Source code for pyvbmc.priors.scipy

from textwrap import indent

import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats._distn_infrastructure import rv_continuous_frozen
from scipy.stats._multivariate import (
    multivariate_normal_frozen,
    multivariate_t_frozen,
)

from pyvbmc.formatting import full_repr
from pyvbmc.priors import Prior


[docs] class SciPy(Prior): """Wrapper class for `scipy.stats` distributions. Attributes ---------- D : int The dimension of the prior distribution. distribution : scipy.stats._multivariate.multivariate_normal_frozen\ or scipy.stats._multivariate.multivariate_t_frozen\ or scipy.stats._distn_infrastructure.rv_continuous_frozen The underlying `scipy.stats` distributions. """
[docs] def __init__(self, distribution): """Initialize `pyvbmc.priors.Prior` from `scipy.stats` distribution(s). Parameters ---------- distribution : scipy.stats.multivariate_normal\ or scipy.stats.multivariate_t\ or scipy.stats.rv_continuous\ The underlying `scipy.stats` distributions. Should be a multivariate normal distribution, a multivariate t distribution, or a univariate continuous distribution (for 1-D models). Raises ------ TypeError If the provided distribution is not of the appropriate type. """ if isinstance(distribution, multivariate_normal_frozen) or isinstance( distribution, multivariate_t_frozen ): x = np.atleast_1d(distribution.rvs(1)) self.D = len(x) self.a = np.full(self.D, -np.inf) self.b = np.full(self.D, np.inf) elif isinstance(distribution, rv_continuous_frozen): self.D = 1 self.a = np.atleast_1d(distribution.a) self.b = np.atleast_1d(distribution.b) else: raise TypeError( f'A SciPy prior should be initialized from a "frozen" multivariate normal, multivariate t, or univariate SciPy distribution, but got `distribution` of type {type(distribution)}.' ) self.distribution = distribution
def _log_pdf(self, x): """Compute the log-pdf of the multivariate uniform-box prior. Parameters ---------- x : np.ndarray The array of input point(s), of dimension `(D,)` or `(n,D)`, where `D` is the distribution dimension. Returns ------- log_pdf : np.ndarray The log-density of the prior at the input point(s), of dimension `(n, 1)`. """ n, D = x.shape log_pdf = self.distribution.logpdf(x).reshape((n, 1)) return log_pdf def sample(self, n): """Sample random variables from the uniform-box distribution. Parameters ---------- n : int The number of points to sample. Returns ------- rvs : np.ndarray The samples points, of shape `(n, D)`, where `D` is the dimension. """ rvs = self.distribution.rvs(n).reshape((n, self.D)) return rvs @classmethod def _generic(cls, D=1): """Return a generic instance of the class (used for tests).""" return SciPy(multivariate_normal(np.zeros(D))) def _support(self): """Returns the support of the distribution. Used to test that the distribution integrates to one, so it is also acceptable to return a box which bounds the support of the distribution. Returns ------- a, b : tuple(np.ndarray, np.ndarray) A tuple of lower and upper bounds of the support, such that [``a[i]``, ``b[i]``] bounds the support of the `i`th marginal. """ return self.a, self.b def __str__(self): """Print a string summary.""" return "SciPy prior:" + indent( f""" dimension = {self.D}, lower bounds = {self.a}, upper bounds = {self.b}, distribution(s) = {self.distribution}""", " ", ) def __repr__(self, arr_size_thresh=10, expand=False): """Construct a detailed string summary. Parameters ---------- arr_size_thresh : float, optional If ``obj`` is an array whose product of dimensions is less than ``arr_size_thresh``, print the full array. Otherwise print only the shape. Default `10`. expand : bool, optional If ``expand`` is `False`, then describe any complex child attributes of the object by their name and memory location. Otherwise, recursively expand the child attributes into their own representations. Default `False`. Returns ------- string : str The string representation of ``self``. """ return full_repr( self, "SciPy", order=[ "D", "a", "b", "distribution", ], expand=expand, arr_size_thresh=arr_size_thresh, )
[docs] def is_valid_scipy_dist(obj): """Assess whether ``obj`` is a valid SciPy distribution for a PyVBMC prior. A valid SciPy distribution is a frozen multivariate normal, multivariate t, or continuous univariate distribution. Parameters ---------- obj : any The object to type-check. Returns ------- is_valid : bool """ return ( isinstance(obj, multivariate_normal_frozen) or isinstance(obj, multivariate_t_frozen) or isinstance(obj, rv_continuous_frozen) )