Source code for pyvbmc.stats.kl_div_mvn
import numpy as np
from pyvbmc.decorators import handle_0D_1D_input
[docs]
@handle_0D_1D_input(
patched_kwargs=["mu1", "sigma1", "mu2", "sigma2"],
patched_argpos=[0, 1, 2, 3],
)
def kl_div_mvn(mu1, sigma1, mu2, sigma2):
"""
Compute the analytical Kullback-Leibler divergence between two multivariate
normal pdfs.
Parameters
----------
mu1 : np.ndarray
The k-dimensional mean vector of the first multivariate normal pdf.
sigma1 : np.ndarray
The covariance matrix of the first multivariate normal pdf.
mu2 : np.ndarray
The k-dimensional mean vector of the second multivariate normal pdf.
sigma2 : np.ndarray
The covariance matrix of the second multivariate normal pdf.
Returns
-------
kl_div : np.array
The computed Kullback-Leibler divergence.
"""
D = mu1.size
mu1 = mu1.reshape(-1, 1)
mu2 = mu2.reshape(-1, 1)
dmu = mu2 - mu1
detq1 = np.linalg.det(sigma1)
detq2 = np.linalg.det(sigma2)
if (detq1 == 0 or detq2 == 0):
# KL divergence is infinite
return np.concatenate((np.inf, np.inf), axis=None)
lndet = np.log(detq2 / detq1)
a, _, _, _ = np.linalg.lstsq(sigma2, sigma1, rcond=None)
b, _, _, _ = np.linalg.lstsq(sigma2, dmu, rcond=None)
kl1 = 0.5 * (np.trace(a) + dmu.T @ b - D + lndet)
a, _, _, _ = np.linalg.lstsq(sigma1, sigma2, rcond=None)
b, _, _, _ = np.linalg.lstsq(sigma1, dmu, rcond=None)
kl2 = 0.5 * (np.trace(a) + dmu.T @ b - D - lndet)
return np.concatenate((kl1, kl2), axis=None)