Source code for my_code_base.core.utils

[docs] # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # Author: Markus Ritschel # eMail: git@markusritschel.de # Date: 2024-03-03 # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% # import functools import logging import matplotlib.pyplot as plt import numpy as np import pandas as pd import xarray as xr logging.basicConfig(level="INFO") log = logging.getLogger(__name__)
def add_metadata(func):
    """Decorate a ``save``-like function so that it receives provenance metadata.

    The wrapper collects the file and line number of the calling code plus the
    short git commit hash, injects them as the keyword argument ``metadata``
    and rewrites the target path before delegating to ``func``.

    The wrapped function must be called with the object to save and the
    target path as its first two *positional* arguments. The optional keyword
    ``add_hash`` (default ``False``) is consumed by the wrapper; when true,
    ``_<git_commit>`` is inserted between the file stem and its suffix.

    Parameters
    ----------
    func : callable
        The function to be decorated. It is called as
        ``func(obj, output_path, *args, metadata=..., **kwargs)``.

    Returns
    -------
    callable
        The decorated function.
    """
    import os
    import subprocess
    import sys
    from pathlib import Path

    def lookup_git_commit():
        """Return the short hash of the current git HEAD, or ``'unknown'``."""
        try:
            raw = subprocess.check_output(
                ['git', 'rev-parse', '--short', 'HEAD'],
                stderr=subprocess.DEVNULL,
            )
        except (OSError, subprocess.CalledProcessError):
            # git is not installed or the working directory is not a
            # repository; saving should still work without a commit hash.
            log.warning("Could not determine the git commit hash; using 'unknown'.")
            return 'unknown'
        return raw.decode('ascii').strip()

    def collect_metadata():
        """Gather code location and git commit of the caller of ``wrapper``."""
        # Frame 1 is `wrapper` (the caller of this helper); its `f_back` is
        # the code that invoked the decorated function.
        frame = sys._getframe(1).f_back
        metadata = {
            'relative_code_path': os.path.relpath(frame.f_code.co_filename),
            'line_number': str(frame.f_lineno),
            'git_commit': lookup_git_commit(),
        }
        return BunchDict(metadata)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        add_hash = kwargs.pop('add_hash', False)
        meta = collect_metadata()
        kwargs['metadata'] = meta

        args = list(args)
        obj = args[0]
        path = Path(args[1])
        suffix = f"_{meta['git_commit']}" if add_hash else ''
        output_path = f'{path.parent}/{path.stem}{suffix}{path.suffix}'
        args[1] = output_path

        result = func(*args, **kwargs)

        # Only report success once `func` has returned without raising.
        obj_type = get_obj_type_str(obj)
        log.info(f"Saved {obj_type} to {output_path}, produced by {meta['relative_code_path']}#{meta['line_number']} @git-commit:{meta['git_commit']}")
        return result

    return wrapper
class BunchDict(dict):
    """A ``dict`` subclass whose keys can also be accessed as attributes.

    Reading an attribute that is not found through normal lookup returns the
    value stored under the key of the same name; assigning an attribute
    stores the value under that key. A missing key raises
    :class:`AttributeError` on attribute access (not :class:`KeyError`), so
    that ``hasattr``, ``getattr`` with a default, and ``copy.deepcopy``
    behave as they do for ordinary objects.

    .. note::
        This is now also implemented in :class:`sklearn.utils.Bunch`

    Example
    -------
    >>> bd = BunchDict()
    >>> bd['key'] = 'value'
    >>> print(bd.key)
    value
    >>> bd.key = 'new value'
    >>> print(bd['key'])
    new value
    >>> hasattr(bd, 'missing')
    False
    """

    def __getattr__(self, attr):
        # Only called when regular attribute lookup fails, so dict methods
        # such as `keys` are never shadowed by stored keys.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None

    def __setattr__(self, attr, value):
        self[attr] = value
def get_obj_type_str(obj):
    """Return the bare class name of ``obj``.

    For example, an :class:`xarray.Dataset` instance, whose type prints as
    ``<class 'xarray.core.dataset.Dataset'>``, yields ``"Dataset"``.

    The name is read from ``type(obj).__name__`` rather than parsed out of
    the type's ``repr``, so it also works for classes whose metaclass
    customises ``__repr__``.
    """
    return type(obj).__name__
