compub/src/compub/db/companies.py

139 lines
4.5 KiB
Python

from datetime import date
from typing import Optional
from uuid import uuid4
from slugify import slugify
from sqlalchemy.engine import Connection
from sqlalchemy.event.api import listens_for
from sqlalchemy.orm import relationship
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table
from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode
from sqlalchemy_utils.types import CountryType, UUIDType
from compub.utils import multi_max
from .base import AbstractBase, ORMBase
# Public names exported by ``from compub.db.companies import *``: the ORM models
# defined in this module.
__all__ = [
    'LegalForm',
    'LegalFormSubcategory',
    'Industry',
    'Company',
    'CompanyName',
]
class LegalForm(AbstractBase):
    """Legal form (table ``legal_form``), subdivided into ``LegalFormSubcategory`` rows.

    ``str()`` of an instance is its ``short`` value.
    """
    __tablename__ = 'legal_form'
    id = Column(Integer, primary_key=True)
    # Short label; required and indexed.
    short = Column(String(32), nullable=False, index=True)
    # Optional full name.
    name = Column(Unicode(255))
    # Optional country (sqlalchemy_utils CountryType).
    country = Column(CountryType)
    # The backref adds ``LegalFormSubcategory.legal_form`` pointing back here;
    # ``lazy='selectin'`` loads the collection eagerly with a separate SELECT ... IN query.
    subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin')

    def __str__(self) -> str:
        return str(self.short)
class LegalFormSubcategory(AbstractBase):
    """Subcategory of a ``LegalForm`` (table ``legal_form_subcategory``).

    Companies reference this table (``Company.legal_form_id``), not ``legal_form``
    directly. ``str()`` of an instance is its ``short`` value.
    """
    __tablename__ = 'legal_form_subcategory'
    id = Column(Integer, primary_key=True)
    # Parent legal form; RESTRICT prevents deleting a form that still has subcategories.
    legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True)
    # Short label; required and indexed.
    short = Column(String(32), nullable=False, index=True)
    # Optional full name.
    name = Column(Unicode(255))
    # NOTE(review): this backref creates ``Company.legal_form``, which therefore yields a
    # LegalFormSubcategory (not a LegalForm) despite its name.
    companies = relationship('Company', backref='legal_form', lazy='selectin')

    def __str__(self) -> str:
        return str(self.short)
class Industry(AbstractBase):
    """Industry (table ``industry``), linked many-to-many to ``Company``.

    ``str()`` of an instance is its ``name``.
    """
    __tablename__ = 'industry'
    id = Column(Integer, primary_key=True)
    # NOTE(review): other human-readable names in this module use Unicode; confirm String is intended.
    name = Column(String(255), nullable=False, index=True)
    # Many-to-many through the ``company_industries`` association table; paired with
    # ``Company.industries`` via back_populates.
    companies = relationship('Company', secondary='company_industries', back_populates='industries')

    def __str__(self) -> str:
        return str(self.name)
class Company(AbstractBase):
    """Company (table ``company``) with its history of names, legal form and industries.

    ``str()`` gives the current name, or ``<Company {id}>`` when there is none.
    """
    __tablename__ = 'company'
    # Sort keys handed to `multi_max` when choosing the current name: registration date
    # first, then 'date_updated' (presumably provided by AbstractBase — TODO confirm).
    # A name without a date_registered ranks as older than any name that has one.
    NAME_SORTING_PARAMS = (
        lambda obj: obj.date_registered if obj.date_registered is not None else date(1, 1, 1),
        'date_updated',
    )
    id = Column(UUIDType, primary_key=True, default=uuid4)
    visible = Column(Boolean, default=True, nullable=False, index=True)
    # References a legal form *subcategory*, not ``legal_form`` itself.
    legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True)
    insolvent = Column(Boolean, default=False, nullable=False, index=True)
    founding_date = Column(Date)
    liquidation_date = Column(Date)
    # TODO: Get rid of city; implement address properly
    city = Column(String(255), index=True)
    address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True)
    industries = relationship('Industry', secondary='company_industries', back_populates='companies')
    # The backref adds ``CompanyName.company``.
    names = relationship('CompanyName', backref='company', lazy='selectin')

    def __str__(self) -> str:
        latest = self.current_name
        if not latest:
            return f"<Company {self.id}>"
        return str(latest)

    @property
    def current_name(self) -> Optional['CompanyName']:
        """Return the company's current name, or ``None`` if it has no names.

        Assumes ``multi_max`` picks the greatest element by ``NAME_SORTING_PARAMS``
        (tie-breaking in order) and returns ``default`` for an empty list — TODO confirm.
        """
        candidates = list(self.names)
        return multi_max(candidates, *self.NAME_SORTING_PARAMS, default=None)
