Source code for pyvbmc.decorators.handle_0D_1D_input

from functools import wraps

import numpy as np


[docs] def handle_0D_1D_input( patched_kwargs: [], patched_argpos: [], return_scalar=False ): """ A decorator that handles 0D, 1D inputs and transforms them to 2D. Parameters ---------- kwarg : list of str The names of the keyword arguments that should be handeled. argpos : list of int The positions of the arguments that should be handeled. return_scalar : bool, optional If the input is 1D the function should return a scalar, by default False. """ def decorator(function): @wraps(function) def wrapper(self, *args, **kwargs): for idx, patched_kwarg in enumerate(patched_kwargs): if patched_kwarg in kwargs: # for keyword arguments input_dims = np.ndim(kwargs.get(patched_kwarg)) kwargs[patched_kwarg] = np.atleast_2d( kwargs.get(patched_kwarg) ) elif len(args) > patched_argpos[idx]: # for positional arguments arg_list = list(args) input_dims = np.ndim(args[patched_argpos[idx]]) arg_list[patched_argpos[idx]] = np.atleast_2d( args[patched_argpos[idx]] ) args = tuple(arg_list) res = function(self, *args, **kwargs) # return value 1D or scalar when boolean set if input_dims == 1: # handle functions with multiple return values if type(res) is tuple: returnvalues = list(res) returnvalues = [o.ravel() for o in returnvalues] if return_scalar: returnvalues = [o[0] for o in returnvalues] return tuple(returnvalues) elif return_scalar and np.ndim(res) != 0: return res.ravel()[0] elif np.ndim(res) != 0: return res.ravel() return res return wrapper return decorator