From 7618f6650e9262c837f877ad751f4f5728cba7c7 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Thu, 30 Dec 2021 19:15:35 +0100 Subject: [PATCH] factored out the YAML utilities to a separate file; added the YAML `!include` functionality --- src/yamlhttpforms/form.py | 35 ++------------------ src/yamlhttpforms/utils.py | 67 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 33 deletions(-) create mode 100644 src/yamlhttpforms/utils.py diff --git a/src/yamlhttpforms/form.py b/src/yamlhttpforms/form.py index 1512d87..eb34500 100644 --- a/src/yamlhttpforms/form.py +++ b/src/yamlhttpforms/form.py @@ -1,14 +1,13 @@ -from pathlib import Path from importlib import import_module from typing import Dict, Callable, Union, Optional, Any, TYPE_CHECKING -from yaml import safe_load if TYPE_CHECKING: from aiohttp import ClientSession as AioSession, ClientResponse as AioResponse from requests import Session as ReqSession, Response as ReqResponse +from .utils import PathT, yaml_overload + -PathT = Union[Path, str] CallableDefaultT = Callable[[], Optional[str]] DefaultInitT = Union[str, Dict[str, str]] OptionsT = Dict[str, str] @@ -176,36 +175,6 @@ class Form: return post(self.url, data=self.get_payload(**kwargs)) -def yaml_overload(*paths: PathT, dir_sort_key: Callable[[PathT], Any] = None, dir_sort_reverse: bool = False) -> dict: - """ - Loads YAML files from any number of paths, recursively going through any directories. - - Args: - paths: - Each argument should be a path to either a YAML file or a directory containing YAML files; - only files with the extension `.yaml` are loaded. - dir_sort_key (optional): - If one of the paths is a directory, its contents are sorted, before recursively passing them into this - function; to apply a specific comparison key for each sub-path, a callable can be used here, which is - passed into the builtin `sorted` function. - dir_sort_reverse (optional): - Same as with the parameter above, this argument is also passed into the `sorted` function. - - Returns: - Dictionary comprised of the contents of iteratively loaded YAML files. - NOTE: Since it is updated each time a file is loaded, their load order matters! - """ - output_dict = {} - for path in paths: - path = Path(path) - if path.is_dir(): - output_dict.update(yaml_overload(*sorted(path.iterdir(), key=dir_sort_key, reverse=dir_sort_reverse))) - elif path.suffix == '.yaml': - with open(path, 'r') as f: - output_dict.update(safe_load(f)) - return output_dict - - def load_form(*def_paths: PathT, dir_sort_key: Callable[[PathT], Any] = None, dir_sort_reverse: bool = False, full_payload: bool = True, url: str = None) -> Form: """ diff --git a/src/yamlhttpforms/utils.py b/src/yamlhttpforms/utils.py new file mode 100644 index 0000000..61dafe7 --- /dev/null +++ b/src/yamlhttpforms/utils.py @@ -0,0 +1,67 @@ +from pathlib import Path +from typing import Callable, IO, Union, Any + +from yaml import SafeLoader, load, ScalarNode + + +PathT = Union[Path, str] + +INCLUDE_TAG = '!include' + + +class RecursiveSafeLoader(SafeLoader): + """ + Custom `SafeLoader` for YAML streams that allows recursively referencing other files. + This is done by honoring the special `INCLUDE_TAG` followed by a path to a readable YAML file. + """ + def __init__(self, stream: IO) -> None: + try: + self._root = Path(stream.name).parent + except AttributeError: + self._root = Path() + super().__init__(stream) + + def include(self, node: ScalarNode) -> Any: + """ + To be used in the constructor of a node that has the `INCLUDE_TAG` in it. + The path following that tag can be either relative (to the file's own parent directory) or absolute. + Since the constructor uses this class to load the node, arbitrary nesting of YAML file inclusions is possible. + """ + path = Path(self.construct_scalar(node)) + if not path.is_absolute(): + path = Path(self._root, path) + with open(path, 'r') as f: + return load(f, RecursiveSafeLoader) + + +RecursiveSafeLoader.add_constructor(INCLUDE_TAG, RecursiveSafeLoader.include) + + +def yaml_overload(*paths: PathT, dir_sort_key: Callable[[PathT], Any] = None, dir_sort_reverse: bool = False) -> dict: + """ + Loads YAML files from any number of paths, recursively going through any directories. + + Args: + paths: + Each argument should be a path to either a YAML file or a directory containing YAML files; + only files with the extension `.yaml` are loaded. + dir_sort_key (optional): + If one of the paths is a directory, its contents are sorted, before recursively passing them into this + function; to apply a specific comparison key for each sub-path, a callable can be used here, which is + passed into the builtin `sorted` function. + dir_sort_reverse (optional): + Same as with the parameter above, this argument is also passed into the `sorted` function. + + Returns: + Dictionary comprised of the contents of iteratively loaded YAML files. + NOTE: Since it is updated each time a file is loaded, their load order matters! + """ + output_dict = {} + for path in paths: + path = Path(path) + if path.is_dir(): + output_dict.update(yaml_overload(*sorted(path.iterdir(), key=dir_sort_key, reverse=dir_sort_reverse))) + elif path.suffix == '.yaml': + with open(path, 'r') as f: + output_dict.update(load(f, RecursiveSafeLoader)) + return output_dict