From b7e627a216e08d0147adb5d88bf2d5babd102854 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Fri, 1 Jul 2022 16:01:07 +0200 Subject: [PATCH] Add first couple of ORM models; implement some helper functions --- src/compub/db/companies.py | 129 +++++++++++++++++++++++++++++++++++++ src/compub/db/geography.py | 48 ++++++++++++++ src/compub/utils.py | 66 +++++++++++++++++++ 3 files changed, 243 insertions(+) create mode 100644 src/compub/db/companies.py create mode 100644 src/compub/db/geography.py create mode 100644 src/compub/utils.py diff --git a/src/compub/db/companies.py b/src/compub/db/companies.py new file mode 100644 index 0000000..077be5e --- /dev/null +++ b/src/compub/db/companies.py @@ -0,0 +1,129 @@ +from datetime import date +from typing import Optional +from uuid import uuid4 + +from slugify import slugify +from sqlalchemy.engine import Connection +from sqlalchemy.event.api import listens_for +from sqlalchemy.orm import relationship +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.sql.expression import select +from sqlalchemy.sql.functions import count +from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table +from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode +from sqlalchemy_utils.types import CountryType, UUIDType + +from compub.utils import multi_max +from .base import AbstractBase, ORMBase + + +class LegalForm(AbstractBase): + __tablename__ = 'legal_form' + + id = Column(Integer, primary_key=True) + short = Column(String(32), nullable=False, index=True) + name = Column(Unicode(255)) + country = Column(CountryType) + + subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin') + + def __str__(self) -> str: + return str(self.short) + + +class LegalFormSubcategory(AbstractBase): + __tablename__ = 'legal_form_subcategory' + + id = Column(Integer, primary_key=True) + legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True) + short = Column(String(32), nullable=False, index=True) + name = Column(Unicode(255)) + + companies = relationship('Company', backref='legal_form', lazy='selectin') + + def __str__(self) -> str: + return str(self.short) + + +class Industry(AbstractBase): + __tablename__ = 'industry' + + id = Column(Integer, primary_key=True) + name = Column(String(255), nullable=False, index=True) + + companies = relationship('Company', secondary='company_industries', back_populates='industries') + + def __str__(self) -> str: + return str(self.name) + + +class Company(AbstractBase): + __tablename__ = 'company' + + NAME_SORTING_PARAMS = ( # passed to `multi_max` + lambda obj: date(1, 1, 1) if obj.date_registered is None else obj.date_registered, + 'date_updated', + ) + # If date_registered is None, the name is considered to be older than one with a date_registered + + id = Column(UUIDType, primary_key=True, default=uuid4) + visible = Column(Boolean, default=True, nullable=False, index=True) + legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True) + insolvent = Column(Boolean, default=False, nullable=False, index=True) + founding_date = Column(Date) + liquidation_date = Column(Date) + # TODO: Get rid of city; implement address properly + city = Column(String(255), index=True) + address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True) + + industries = relationship('Industry', secondary='company_industries', back_populates='companies') + names = relationship('CompanyName', backref='company', lazy='selectin') + + def __str__(self) -> str: + return str(self.current_name or f"") + + @property + def current_name(self) -> Optional['CompanyName']: + return multi_max(list(self.names), *self.NAME_SORTING_PARAMS, default=None) + + +company_industries = Table( + 'company_industries', + ORMBase.metadata, + Column('company_id', FKey('company.id'), primary_key=True), + Column('industry_id', FKey('industry.id'), primary_key=True), +) + + +class CompanyName(AbstractBase): + __tablename__ = 'company_name' + + MAX_SLUG_LENGTH = 255 + + id = Column(UUIDType, primary_key=True, default=uuid4) + name = Column(Unicode(768), nullable=False, index=True) + company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True) + date_registered = Column(Date) + slug = Column(String(MAX_SLUG_LENGTH), index=True) + + def __str__(self) -> str: + return str(self.name) + + +@listens_for(CompanyName, 'before_insert') +def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None: + if target.slug: + return + slug = slugify(target.name)[:(target.MAX_SLUG_LENGTH - 2)] + statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug)) + num = connection.execute(statement).scalar() + if num == 0: + target.slug = slug + else: + target.slug = f'{slug}-{str(num + 1)}' + + +def get_reg_date(obj: CompanyName) -> date: + if obj.date_registered is None: + return date(1, 1, 1) + return obj.date_registered diff --git a/src/compub/db/geography.py b/src/compub/db/geography.py new file mode 100644 index 0000000..35c0d29 --- /dev/null +++ b/src/compub/db/geography.py @@ -0,0 +1,48 @@ +from uuid import uuid4 + +from sqlalchemy.orm import relationship +from sqlalchemy.sql.schema import Column, ForeignKey as FKey +from sqlalchemy.sql.sqltypes import Integer, String, Unicode +from sqlalchemy_utils.types import CountryType, UUIDType + +from .base import AbstractBase + + +class StateProvince(AbstractBase): + __tablename__ = 'state_province' + + id = Column(Integer, primary_key=True) + country = Column(CountryType, nullable=False, index=True) + name = Column(Unicode(255), nullable=False, index=True) + + cities = relationship('City', backref='state_province', lazy='selectin') + + +class City(AbstractBase): + __tablename__ = 'city' + + id = Column(Integer, primary_key=True) + state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True) + zip_code = Column(String(5), nullable=False, index=True) + name = Column(Unicode(255), nullable=False, index=True) + + streets = relationship('Street', backref='city', lazy='selectin') + + +class Street(AbstractBase): + __tablename__ = 'street' + + id = Column(Integer, primary_key=True) + city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True) + name = Column(Unicode(255), nullable=False, index=True) + + addresses = relationship('Address', backref='street', lazy='selectin') + + +class Address(AbstractBase): + __tablename__ = 'address' + + id = Column(UUIDType, primary_key=True, default=uuid4) + street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True) + house_number = Column(String(8), nullable=False) + supplement = Column(String(255)) diff --git a/src/compub/utils.py b/src/compub/utils.py new file mode 100644 index 0000000..c1d6ab1 --- /dev/null +++ b/src/compub/utils.py @@ -0,0 +1,66 @@ +from operator import attrgetter +from typing import Any, Callable, TypeVar + + +T = TypeVar('T') +KeyFuncT = Callable[[T], Any] +_sentinel = object() + + +def multi_sort(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> None: + for param in reversed(parameters): + if isinstance(param, str): + obj_list.sort(key=attrgetter(param)) + elif callable(param): + obj_list.sort(key=param) + else: + try: + param, reverse = param + assert isinstance(reverse, bool) + except (ValueError, TypeError): + raise ValueError(f"Sorting parameter {param} is neither a key nor a key-boolean-tuple.") + if isinstance(param, str): + obj_list.sort(key=attrgetter(param), reverse=reverse) + elif callable(param): + obj_list.sort(key=param, reverse=reverse) + else: + raise ValueError(f"Sorting key {param} is neither a string nor a callable.") + + +def multi_gt(left: T, right: T, *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> bool: + for param in parameters: + invert = False + if isinstance(param, str): + left_val, right_val = getattr(left, param), getattr(right, param) + elif callable(param): + left_val, right_val = param(left), param(right) + else: + try: + param, invert = param + assert isinstance(invert, bool) + except (ValueError, TypeError, AssertionError): + raise ValueError(f"Ordering parameter {param} is neither a key nor a key-boolean-tuple.") + if isinstance(param, str): + left_val, right_val = getattr(left, param), getattr(right, param) + elif callable(param): + left_val, right_val = param(left), param(right) + else: + raise ValueError(f"Ordering key {param} is neither a string nor a callable.") + if left_val == right_val: + continue + return left_val < right_val if invert else left_val > right_val + return False + + +def multi_max(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool], + default: Any = _sentinel) -> T: + try: + largest = obj_list[0] + except IndexError: + if default is not _sentinel: + return default + raise ValueError("Cannot get largest item from an empty list.") + for obj in obj_list[1:]: + if multi_gt(obj, largest, *parameters): + largest = obj + return largest