yamlhttpforms/src/yamlhttpforms/utils.py

68 lines
2.7 KiB
Python

from pathlib import Path
from typing import Callable, IO, Union, Any
from yaml import SafeLoader, load, ScalarNode
# Type alias for path arguments: either a `pathlib.Path` or a plain string.
PathT = Union[Path, str]
# YAML tag handled by `RecursiveSafeLoader`: the scalar following it is the path of a YAML file to load in place,
# e.g. `key: !include other.yaml`.
INCLUDE_TAG = '!include'
class RecursiveSafeLoader(SafeLoader):
    """
    Custom `SafeLoader` for YAML streams that allows recursively referencing other files.

    This is done by honoring the special `INCLUDE_TAG` followed by a path to a readable YAML file,
    e.g. `key: !include other.yaml`; the tagged node is replaced by the contents of that file.

    NOTE: Inclusion is not confined to any directory (absolute paths and `..` components are honored),
    so any file readable by the process can be pulled in. Only use this loader on trusted YAML input.
    There is also no cycle detection: a file that (transitively) includes itself recurses until
    Python's recursion limit is reached.
    """

    def __init__(self, stream: IO) -> None:
        """
        Args:
            stream:
                The YAML input passed on to `SafeLoader`. If it has a path-like `name` attribute
                (as file objects returned by `open(path)` do), relative inclusions are resolved against
                that file's parent directory; otherwise they are resolved against the current working directory.
        """
        try:
            self._root = Path(stream.name).parent
        except (AttributeError, TypeError):
            # `AttributeError`: the stream has no `name` (e.g. a plain string was passed).
            # `TypeError`: the name is not path-like (e.g. an integer file descriptor from `open(fd)`).
            # In both cases fall back to the current working directory.
            self._root = Path()
        super().__init__(stream)

    def include(self, node: ScalarNode) -> Any:
        """
        Construct the value of a node tagged with `INCLUDE_TAG` by loading the YAML file it names.

        The path following the tag can be either relative (to the including file's parent directory) or absolute.
        Since the included file is loaded with this class as well, arbitrary nesting of YAML file inclusions is possible.

        Args:
            node: The scalar node holding the path of the file to include.
        Returns:
            The Python object that the included YAML document deserializes to.
        """
        path = Path(self.construct_scalar(node))
        if not path.is_absolute():
            path = self._root / path
        # Binary mode lets PyYAML detect the encoding itself (UTF-8, or UTF-16 via BOM),
        # instead of depending on the platform's locale-default text encoding.
        with open(path, 'rb') as f:
            return load(f, RecursiveSafeLoader)


# Register the `!include` handler on the loader class so every instance (including nested ones) honors it.
RecursiveSafeLoader.add_constructor(INCLUDE_TAG, RecursiveSafeLoader.include)
def yaml_overload(*paths: PathT, dir_sort_key: Union[Callable[[PathT], Any], None] = None,
                  dir_sort_reverse: bool = False) -> dict:
    """
    Loads YAML files from any number of paths, recursively going through any directories.

    Args:
        paths:
            Each argument should be a path to either a YAML file or a directory containing YAML files;
            only files with the extension `.yaml` are loaded, anything else is silently skipped.
        dir_sort_key (optional):
            If one of the paths is a directory, its contents are sorted, before recursively passing them into this
            function; to apply a specific comparison key for each sub-path (a `pathlib.Path`), a callable can be used
            here, which is passed into the builtin `sorted` function. It applies at every level of nesting.
        dir_sort_reverse (optional):
            Same as with the parameter above, this argument is also passed into the `sorted` function,
            at every level of nesting.
    Returns:
        Dictionary comprised of the contents of iteratively loaded YAML files.
        NOTE: Since it is updated each time a file is loaded, their load order matters!
        The merge is shallow: a top-level key in a later file replaces the whole value from an earlier one.
        Empty YAML files (which load as `None`) contribute nothing.
    Raises:
        OSError: If a `.yaml` file cannot be opened.
        TypeError/ValueError: If a file's top-level YAML value cannot be merged via `dict.update`
            (e.g. a plain scalar).
        Any error raised by `yaml.load` for malformed YAML is propagated as well.
    """
    output_dict = {}
    for path in map(Path, paths):
        if path.is_dir():
            children = sorted(path.iterdir(), key=dir_sort_key, reverse=dir_sort_reverse)
            # Forward the sort options so nested sub-directories are ordered the same way as this one;
            # the order decides which file's keys win.
            output_dict.update(
                yaml_overload(*children, dir_sort_key=dir_sort_key, dir_sort_reverse=dir_sort_reverse)
            )
        elif path.suffix == '.yaml':
            # Binary mode lets PyYAML detect the encoding itself (UTF-8, or UTF-16 via BOM),
            # instead of depending on the platform's locale-default text encoding.
            with open(path, 'rb') as f:
                content = load(f, RecursiveSafeLoader)
            # An empty (or comment-only) YAML file loads as `None`; `dict.update(None)` would raise.
            if content is not None:
                output_dict.update(content)
    return output_dict