Compare commits

...

9 Commits

Author SHA1 Message Date
Daniil Fajnberg 9f487d515d Write a few SQLModels;
set up two basic CRUD-like functions;
create two test routes;
make `db_uri` setting more flexible
2022-08-10 21:50:00 +02:00
Daniil Fajnberg 60c6b1c4bb Improve imports 2022-07-01 20:06:55 +02:00
Daniil Fajnberg 8a8ea7ec09 Make use of more SQLAlchemy Utils 2022-07-01 17:17:52 +02:00
Daniil Fajnberg b7e627a216 Add first couple of ORM models; implement some helper functions 2022-07-01 16:01:07 +02:00
Daniil Fajnberg 71f9db6c0d Add ORM base class; make Postgresql mandatory; add dependencies 2022-07-01 10:46:48 +02:00
Daniil Fajnberg eb08e5a674 Add database setup 2022-06-29 13:06:20 +02:00
Daniil Fajnberg 8ec4e8885a Add uvicorn server; make in-memory DB default 2022-06-29 13:04:36 +02:00
Daniil Fajnberg 27696ad57c Implement settings singleton 2022-06-29 12:43:48 +02:00
Daniil Fajnberg aa7fed7817 Add meta data and dependencies 2022-06-29 11:53:42 +02:00
17 changed files with 674 additions and 7 deletions

View File

@ -12,7 +12,7 @@ Company publications API
## Dependencies
Python Version ..., OS ...
Python Version 3.10+, OS ...
## Building from source

View File

@ -0,0 +1,8 @@
Pydantic
FastAPI
SQLAlchemy[asyncio]==1.4.35
Alembic
SQLAlchemy-Utils
SQLModel
Babel
Python-Slugify

View File

@ -1,2 +1,4 @@
-r common.txt
-r srv.txt
coverage
sqlalchemy-stubs
aiosqlite

2
requirements/srv.txt Normal file
View File

@ -0,0 +1,2 @@
-r common.txt
Uvicorn[standard]

View File

@ -1,8 +1,8 @@
[metadata]
name = compub
version = 0.0.1
author = Daniil
author_email = mail@placeholder123.to
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Company publications API
long_description = file: README.md
long_description_content_type = text/markdown
@ -10,20 +10,36 @@ url = https://git.fajnberg.de/daniil/compub
project_urls =
Bug Tracker = https://git.fajnberg.de/daniil/compub/issues
classifiers =
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.10
Operating System :: OS Independent
Development Status :: 3 - Alpha
Framework :: AsyncIO
Framework :: FastAPI
[options]
package_dir =
= src
packages = find:
python_requires = >=3
python_requires = >=3.10, <4
install_requires =
...
Pydantic
FastAPI
SQLAlchemy[asyncio]==1.4.35
Alembic
SQLAlchemy-Utils
SQLModel
Babel
Python-Slugify
[options.extras_require]
srv =
Uvicorn[standard]
dev =
Uvicorn[standard]
coverage
sqlalchemy-stubs
aiosqlite
[options.packages.find]
where = src

30
src/compub/__main__.py Normal file
View File

@ -0,0 +1,30 @@
from argparse import ArgumentParser, SUPPRESS
from pathlib import Path
from typing import Any, Sequence
import uvicorn
from .settings import PROGRAM_NAME, CONFIG_FILE_PATH_PARAM, DEFAULT_CONFIG_FILE_PATHS, init, settings
def parse_cli(args: Sequence[str] = None) -> dict[str, Any]:
parser = ArgumentParser(PROGRAM_NAME)
parser.add_argument(
'-c', f'--{CONFIG_FILE_PATH_PARAM}',
type=Path,
metavar='PATH',
default=SUPPRESS,
help=f"Paths to config file that will take precedence over all others; "
f"the following {len(DEFAULT_CONFIG_FILE_PATHS)} paths are always checked first (in that order):"
f" {','.join(str(p) for p in DEFAULT_CONFIG_FILE_PATHS)}"
)
return vars(parser.parse_args(args))
def main():
init(**parse_cli())
uvicorn.run(f'{PROGRAM_NAME}.routes:app', **settings.server.dict())
if __name__ == '__main__':
main()

