generated from daniil-berg/boilerplate-py
Compare commits
9 Commits
934f596316
...
9f487d515d
Author | SHA1 | Date | |
---|---|---|---|
9f487d515d | |||
60c6b1c4bb | |||
8a8ea7ec09 | |||
b7e627a216 | |||
71f9db6c0d | |||
eb08e5a674 | |||
8ec4e8885a | |||
27696ad57c | |||
aa7fed7817 |
@ -12,7 +12,7 @@ Company publications API
|
||||
|
||||
## Dependencies
|
||||
|
||||
Python Version ..., OS ...
|
||||
Python Version 3.10+, OS ...
|
||||
|
||||
## Building from source
|
||||
|
||||
|
@ -0,0 +1,8 @@
|
||||
Pydantic
|
||||
FastAPI
|
||||
SQLAlchemy[asyncio]==1.4.35
|
||||
Alembic
|
||||
SQLAlchemy-Utils
|
||||
SQLModel
|
||||
Babel
|
||||
Python-Slugify
|
@ -1,2 +1,4 @@
|
||||
-r common.txt
|
||||
-r srv.txt
|
||||
coverage
|
||||
sqlalchemy-stubs
|
||||
aiosqlite
|
||||
|
2
requirements/srv.txt
Normal file
2
requirements/srv.txt
Normal file
@ -0,0 +1,2 @@
|
||||
-r common.txt
|
||||
Uvicorn[standard]
|
26
setup.cfg
26
setup.cfg
@ -1,8 +1,8 @@
|
||||
[metadata]
|
||||
name = compub
|
||||
version = 0.0.1
|
||||
author = Daniil
|
||||
author_email = mail@placeholder123.to
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Company publications API
|
||||
long_description = file: README.md
|
||||
long_description_content_type = text/markdown
|
||||
@ -10,20 +10,36 @@ url = https://git.fajnberg.de/daniil/compub
|
||||
project_urls =
|
||||
Bug Tracker = https://git.fajnberg.de/daniil/compub/issues
|
||||
classifiers =
|
||||
Programming Language :: Python :: 3
|
||||
Programming Language :: Python :: 3 :: Only
|
||||
Programming Language :: Python :: 3.10
|
||||
Operating System :: OS Independent
|
||||
Development Status :: 3 - Alpha
|
||||
Framework :: AsyncIO
|
||||
Framework :: FastAPI
|
||||
|
||||
[options]
|
||||
package_dir =
|
||||
= src
|
||||
packages = find:
|
||||
python_requires = >=3
|
||||
python_requires = >=3.10, <4
|
||||
install_requires =
|
||||
...
|
||||
Pydantic
|
||||
FastAPI
|
||||
SQLAlchemy[asyncio]==1.4.35
|
||||
Alembic
|
||||
SQLAlchemy-Utils
|
||||
SQLModel
|
||||
Babel
|
||||
Python-Slugify
|
||||
|
||||
[options.extras_require]
|
||||
srv =
|
||||
Uvicorn[standard]
|
||||
dev =
|
||||
Uvicorn[standard]
|
||||
coverage
|
||||
sqlalchemy-stubs
|
||||
aiosqlite
|
||||
|
||||
[options.packages.find]
|
||||
where = src
|
||||
|
30
src/compub/__main__.py
Normal file
30
src/compub/__main__.py
Normal file
@ -0,0 +1,30 @@
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence
|
||||
|
||||
import uvicorn
|
||||
|
||||
from .settings import PROGRAM_NAME, CONFIG_FILE_PATH_PARAM, DEFAULT_CONFIG_FILE_PATHS, init, settings
|
||||
|
||||
|
||||
def parse_cli(args: Sequence[str] = None) -> dict[str, Any]:
|
||||
parser = ArgumentParser(PROGRAM_NAME)
|
||||
parser.add_argument(
|
||||
'-c', f'--{CONFIG_FILE_PATH_PARAM}',
|
||||
type=Path,
|
||||
metavar='PATH',
|
||||
default=SUPPRESS,
|
||||
help=f"Paths to config file that will take precedence over all others; "
|
||||
f"the following {len(DEFAULT_CONFIG_FILE_PATHS)} paths are always checked first (in that order):"
|
||||
f" {','.join(str(p) for p in DEFAULT_CONFIG_FILE_PATHS)}"
|
||||
)
|
||||
return vars(parser.parse_args(args))
|
||||
|
||||
|
||||
def main():
|
||||
init(**parse_cli())
|
||||
uvicorn.run(f'{PROGRAM_NAME}.routes:app', **settings.server.dict())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
1
src/compub/crud/__init__.py
Normal file
1
src/compub/crud/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .geography import *
|
33
src/compub/crud/geography.py
Normal file
33
src/compub/crud/geography.py
Normal file
@ -0,0 +1,33 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql.expression import select
|
||||
|
||||
from compub.models.geography import *
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 'get_state_by_name_and_country',
|
||||
'get_states',
|
||||
'create_state'
|
||||
]
|
||||
|
||||
|
||||
# async def get_state_by_name_and_country(session: AsyncSession, name: str, country: str) -> db.StateProvince:
|
||||
# statement = select(db.StateProvince).filter(
|
||||
# db.StateProvince.name == name,
|
||||
# db.StateProvince.country == country
|
||||
# ).limit(1)
|
||||
# result = await session.execute(statement)
|
||||
# return result.scalars().first()
|
||||
|
||||
|
||||
async def get_states(session: AsyncSession, skip: int = 0, limit: int = 100) -> list[StateProvince]:
|
||||
statement = select(StateProvince).offset(skip).limit(limit)
|
||||
result = await session.execute(statement)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def create_state(session: AsyncSession, state: StateProvince) -> StateProvince:
|
||||
session.add(state)
|
||||
await session.commit()
|
||||
await session.refresh(state)
|
||||
return state
|
6
src/compub/exceptions.py
Normal file
6
src/compub/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
||||
class ConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoDatabaseConfigured(ConfigError):
|
||||
pass
|
3
src/compub/models/__init__.py
Normal file
3
src/compub/models/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .base import DB
|
||||
from .geography import *
|
||||
from .companies import *
|
80
src/compub/models/base.py
Normal file
80
src/compub/models/base.py
Normal file
@ -0,0 +1,80 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterator, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.future.engine import Engine
|
||||
from sqlalchemy.orm.session import sessionmaker
|
||||
from sqlalchemy.sql.functions import now as db_now
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import TIMESTAMP
|
||||
from sqlalchemy_utils.functions.orm import get_columns
|
||||
from sqlalchemy_utils.listeners import force_auto_coercion
|
||||
from sqlmodel.main import Field, SQLModel
|
||||
|
||||
from compub.exceptions import NoDatabaseConfigured
|
||||
from compub.settings import settings
|
||||
|
||||
from sqlmodel import create_engine, Session
|
||||
|
||||
|
||||
class DB:
|
||||
def __init__(self):
|
||||
self._engine: AsyncEngine | Engine | None = None
|
||||
self._session_maker: sessionmaker | None = None
|
||||
|
||||
def start_engine(self) -> None:
|
||||
if settings.db_uri is None:
|
||||
raise NoDatabaseConfigured
|
||||
if settings.db_uri.scheme == 'sqlite':
|
||||
self._engine = create_engine(settings.db_uri)
|
||||
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
|
||||
else:
|
||||
self._engine = create_async_engine(settings.db_uri, future=True)
|
||||
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
|
||||
force_auto_coercion()
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
if self._engine is None:
|
||||
self.start_engine()
|
||||
assert isinstance(self._engine, (AsyncEngine, Engine))
|
||||
return self._engine
|
||||
|
||||
async def get_session(self) -> Session | AsyncSession:
|
||||
if self._session_maker is None:
|
||||
self.start_engine()
|
||||
assert isinstance(self._session_maker, sessionmaker)
|
||||
session = self._session_maker()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
class AbstractBase(SQLModel):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
date_created: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
|
||||
)
|
||||
date_updated: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
|
||||
attrs = ', '.join(f"{name}={repr(value)}" for name, value in fields)
|
||||
return f"{self.__class__.__name__}({attrs})"
|
||||
|
||||
def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
|
||||
for name in get_columns(self).keys():
|
||||
if excl_non_repr and name in self.get_non_repr_fields():
|
||||
continue
|
||||
value = getattr(self, name)
|
||||
if excl_null_value and value is None:
|
||||
continue
|
||||
yield name, value
|
||||
|
||||
@staticmethod
|
||||
def get_non_repr_fields() -> list[str]:
|
||||
return ['id', 'date_created', 'date_updated']
|
170
src/compub/models/companies.py
Normal file
170
src/compub/models/companies.py
Normal file
@ -0,0 +1,170 @@
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from slugify import slugify
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.event.api import listens_for
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.sql.expression import select
|
||||
from sqlalchemy.sql.functions import count
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import Unicode
|
||||
from sqlalchemy_utils.types import CountryType
|
||||
from sqlmodel.main import Field, Relationship
|
||||
|
||||
from compub.utils import multi_max
|
||||
from .base import AbstractBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LegalForm',
|
||||
'LegalFormSubcategory',
|
||||
'Industry',
|
||||
'Company',
|
||||
'CompanyName',
|
||||
]
|
||||
|
||||
|
||||
class LegalForm(AbstractBase, table=True):
|
||||
__tablename__ = 'legal_form'
|
||||
|
||||
# Fields
|
||||
short: str = Field(max_length=32, nullable=False, index=True)
|
||||
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
|
||||
country: str = Field(sa_column=Column(CountryType))
|
||||
|
||||
# Relationships
|
||||
subcategories: list['LegalFormSubcategory'] = Relationship(
|
||||
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.short)
|
||||
|
||||
|
||||
class LegalFormSubcategory(AbstractBase, table=True):
|
||||
__tablename__ = 'legal_form_subcategory'
|
||||
|
||||
# Fields
|
||||
short: str = Field(max_length=32, nullable=False, index=True)
|
||||
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
|
||||
|
||||
# Relationships
|
||||
legal_form_id: Optional[int] = Field(
|
||||
foreign_key='legal_form.id', default=None, nullable=False, index=True
|
||||
)
|
||||
legal_form: Optional[LegalForm] = Relationship(
|
||||
back_populates='subcategories', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
companies: list['Company'] = Relationship(
|
||||
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.short)
|
||||
|
||||
|
||||
class CompanyIndustryLink(AbstractBase, table=True):
|
||||
__tablename__ = 'company_industries'
|
||||
|
||||
# Relationships
|
||||
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
||||
industry_id: Optional[int] = Field(foreign_key='industry.id', default=None, nullable=False, primary_key=True)
|
||||
|
||||
|
||||
class Industry(AbstractBase, table=True):
|
||||
__tablename__ = 'industry'
|
||||
|
||||
# Fields
|
||||
name: str = Field(max_length=255, nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
companies: list['Company'] = Relationship(
|
||||
back_populates='industries', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
|
||||
class Company(AbstractBase, table=True):
|
||||
__tablename__ = 'company'
|
||||
|
||||
# Fields
|
||||
visible: bool = Field(default=True, nullable=False, index=True)
|
||||
insolvent: bool = Field(default=False, nullable=False, index=True)
|
||||
founding_data: Optional[date]
|
||||
liquidation_date: Optional[date]
|
||||
city: str = Field(max_length=255, index=True) # TODO: Get rid of city; implement address properly
|
||||
|
||||
# Relationships
|
||||
legal_form_id: Optional[int] = Field(
|
||||
foreign_key='legal_form_subcategory.id', default=None, index=True
|
||||
)
|
||||
legal_form: Optional[LegalFormSubcategory] = Relationship(
|
||||
back_populates='companies', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
address_id: Optional[int] = Field(
|
||||
foreign_key='address.id', default=None, index=True
|
||||
)
|
||||
|
||||
industries: list[Industry] = Relationship(
|
||||
back_populates='companies', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
names: list['CompanyName'] = Relationship(
|
||||
back_populates='company', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.current_name or f"<Company {self.id}>")
|
||||
|
||||
@property
|
||||
def current_name(self) -> Optional['CompanyName']:
|
||||
return multi_max(list(self.names), CompanyName.get_reg_date, 'date_updated', default=None)
|
||||
|
||||
|
||||
class CompanyName(AbstractBase, table=True):
|
||||
__tablename__ = 'company_name'
|
||||
__MAX_NAME_LENGTH__: int = 768
|
||||
__MAX_SLUG_LENGTH__: int = 255
|
||||
|
||||
# Fields
|
||||
name: str = Field(
|
||||
max_length=__MAX_NAME_LENGTH__, sa_column=Column(Unicode(__MAX_NAME_LENGTH__), nullable=False, index=True)
|
||||
)
|
||||
date_registered: Optional[date]
|
||||
slug: Optional[str] = Field(default=None, max_length=__MAX_SLUG_LENGTH__, index=True)
|
||||
|
||||
# Relationships
|
||||
company_id: Optional[int] = Field(
|
||||
foreign_key='company.id', default=None, nullable=False, index=True
|
||||
)
|
||||
company: Optional[Company] = Relationship(
|
||||
back_populates='names', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
def get_reg_date(self) -> date:
|
||||
return date(1, 1, 1) if self.date_registered is None else self.date_registered
|
||||
|
||||
@property
|
||||
def max_slug_length(self) -> int:
|
||||
return self.__MAX_SLUG_LENGTH__
|
||||
|
||||
|
||||
@listens_for(CompanyName, 'before_insert')
|
||||
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
|
||||
if target.slug:
|
||||
return
|
||||
slug = slugify(target.name)[:(target.max_slug_length - 2)]
|
||||
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
|
||||
num = connection.execute(statement).scalar()
|
||||
if num == 0:
|
||||
target.slug = slug
|
||||
else:
|
||||
target.slug = f'{slug}-{str(num + 1)}'
|
94
src/compub/models/geography.py
Normal file
94
src/compub/models/geography.py
Normal file
@ -0,0 +1,94 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import validator
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import Unicode
|
||||
from sqlalchemy_utils.primitives.country import Country
|
||||
from sqlalchemy_utils.types import CountryType
|
||||
from sqlmodel.main import Field, Relationship
|
||||
|
||||
from .base import AbstractBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'StateProvince',
|
||||
'City',
|
||||
'Street',
|
||||
'Address',
|
||||
]
|
||||
|
||||
|
||||
class StateProvince(AbstractBase, table=True):
|
||||
__tablename__ = 'state_province'
|
||||
|
||||
# Fields
|
||||
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
|
||||
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||
|
||||
# Relationships
|
||||
cities: list['City'] = Relationship(
|
||||
back_populates='state_province', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
@validator('country', pre=True)
|
||||
def country_as_uppercase_string(cls, v: Any) -> str:
|
||||
if isinstance(v, Country):
|
||||
return v.code
|
||||
if isinstance(v, str):
|
||||
return v.upper()
|
||||
raise TypeError
|
||||
|
||||
|
||||
class City(AbstractBase, table=True):
|
||||
__tablename__ = 'city'
|
||||
|
||||
# Fields
|
||||
zip_code: str = Field(max_length=5, nullable=False, index=True)
|
||||
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||
|
||||
# Relationships
|
||||
state_province_id: Optional[int] = Field(
|
||||
foreign_key='state_province.id', default=None, nullable=False, index=True
|
||||
)
|
||||
state_province: Optional[StateProvince] = Relationship(
|
||||
back_populates='cities', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
streets: list['Street'] = Relationship(
|
||||
back_populates='city', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
|
||||
class Street(AbstractBase, table=True):
|
||||
__tablename__ = 'street'
|
||||
|
||||
# Fields
|
||||
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||
|
||||
# Relationships
|
||||
city_id: Optional[int] = Field(
|
||||
foreign_key='city.id', default=None, nullable=False, index=True
|
||||
)
|
||||
city: Optional[City] = Relationship(
|
||||
back_populates='streets', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
addresses: list['Address'] = Relationship(
|
||||
back_populates='street', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
|
||||
class Address(AbstractBase, table=True):
|
||||
__tablename__ = 'address'
|
||||
|
||||
# Fields
|
||||
house_number: str = Field(max_length=8, nullable=False)
|
||||
supplement: str = Field(max_length=255)
|
||||
|
||||
# Relationships
|
||||
street_id: Optional[int] = Field(
|
||||
foreign_key='street.id', default=None, nullable=False, index=True
|
||||
)
|
||||
street: Optional[Street] = Relationship(
|
||||
back_populates='addresses', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
43
src/compub/routes.py
Normal file
43
src/compub/routes.py
Normal file
@ -0,0 +1,43 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from .models import *
|
||||
from .models.base import DB
|
||||
from . import crud
|
||||
|
||||
|
||||
api, db = FastAPI(), DB()
|
||||
|
||||
|
||||
@api.on_event('startup')
|
||||
async def initialize_db():
|
||||
async with db.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
@api.on_event("shutdown")
|
||||
async def close_connection_pool():
|
||||
await db.engine.dispose()
|
||||
|
||||
|
||||
@api.post("/states/", response_model=StateProvince)
|
||||
async def create_state(state: StateProvince, session: AsyncSession = Depends(db.get_session)):
|
||||
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
|
||||
# if db_obj:
|
||||
# raise HTTPException(status_code=400, detail="Name already registered")
|
||||
return await crud.create_state(session, state=state)
|
||||
|
||||
|
||||
@api.get("/states/", response_model=list[StateProvince])
|
||||
async def list_states(skip: int = 0, limit: int = 100, session: AsyncSession = Depends(db.get_session)):
|
||||
return await crud.get_states(session, skip=skip, limit=limit)
|
||||
|
||||
|
||||
# @app.post("/cities/", response_model=City)
|
||||
# async def create_city(city: City, session: AsyncSession = Depends(get_session)):
|
||||
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
|
||||
# if db_obj:
|
||||
# raise HTTPException(status_code=400, detail="Name already registered")
|
||||
# return await crud.create_state(session, state=state)
|
113
src/compub/settings.py
Normal file
113
src/compub/settings.py
Normal file
@ -0,0 +1,113 @@
|
||||
import logging
|
||||
import logging.config
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar
|
||||
|
||||
from pydantic import BaseModel, BaseSettings, AnyUrl, validator
|
||||
from pydantic.env_settings import SettingsSourceCallable
|
||||
from yaml import safe_load
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
PROGRAM_NAME = 'compub'
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
PROJECT_DIR = THIS_DIR.parent.parent
|
||||
|
||||
DEFAULT_CONFIG_FILE_NAME = 'config.yaml'
|
||||
DEFAULT_CONFIG_FILE_PATHS = [
|
||||
Path('/etc', PROGRAM_NAME, DEFAULT_CONFIG_FILE_NAME), # system directory
|
||||
Path(PROJECT_DIR, DEFAULT_CONFIG_FILE_NAME), # project directory
|
||||
Path('.', DEFAULT_CONFIG_FILE_NAME), # working directory
|
||||
]
|
||||
|
||||
CONFIG_FILE_PATH_PARAM = 'config_file'
|
||||
|
||||
|
||||
class AbstractBaseSettings(BaseSettings):
|
||||
_config_file_paths: ClassVar[list[Path]] = DEFAULT_CONFIG_FILE_PATHS
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
config_file_path = kwargs.pop(CONFIG_FILE_PATH_PARAM, None)
|
||||
if config_file_path is not None:
|
||||
self._config_file_paths.append(Path(config_file_path))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_config_file_paths(self) -> list[Path]:
|
||||
return self._config_file_paths
|
||||
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
env_file_encoding = 'utf-8'
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
@classmethod
|
||||
def customise_sources(
|
||||
cls,
|
||||
init_settings: SettingsSourceCallable,
|
||||
env_settings: SettingsSourceCallable,
|
||||
file_secret_settings: SettingsSourceCallable
|
||||
) -> tuple[Callable, ...]:
|
||||
return init_settings, env_settings, _yaml_config_settings_source
|
||||
|
||||
|
||||
def _yaml_config_settings_source(settings_obj: AbstractBaseSettings) -> dict[str, Any]:
|
||||
"""
|
||||
Incrementally loads (and updates) settings from all config files that can be found as returned by the
|
||||
`Settings.get_config_file_paths` method and returns the result in a dictionary.
|
||||
This function is intended to be used as a settings source in the `Config.customise_sources` method.
|
||||
"""
|
||||
config = {}
|
||||
for path in settings_obj.get_config_file_paths():
|
||||
if not path.is_file():
|
||||
log.debug(f"No config file found at '{path}'")
|
||||
continue
|
||||
log.info(f"Reading config file '{path}'")
|
||||
with open(path, 'r') as f:
|
||||
config.update(safe_load(f))
|
||||
return config
|
||||
|
||||
|
||||
class ServerSettings(BaseModel):
|
||||
host: str = '127.0.0.1'
|
||||
port: int = 9009
|
||||
uds: str | None = None
|
||||
|
||||
|
||||
class DBURI(AnyUrl):
|
||||
host_required = False
|
||||
|
||||
|
||||
class Settings(AbstractBaseSettings):
|
||||
db_uri: DBURI | None = None
|
||||
server: ServerSettings = ServerSettings()
|
||||
log_config: dict | Path | None = None
|
||||
|
||||
@validator('log_config')
|
||||
def configure_logging(cls, v: dict | Path | None) -> dict | None:
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, Path):
|
||||
with open(v, 'r') as f:
|
||||
logging_conf = safe_load(f)
|
||||
logging.config.dictConfig(logging_conf)
|
||||
return logging_conf
|
||||
if isinstance(v, dict):
|
||||
logging.config.dictConfig(v)
|
||||
return v
|
||||
raise TypeError
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def init(**kwargs) -> None:
|
||||
settings.__init__(**kwargs)
|
||||
|
||||
|
||||
def update(**kwargs) -> None:
|
||||
settings_dict = settings.dict()
|
||||
settings_dict.update(kwargs)
|
||||
settings.__init__(**settings_dict)
|
66
src/compub/utils.py
Normal file
66
src/compub/utils.py
Normal file
@ -0,0 +1,66 @@
|
||||
from operator import attrgetter
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
KeyFuncT = Callable[[T], Any]
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def multi_sort(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> None:
|
||||
for param in reversed(parameters):
|
||||
if isinstance(param, str):
|
||||
obj_list.sort(key=attrgetter(param))
|
||||
elif callable(param):
|
||||
obj_list.sort(key=param)
|
||||
else:
|
||||
try:
|
||||
param, reverse = param
|
||||
assert isinstance(reverse, bool)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError(f"Sorting parameter {param} is neither a key nor a key-boolean-tuple.")
|
||||
if isinstance(param, str):
|
||||
obj_list.sort(key=attrgetter(param), reverse=reverse)
|
||||
elif callable(param):
|
||||
obj_list.sort(key=param, reverse=reverse)
|
||||
else:
|
||||
raise ValueError(f"Sorting key {param} is neither a string nor a callable.")
|
||||
|
||||
|
||||
def multi_gt(left: T, right: T, *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> bool:
|
||||
for param in parameters:
|
||||
invert = False
|
||||
if isinstance(param, str):
|
||||
left_val, right_val = getattr(left, param), getattr(right, param)
|
||||
elif callable(param):
|
||||
left_val, right_val = param(left), param(right)
|
||||
else:
|
||||
try:
|
||||
param, invert = param
|
||||
assert isinstance(invert, bool)
|
||||
except (ValueError, TypeError, AssertionError):
|
||||
raise ValueError(f"Ordering parameter {param} is neither a key nor a key-boolean-tuple.")
|
||||
if isinstance(param, str):
|
||||
left_val, right_val = getattr(left, param), getattr(right, param)
|
||||
elif callable(param):
|
||||
left_val, right_val = param(left), param(right)
|
||||
else:
|
||||
raise ValueError(f"Ordering key {param} is neither a string nor a callable.")
|
||||
if left_val == right_val:
|
||||
continue
|
||||
return left_val < right_val if invert else left_val > right_val
|
||||
return False
|
||||
|
||||
|
||||
def multi_max(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool],
|
||||
default: Any = _sentinel) -> T:
|
||||
try:
|
||||
largest = obj_list[0]
|
||||
except IndexError:
|
||||
if default is not _sentinel:
|
||||
return default
|
||||
raise ValueError("Cannot get largest item from an empty list.")
|
||||
for obj in obj_list[1:]:
|
||||
if multi_gt(obj, largest, *parameters):
|
||||
largest = obj
|
||||
return largest
|
Loading…
x
Reference in New Issue
Block a user