@add_metadata
@functools.singledispatch
def save(obj, path, *args, **kwargs):
    """Save the given object including metadata.

    This is a single-dispatch function: the implementation is selected by
    the type of ``obj``. Implementations are registered in this module for
    :class:`matplotlib.figure.Figure`, :class:`pandas.DataFrame` and
    :class:`xarray.Dataset`. This generic fallback is used for every other
    type and always raises :class:`NotImplementedError`.

    Because the function is wrapped by ``add_metadata``, ``obj`` and ``path``
    must be passed positionally, the keyword ``add_hash`` may be given to
    append the git commit hash to the file name, and every implementation
    receives a ``metadata`` keyword argument.

    Parameters
    ----------
    obj : object
        The object to be saved.
    path : str
        The path to which the object will be saved.

    Raises
    ------
    NotImplementedError
        If no implementation is registered for the type of ``obj``.

    Examples
    --------
    >>> ds = xr.tutorial.load_dataset('air_temperature')  # doctest: +SKIP
    >>> save(ds, '/tmp/mynetcdf.nc', add_hash=True)  # doctest: +SKIP
    >>> !ncdump -h /tmp/mynetcdf_500e15f.nc | grep history  # doctest: +SKIP
    :history = "2024-06-12 16:18:16: File saved by myscript.py#3 @git-commit:500e15f;"
    >>> save(my_object, '/tmp/myobj')  # doctest: +SKIP
    NotImplementedError: No implementation of `save` found for object of type <class 'type'>. Please use the native method.
    """
    message = (
        f"No implementation of `save` found for object of type {type(obj)}. "
        "Please use the native method."
    )
    raise NotImplementedError(message)
@save.register(plt.Figure)
def _(fig, path, *args, **kwargs):
    """Write a matplotlib figure to ``path``."""
    # Save the figure that was passed in. `plt.savefig` would write whichever
    # figure is currently active in pyplot, which need not be `fig`.
    # NOTE(review): `kwargs` still contains the injected `metadata` mapping,
    # which is forwarded to `savefig` -- confirm that the output formats in
    # use accept these keys.
    fig.savefig(path, *args, **kwargs)


@save.register(pd.DataFrame)
def _(df, path, *args, **kwargs):
    """Write a DataFrame to ``path`` as CSV."""
    # `DataFrame.to_csv` cannot interpret `metadata`, so drop it (tolerating
    # its absence).
    kwargs.pop('metadata', None)
    df.to_csv(path, *args, **kwargs)


@save.register(xr.Dataset)
def _(ds, path, *args, **kwargs):
    """Write a Dataset to ``path`` as netCDF, recording where it was saved from."""
    # NOTE(review): the imported name is unused here; presumably the import
    # is needed for its side effect of making `ds.history` available -- confirm.
    from .xarray_utils import HistoryAccessor  # noqa: F401

    metadata = kwargs.pop('metadata')
    msg = f"File saved by {metadata['relative_code_path']}#{metadata['line_number']} @git-commit:{metadata['git_commit']}"
    ds = ds.history.add(msg)
    ds.to_netcdf(path, *args, **kwargs)
def centered_bins(x):
    """Create bin boundaries that use the values of ``x`` as bin centers.

    The input is extended by one extra point (the last value plus the last
    step) and every point is then shifted down by ``np.gradient`` of the
    extended array evaluated with a sample spacing of 2. For evenly spaced
    input this is exactly half a step, so the result has one element more
    than ``x`` and encloses each value symmetrically. For unevenly spaced
    input the edges are not exact midpoints between neighbouring values.

    Example
    -------
    >>> x = np.arange(-3, 4)
    >>> x
    array([-3, -2, -1,  0,  1,  2,  3])
    >>> centered_bins(x)
    array([-3.5, -2.5, -1.5, -0.5,  0.5,  1.5,  2.5,  3.5])
    """
    centers = np.array(x)
    # Extrapolate one more center so that the last bin gets an upper edge.
    last_step = np.diff(centers[-2:])
    extended = np.append(centers, centers[-1] + last_step)
    half_widths = np.gradient(extended, 2)
    return extended - half_widths
[docs] def find_nearest(items: list | np.ndarray, pivot: float) -> float: """ Find the element inside `items` that is closest to the `pivot` element. Parameters ---------- items: A list of elements to search from. pivot: The pivot element to find the closest element to. Returns ------- float The element from `items` that is closest to the `pivot` element. Examples -------- >>> result = find_nearest(np.array([2,4,5,7,9,10]), 4.6) >>> int(result) # Cast to int for consistent comparison 5 """ return min(items, key=lambda x: abs(x - pivot))
[docs] def order_of_magnitude(x: int | float | np.ndarray | pd.Series) -> np.ndarray: """Determine the order of magnitude of the numeric input. Examples -------- >>> order_of_magnitude(11) array([1.]) >>> order_of_magnitude(234) array([2.]) >>> order_of_magnitude(1) array([0.]) >>> order_of_magnitude(.15) array([-1.]) >>> order_of_magnitude(np.array([24.13, 254.2])) array([1., 2.]) >>> order_of_magnitude(pd.Series([24.13, 254.2])) array([1., 2.]) """ x = np.asarray(x) if np.all(x == 0): return None x = x[x != 0] oom = np.floor(np.log10(x)) # oom = (np.int32(np.log10(np.abs(x))) + 1) return np.array(oom)