From 9f487d515d286c6b193617e97504f3ee378ec95f Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Wed, 10 Aug 2022 21:50:00 +0200 Subject: [PATCH] Write a few SQLModels; set up two basic CRUD-like functions; create two test routes; make `db_uri` setting more flexible --- requirements/common.txt | 5 +- setup.cfg | 7 +- src/compub/crud/__init__.py | 1 + src/compub/crud/geography.py | 33 +++++ src/compub/db/base.py | 35 ------ src/compub/db/companies.py | 138 --------------------- src/compub/db/geography.py | 56 --------- src/compub/{db => models}/__init__.py | 1 + src/compub/models/base.py | 80 ++++++++++++ src/compub/models/companies.py | 170 ++++++++++++++++++++++++++ src/compub/models/geography.py | 94 ++++++++++++++ src/compub/routes.py | 43 +++++++ src/compub/settings.py | 8 +- 13 files changed, 435 insertions(+), 236 deletions(-) create mode 100644 src/compub/crud/__init__.py create mode 100644 src/compub/crud/geography.py delete mode 100644 src/compub/db/base.py delete mode 100644 src/compub/db/companies.py delete mode 100644 src/compub/db/geography.py rename src/compub/{db => models}/__init__.py (70%) create mode 100644 src/compub/models/base.py create mode 100644 src/compub/models/companies.py create mode 100644 src/compub/models/geography.py create mode 100644 src/compub/routes.py diff --git a/requirements/common.txt b/requirements/common.txt index 5a06d14..7d158af 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,7 +1,8 @@ Pydantic FastAPI -SQLAlchemy[asyncio] +SQLAlchemy[asyncio]==1.4.35 Alembic SQLAlchemy-Utils +SQLModel Babel -python-slugify \ No newline at end of file +Python-Slugify \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index eb46a7c..de7c8e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,15 +21,16 @@ classifiers = package_dir = = src packages = find: -python_requires = >=3.10 +python_requires = >=3.10, <4 install_requires = Pydantic FastAPI - SQLAlchemy[asyncio] + SQLAlchemy[asyncio]==1.4.35 Alembic SQLAlchemy-Utils + SQLModel Babel - python-slugify + Python-Slugify [options.extras_require] srv = diff --git a/src/compub/crud/__init__.py b/src/compub/crud/__init__.py new file mode 100644 index 0000000..5acc368 --- /dev/null +++ b/src/compub/crud/__init__.py @@ -0,0 +1 @@ +from .geography import * diff --git a/src/compub/crud/geography.py b/src/compub/crud/geography.py new file mode 100644 index 0000000..106dd52 --- /dev/null +++ b/src/compub/crud/geography.py @@ -0,0 +1,33 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.expression import select + +from compub.models.geography import * + + +__all__ = [ + # 'get_state_by_name_and_country', + 'get_states', + 'create_state' +] + + +# async def get_state_by_name_and_country(session: AsyncSession, name: str, country: str) -> db.StateProvince: +# statement = select(db.StateProvince).filter( +# db.StateProvince.name == name, +# db.StateProvince.country == country +# ).limit(1) +# result = await session.execute(statement) +# return result.scalars().first() + + +async def get_states(session: AsyncSession, skip: int = 0, limit: int = 100) -> list[StateProvince]: + statement = select(StateProvince).offset(skip).limit(limit) + result = await session.execute(statement) + return result.scalars().all() + + +async def create_state(session: AsyncSession, state: StateProvince) -> StateProvince: + session.add(state) + await session.commit() + await session.refresh(state) + return state diff --git a/src/compub/db/base.py b/src/compub/db/base.py deleted file mode 100644 index d2cb273..0000000 --- a/src/compub/db/base.py +++ /dev/null @@ -1,35 +0,0 @@ -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import declarative_base, sessionmaker -from sqlalchemy.sql.functions import now as db_now -from sqlalchemy.sql.schema import Column -from sqlalchemy.sql.sqltypes import TIMESTAMP -from sqlalchemy_utils.functions.orm import get_columns -from sqlalchemy_utils.listeners import force_auto_coercion - -from compub.exceptions import NoDatabaseConfigured -from compub.settings import settings - - -if settings.db_uri is None: - raise NoDatabaseConfigured -engine = create_async_engine(settings.db_uri, future=True) -LocalAsyncSession = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) -ORMBase = declarative_base() - -force_auto_coercion() - - -class AbstractBase(ORMBase): - __abstract__ = True - - NON_REPR_FIELDS = ['id', 'date_created', 'date_updated'] - - date_created = Column(TIMESTAMP(timezone=False), server_default=db_now()) - date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now()) - - def __repr__(self) -> str: - # Exclude non-representative fields: - fields = (name for name in get_columns(self).keys() if name not in self.NON_REPR_FIELDS) - # Exclude NULL value fields: - attrs = ', '.join(f"{name}={repr(getattr(self, name))}" for name in fields if getattr(self, name) is not None) - return f"{self.__class__.__name__}({attrs})" diff --git a/src/compub/db/companies.py b/src/compub/db/companies.py deleted file mode 100644 index 7158a52..0000000 --- a/src/compub/db/companies.py +++ /dev/null @@ -1,138 +0,0 @@ -from datetime import date -from typing import Optional -from uuid import uuid4 - -from slugify import slugify -from sqlalchemy.engine import Connection -from sqlalchemy.event.api import listens_for -from sqlalchemy.orm import relationship -from sqlalchemy.orm.mapper import Mapper -from sqlalchemy.sql.expression import select -from sqlalchemy.sql.functions import count -from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table -from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode -from sqlalchemy_utils.types import CountryType, UUIDType - -from compub.utils import multi_max -from .base import AbstractBase, ORMBase - - -__all__ = [ - 'LegalForm', - 'LegalFormSubcategory', - 'Integer', - 'Company', - 'CompanyName', -] - - -class LegalForm(AbstractBase): - __tablename__ = 'legal_form' - - id = Column(Integer, primary_key=True) - short = Column(String(32), nullable=False, index=True) - name = Column(Unicode(255)) - country = Column(CountryType) - - subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin') - - def __str__(self) -> str: - return str(self.short) - - -class LegalFormSubcategory(AbstractBase): - __tablename__ = 'legal_form_subcategory' - - id = Column(Integer, primary_key=True) - legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True) - short = Column(String(32), nullable=False, index=True) - name = Column(Unicode(255)) - - companies = relationship('Company', backref='legal_form', lazy='selectin') - - def __str__(self) -> str: - return str(self.short) - - -class Industry(AbstractBase): - __tablename__ = 'industry' - - id = Column(Integer, primary_key=True) - name = Column(String(255), nullable=False, index=True) - - companies = relationship('Company', secondary='company_industries', back_populates='industries') - - def __str__(self) -> str: - return str(self.name) - - -class Company(AbstractBase): - __tablename__ = 'company' - - NAME_SORTING_PARAMS = ( # passed to `multi_max` - lambda obj: date(1, 1, 1) if obj.date_registered is None else obj.date_registered, - 'date_updated', - ) - # If date_registered is None, the name is considered to be older than one with a date_registered - - id = Column(UUIDType, primary_key=True, default=uuid4) - visible = Column(Boolean, default=True, nullable=False, index=True) - legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True) - insolvent = Column(Boolean, default=False, nullable=False, index=True) - founding_date = Column(Date) - liquidation_date = Column(Date) - # TODO: Get rid of city; implement address properly - city = Column(String(255), index=True) - address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True) - - industries = relationship('Industry', secondary='company_industries', back_populates='companies') - names = relationship('CompanyName', backref='company', lazy='selectin') - - def __str__(self) -> str: - return str(self.current_name or f"") - - @property - def current_name(self) -> Optional['CompanyName']: - return multi_max(list(self.names), *self.NAME_SORTING_PARAMS, default=None) - - -company_industries = Table( - 'company_industries', - ORMBase.metadata, - Column('company_id', FKey('company.id'), primary_key=True), - Column('industry_id', FKey('industry.id'), primary_key=True), -) - - -class CompanyName(AbstractBase): - __tablename__ = 'company_name' - - MAX_SLUG_LENGTH = 255 - - id = Column(UUIDType, primary_key=True, default=uuid4) - name = Column(Unicode(768), nullable=False, index=True) - company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True) - date_registered = Column(Date) - slug = Column(String(MAX_SLUG_LENGTH), index=True) - - def __str__(self) -> str: - return str(self.name) - - -@listens_for(CompanyName, 'before_insert') -def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None: - if target.slug: - return - slug = slugify(target.name)[:(target.MAX_SLUG_LENGTH - 2)] - statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug)) - num = connection.execute(statement).scalar() - if num == 0: - target.slug = slug - else: - target.slug = f'{slug}-{str(num + 1)}' - - -def get_reg_date(obj: CompanyName) -> date: - if obj.date_registered is None: - return date(1, 1, 1) - return obj.date_registered diff --git a/src/compub/db/geography.py b/src/compub/db/geography.py deleted file mode 100644 index f40af5f..0000000 --- a/src/compub/db/geography.py +++ /dev/null @@ -1,56 +0,0 @@ -from uuid import uuid4 - -from sqlalchemy.orm import relationship -from sqlalchemy.sql.schema import Column, ForeignKey as FKey -from sqlalchemy.sql.sqltypes import Integer, String, Unicode -from sqlalchemy_utils.types import CountryType, UUIDType - -from .base import AbstractBase - - -__all__ = [ - 'StateProvince', - 'City', - 'Street', - 'Address', -] - - -class StateProvince(AbstractBase): - __tablename__ = 'state_province' - - id = Column(Integer, primary_key=True) - country = Column(CountryType, nullable=False, index=True) - name = Column(Unicode(255), nullable=False, index=True) - - cities = relationship('City', backref='state_province', lazy='selectin') - - -class City(AbstractBase): - __tablename__ = 'city' - - id = Column(Integer, primary_key=True) - state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True) - zip_code = Column(String(5), nullable=False, index=True) - name = Column(Unicode(255), nullable=False, index=True) - - streets = relationship('Street', backref='city', lazy='selectin') - - -class Street(AbstractBase): - __tablename__ = 'street' - - id = Column(Integer, primary_key=True) - city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True) - name = Column(Unicode(255), nullable=False, index=True) - - addresses = relationship('Address', backref='street', lazy='selectin') - - -class Address(AbstractBase): - __tablename__ = 'address' - - id = Column(UUIDType, primary_key=True, default=uuid4) - street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True) - house_number = Column(String(8), nullable=False) - supplement = Column(String(255)) diff --git a/src/compub/db/__init__.py b/src/compub/models/__init__.py similarity index 70% rename from src/compub/db/__init__.py rename to src/compub/models/__init__.py index eb71b46..0098ddc 100644 --- a/src/compub/db/__init__.py +++ b/src/compub/models/__init__.py @@ -1,2 +1,3 @@ +from .base import DB from .geography import * from .companies import * diff --git a/src/compub/models/base.py b/src/compub/models/base.py new file mode 100644 index 0000000..4e605ee --- /dev/null +++ b/src/compub/models/base.py @@ -0,0 +1,80 @@ +from datetime import datetime +from typing import Any, Iterator, Optional + +from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine +from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy.future.engine import Engine +from sqlalchemy.orm.session import sessionmaker +from sqlalchemy.sql.functions import now as db_now +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import TIMESTAMP +from sqlalchemy_utils.functions.orm import get_columns +from sqlalchemy_utils.listeners import force_auto_coercion +from sqlmodel.main import Field, SQLModel + +from compub.exceptions import NoDatabaseConfigured +from compub.settings import settings + +from sqlmodel import create_engine, Session + + +class DB: + def __init__(self): + self._engine: AsyncEngine | Engine | None = None + self._session_maker: sessionmaker | None = None + + def start_engine(self) -> None: + if settings.db_uri is None: + raise NoDatabaseConfigured + if settings.db_uri.scheme == 'sqlite': + self._engine = create_engine(settings.db_uri) + self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session) + else: + self._engine = create_async_engine(settings.db_uri, future=True) + self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) + force_auto_coercion() + + @property + def engine(self): + if self._engine is None: + self.start_engine() + assert isinstance(self._engine, (AsyncEngine, Engine)) + return self._engine + + async def get_session(self) -> Session | AsyncSession: + if self._session_maker is None: + self.start_engine() + assert isinstance(self._session_maker, sessionmaker) + session = self._session_maker() + try: + yield session + finally: + await session.close() + + +class AbstractBase(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + date_created: Optional[datetime] = Field( + default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now()) + ) + date_updated: Optional[datetime] = Field( + default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now()) + ) + + def __repr__(self) -> str: + fields = self.iter_fields(excl_non_repr=True, excl_null_value=True) + attrs = ', '.join(f"{name}={repr(value)}" for name, value in fields) + return f"{self.__class__.__name__}({attrs})" + + def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]: + for name in get_columns(self).keys(): + if excl_non_repr and name in self.get_non_repr_fields(): + continue + value = getattr(self, name) + if excl_null_value and value is None: + continue + yield name, value + + @staticmethod + def get_non_repr_fields() -> list[str]: + return ['id', 'date_created', 'date_updated'] diff --git a/src/compub/models/companies.py b/src/compub/models/companies.py new file mode 100644 index 0000000..8d0ae84 --- /dev/null +++ b/src/compub/models/companies.py @@ -0,0 +1,170 @@ +from datetime import date +from typing import Optional + +from slugify import slugify +from sqlalchemy.engine.base import Connection +from sqlalchemy.event.api import listens_for +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.sql.expression import select +from sqlalchemy.sql.functions import count +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import Unicode +from sqlalchemy_utils.types import CountryType +from sqlmodel.main import Field, Relationship + +from compub.utils import multi_max +from .base import AbstractBase + + +__all__ = [ + 'LegalForm', + 'LegalFormSubcategory', + 'Industry', + 'Company', + 'CompanyName', +] + + +class LegalForm(AbstractBase, table=True): + __tablename__ = 'legal_form' + + # Fields + short: str = Field(max_length=32, nullable=False, index=True) + name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255))) + country: str = Field(sa_column=Column(CountryType)) + + # Relationships + subcategories: list['LegalFormSubcategory'] = Relationship( + back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + def __str__(self) -> str: + return str(self.short) + + +class LegalFormSubcategory(AbstractBase, table=True): + __tablename__ = 'legal_form_subcategory' + + # Fields + short: str = Field(max_length=32, nullable=False, index=True) + name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255))) + + # Relationships + legal_form_id: Optional[int] = Field( + foreign_key='legal_form.id', default=None, nullable=False, index=True + ) + legal_form: Optional[LegalForm] = Relationship( + back_populates='subcategories', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + companies: list['Company'] = Relationship( + back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + def __str__(self) -> str: + return str(self.short) + + +class CompanyIndustryLink(AbstractBase, table=True): + __tablename__ = 'company_industries' + + # Relationships + company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) + industry_id: Optional[int] = Field(foreign_key='industry.id', default=None, nullable=False, primary_key=True) + + +class Industry(AbstractBase, table=True): + __tablename__ = 'industry' + + # Fields + name: str = Field(max_length=255, nullable=False, index=True) + + # Relationships + companies: list['Company'] = Relationship( + back_populates='industries', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'} + ) + + def __str__(self) -> str: + return str(self.name) + + +class Company(AbstractBase, table=True): + __tablename__ = 'company' + + # Fields + visible: bool = Field(default=True, nullable=False, index=True) + insolvent: bool = Field(default=False, nullable=False, index=True) + founding_data: Optional[date] + liquidation_date: Optional[date] + city: str = Field(max_length=255, index=True) # TODO: Get rid of city; implement address properly + + # Relationships + legal_form_id: Optional[int] = Field( + foreign_key='legal_form_subcategory.id', default=None, index=True + ) + legal_form: Optional[LegalFormSubcategory] = Relationship( + back_populates='companies', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + address_id: Optional[int] = Field( + foreign_key='address.id', default=None, index=True + ) + + industries: list[Industry] = Relationship( + back_populates='companies', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'} + ) + + names: list['CompanyName'] = Relationship( + back_populates='company', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + def __str__(self) -> str: + return str(self.current_name or f"") + + @property + def current_name(self) -> Optional['CompanyName']: + return multi_max(list(self.names), CompanyName.get_reg_date, 'date_updated', default=None) + + +class CompanyName(AbstractBase, table=True): + __tablename__ = 'company_name' + __MAX_NAME_LENGTH__: int = 768 + __MAX_SLUG_LENGTH__: int = 255 + + # Fields + name: str = Field( + max_length=__MAX_NAME_LENGTH__, sa_column=Column(Unicode(__MAX_NAME_LENGTH__), nullable=False, index=True) + ) + date_registered: Optional[date] + slug: Optional[str] = Field(default=None, max_length=__MAX_SLUG_LENGTH__, index=True) + + # Relationships + company_id: Optional[int] = Field( + foreign_key='company.id', default=None, nullable=False, index=True + ) + company: Optional[Company] = Relationship( + back_populates='names', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + def __str__(self) -> str: + return str(self.name) + + def get_reg_date(self) -> date: + return date(1, 1, 1) if self.date_registered is None else self.date_registered + + @property + def max_slug_length(self) -> int: + return self.__MAX_SLUG_LENGTH__ + + +@listens_for(CompanyName, 'before_insert') +def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None: + if target.slug: + return + slug = slugify(target.name)[:(target.max_slug_length - 2)] + statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug)) + num = connection.execute(statement).scalar() + if num == 0: + target.slug = slug + else: + target.slug = f'{slug}-{str(num + 1)}' diff --git a/src/compub/models/geography.py b/src/compub/models/geography.py new file mode 100644 index 0000000..9aed86a --- /dev/null +++ b/src/compub/models/geography.py @@ -0,0 +1,94 @@ +from typing import Any, Optional + +from pydantic import validator +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import Unicode +from sqlalchemy_utils.primitives.country import Country +from sqlalchemy_utils.types import CountryType +from sqlmodel.main import Field, Relationship + +from .base import AbstractBase + + +__all__ = [ + 'StateProvince', + 'City', + 'Street', + 'Address', +] + + +class StateProvince(AbstractBase, table=True): + __tablename__ = 'state_province' + + # Fields + country: str = Field(sa_column=Column(CountryType, nullable=False, index=True)) + name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True)) + + # Relationships + cities: list['City'] = Relationship( + back_populates='state_province', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + @validator('country', pre=True) + def country_as_uppercase_string(cls, v: Any) -> str: + if isinstance(v, Country): + return v.code + if isinstance(v, str): + return v.upper() + raise TypeError + + +class City(AbstractBase, table=True): + __tablename__ = 'city' + + # Fields + zip_code: str = Field(max_length=5, nullable=False, index=True) + name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True)) + + # Relationships + state_province_id: Optional[int] = Field( + foreign_key='state_province.id', default=None, nullable=False, index=True + ) + state_province: Optional[StateProvince] = Relationship( + back_populates='cities', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + streets: list['Street'] = Relationship( + back_populates='city', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + +class Street(AbstractBase, table=True): + __tablename__ = 'street' + + # Fields + name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True)) + + # Relationships + city_id: Optional[int] = Field( + foreign_key='city.id', default=None, nullable=False, index=True + ) + city: Optional[City] = Relationship( + back_populates='streets', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + addresses: list['Address'] = Relationship( + back_populates='street', sa_relationship_kwargs={'lazy': 'selectin'} + ) + + +class Address(AbstractBase, table=True): + __tablename__ = 'address' + + # Fields + house_number: str = Field(max_length=8, nullable=False) + supplement: str = Field(max_length=255) + + # Relationships + street_id: Optional[int] = Field( + foreign_key='street.id', default=None, nullable=False, index=True + ) + street: Optional[Street] = Relationship( + back_populates='addresses', sa_relationship_kwargs={'lazy': 'selectin'} + ) diff --git a/src/compub/routes.py b/src/compub/routes.py new file mode 100644 index 0000000..bdca78a --- /dev/null +++ b/src/compub/routes.py @@ -0,0 +1,43 @@ +from fastapi import FastAPI, HTTPException, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import SQLModel + +from .models import * +from .models.base import DB +from . import crud + + +api, db = FastAPI(), DB() + + +@api.on_event('startup') +async def initialize_db(): + async with db.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.drop_all) + await conn.run_sync(SQLModel.metadata.create_all) + + +@api.on_event("shutdown") +async def close_connection_pool(): + await db.engine.dispose() + + +@api.post("/states/", response_model=StateProvince) +async def create_state(state: StateProvince, session: AsyncSession = Depends(db.get_session)): + # db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country) + # if db_obj: + # raise HTTPException(status_code=400, detail="Name already registered") + return await crud.create_state(session, state=state) + + +@api.get("/states/", response_model=list[StateProvince]) +async def list_states(skip: int = 0, limit: int = 100, session: AsyncSession = Depends(db.get_session)): + return await crud.get_states(session, skip=skip, limit=limit) + + +# @app.post("/cities/", response_model=City) +# async def create_city(city: City, session: AsyncSession = Depends(get_session)): +# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country) +# if db_obj: +# raise HTTPException(status_code=400, detail="Name already registered") +# return await crud.create_state(session, state=state) diff --git a/src/compub/settings.py b/src/compub/settings.py index 67c04d5..86b84e8 100644 --- a/src/compub/settings.py +++ b/src/compub/settings.py @@ -3,7 +3,7 @@ import logging.config from pathlib import Path from typing import Any, Callable, ClassVar -from pydantic import BaseModel, BaseSettings, PostgresDsn, validator +from pydantic import BaseModel, BaseSettings, AnyUrl, validator from pydantic.env_settings import SettingsSourceCallable from yaml import safe_load @@ -76,8 +76,12 @@ class ServerSettings(BaseModel): uds: str | None = None +class DBURI(AnyUrl): + host_required = False + + class Settings(AbstractBaseSettings): - db_uri: PostgresDsn | None = None + db_uri: DBURI | None = None server: ServerSettings = ServerSettings() log_config: dict | Path | None = None