Source code for nifty8.re.evidence_lower_bound

#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

import jax.flatten_util
import numpy as np
import scipy.linalg as slg
import scipy.sparse.linalg as ssl

from .evi import Samples
from .likelihood import Likelihood
from .logger import logger
from .optimize_kl import _StandardHamiltonian as StandardHamiltonian
from .tree_math.vector_math import size


class _Projector(ssl.LinearOperator):
    """Computes the projector of a Matrix or LinearOperator as a LinearOperator
    given the eigenvectors of the complementary space.

    Parameters
    ----------
    eigenvectors : ndarray
        The eigenvectors representing the directions to project out.

    Returns
    -------
    Projector : LinearOperator
        Operator representing the projection.
    """

    def __init__(self, eigenvectors):
        super().__init__(eigenvectors.dtype, 2 * (eigenvectors.shape[0],))
        self.eigenvectors = eigenvectors

    def _matvec(self, x):
        res = x.copy()
        for eigenvector in self.eigenvectors.T:
            res -= eigenvector * np.vdot(eigenvector, x)
        return res

    def _rmatvec(self, x):
        return self._matvec(x)
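
# Illustrative sketch (not part of the library code): for orthonormal
# `eigenvectors` the projector annihilates exactly those directions, e.g.
#
#   v = np.array([[1.0], [0.0], [0.0]])  # one unit vector as a (3, 1) array
#   p = _Projector(v)
#   p @ np.array([3.0, 2.0, 1.0])        # -> array([0., 2., 1.])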


def _explicify(M):
    n = M.shape[0]
    m = []
    for i in range(n):
        basis_vector = np.zeros(n)
        basis_vector[i] = 1
        m.append(M @ basis_vector)
    return np.stack(m, axis=1)
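
# Illustrative sketch (not part of the library code): `_explicify` densifies a
# matrix-free operator by applying it to all standard basis vectors, e.g.
#
#   op = ssl.aslinearoperator(np.diag([1.0, 2.0, 3.0]))
#   _explicify(op)                       # -> dense 3x3 array np.diag([1., 2., 3.])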


def _ravel_metric(metric, position, dtype):
    shape = 2 * (size(metric(position, position)),)

    ravel = lambda x: jax.flatten_util.ravel_pytree(x)[0]
    unravel = lambda x: jax.linear_transpose(ravel, position)(x)[0]
    met = lambda x: ravel(metric(position, unravel(x)))

    return ssl.LinearOperator(shape=shape, dtype=dtype, matvec=met)
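
# `_ravel_metric` wraps a metric acting on parameter pytrees as a flat scipy
# LinearOperator so that generic eigensolvers can be applied. Illustrative
# sketch (not part of the library code) with a toy identity metric:
#
#   pos = {"a": jax.numpy.zeros(2), "b": jax.numpy.zeros(3)}
#   met = _ravel_metric(lambda p, t: t, pos, dtype=np.float64)
#   met @ jax.numpy.arange(5.0)          # -> array([0., 1., 2., 3., 4.])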


def _eigsh(
    metric,
    n_eigenvalues,
    tot_dofs,
    min_lh_eval=1e-4,
    batch_size=10,
    tol=0.0,
    verbose=True,
):
    metric_size = metric.shape[0]
    eigenvectors = None
    if n_eigenvalues > tot_dofs:
        raise ValueError(
            "Number of requested eigenvalues "
            "exceeds the number of relevant degrees of freedom!"
        )

    if tot_dofs == n_eigenvalues:
        # Compute exact eigensystem
        if verbose:
            logger.info(f"Computing all {tot_dofs} relevant metric eigenvalues.")
        eigenvalues = slg.eigh(
            _explicify(metric),
            eigvals_only=True,
            subset_by_index=[metric_size - tot_dofs, metric_size - 1],
        )
        eigenvalues = np.flip(eigenvalues)
    else:
        # Set up batches; `batch_size` is the number of batches (see docstring),
        # so derive how many eigenvalues to request in each batch.
        n_per_batch = n_eigenvalues // batch_size
        batches = [n_per_batch] * (batch_size - 1)
        batches += [n_eigenvalues - n_per_batch * (batch_size - 1)]
        eigenvalues, projected_metric = None, metric
        for batch in batches:
            if verbose:
                logger.info(f"\nNumber of eigenvalues being computed: {batch}")
            # Get eigensystem for current batch
            eigvals, eigvecs = ssl.eigsh(
                projected_metric, k=batch, tol=tol, return_eigenvectors=True, which="LM"
            )
            i = np.argsort(-eigvals)
            eigvals, eigvecs = eigvals[i], eigvecs[:, i]
            eigenvalues = (
                eigvals
                if eigenvalues is None
                else np.concatenate((eigenvalues, eigvals))
            )
            eigenvectors = (
                eigvecs if eigenvectors is None else np.hstack((eigenvectors, eigvecs))
            )

            if abs(1.0 - np.min(eigenvalues)) < min_lh_eval:
                break
            # Project out subspace of already computed eigenvalues
            projector = _Projector(eigenvectors)
            projected_metric = projector @ metric @ projector.T
    return eigenvalues, eigenvectors
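
# Illustrative sketch (not part of the library code): on a small diagonal
# operator `_eigsh` returns the dominant eigenvalues in descending order, e.g.
#
#   op = ssl.aslinearoperator(np.diag([5.0, 3.0, 1.0, 1.0]))
#   evals, _ = _eigsh(op, n_eigenvalues=2, tot_dofs=2, verbose=False)
#   evals                                # -> array([5., 3.])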


def estimate_evidence_lower_bound(
    likelihood,
    samples,
    n_eigenvalues,
    min_lh_eval=1e-3,
    batch_size=10,
    tol=0.0,
    verbose=True,
):
    """Provides an estimate for the Evidence Lower Bound (ELBO).

    Statistical inference deals with the problem of hypothesis testing, given
    some data and models that can describe it. In general, it is hard to find
    a good metric to discern between different models. In Bayesian Inference,
    the Bayes factor can serve this purpose. To compute the Bayes factor it is
    necessary to calculate the evidence :math:`p(d | \\text{model})` of the
    specific model at hand. Then, the ratio between the evidence of a model A
    and the one of a model B quantifies how much more likely it is that model
    A describes the data better than model B.

    The evidence of an approximated-inference problem can in principle be
    calculated. However, this is only practically feasible in a
    low-dimensional setting. What often can be computed is

    .. math::
        \\log(p(d)) - D_\\text{KL} \\left[ Q(\\theta(\\xi)|d) || p(\\theta(\\xi)|d)
        \\right] = -\\langle H(\\theta(\\xi), d)\\rangle + \\frac{1}{2} \\left( N
        + \\text{Tr } \\log\\Lambda\\right),

    where :math:`D_\\text{KL} \\left[ Q || p \\right]` is the Kullback-Leibler
    (KL) divergence between the approximating posterior distribution
    :math:`Q` and the actual posterior :math:`p`. Since the KL divergence
    :math:`D_\\text{KL} [\\cdot, \\cdot] \\geq 0` is non-negative, it is
    convenient to consider the lower bound

    .. math::
        \\log(p(d)) \\geq -\\langle H(\\theta(\\xi), d)\\rangle + \\frac{1}{2}
        \\left( N + \\text{Tr } \\log\\Lambda \\right),

    which is known as the Evidence Lower Bound (ELBO).

    If the KL divergence is well minimized (which should always be the case
    when a Variational Inference approach is followed), then it is possible
    to utilize the ELBO (as a proxy for the actual evidence) and calculate
    the Bayes factors for model comparison.

    Parameters
    ----------
    likelihood : :class:`nifty8.re.likelihood.Likelihood`
        Log-likelihood of the model.
    samples : :class:`nifty8.re.evi.Samples`
        Collection of samples from the posterior distribution.
    n_eigenvalues : int
        Maximum number of eigenvalues to be considered for the estimation of
        the log-determinant of the metric. Note that if `n_eigenvalues`
        equals the total number of relevant degrees of freedom of the
        problem, all relevant eigenvalues are always computed irrespective of
        other stopping criteria.
    min_lh_eval : float
        Smallest eigenvalue of the likelihood to be considered. If the
        estimated eigenvalues become smaller than 1 + `min_lh_eval`, the
        eigenvalue estimation terminates and uses the smallest eigenvalue as
        a proxy for all remaining eigenvalues in the trace-log estimation.
        Default is 1e-3.
    batch_size : int
        Number of batches into which the eigenvalue estimation gets
        subdivided. Only after completing a batch is the early stopping
        criterion based on `min_lh_eval` checked. Default is 10.
    tol : Optional[float]
        Tolerance on the eigenvalue calculation. Zero indicates machine
        precision. Default is 0.
    verbose : Optional[bool]
        Print list of eigenvalues and summary of evidence calculation.
        Default is True.

    Returns
    -------
    elbo_samples : np.ndarray
        List of ELBO samples from the posterior distribution. The samples
        are returned to allow for more accurate ELBO statistics.
    stats : dict
        Dictionary with a summary of the statistics of the estimated ELBO.
        The keys of this dictionary are:

        - `elbo_mean`: mean value of the ELBO estimate calculated over
          posterior samples
        - `elbo_up`: upper bound of the ELBO estimate (mean plus one
          posterior-sample standard deviation)
        - `elbo_lw`: lower bound of the ELBO estimate (mean minus one
          standard deviation minus the maximal error on the metric
          trace-log)
        - `lower_error`: maximal error on the metric trace-log term, given
          by the number of relevant metric eigenvalues different from 1 that
          were neglected in the trace-log estimation times the log of the
          smallest calculated eigenvalue.

    Warning
    -------
    To perform Variational Inference there is no need to take into account
    quantities that do not explicitly depend on the inferred parameters.
    Explicitly calculating these terms can be expensive, therefore they are
    neglected in NIFTy. Since in most cases they are also not required for
    model comparison, the provided estimate may not include terms which are
    constant in these parameters. Only when comparing models for which the
    likelihood includes (possibly data-dependent) constants, or when the
    ELBO is needed to approximate the true evidence, do these contributions
    have to be considered. For example, for a Gaussian distributed signal
    and a linear problem (Wiener Filter problem) the only term missing is
    :math:`-\\frac{1}{2} \\log \\det |2 \\pi N|`, where :math:`N` is the noise
    covariance matrix.

    See also
    --------
    For further details we refer to:

    * Analytic geoVI parametrization: P. Frank et al., Geometric Variational
      Inference <https://arxiv.org/pdf/2105.10470.pdf> (Sec. 5.1)
    * Conceptualization: A. Kostić et al. (manuscript in preparation).
    """
    if not isinstance(samples, Samples):
        raise TypeError("samples attribute should be of type `Samples`.")
    if not isinstance(likelihood, Likelihood):
        raise TypeError("likelihood is not an instance of `Likelihood`.")

    hamiltonian = StandardHamiltonian(likelihood)
    metric = hamiltonian.metric
    metric = _ravel_metric(metric, samples.pos, dtype=likelihood.target.dtype)
    metric_size = metric.shape[0]
    n_data_points = (
        likelihood.lsm_tangents_shape.size
        if likelihood.lsm_tangents_shape is not None
        else metric_size
    )
    # Number of metric eigenvalues that are not 1
    n_relevant_dofs = min(n_data_points, metric_size)

    eigenvalues, _ = _eigsh(
        metric,
        n_eigenvalues,
        tot_dofs=n_relevant_dofs,
        min_lh_eval=min_lh_eval,
        batch_size=batch_size,
        tol=tol,
        verbose=verbose,
    )

    if verbose:
        logger.info(
            f"\nComputed {eigenvalues.size} largest eigenvalues (out of "
            f"{n_relevant_dofs} relevant degrees of freedom)."
            f"\nThe remaining {metric_size - n_relevant_dofs} metric "
            f"eigenvalues (out of {metric_size}) are equal to "
            f"1.\n\n{eigenvalues}."
        )

    # Return a list of ELBO samples and a summary of the ELBO statistics
    log_eigenvalues = np.log(eigenvalues)
    tr_log_lat_cov = -0.5 * np.sum(log_eigenvalues)
    tr_log_lat_cov_lower = (
        0.5 * (n_relevant_dofs - log_eigenvalues.size) * np.min(log_eigenvalues)
    )
    posterior_contribution = tr_log_lat_cov + 0.5 * metric_size
    elbo_samples = np.array(
        [posterior_contribution - hamiltonian(s) for s in samples]
    )

    stats = {"lower_error": tr_log_lat_cov_lower}
    elbo_mean = np.mean(elbo_samples)
    elbo_std = np.std(elbo_samples, ddof=1)
    elbo_up = elbo_mean + elbo_std
    elbo_lw = elbo_mean - elbo_std - stats["lower_error"]
    stats["elbo_mean"], stats["elbo_up"], stats["elbo_lw"] = (
        elbo_mean,
        elbo_up,
        elbo_lw,
    )

    if verbose:
        s = (
            f"\nELBO decomposition (in log units)"
            f"\nELBO mean : {elbo_mean:.4e} "
            f"(upper: {elbo_up:.4e}, lower: {elbo_lw:.4e})"
        )
        logger.info(s)

    return elbo_samples, stats
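
# Illustrative usage sketch (not part of the library code). `lh` and
# `posterior_samples` are hypothetical placeholders for a `Likelihood` and the
# `Samples` object returned by an MGVI/geoVI run:
#
#   elbo_samples, stats = estimate_evidence_lower_bound(
#       lh, posterior_samples, n_eigenvalues=100
#   )
#   stats["elbo_mean"]                   # ELBO estimate used for model comparison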