Add first couple of ORM models; implement some helper functions

This commit is contained in:
Daniil Fajnberg 2022-07-01 16:01:07 +02:00
parent 71f9db6c0d
commit b7e627a216
3 changed files with 243 additions and 0 deletions

129
src/compub/db/companies.py Normal file
View File

@ -0,0 +1,129 @@
from datetime import date
from typing import Optional
from uuid import uuid4
from slugify import slugify
from sqlalchemy.engine import Connection
from sqlalchemy.event.api import listens_for
from sqlalchemy.orm import relationship
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table
from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode
from sqlalchemy_utils.types import CountryType, UUIDType
from compub.utils import multi_max
from .base import AbstractBase, ORMBase
class LegalForm(AbstractBase):
__tablename__ = 'legal_form'
id = Column(Integer, primary_key=True)
short = Column(String(32), nullable=False, index=True)
name = Column(Unicode(255))
country = Column(CountryType)
subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin')
def __str__(self) -> str:
return str(self.short)
class LegalFormSubcategory(AbstractBase):
__tablename__ = 'legal_form_subcategory'
id = Column(Integer, primary_key=True)
legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True)
short = Column(String(32), nullable=False, index=True)
name = Column(Unicode(255))
companies = relationship('Company', backref='legal_form', lazy='selectin')
def __str__(self) -> str:
return str(self.short)
class Industry(AbstractBase):
__tablename__ = 'industry'
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False, index=True)
companies = relationship('Company', secondary='company_industries', back_populates='industries')
def __str__(self) -> str:
return str(self.name)
class Company(AbstractBase):
__tablename__ = 'company'
NAME_SORTING_PARAMS = ( # passed to `multi_max`
lambda obj: date(1, 1, 1) if obj.date_registered is None else obj.date_registered,
'date_updated',
)
# If date_registered is None, the name is considered to be older than one with a date_registered
id = Column(UUIDType, primary_key=True, default=uuid4)
visible = Column(Boolean, default=True, nullable=False, index=True)
legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True)
insolvent = Column(Boolean, default=False, nullable=False, index=True)
founding_date = Column(Date)
liquidation_date = Column(Date)
# TODO: Get rid of city; implement address properly
city = Column(String(255), index=True)
address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True)
industries = relationship('Industry', secondary='company_industries', back_populates='companies')
names = relationship('CompanyName', backref='company', lazy='selectin')
def __str__(self) -> str:
return str(self.current_name or f"<Company {self.id}>")
@property
def current_name(self) -> Optional['CompanyName']:
return multi_max(list(self.names), *self.NAME_SORTING_PARAMS, default=None)
company_industries = Table(
'company_industries',
ORMBase.metadata,
Column('company_id', FKey('company.id'), primary_key=True),
Column('industry_id', FKey('industry.id'), primary_key=True),
)
class CompanyName(AbstractBase):
__tablename__ = 'company_name'
MAX_SLUG_LENGTH = 255
id = Column(UUIDType, primary_key=True, default=uuid4)
name = Column(Unicode(768), nullable=False, index=True)
company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True)
date_registered = Column(Date)
slug = Column(String(MAX_SLUG_LENGTH), index=True)
def __str__(self) -> str:
return str(self.name)
@listens_for(CompanyName, 'before_insert')
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
if target.slug:
return
slug = slugify(target.name)[:(target.MAX_SLUG_LENGTH - 2)]
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
num = connection.execute(statement).scalar()
if num == 0:
target.slug = slug
else:
target.slug = f'{slug}-{str(num + 1)}'
def get_reg_date(obj: CompanyName) -> date:
if obj.date_registered is None:
return date(1, 1, 1)
return obj.date_registered

View File

@ -0,0 +1,48 @@
from uuid import uuid4
from sqlalchemy.orm import relationship
from sqlalchemy.sql.schema import Column, ForeignKey as FKey
from sqlalchemy.sql.sqltypes import Integer, String, Unicode
from sqlalchemy_utils.types import CountryType, UUIDType
from .base import AbstractBase
class StateProvince(AbstractBase):
__tablename__ = 'state_province'
id = Column(Integer, primary_key=True)
country = Column(CountryType, nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
cities = relationship('City', backref='state_province', lazy='selectin')
class City(AbstractBase):
__tablename__ = 'city'
id = Column(Integer, primary_key=True)
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
zip_code = Column(String(5), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
streets = relationship('Street', backref='city', lazy='selectin')
class Street(AbstractBase):
__tablename__ = 'street'
id = Column(Integer, primary_key=True)
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
addresses = relationship('Address', backref='street', lazy='selectin')
class Address(AbstractBase):
__tablename__ = 'address'
id = Column(UUIDType, primary_key=True, default=uuid4)
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
house_number = Column(String(8), nullable=False)
supplement = Column(String(255))

66
src/compub/utils.py Normal file
View File

@ -0,0 +1,66 @@
from operator import attrgetter
from typing import Any, Callable, TypeVar
T = TypeVar('T')
KeyFuncT = Callable[[T], Any]
_sentinel = object()
def multi_sort(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> None:
for param in reversed(parameters):
if isinstance(param, str):
obj_list.sort(key=attrgetter(param))
elif callable(param):
obj_list.sort(key=param)
else:
try:
param, reverse = param
assert isinstance(reverse, bool)
except (ValueError, TypeError):
raise ValueError(f"Sorting parameter {param} is neither a key nor a key-boolean-tuple.")
if isinstance(param, str):
obj_list.sort(key=attrgetter(param), reverse=reverse)
elif callable(param):
obj_list.sort(key=param, reverse=reverse)
else:
raise ValueError(f"Sorting key {param} is neither a string nor a callable.")
def multi_gt(left: T, right: T, *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool]) -> bool:
for param in parameters:
invert = False
if isinstance(param, str):
left_val, right_val = getattr(left, param), getattr(right, param)
elif callable(param):
left_val, right_val = param(left), param(right)
else:
try:
param, invert = param
assert isinstance(invert, bool)
except (ValueError, TypeError, AssertionError):
raise ValueError(f"Ordering parameter {param} is neither a key nor a key-boolean-tuple.")
if isinstance(param, str):
left_val, right_val = getattr(left, param), getattr(right, param)
elif callable(param):
left_val, right_val = param(left), param(right)
else:
raise ValueError(f"Ordering key {param} is neither a string nor a callable.")
if left_val == right_val:
continue
return left_val < right_val if invert else left_val > right_val
return False
def multi_max(obj_list: list[T], *parameters: str | KeyFuncT | tuple[str | KeyFuncT, bool],
default: Any = _sentinel) -> T:
try:
largest = obj_list[0]
except IndexError:
if default is not _sentinel:
return default
raise ValueError("Cannot get largest item from an empty list.")
for obj in obj_list[1:]:
if multi_gt(obj, largest, *parameters):
largest = obj
return largest