# Association table for the Company <-> Industry many-to-many relationship.
# The composite primary key allows each (company, industry) pair at most once.
# Columns carry no explicit type; SQLAlchemy takes it from the referenced column.
# NOTE(review): unlike the other foreign keys in this module, no ``ondelete`` is set here.
company_industries = Table(
    'company_industries',
    ORMBase.metadata,
    Column('company_id', FKey('company.id'), primary_key=True),
    Column('industry_id', FKey('industry.id'), primary_key=True),
)
class CompanyName(AbstractBase):
    """One name of a ``Company`` (table ``company_name``); a company may have several.

    ``slug`` is filled in by the ``before_insert`` listener
    ``generate_company_name_slug`` when it is not set explicitly.
    ``str()`` of an instance is its ``name``.
    """
    __tablename__ = 'company_name'
    # Upper bound for ``slug`` length, also used by the slug generator.
    MAX_SLUG_LENGTH = 255
    id = Column(UUIDType, primary_key=True, default=uuid4)
    name = Column(Unicode(768), nullable=False, index=True)
    # Owning company; nullable at the schema level.
    company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True)
    # Optional; names without it rank as oldest when picking Company.current_name.
    date_registered = Column(Date)
    slug = Column(String(MAX_SLUG_LENGTH), index=True)

    def __str__(self) -> str:
        return str(self.name)
@listens_for(CompanyName, 'before_insert')
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
    """Fill in ``target.slug`` from ``target.name`` before a ``CompanyName`` row is inserted.

    A slug that is already set (truthy) is kept unchanged. Otherwise the slugified
    name is used if no existing row has that slug. If it is taken, the first free
    ``<slug>-<n>`` with n = 2, 3, ... is used instead. The base is shortened as
    needed so the result never exceeds ``MAX_SLUG_LENGTH`` characters.

    :param _mapper: the mapper of the class being inserted (unused).
    :param connection: connection used to look up existing slugs.
    :param target: the ``CompanyName`` instance about to be inserted; mutated in place.
    """
    if target.slug:
        return
    max_length = target.MAX_SLUG_LENGTH
    # Same truncation as before, so uncontested names keep their historical slugs.
    base = slugify(target.name)[:(max_length - 2)]
    # Every candidate below is base[:k] + suffix with len(suffix) <= 12, so it starts with
    # this stem; one prefix query therefore returns every slug that could collide.
    stem = base[:(max_length - 12)]
    statement = select(CompanyName.slug).where(CompanyName.slug.startswith(stem))
    taken = {row[0] for row in connection.execute(statement)}
    candidate = base
    suffix_number = 1
    while candidate in taken:
        suffix_number += 1
        suffix = f'-{suffix_number}'
        # Trim the base so that multi-digit suffixes still fit within the column length.
        candidate = base[:(max_length - len(suffix))] + suffix
    target.slug = candidate
def get_reg_date(obj: CompanyName) -> date:
    """Return ``obj.date_registered``, or ``date.min`` (1 January of year 1) when it is ``None``.

    Substituting the earliest representable date makes an undated name compare
    as older than any dated one.
    """
    return date.min if obj.date_registered is None else obj.date_registered