# Source code for nifty8.minimization.kl_energies

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2022 Max-Planck-Society
# Authors: Philipp Arras, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from functools import reduce

import numpy as np

from .. import random, utilities
from ..linearization import Linearization
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import GaussianEnergy, StandardHamiltonian
from ..operators.sampling_enabler import SamplingEnabler
from ..operators.sandwich_operator import SandwichOperator
from ..operators.scaling_operator import ScalingOperator
from ..probing import approximation2endo
from ..sugar import makeOp
from ..utilities import get_MPI_params_from_comm, myassert, shareRange
from .descent_minimizers import DescentMinimizer
from .energy import Energy
from .energy_adapter import EnergyAdapter
from .sample_list import ResidualSampleList, SampleListBase


def _reduce_field(field, keys):
    if isinstance(field, MultiField) and len(keys) > 0:
        return field.extract_by_keys(set(field.keys()) - set(keys))
    return field


def _reduce_by_keys(field, operator, keys):
    """Partially insert a field into an operator

    If the domain of the operator is an instance of `DomainTuple`, `keys`
    must be empty and `field` and `operator` are returned unchanged.

    Parameters
    ----------
    field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
        Potentially partially constant input field.
    operator : Operator
        Operator into which `field` is partially inserted.
    keys : list
        List of `MultiDomain` keys that are treated as constant.

    Returns
    -------
    tuple
        The variable part of the field and the contracted operator.
    """
    from ..sugar import is_fieldlike, is_operator
    myassert(is_fieldlike(field))
    myassert(is_operator(operator))
    if isinstance(field, MultiField):
        cst_field = field.extract_by_keys(keys)
        var_field = field.extract_by_keys(set(field.keys()) - set(keys))
        _, operator = operator.simplify_for_constant_input(cst_field)
        return var_field, operator
    myassert(len(keys) == 0)
    return field, operator
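
# Illustrative sketch, not part of the module: for an operator `op` defined on
# a MultiDomain with keys {"a", "b"} and a MultiField `f` on the same domain,
# keeping "b" constant yields the "a"-part of the field together with an
# operator that only takes the "a"-part as input (`op` and `f` are
# hypothetical placeholders):
#
#     var_f, op_red = _reduce_by_keys(f, op, ["b"])
#     # var_f.domain.keys() == {"a"}; op_red(var_f) evaluates like op(f)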


class _SelfAdjointOperatorWrapper(EndomorphicOperator):
    def __init__(self, domain, func):
        from ..sugar import makeDomain
        self._func = func
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = makeDomain(domain)

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._func(x)
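
# Illustrative sketch (hypothetical domain `dom`): the wrapper exposes a
# self-adjoint function as a NIFTy operator, e.g.
#
#     op = _SelfAdjointOperatorWrapper(dom, lambda x: 2*x)
#
# which behaves like ScalingOperator(dom, 2.) under TIMES and ADJOINT_TIMES.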


def draw_samples(position, H, minimizer, n_samples, mirror_samples, napprox=0,
                 want_error=False, comm=None):
    if not isinstance(n_samples, int):
        raise TypeError
    if not isinstance(mirror_samples, bool):
        raise TypeError
    if not isinstance(H, StandardHamiltonian):
        raise TypeError
    if isinstance(position, MultiField):
        sam_position = position.extract(H.domain)
    else:
        sam_position = position

    # Construct transformation
    geometric = minimizer is not None
    if geometric:
        tr = H.likelihood_energy.get_transformation()
        if tr is None:
            raise ValueError("Geometric sampling only works for likelihoods")
        dtype, f_lh = tr
        if isinstance(dtype, dict):
            myassert(all(dtype[k] is not None for k in dtype.keys()))
        else:
            myassert(dtype is not None)
        scale = ScalingOperator(f_lh.target, 1., dtype)
        fl = f_lh(Linearization.make_var(sam_position))

        transformation = ScalingOperator(f_lh.domain, 1.) + fl.jac.adjoint @ f_lh
        transformation_mean = sam_position + fl.jac.adjoint(fl.val)
        # Note: This metric is equivalent to H.metric, except for the case of
        # a `VariableCovarianceGaussianEnergy` with `use_full_fisher = True`.
        met = SamplingEnabler(SandwichOperator.make(fl.jac, scale),
                              ScalingOperator(fl.domain, 1., float),
                              H.iteration_controller)
    else:
        met = H(Linearization.make_var(sam_position, want_metric=True)).metric
    if napprox >= 1:
        met._approximation = makeOp(approximation2endo(met, napprox))
    # /Construct transformation

    # Draw samples
    sseq = random.spawn_sseq(n_samples)
    if mirror_samples:
        sseq = reduce(lambda a, b: a + b, [[ss]*2 for ss in sseq])
    local_samples = []
    local_neg = []
    utilities.check_MPI_synced_random_state(comm)
    utilities.check_MPI_equality(sseq, comm)
    y = None
    ntask, rank, _ = get_MPI_params_from_comm(comm)
    for i in range(*shareRange(len(sseq), ntask, rank)):
        with random.Context(sseq[i]):
            neg = mirror_samples and (i % 2 != 0)
            if not neg or y is None:  # we really need to draw a sample
                y, yi = met.special_draw_sample(True)

            if geometric:
                m = transformation_mean - y if neg else transformation_mean + y
                pos = sam_position - yi if neg else sam_position + yi
                en = GaussianEnergy(m) @ transformation
                en = EnergyAdapter(pos, en, nanisinf=True, want_metric=True)
                en, _ = minimizer(en)
                local_samples.append(en.position - sam_position)
                local_neg.append(False)
            else:
                local_samples.append(yi)
                local_neg.append(neg)
    return ResidualSampleList(position, local_samples, local_neg, comm)
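
# Illustrative usage sketch (placeholder names `ham` and `pos`: a
# StandardHamiltonian and a latent position on its domain). Passing
# `minimizer=None` restricts sampling to the linear (MGVI) approximation:
#
#     samples = draw_samples(pos, ham, None, n_samples=4, mirror_samples=True)
#     # `samples` is a ResidualSampleList of 8 residuals centered on `pos`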
def SampledKLEnergy(position, hamiltonian, n_samples, minimizer_sampling,
                    mirror_samples=True, constants=[], point_estimates=[],
                    napprox=0, comm=None, nanisinf=True):
    """Provides the sampled Kullback-Leibler divergence used for Variational
    Inference, specifically for geometric Variational Inference (geoVI) and
    Metric Gaussian VI (MGVI).

    In geoVI a probability distribution is approximated with a standard normal
    distribution in the canonical coordinate system of the Riemannian manifold
    associated with the metric of the other distribution. The coordinate
    transformation is approximated by expanding around a point. The MGVI
    simplification occurs in case this transformation can be approximated
    using a linear expansion. In order to infer the optimal expansion point, a
    stochastic estimate of the Kullback-Leibler divergence is minimized. This
    estimate is obtained by sampling from the approximation using the current
    expansion point. During minimization these samples are kept constant; only
    the expansion point is updated. Due to the typically nonlinear structure
    of the true distribution these samples have to be updated eventually by
    instantiating a `SampledKLEnergy` again. For the true probability
    distribution the standard parametrization is assumed. The samples of this
    class can be distributed among MPI tasks.

    Parameters
    ----------
    position : :class:`nifty8.field.Field`
        Expansion point of the coordinate transformation.
    hamiltonian : :class:`nifty8.operators.energy_operators.StandardHamiltonian`
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    minimizer_sampling : DescentMinimizer or None
        Minimizer used to perform the non-linear part of geoVI sampling. If it
        is None, only the linear (MGVI) approximation for sampling is used and
        no further non-linear steps are performed.
    mirror_samples : bool
        Whether the mirrored versions of the drawn samples are also used. If
        true, the number of used samples doubles. Mirroring samples stabilizes
        the KL estimate as extreme sample variation is counterbalanced.
        Default is True.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    napprox : int
        Number of samples for computing the preconditioner for linear
        sampling. No preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible across
        this communicator. If `mirror_samples` is set, a sample and its mirror
        image will preferably reside on the same task.
    nanisinf : bool
        If true, NaN energies, which can happen due to overflows in the
        forward model, are interpreted as inf. Thereby, the code does not
        crash on these occasions but rather the minimizer is told that the
        position it has tried is not sensible.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent from each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa. If a key is in both lists, it will be
    inserted into the Hamiltonian and removed from the KL.

    Note
    ----
    Mirroring samples can help to stabilize the latent mean as it reduces
    sampling noise. But a mirrored sample involves an additional solve of the
    non-linear part of the transformation.

    When using MPI, the samples get distributed as evenly as possible over all
    tasks. If the number of tasks is smaller than the total number of samples
    (including mirrored ones), the mirrored pairs try to reside on the same
    task as their non-mirrored partners. This ensures that at least the linear
    part of the sampling is re-used.

    See also
    --------
    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
    `<https://doi.org/10.3390/e23070853>`_

    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_

    Consider citing these papers, if you use MGVI or geoVI.
    """
    if not isinstance(hamiltonian, StandardHamiltonian):
        raise TypeError
    if hamiltonian.domain is not position.domain:
        raise ValueError
    if not isinstance(n_samples, int):
        raise TypeError
    if not isinstance(mirror_samples, bool):
        raise TypeError
    if not (minimizer_sampling is None
            or isinstance(minimizer_sampling, DescentMinimizer)):
        raise TypeError
    if isinstance(position, MultiField):
        if not set(constants).issubset(set(position.keys())):
            raise ValueError(
                "Constants are not a subset of the keys of the latent space\n"
                f"Latent space keys: {position.keys()}\n"
                f"Constants keys: {constants}")
        if not set(point_estimates).issubset(set(position.keys())):
            raise ValueError(
                "Point estimates are not a subset of the keys of the latent space\n"
                f"Latent space keys: {position.keys()}\n"
                f"Point estimate keys: {point_estimates}")
        if set(point_estimates) == set(position.keys()):
            raise RuntimeError(
                'Point estimates for whole domain. Use EnergyAdapter instead.')
        # if comm is not None:
        #     eff_n_samples = (2 if mirror_samples else 1)*n_samples
        #     n_tasks = comm.Get_size()
        #     if n_tasks > eff_n_samples:
        #         raise RuntimeError("More MPI tasks available than number of samples "
        #                            f"({n_tasks} > {eff_n_samples})")

    # If a key is in both lists `constants` and `point_estimates`, remove it.
    invariant = list(set(constants).intersection(point_estimates))
    if isinstance(position, MultiField) and len(invariant) > 0:
        inv_pos = position.extract_by_keys(invariant)
    else:
        inv_pos = None
    position, hamiltonian = _reduce_by_keys(position, hamiltonian, invariant)
    _, ham_sampling = _reduce_by_keys(position, hamiltonian, point_estimates)

    sample_list = draw_samples(position, ham_sampling, minimizer_sampling,
                               n_samples, mirror_samples, napprox=napprox,
                               comm=comm)
    return SampledKLEnergyClass(sample_list, hamiltonian, constants, inv_pos,
                                nanisinf)
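
# Illustrative MGVI/geoVI iteration sketch (placeholder names `ham`, `mean`,
# `kl_minimizer` and `n_iterations`; not defined in this module). Samples are
# redrawn at every iteration by constructing a fresh SampledKLEnergy:
#
#     for _ in range(n_iterations):
#         kl = SampledKLEnergy(mean, ham, n_samples=4,
#                              minimizer_sampling=None)  # None -> MGVI
#         kl, _ = kl_minimizer(kl)
#         mean = kl.position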
class SampledKLEnergyClass(Energy):
    """Base class for Energies representing a sampled Kullback-Leibler
    divergence for the variational approximation of a distribution with
    another distribution.

    Supports the samples to be distributed across MPI tasks.
    """

    def __init__(self, sample_list, hamiltonian, constants, invariants,
                 nanisinf):
        myassert(isinstance(sample_list, ResidualSampleList))
        myassert(sample_list.domain is hamiltonian.domain)
        if isinstance(sample_list.domain, MultiDomain):
            if not (invariants is None or isinstance(invariants, MultiField)):
                raise TypeError

        super(SampledKLEnergyClass, self).__init__(
            _reduce_field(sample_list._m, constants))

        self._sample_list = sample_list
        self._hamiltonian = hamiltonian
        self._nanisinf = bool(nanisinf)
        self._constants = constants
        self._invariants = invariants

        def _func(inp):
            inp, tmp = _reduce_by_keys(inp, hamiltonian, constants)
            tmp = tmp(Linearization.make_var(inp))
            return tmp.val.val[()], tmp.gradient

        self._val, self._grad = sample_list._average_tuple(_func)
        if np.isnan(self._val) and self._nanisinf:
            self._val = np.inf
    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    def at(self, position):
        return SampledKLEnergyClass(self._sample_list.at(position),
                                    self._hamiltonian, self._constants,
                                    self._invariants, self._nanisinf)

    def apply_metric(self, x):
        def _func(inp):
            inp, tmp = _reduce_by_keys(inp, self._hamiltonian, self._constants)
            tmp = tmp(Linearization.make_var(inp, want_metric=True))
            return tmp.metric(x)
        return self._sample_list.average(_func)

    @property
    def metric(self):
        return _SelfAdjointOperatorWrapper(self.position.domain,
                                           self.apply_metric)

    @property
    def samples(self):
        if self._invariants is None:
            return self._sample_list
        return self._sample_list.at(self._invariants)
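
# Illustrative note (hypothetical `kl` and field `x`): `metric` wraps
# `apply_metric` in an implicit operator, so applying the sample-averaged
# metric reads
#
#     curv = kl.metric   # a _SelfAdjointOperatorWrapper
#     y = curv(x)        # averages Hamiltonian metric applications over samples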