import logging
import logging.config
from pathlib import Path
from typing import Any, Callable, ClassVar

from pydantic import BaseModel, BaseSettings, AnyUrl, validator
from pydantic.env_settings import SettingsSourceCallable
from yaml import safe_load

log = logging.getLogger(__name__)

PROGRAM_NAME = 'compub'
THIS_DIR = Path(__file__).parent
PROJECT_DIR = THIS_DIR.parent.parent

DEFAULT_CONFIG_FILE_NAME = 'config.yaml'
DEFAULT_CONFIG_FILE_PATHS = [
    Path('/etc', PROGRAM_NAME, DEFAULT_CONFIG_FILE_NAME),  # system directory
    Path(PROJECT_DIR, DEFAULT_CONFIG_FILE_NAME),  # project directory
    Path('.', DEFAULT_CONFIG_FILE_NAME),  # working directory
]
CONFIG_FILE_PATH_PARAM = 'config_file'


class AbstractBaseSettings(BaseSettings):
    """
    Base class for settings built from init kwargs, environment variables and
    YAML config files (see `Config.customise_sources`).

    An additional config file can be registered by passing its path as the
    `config_file` keyword argument to the constructor.
    """

    # A copy of the defaults, so registering an extra config file never
    # mutates the module-level `DEFAULT_CONFIG_FILE_PATHS` list.
    # NOTE(review): this is still class-level state: a path registered through
    # one instance is seen by every instance (including later re-inits of the
    # module-level `settings` object) -- confirm this is intended.
    _config_file_paths: ClassVar[list[Path]] = list(DEFAULT_CONFIG_FILE_PATHS)

    def __init__(self, *args, **kwargs):
        config_file_path = kwargs.pop(CONFIG_FILE_PATH_PARAM, None)
        if config_file_path is not None:
            path = Path(config_file_path)
            paths = self._config_file_paths
            # Move an already-registered path to the end instead of adding it
            # again. The merged result is unchanged (the last read of a file
            # wins either way), but the list no longer grows on every call.
            if path in paths:
                paths.remove(path)
            paths.append(path)
        super().__init__(*args, **kwargs)

    def get_config_file_paths(self) -> list[Path]:
        """Return the config file paths in read order; later files override earlier ones."""
        return self._config_file_paths

    class Config:
        allow_mutation = False
        env_file_encoding = 'utf-8'
        underscore_attrs_are_private = True

        @classmethod
        def customise_sources(
            cls,
            init_settings: SettingsSourceCallable,
            env_settings: SettingsSourceCallable,
            file_secret_settings: SettingsSourceCallable,
        ) -> tuple[Callable, ...]:
            # In pydantic v1, sources listed first take priority:
            # init kwargs > environment > YAML config files.
            # `file_secret_settings` is not used.
            return init_settings, env_settings, _yaml_config_settings_source


def _yaml_config_settings_source(settings_obj: AbstractBaseSettings) -> dict[str, Any]:
    """
    Incrementally loads (and updates) settings from all config files that can
    be found as returned by the `Settings.get_config_file_paths` method and
    returns the result in a dictionary.

    Files are read in order, so a later file overrides the keys of an earlier
    one. The merge is shallow: a top-level key (e.g. `server`) read from a
    later file replaces the earlier value as a whole. Paths that are not
    regular files are skipped, and empty files contribute nothing.

    This function is intended to be used as a settings source in the
    `Config.customise_sources` method.

    Raises:
        TypeError: if a config file's top-level YAML value is not a mapping.
    """
    config: dict[str, Any] = {}
    for path in settings_obj.get_config_file_paths():
        if not path.is_file():
            log.debug("No config file found at '%s'", path)
            continue
        log.info("Reading config file '%s'", path)
        with open(path, 'r', encoding='utf-8') as f:
            data = safe_load(f)
        # `safe_load` returns None for an empty document, and
        # `dict.update(None)` would raise, so such a file adds no settings.
        if data is None:
            continue
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file '{path}' must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        config.update(data)
    return config


class ServerSettings(BaseModel):
    host: str = '127.0.0.1'
    port: int = 9009
    uds: str | None = None


class DBURI(AnyUrl):
    # Allow URIs without a host part.
    host_required = False


class Settings(AbstractBaseSettings):
    db_uri: DBURI | None = None
    server: ServerSettings = ServerSettings()
    log_config: dict | Path | None = None

    @validator('log_config')
    def configure_logging(cls, v: dict | Path | None) -> dict | None:
        """
        Apply the logging configuration as a side effect of validation.

        A `Path` is read as a YAML file containing a `logging.config.dictConfig`
        dictionary; a dict is passed to `dictConfig` as-is. The resulting
        dictionary (or None) is returned, so the stored value is never a Path.
        """
        if v is None:
            return None
        if isinstance(v, Path):
            with open(v, 'r', encoding='utf-8') as f:
                logging_conf = safe_load(f)
            logging.config.dictConfig(logging_conf)
            return logging_conf
        if isinstance(v, dict):
            logging.config.dictConfig(v)
            return v
        raise TypeError(f"log_config must be a dict or a Path, got {type(v).__name__}")


# Module-level settings singleton; `init` and `update` re-initialise it in place.
settings = Settings()


def init(**kwargs) -> None:
    """
    Re-initialise the module-level `settings` object in place from `kwargs`.

    `kwargs` may include `config_file` to register an extra YAML config file.
    Validators run again, so `log_config` is re-applied.
    """
    settings.__init__(**kwargs)


def update(**kwargs) -> None:
    """
    Re-initialise `settings` in place from its current values overridden by
    `kwargs`.
    """
    settings_dict = settings.dict()
    settings_dict.update(kwargs)
    settings.__init__(**settings_dict)