# Source code for pyvbmc.priors.user_function

from textwrap import indent

import numpy as np
from scipy.stats import multivariate_normal

from pyvbmc.formatting import full_repr
from pyvbmc.priors import Prior


class UserFunction(Prior):
    """Lightweight wrapper for user-defined priors.

    Attributes
    ----------
    log_pdf : callable or None
        The user-provided function representing the log-density of the
        prior (stored exactly as passed to ``__init__``).
    sample : callable or None
        A function for sampling from the prior, if provided. This instance
        attribute is always assigned in ``__init__`` and therefore shadows
        the class-level ``sample`` placeholder method.
    D : int or None
        The dimension of the prior, if provided.
    """

    def __init__(self, log_prior, sample_prior=None, D=None):
        """Initialize a user-specified prior from a function.

        Parameters
        ----------
        log_prior : callable
            The user-provided function. Should take a one-dimensional array
            as a single argument, and return the log-density of the prior
            evaluated at that point. ``None`` passes validation and is
            stored as-is, but ``pdf`` cannot be evaluated in that case.
        sample_prior : callable, optional
            An optional user-provided function for sampling from the prior.
            Should take an integer `n` as a single argument, and return `n`
            samples from the prior distribution as an `(n, D)` array.
        D : int, optional
            Specified dimension of the prior (optional).

        Raises
        ------
        TypeError
            If `log_prior` or `sample_prior` is neither ``None`` nor
            callable.
        """
        self.D = D
        if (log_prior is not None) and (not callable(log_prior)):
            raise TypeError("`log_prior` must be callable.")
        self.log_pdf = log_prior
        if (sample_prior is not None) and (not callable(sample_prior)):
            raise TypeError(
                "Optional keyword `sample_prior` must be callable."
            )
        # Stored on the instance, so it takes precedence over the
        # class-level ``sample`` placeholder (even when it is ``None``).
        self.sample = sample_prior

    def pdf(self, *args, **kwargs):
        """Compute the pdf of the distribution.

        All arguments are forwarded unchanged to the user-provided
        log-density function; the result is exponentiated with ``np.exp``.

        Raises
        ------
        TypeError
            If no log-density function was provided at construction.
        """
        if self.log_pdf is None:
            # Same exception type as calling ``None`` would raise, but with
            # a message that points at the actual cause.
            raise TypeError(
                "Cannot evaluate `pdf`: no `log_prior` function was provided."
            )
        return np.exp(self.log_pdf(*args, **kwargs))

    @classmethod
    def _generic(cls, D=1):
        """Return a generic instance of the class (used for tests).

        The log-density is that of a ``scipy.stats.multivariate_normal``
        with mean ``np.zeros(D)`` and SciPy's default covariance.
        """
        # Build the frozen distribution once instead of on every evaluation.
        frozen_normal = multivariate_normal(np.zeros(D))
        return cls(frozen_normal.logpdf, D=D)

    # NOTE(review): the two placeholders below presumably exist to satisfy
    # the interface required by ``Prior`` -- confirm against the base class.
    def sample(self, n):
        """Unused"""
        pass

    def _log_pdf(self):
        """Unused"""
        pass

    def __str__(self):
        """Print a string summary."""
        # One field per line; every non-empty line is indented under the
        # heading by ``textwrap.indent``.
        summary = (
            f"\ndimension = {self.D},"
            f"\nlog pdf = {self.log_pdf},"
            f"\nsample function = {self.sample}"
        )
        return "UserFunction prior:" + indent(summary, "    ")

    def __repr__(self, arr_size_thresh=10, expand=False):
        """Construct a detailed string summary.

        Parameters
        ----------
        arr_size_thresh : float, optional
            If an attribute is an array whose product of dimensions is less
            than ``arr_size_thresh``, print the full array. Otherwise print
            only the shape. Default `10`.
        expand : bool, optional
            If ``expand`` is `False`, then describe any complex child
            attributes of the object by their name and memory location.
            Otherwise, recursively expand the child attributes into their
            own representations. Default `False`.

        Returns
        -------
        string : str
            The string representation of ``self``.
        """
        return full_repr(
            self,
            "UserFunction",
            order=[
                "D",
                "log_pdf",
                "sample",
            ],
            expand=expand,
            arr_size_thresh=arr_size_thresh,
        )