factored out the YAML utilities to a separate file; added the YAML `!include` functionality

This commit is contained in:
Daniil Fajnberg 2021-12-30 19:15:35 +01:00
parent 6d80cf957f
commit 7618f6650e
2 changed files with 69 additions and 33 deletions

View File

@ -1,14 +1,13 @@
from pathlib import Path
from importlib import import_module
from typing import Dict, Callable, Union, Optional, Any, TYPE_CHECKING
from yaml import safe_load
if TYPE_CHECKING:
from aiohttp import ClientSession as AioSession, ClientResponse as AioResponse
from requests import Session as ReqSession, Response as ReqResponse
from .utils import PathT, yaml_overload
PathT = Union[Path, str]
CallableDefaultT = Callable[[], Optional[str]]
DefaultInitT = Union[str, Dict[str, str]]
OptionsT = Dict[str, str]
@ -176,36 +175,6 @@ class Form:
return post(self.url, data=self.get_payload(**kwargs))
def yaml_overload(*paths: PathT, dir_sort_key: Callable[[PathT], Any] = None, dir_sort_reverse: bool = False) -> dict:
"""
Loads YAML files from any number of paths, recursively going through any directories.
Args:
paths:
Each argument should be a path to either a YAML file or a directory containing YAML files;
only files with the extension `.yaml` are loaded.
dir_sort_key (optional):
If one of the paths is a directory, its contents are sorted, before recursively passing them into this
function; to apply a specific comparison key for each sub-path, a callable can be used here, which is
passed into the builtin `sorted` function.
dir_sort_reverse (optional):
Same as with the parameter above, this argument is also passed into the `sorted` function.
Returns:
Dictionary comprised of the contents of iteratively loaded YAML files.
NOTE: Since it is updated each time a file is loaded, their load order matters!
"""
output_dict = {}
for path in paths:
path = Path(path)
if path.is_dir():
output_dict.update(yaml_overload(*sorted(path.iterdir(), key=dir_sort_key, reverse=dir_sort_reverse)))
elif path.suffix == '.yaml':
with open(path, 'r') as f:
output_dict.update(safe_load(f))
return output_dict
def load_form(*def_paths: PathT, dir_sort_key: Callable[[PathT], Any] = None, dir_sort_reverse: bool = False,
full_payload: bool = True, url: str = None) -> Form:
"""

View File

@ -0,0 +1,67 @@
from pathlib import Path
from typing import Callable, IO, Union, Any
from yaml import SafeLoader, load, ScalarNode
PathT = Union[Path, str]
INCLUDE_TAG = '!include'
class RecursiveSafeLoader(SafeLoader):
"""
Custom `SafeLoader` for YAML streams that allows recursively referencing other files.
This is done by honoring the special `INCLUDE_TAG` followed by a path to a readable YAML file.
"""
def __init__(self, stream: IO) -> None:
try:
self._root = Path(stream.name).parent
except AttributeError:
self._root = Path()
super().__init__(stream)
def include(self, node: ScalarNode) -> Any:
"""
To be used in the constructor of a node that has the `INCLUDE_TAG` in it.
The path following that tag can be either relative (to the file's own parent directory) or absolute.
Since the constructor uses this class to load the node, arbitrary nesting of YAML file inclusions is possible.
"""
path = Path(self.construct_scalar(node))
if not path.is_absolute():
path = Path(self._root, path)
with open(path, 'r') as f:
return load(f, RecursiveSafeLoader)
RecursiveSafeLoader.add_constructor(INCLUDE_TAG, RecursiveSafeLoader.include)
def yaml_overload(*paths: PathT, dir_sort_key: Callable[[PathT], Any] = None, dir_sort_reverse: bool = False) -> dict:
"""
Loads YAML files from any number of paths, recursively going through any directories.
Args:
paths:
Each argument should be a path to either a YAML file or a directory containing YAML files;
only files with the extension `.yaml` are loaded.
dir_sort_key (optional):
If one of the paths is a directory, its contents are sorted, before recursively passing them into this
function; to apply a specific comparison key for each sub-path, a callable can be used here, which is
passed into the builtin `sorted` function.
dir_sort_reverse (optional):
Same as with the parameter above, this argument is also passed into the `sorted` function.
Returns:
Dictionary comprised of the contents of iteratively loaded YAML files.
NOTE: Since it is updated each time a file is loaded, their load order matters!
"""
output_dict = {}
for path in paths:
path = Path(path)
if path.is_dir():
output_dict.update(yaml_overload(*sorted(path.iterdir(), key=dir_sort_key, reverse=dir_sort_reverse)))
elif path.suffix == '.yaml':
with open(path, 'r') as f:
output_dict.update(load(f, RecursiveSafeLoader))
return output_dict