View File

@ -0,0 +1 @@
from .geography import *

View File

@ -0,0 +1,33 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.expression import select
from compub.models.geography import *
__all__ = [
# 'get_state_by_name_and_country',
'get_states',
'create_state'
]
# async def get_state_by_name_and_country(session: AsyncSession, name: str, country: str) -> db.StateProvince:
# statement = select(db.StateProvince).filter(
# db.StateProvince.name == name,
# db.StateProvince.country == country
# ).limit(1)
# result = await session.execute(statement)
# return result.scalars().first()
async def get_states(session: AsyncSession, skip: int = 0, limit: int = 100) -> list[StateProvince]:
statement = select(StateProvince).offset(skip).limit(limit)
result = await session.execute(statement)
return result.scalars().all()
async def create_state(session: AsyncSession, state: StateProvince) -> StateProvince:
session.add(state)
await session.commit()
await session.refresh(state)
return state

6
src/compub/exceptions.py Normal file
View File

@ -0,0 +1,6 @@
class ConfigError(Exception):
pass
class NoDatabaseConfigured(ConfigError):
pass

View File

@ -0,0 +1,3 @@
from .base import DB
from .geography import *
from .companies import *

80
src/compub/models/base.py Normal file
View File

@ -0,0 +1,80 @@
from datetime import datetime
from typing import Any, Iterator, Optional
from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.future.engine import Engine
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import TIMESTAMP
from sqlalchemy_utils.functions.orm import get_columns
from sqlalchemy_utils.listeners import force_auto_coercion
from sqlmodel.main import Field, SQLModel
from compub.exceptions import NoDatabaseConfigured
from compub.settings import settings
from sqlmodel import create_engine, Session
class DB:
def __init__(self):
self._engine: AsyncEngine | Engine | None = None
self._session_maker: sessionmaker | None = None
def start_engine(self) -> None:
if settings.db_uri is None:
raise NoDatabaseConfigured
if settings.db_uri.scheme == 'sqlite':
self._engine = create_engine(settings.db_uri)
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
else:
self._engine = create_async_engine(settings.db_uri, future=True)
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
force_auto_coercion()
@property
def engine(self):
if self._engine is None:
self.start_engine()
assert isinstance(self._engine, (AsyncEngine, Engine))
return self._engine
async def get_session(self) -> Session | AsyncSession:
if self._session_maker is None:
self.start_engine()
assert isinstance(self._session_maker, sessionmaker)
session = self._session_maker()
try:
yield session
finally:
await session.close()
class AbstractBase(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
date_created: Optional[datetime] = Field(
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
)
date_updated: Optional[datetime] = Field(
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
)
def __repr__(self) -> str:
fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
attrs = ', '.join(f"{name}={repr(value)}" for name, value in fields)
return f"{self.__class__.__name__}({attrs})"
def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
for name in get_columns(self).keys():
if excl_non_repr and name in self.get_non_repr_fields():
continue
value = getattr(self, name)
if excl_null_value and value is None:
continue
yield name, value
@staticmethod
def get_non_repr_fields() -> list[str]:
return ['id', 'date_created', 'date_updated']

View File

@ -0,0 +1,170 @@
from datetime import date
from typing import Optional
from slugify import slugify
from sqlalchemy.engine.base import Connection
from sqlalchemy.event.api import listens_for
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import Unicode
from sqlalchemy_utils.types import CountryType
from sqlmodel.main import Field, Relationship
from compub.utils import multi_max
from .base import AbstractBase
__all__ = [
'LegalForm',
'LegalFormSubcategory',
'Industry',
'Company',
'CompanyName',
]
class LegalForm(AbstractBase, table=True):
__tablename__ = 'legal_form'
# Fields
short: str = Field(max_length=32, nullable=False, index=True)
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
country: str = Field(sa_column=Column(CountryType))
# Relationships
subcategories: list['LegalFormSubcategory'] = Relationship(
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.short)
class LegalFormSubcategory(AbstractBase, table=True):
__tablename__ = 'legal_form_subcategory'
# Fields
short: str = Field(max_length=32, nullable=False, index=True)
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
# Relationships
legal_form_id: Optional[int] = Field(
foreign_key='legal_form.id', default=None, nullable=False, index=True
)
legal_form: Optional[LegalForm] = Relationship(
back_populates='subcategories', sa_relationship_kwargs={'lazy': 'selectin'}
)
companies: list['Company'] = Relationship(
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.short)
class CompanyIndustryLink(AbstractBase, table=True):
__tablename__ = 'company_industries'
# Relationships
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
industry_id: Optional[int] = Field(foreign_key='industry.id', default=None, nullable=False, primary_key=True)
class Industry(AbstractBase, table=True):
__tablename__ = 'industry'
# Fields
name: str = Field(max_length=255, nullable=False, index=True)
# Relationships
companies: list['Company'] = Relationship(
back_populates='industries', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.name)
class Company(AbstractBase, table=True):
__tablename__ = 'company'
# Fields
visible: bool = Field(default=True, nullable=False, index=True)
insolvent: bool = Field(default=False, nullable=False, index=True)
founding_data: Optional[date]
liquidation_date: Optional[date]
city: str = Field(max_length=255, index=True) # TODO: Get rid of city; implement address properly
# Relationships
legal_form_id: Optional[int] = Field(
foreign_key='legal_form_subcategory.id', default=None, index=True
)
legal_form: Optional[LegalFormSubcategory] = Relationship(
back_populates='companies', sa_relationship_kwargs={'lazy': 'selectin'}
)
address_id: Optional[int] = Field(
foreign_key='address.id', default=None, index=True
)
industries: list[Industry] = Relationship(
back_populates='companies', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
)
names: list['CompanyName'] = Relationship(
back_populates='company', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.current_name or f"<Company {self.id}>")
@property
def current_name(self) -> Optional['CompanyName']:
return multi_max(list(self.names), CompanyName.get_reg_date, 'date_updated', default=None)
class CompanyName(AbstractBase, table=True):
__tablename__ = 'company_name'
__MAX_NAME_LENGTH__: int = 768
__MAX_SLUG_LENGTH__: int = 255
# Fields
name: str = Field(
max_length=__MAX_NAME_LENGTH__, sa_column=Column(Unicode(__MAX_NAME_LENGTH__), nullable=False, index=True)
)
date_registered: Optional[date]
slug: Optional[str] = Field(default=None, max_length=__MAX_SLUG_LENGTH__, index=True)
# Relationships
company_id: Optional[int] = Field(
foreign_key='company.id', default=None, nullable=False, index=True
)
company: Optional[Company] = Relationship(
back_populates='names', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.name)
def get_reg_date(self) -> date:
return date(1, 1, 1) if self.date_registered is None else self.date_registered
@property
def max_slug_length(self) -> int:
return self.__MAX_SLUG_LENGTH__
@listens_for(CompanyName, 'before_insert')
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
if target.slug:
return
slug = slugify(target.name)[:(target.max_slug_length - 2)]
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
num = connection.execute(statement).scalar()
if num == 0:
target.slug = slug
else:
target.slug = f'{slug}-{str(num + 1)}'

View File

@ -0,0 +1,94 @@
from typing import Any, Optional
from pydantic import validator
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import Unicode
from sqlalchemy_utils.primitives.country import Country
from sqlalchemy_utils.types import CountryType
from sqlmodel.main import Field, Relationship
from .base import AbstractBase
__all__ = [
'StateProvince',
'City',
'Street',
'Address',
]
class StateProvince(AbstractBase, table=True):
__tablename__ = 'state_province'
# Fields
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
# Relationships
cities: list['City'] = Relationship(
back_populates='state_province', sa_relationship_kwargs={'lazy': 'selectin'}
)
@validator('country', pre=True)
def country_as_uppercase_string(cls, v: Any) -> str:
if isinstance(v, Country):
return v.code
if isinstance(v, str):
return v.upper()
raise TypeError
class City(AbstractBase, table=True):
__tablename__ = 'city'
# Fields
zip_code: str = Field(max_length=5, nullable=False, index=True)
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
# Relationships
state_province_id: Optional[int] = Field(
foreign_key='state_province.id', default=None, nullable=False, index=True
)
state_province: Optional[StateProvince] = Relationship(
back_populates='cities', sa_relationship_kwargs={'lazy': 'selectin'}
)
streets: list['Street'] = Relationship(
back_populates='city', sa_relationship_kwargs={'lazy': 'selectin'}
)
class Street(AbstractBase, table=True):
__tablename__ = 'street'
# Fields
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
# Relationships
city_id: Optional[int] = Field(
foreign_key='city.id', default=None, nullable=False, index=True
)
city: Optional[City] = Relationship(
back_populates='streets', sa_relationship_kwargs={'lazy': 'selectin'}
)
addresses: list['Address'] = Relationship(
back_populates='street', sa_relationship_kwargs={'lazy': 'selectin'}
)
class Address(AbstractBase, table=True):
__tablename__ = 'address'
# Fields
house_number: str = Field(max_length=8, nullable=False)
supplement: str = Field(max_length=255)
# Relationships
street_id: Optional[int] = Field(
foreign_key='street.id', default=None, nullable=False, index=True
)
street: Optional[Street] = Relationship(
back_populates='addresses', sa_relationship_kwargs={'lazy': 'selectin'}
)

43
src/compub/routes.py Normal file
View File

@ -0,0 +1,43 @@
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from .models import *
from .models.base import DB
from . import crud
api, db = FastAPI(), DB()
@api.on_event('startup')
async def initialize_db():
async with db.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
await conn.run_sync(SQLModel.metadata.create_all)
@api.on_event("shutdown")
async def close_connection_pool():
await db.engine.dispose()
@api.post("/states/", response_model=StateProvince)
async def create_state(state: StateProvince, session: AsyncSession = Depends(db.get_session)):
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
# if db_obj:
# raise HTTPException(status_code=400, detail="Name already registered")
return await crud.create_state(session, state=state)
@api.get("/states/", response_model=list[StateProvince])
async def list_states(skip: int = 0, limit: int = 100, session: AsyncSession = Depends(db.get_session)):
return await crud.get_states(session, skip=skip, limit=limit)
# @app.post("/cities/", response_model=City)
# async def create_city(city: City, session: AsyncSession = Depends(get_session)):
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
# if db_obj:
# raise HTTPException(status_code=400, detail="Name already registered")
# return await crud.create_state(session, state=state)

113
src/compub/settings.py Normal file
View File

@ -0,0 +1,113 @@
import logging
import logging.config
from pathlib import Path
from typing import Any, Callable, ClassVar
from pydantic import BaseModel, BaseSettings, AnyUrl, validator
from pydantic.env_settings import SettingsSourceCallable
from yaml import safe_load
log = logging.getLogger(__name__)
PROGRAM_NAME = 'compub'
THIS_DIR = Path(__file__).parent
PROJECT_DIR = THIS_DIR.parent.parent
DEFAULT_CONFIG_FILE_NAME = 'config.yaml'
DEFAULT_CONFIG_FILE_PATHS = [
Path('/etc', PROGRAM_NAME, DEFAULT_CONFIG_FILE_NAME), # system directory
Path(PROJECT_DIR, DEFAULT_CONFIG_FILE_NAME), # project directory
Path('.', DEFAULT_CONFIG_FILE_NAME), # working directory
]
CONFIG_FILE_PATH_PARAM = 'config_file'
class AbstractBaseSettings(BaseSettings):
_config_file_paths: ClassVar[list[Path]] = DEFAULT_CONFIG_FILE_PATHS
def __init__(self, *args, **kwargs):
config_file_path = kwargs.pop(CONFIG_FILE_PATH_PARAM, None)
if config_file_path is not None:
self._config_file_paths.append(Path(config_file_path))
super().__init__(*args, **kwargs)
def get_config_file_paths(self) -> list[Path]:
return self._config_file_paths
class Config:
allow_mutation = False
env_file_encoding = 'utf-8'
underscore_attrs_are_private = True
@classmethod
def customise_sources(
cls,
init_settings: SettingsSourceCallable,
env_settings: SettingsSourceCallable,
file_secret_settings: SettingsSourceCallable
) -> tuple[Callable, ...]:
return init_settings, env_settings, _yaml_config_settings_source
def _yaml_config_settings_source(settings_obj: AbstractBaseSettings) -> dict[str, Any]:
"""
Incrementally loads (and updates) settings from all config files that can be found as returned by the
`Settings.get_config_file_paths` method and returns the result in a dictionary.
This function is intended to be used as a settings source in the `Config.customise_sources` method.
"""
config = {}
for path in settings_obj.get_config_file_paths():
if not path.is_file():
log.debug(f"No config file found at '{path}'")
continue
log.info(f"Reading config file '{path}'")
with open(path, 'r') as f:
config.update(safe_load(f))
return config
class ServerSettings(BaseModel):
host: str = '127.0.0.1'
port: int = 9009
uds: str | None = None
class DBURI(AnyUrl):
host_required = False
class Settings(AbstractBaseSettings):
db_uri: DBURI | None = None
server: ServerSettings = ServerSettings()
log_config: dict | Path | None = None
@validator('log_config')
def configure_logging(cls, v: dict | Path | None) -> dict | None:
if v is None:
return None
if isinstance(v, Path):
with open(v, 'r') as f:
logging_conf = safe_load(f)
logging.config.dictConfig(logging_conf)
return logging_conf
if isinstance(v, dict):
logging.config.dictConfig(v)
return v
raise TypeError
settings = Settings()
def init(**kwargs) -> None:
settings.__init__(**kwargs)
def update(**kwargs) -> None:
settings_dict = settings.dict()
settings_dict.update(kwargs)
settings.__init__(**settings_dict)

66
src/compub/utils.py Normal file
View File

@ -0,0 +1,66 @@
from operator import attrgetter
from typing import Any, Callable, TypeVar
T = TypeVar('T')
KeyFuncT = Callable[[T], Any]
_sentinel = object()
def multi_sort(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> None:
for param in reversed(parameters):
if isinstance(param, str):
obj_list.sort(key=attrgetter(param))
elif callable(param):
obj_list.sort(key=param)
else:
try:
param, reverse = param
assert isinstance(reverse, bool)
except (ValueError, TypeError):
raise ValueError(f"Sorting parameter {param} is neither a key nor a key-boolean-tuple.")
if isinstance(param, str):
obj_list.sort(key=attrgetter(param), reverse=reverse)
elif callable(param):
obj_list.sort(key=param, reverse=reverse)
else:
raise ValueError(f"Sorting key {param} is neither a string nor a callable.")
def multi_gt(left: T, right: T, *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> bool:
for param in parameters:
invert = False
if isinstance(param, str):
left_val, right_val = getattr(left, param), getattr(right, param)
elif callable(param):
left_val, right_val = param(left), param(right)
else:
try:
param, invert = param
assert isinstance(invert, bool)
except (ValueError, TypeError, AssertionError):
raise ValueError(f"Ordering parameter {param} is neither a key nor a key-boolean-tuple.")
if isinstance(param, str):
left_val, right_val = getattr(left, param), getattr(right, param)
elif callable(param):
left_val, right_val = param(left), param(right)
else:
raise ValueError(f"Ordering key {param} is neither a string nor a callable.")
if left_val == right_val:
continue
return left_val < right_val if invert else left_val > right_val
return False
def multi_max(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool],
default: Any = _sentinel) -> T:
try:
largest = obj_list[0]
except IndexError:
if default is not _sentinel:
return default
raise ValueError("Cannot get largest item from an empty list.")
for obj in obj_list[1:]:
if multi_gt(obj, largest, *parameters):
largest = obj
return largest