"""Database engine/session helpers and the shared SQLModel base class."""

from datetime import datetime
from typing import Any, AsyncIterator, Iterator, Optional

from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.future.engine import Engine
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import TIMESTAMP
from sqlalchemy_utils.functions.orm import get_columns
from sqlalchemy_utils.listeners import force_auto_coercion
from sqlalchemy_utils.types.choice import Choice
from sqlmodel import create_engine, Session
from sqlmodel.main import Field, SQLModel

from compub.exceptions import NoDatabaseConfigured
from compub.settings import settings

DEFAULT_PK_TYPE = int


class DB:
    """Lazily create and hold the engine and session factory.

    The engine is built from ``settings.db_uri`` the first time it is needed.
    A URI whose scheme is ``sqlite`` gets a blocking engine; any other scheme
    gets an asyncio engine.
    """

    def __init__(self):
        self._engine: AsyncEngine | Engine | None = None
        self.session_maker: sessionmaker | None = None

    def start_engine(self) -> None:
        """Create the engine and session factory from ``settings.db_uri``.

        Raises:
            NoDatabaseConfigured: if ``settings.db_uri`` is ``None``.
        """
        if settings.db_uri is None:
            raise NoDatabaseConfigured

        if settings.db_uri.scheme == 'sqlite':
            # Blocking engine + SQLModel session; SQL statements are echoed.
            self._engine = create_engine(settings.db_uri, echo=True)
            self.session_maker = sessionmaker(
                self.engine, expire_on_commit=False, class_=Session
            )
        else:
            self._engine = create_async_engine(settings.db_uri, future=True)
            self.session_maker = sessionmaker(
                self.engine, expire_on_commit=False, class_=AsyncSession
            )

        # Register sqlalchemy-utils' automatic type-coercion listeners.
        force_auto_coercion()

    @property
    def engine(self) -> AsyncEngine | Engine:
        """Return the engine, starting it on first access."""
        if self._engine is None:
            self.start_engine()
        assert isinstance(self._engine, (AsyncEngine, Engine))
        return self._engine

    def _require_session_class(self, expected: type) -> None:
        """Raise unless the configured session factory produces *expected*.

        ``sessionmaker.class_`` is a class object, so it must be compared with
        ``issubclass``; an ``isinstance`` check against it can never succeed.

        Raises:
            RuntimeError: if the session factory builds a different kind of
                session than *expected*.
        """
        session_class = self.session_maker.class_
        if not issubclass(session_class, expected):
            raise RuntimeError(
                f"Configured session class {session_class.__name__} is not a "
                f"{expected.__name__}; use the matching get_session* method."
            )

    async def get_session(self) -> AsyncIterator[AsyncSession]:
        """Yield one ``AsyncSession`` and close it when the generator finishes.

        Starts the engine first if no session factory exists yet.

        Raises:
            RuntimeError: if the configured engine is the blocking one.
        """
        if self.session_maker is None:
            self.start_engine()
        self._require_session_class(AsyncSession)

        session = self.session_maker()
        try:
            yield session
        finally:
            await session.close()

    def get_session_blocking(self) -> Iterator[Session]:
        """Yield one blocking ``Session`` and close it when the generator finishes.

        Starts the engine first if no session factory exists yet.

        Raises:
            RuntimeError: if the configured engine is the asyncio one.
        """
        if self.session_maker is None:
            self.start_engine()
        self._require_session_class(Session)

        session = self.session_maker()
        try:
            yield session
        finally:
            session.close()


class AbstractBase(SQLModel):
    """Common base model: integer primary key plus created/updated timestamps."""

    id: Optional[DEFAULT_PK_TYPE] = Field(default=None, primary_key=True)

    # NOTE(review): each sa_column below is a single Column instance defined
    # on the base class; confirm that several table models inheriting from
    # this base do not conflict over the same Column object.

    # Filled in by the database on insert (server-side default).
    date_created: Optional[datetime] = Field(
        default=None,
        sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
    )
    # Filled in by the database on insert and refreshed on every update.
    date_updated: Optional[datetime] = Field(
        default=None,
        sa_column=Column(
            TIMESTAMP(timezone=False),
            server_default=db_now(),
            onupdate=db_now(),
        )
    )

    def __repr__(self) -> str:
        """Return ``ClassName(field=value, ...)`` without bookkeeping or null fields."""
        fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
        attrs = ', '.join(f"{name}={value!r}" for name, value in fields)
        return f"{self.__class__.__name__}({attrs})"

    def iter_fields(
        self, excl_non_repr: bool = False, excl_null_value: bool = False
    ) -> Iterator[tuple[str, Any]]:
        """Yield ``(column name, current value)`` pairs for this instance.

        Args:
            excl_non_repr: skip the names listed by ``get_non_repr_fields()``.
            excl_null_value: skip columns whose current value is ``None``.
        """
        # Look the exclusion list up once instead of once per column.
        hidden = frozenset(self.get_non_repr_fields()) if excl_non_repr else frozenset()

        for name in get_columns(self).keys():
            if name in hidden:
                continue
            value = getattr(self, name)
            if excl_null_value and value is None:
                continue
            yield name, value

    @staticmethod
    def get_non_repr_fields() -> list[str]:
        """Return the column names that ``__repr__`` leaves out."""
        return ['id', 'date_created', 'date_updated']


def get_choice_value(obj: Choice | str) -> str:
    """Return ``obj.value`` for a ``Choice``; return any other object unchanged."""
    return obj.value if isinstance(obj, Choice) else obj