Source code for pyvbmc.entropy.entlb_vbmc

import numpy as np

from pyvbmc.variational_posterior import VariationalPosterior


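# Note on the multi-component branch below: it evaluates the Jensen lower
# bound on the mixture entropy from Gershman et al. (2012),
#
#     H[q] >= -sum_j w_j * log( sum_k w_k * N(mu_j; mu_k, (sigma_j^2 + sigma_k^2) * diag(lambd)^2) ),
#
# where gamma[j, k] is the Gaussian density N(mu_j; mu_k, ...) and
# gammasum[j] is the inner sum over k.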
def entlb_vbmc(
    vp: VariationalPosterior,
    grad_flags: tuple = tuple([True] * 4),
    jacobian_flag: bool = True,
):
    r"""Entropy lower bound for variational posterior by Jensen's inequality.

    Parameters
    ----------
    vp : VariationalPosterior
        An instance of the VariationalPosterior class.
    grad_flags : tuple of bool, len(grad_flags)=4, optional
        Whether to compute the gradients for [mu, sigma, lambda, w].
    jacobian_flag : bool, optional
        Whether variational parameters are transformed. The variational
        parameters and corresponding transformations are:
        sigma (log), lambda (log), w (softmax).

    Returns
    -------
    H : float
        Entropy lower bound of vp [1]_.
    dH : np.ndarray
        Gradient of entropy lower bound.

    Raises
    ------
    NotImplementedError
        Not implemented for K > BigK.

    References
    ----------
    .. [1] Gershman, S. J., Hoffman, M. D., & Blei, D. M. (2012).
        Nonparametric variational inference. Proceedings of the 29th
        International Conference on Machine Learning, 235–242.
    """
    BigK = np.inf  # large number of components

    D = vp.D
    K = vp.K
    mu = vp.mu  # [D, K]
    mu_t = mu.T  # [K, D]
    sigma = vp.sigma.ravel()  # [1,K] -> [K, ]
    lambd = vp.lambd.ravel()  # [D,1] -> [D, ]
    w = vp.w.ravel()  # [1,K] -> [K, ]
    eta = vp.eta.ravel()  # [1,K] -> [K, ]

    # Check which gradients are computed
    mu_grad = np.zeros((D, K)) if grad_flags[0] else np.empty(0)
    sigma_grad = np.zeros(K) if grad_flags[1] else np.empty(0)
    lambd_grad = np.zeros(D) if grad_flags[2] else np.empty(0)
    w_grad = np.zeros(K) if grad_flags[3] else np.empty(0)

    if K == 1:
        # Entropy of single component, uses exact expression
        H = (
            0.5 * D * (1 + np.log(2 * np.pi))
            + D * np.log(sigma).sum()
            + np.log(lambd).sum()
        )

        if grad_flags[0]:
            mu_grad = np.zeros(D)
        if grad_flags[1]:
            sigma_grad = D / sigma
        if grad_flags[2]:
            lambd_grad = 1 / lambd
        if grad_flags[3]:
            w_grad = np.zeros(1)
    elif K > BigK:
        raise NotImplementedError("Not implemented yet for K > BigK.")
    else:
        # Multiple components
        sumsigma2 = sigma[:, None] ** 2 + sigma[None, :] ** 2
        sumsigma = np.sqrt(sumsigma2)  # [K, K]
        nconst = 1 / (2 * np.pi) ** (D / 2) / np.prod(lambd)
        d2 = (
            (mu_t[:, None, :] - mu_t[None, :, :])
            / (sumsigma[..., None] * lambd)
        ) ** 2  # [K, K, D]
        d2 = d2.sum(2)  # [K, K]
        gamma = nconst / sumsigma**D * np.exp(-0.5 * d2)  # [K, K]
        gammasum = (w * gamma).sum(1)  # [K, ]
        H = -(w * np.log(gammasum)).sum()

        if any(grad_flags):
            gammafrac = (
                gamma / gammasum
            )  # [K, K], gammafrac[i,j] = gamma[i,j] / gammasum[j]
            wgammafrac = (
                w * gammafrac
            )  # [K, K], wgammafrac[i,j] = w[j] * gamma[i,j] / gammasum[j]

            if grad_flags[0]:
                dmu = (mu_t[:, None, :] - mu_t[None, :, :]) / (
                    sumsigma2[..., None] * lambd**2
                )  # [K, K, D], dmu[i,j,:] = (mu_i - mu_j) / sumsigma2[i,j] / lambd^2

            if grad_flags[1]:
                dsigma = -D / sumsigma2 + 1 / sumsigma2**2 * np.sum(
                    ((mu_t[:, None, :] - mu_t[None, :, :]) / lambd) ** 2, 2
                )  # [K, K]

            # Loop over mixture components
            for j in range(K):
                if grad_flags[0]:
                    m1 = (wgammafrac[j, :][:, None] * dmu[:, j, :]).sum(0)
                    m2 = (
                        dmu[:, j, :] * gamma[:, [j]] * w[:, None]
                    ) / gammasum[j]  # [K, D]
                    m2 = m2.sum(0)
                    mu_grad[:, j] = -w[j] * (m1 + m2)

                if grad_flags[1]:
                    # Compute terms of gradient with respect to sigma_j
                    s1 = (wgammafrac[j, :] * dsigma[:, j]).sum(0)
                    s2 = (dsigma[:, j] * gamma[:, j] * w) / gammasum[j]
                    s2 = s2.sum(0)
                    sigma_grad[j] = -w[j] * sigma[j] * (s1 + s2)

            if grad_flags[2]:
                dmu2 = (
                    (mu_t[:, None, :] - mu_t[None, :, :]) ** 2
                    / sumsigma2[..., None]
                    / lambd**2
                )  # [K, K, D]
                lambd_grad[:] = (
                    -np.sum(
                        w[:, None]
                        * np.sum(
                            w[:, None, None] * gamma[:, :, None] * (dmu2 - 1),
                            0,
                        )
                        / gammasum[:, None],
                        0,
                    )
                    / lambd
                )

            if grad_flags[3]:
                w_grad[:] = -np.log(gammasum) - wgammafrac.sum(1)

    # Correct for standard log reparameterization of SIGMA
    if jacobian_flag and grad_flags[1]:
        sigma_grad = sigma_grad * sigma

    # Correct for standard log reparameterization of LAMBDA
    if jacobian_flag and grad_flags[2]:
        lambd_grad = lambd_grad * lambd

    # Correct for standard softmax reparameterization of W
    if jacobian_flag and grad_flags[3]:
        eta_exp = np.exp(eta)
        eta_sum = eta_exp.sum()
        J_w = (
            -eta_exp[:, None] * eta_exp[None, :] / eta_sum**2
            + np.diag(eta_exp) / eta_sum
        )
        w_grad = J_w @ w_grad

    dH = np.concatenate([mu_grad.ravel("F"), sigma_grad, lambd_grad, w_grad])

    return H, dH
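

# A minimal usage sketch. It assumes that VariationalPosterior(D, K) builds a
# K-component mixture in D dimensions with mu, sigma, lambd, w, and eta
# already initialized; check the VariationalPosterior constructor for its
# exact signature before relying on this.
if __name__ == "__main__":
    D, K = 2, 3
    vp = VariationalPosterior(D, K)  # assumed constructor signature

    # Lower bound and gradient with respect to the transformed parameters
    # (log sigma, log lambda, softmax weights), since jacobian_flag=True.
    H, dH = entlb_vbmc(vp, grad_flags=(True, True, True, True))

    # dH stacks the gradients as [mu (D*K), sigma (K), lambda (D), w (K)].
    assert dH.shape == (D * K + K + D + K,)
    print("Entropy lower bound:", H)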