# Source code for magmap.io.yaml_io

# YAML Input/Output
# Author: David Young, 2020
"""YAML file format input/output."""

from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union

import yaml

from magmap.io import libmag

if TYPE_CHECKING:
    import pathlib


def _filter_dict(d: Dict, fn_parse_val: Callable[[Any], Any]) -> Dict:
    """Build a new dictionary by passing each key and value through a filter.

    Sequence values (as detected by ``libmag.is_seq``) are filtered element
    by element and returned as lists. All other values, including nested
    dictionaries, are handed to ``fn_parse_val`` whole. Deeper recursion
    therefore only happens if ``fn_parse_val`` itself calls this function.

    Args:
        d: Dictionary to filter.
        fn_parse_val: Function applied to each key and value.

    Returns:
        New dictionary of filtered keys mapped to filtered values.

    """
    def filter_val(item):
        # dicts are checked first so they are always passed to the filter
        # intact rather than being iterated as a sequence
        if not isinstance(item, dict) and libmag.is_seq(item):
            return [fn_parse_val(elem) for elem in item]
        return fn_parse_val(item)

    filtered = {}
    for key, val in d.items():
        # filter the value before its key
        new_val = filter_val(val)
        filtered[fn_parse_val(key)] = new_val
    return filtered


def load_yaml(
        path: Union[str, "pathlib.Path"],
        enums: Optional[Dict[str, Enum]] = None) -> List[Dict]:
    """Load a YAML file with support for multiple documents and Enums.

    Args:
        path: Path to YAML file.
        enums: Dictionary mapping Enum names to Enum classes; defaults to
            None. If a key or value in the YAML file is a string made of an
            Enum name from this dictionary, a period, and a member name
            (eg ``"MyEnum.member"``), it is replaced by that Enum member.
            Strings are checked at any nesting depth within dictionaries
            and sequences.

    Returns:
        Sequence of parsed documents within the YAML file, typically
        dictionaries. Empty documents are skipped.

    Raises:
        FileNotFoundError: if ``path`` could not be found or decoded.
        KeyError: if a string starts with a known Enum name followed by a
            period, but the following part is not a member of that Enum.

    """
    def parse_enum_val(val):
        # recursively parse Enum values within dicts and sequences
        if isinstance(val, dict):
            val = _filter_dict(val, parse_enum_val)
        elif libmag.is_seq(val):
            val = [parse_enum_val(v) for v in val]
        elif isinstance(val, str):
            val_split = val.split(".")
            if len(val_split) > 1 and val_split[0] in enums:
                # replace with the named member of the corresponding Enum
                val = enums[val_split[0]][val_split[1]]
        return val

    try:
        with open(path) as yaml_file:
            # load all documents into a generator; NOTE: FullLoader can
            # construct some Python types beyond the YAML safe subset (eg
            # tuples), so this function should not be used on untrusted
            # files
            docs = yaml.load_all(yaml_file, Loader=yaml.FullLoader)
            data = []
            for doc in docs:
                if not doc:
                    # skip empty document
                    continue
                if enums:
                    # parse the document itself rather than assuming it is
                    # a dict so that top-level lists and scalars also work
                    doc = parse_enum_val(doc)
                data.append(doc)
    except (FileNotFoundError, UnicodeDecodeError) as e:
        # report undecodable files as unloadable, keeping the cause
        raise FileNotFoundError(e) from e
    return data
def save_yaml(
        path: Union[str, "pathlib.Path"], data: Dict,
        use_primitives: bool = False, convert_enums: bool = False) -> Dict:
    """Save a dictionary to YAML file format.

    Args:
        path: Output path.
        data: Dictionary to output.
        use_primitives: True to replace Numpy data types with Python
            primitives; defaults to False.
        convert_enums: True to convert keys and vals that are Enums to
            ``"EnumClass.member"`` strings at any nesting depth within
            dictionaries and sequences; defaults to False. Sequences
            traversed during this conversion are output as lists.

    Returns:
        ``data`` itself if no conversion was requested, otherwise a new
        dictionary with the conversions applied.

    """
    def convert_numpy_val(val):
        # recursively convert Numpy data types to primitives
        if isinstance(val, dict):
            val = _filter_dict(val, convert_numpy_val)
        elif libmag.is_seq(val):
            # also replaces any tuples with lists, avoiding tuple flags in
            # the output file for simplicity
            val = [convert_numpy_val(v) for v in val]
        else:
            try:
                val = val.item()
            except AttributeError:
                pass
        return val

    def convert_enum(val):
        # recursively convert Enums to class.name strings; _filter_dict only
        # applies this function one level deep, so nested dicts and
        # sequences must be traversed here as well
        if isinstance(val, Enum):
            return f"{val.__class__.__name__}.{val.name}"
        if isinstance(val, dict):
            return _filter_dict(val, convert_enum)
        if libmag.is_seq(val):
            return [convert_enum(v) for v in val]
        return val

    if use_primitives:
        # replace Numpy arrays and types with Python primitives
        data = _filter_dict(data, convert_numpy_val)
    if convert_enums:
        data = _filter_dict(data, convert_enum)
    with open(path, "w") as yaml_file:
        # save to YAML format
        yaml.dump(data, yaml_file)
        print("Saved data to:", path)
    return data