Source code for nifty8.re.evi

# Copyright(C) 2023 Gordian Edenhofer
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
# Authors: Gordian Edenhofer, Philipp Frank

from functools import partial
from operator import getitem
from typing import Callable, Optional, Tuple, TypeVar, Union

import jax
from jax import numpy as jnp
from jax import random
from jax.tree_util import (
    Partial, register_pytree_node_class, tree_leaves, tree_map
)

from . import conjugate_gradient, optimize
from .likelihood import (
    Likelihood, _parse_point_estimates, partial_insert_and_remove
)
from .tree_math import (
    Vector, assert_arithmetics, dot, hide_strings, random_like, stack, vdot
)

P = TypeVar("P")


def _no_jit(x, **kwargs):
    return x


def _parse_jit(jit):
    if callable(jit):
        return jit
    if isinstance(jit, bool):
        return jax.jit if jit else _no_jit
    raise TypeError(f"expected `jit` to be callable or bolean; got {jit!r}")


def _hcb_maybe_raise(condition_exception):
    condition, exception = condition_exception
    if condition:
        raise exception()


def _cond_raise(condition, exception):
    from jax.experimental.host_callback import call

    # Register as few host-callbacks as possible by implicitly hashing the
    # exception type and the strings within
    call(
        _hcb_maybe_raise, (
            condition,
            Partial(exception.__class__, *hide_strings(exception.args))
        ),
        result_shape=None
    )


def _process_point_estimate(x, primals, point_estimates, insert):
    if not point_estimates:
        return x

    point_estimates, _, p_frozen = _parse_point_estimates(
        point_estimates, primals
    )
    assert p_frozen is not None
    fill = tree_map(lambda x: jnp.zeros((1, ) * jnp.ndim(x)), p_frozen)
    in_out = partial_insert_and_remove(
        lambda *x: x[0],
        insert_axes=(point_estimates, ) if insert else None,
        flat_fill=(fill, ) if insert else None,
        remove_axes=None if insert else (point_estimates, ),
        unflatten=None if insert else Vector
    )
    return in_out(x)


def sample_likelihood(likelihood: Likelihood, primals, key):
    white_sample = random_like(key, likelihood.left_sqrt_metric_tangents_shape)
    return likelihood.left_sqrt_metric(primals, white_sample)

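A minimal usage sketch, not part of the module: drawing a single likelihood-metric sample. The toy setup below is an assumption for illustration only; it uses `nifty8.re.Gaussian` with identity noise operators acting directly on a three-dimensional data space.

import nifty8.re as jft
from jax import numpy as jnp, random

from nifty8.re.evi import sample_likelihood

# Assumed toy Gaussian likelihood on data space (identity noise covariance).
data = jnp.zeros(3)
lh = jft.Gaussian(data, noise_cov_inv=lambda x: x, noise_std_inv=lambda x: x)
pos = jnp.ones(3)
key = random.PRNGKey(42)

nll_smpl = sample_likelihood(lh, pos, key)  # sample of the likelihood metric
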
def draw_linear_residual(
    likelihood: Likelihood,
    pos: P,
    key,
    *,
    from_inverse: bool = True,
    point_estimates: Union[P, Tuple[str]] = (),
    cg: Callable = conjugate_gradient.static_cg,
    cg_name: Optional[str] = None,
    cg_kwargs: Optional[dict] = None,
    _raise_nonposdef: bool = False,
) -> tuple[P, int]:
    assert_arithmetics(pos)
    if not isinstance(likelihood, Likelihood):
        te = f"`likelihood` of invalid type; got '{type(likelihood)}'"
        raise TypeError(te)

    if point_estimates:
        lh, p_liquid = likelihood.freeze(
            point_estimates=point_estimates, primals=pos
        )
    else:
        lh = likelihood
        p_liquid = pos

    def ham_metric(primals, tangents, **primals_kw):
        return lh.metric(primals, tangents, **primals_kw) + tangents

    cg_kwargs = cg_kwargs if cg_kwargs is not None else {}

    subkey_nll, subkey_prr = random.split(key, 2)
    nll_smpl = sample_likelihood(lh, p_liquid, key=subkey_nll)
    prr_inv_metric_smpl = random_like(key=subkey_prr, primals=p_liquid)
    # One may transform any metric sample to a sample of the inverse
    # metric by simply applying the inverse metric to it
    prr_smpl = prr_inv_metric_smpl
    # Note, we can sample antithetically by swapping the global sign of
    # the metric sample below (which corresponds to mirroring the final
    # sample) and additionally by swapping the relative sign between
    # the prior and the likelihood sample. The first technique is
    # computationally cheap and empirically known to improve stability.
    # The latter technique requires an additional inversion and its
    # impact on stability is still unknown.
    # TODO: investigate the impact of sampling the prior and likelihood
    # antithetically.
    smpl = nll_smpl + prr_smpl
    info = 0
    if from_inverse:
        inv_metric_at_p = partial(
            cg, Partial(ham_metric, p_liquid), **{
                "name": cg_name,
                "_raise_nonposdef": _raise_nonposdef,
                **cg_kwargs
            }
        )
        smpl, info = inv_metric_at_p(smpl, x0=prr_inv_metric_smpl)
        _cond_raise(
            (info < 0) if info is not None else False,
            ValueError("conjugate gradient failed")
        )
    smpl = _process_point_estimate(smpl, pos, point_estimates, insert=True)
    return smpl, info

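A sketch of drawing a linear (MGVI) residual, reusing the toy `lh`, `pos`, and `key` assumed in the sketch above. As the antithetic-sampling comment in the function notes, mirroring the residual at the mean yields a second approximate posterior sample essentially for free.

from nifty8.re.evi import draw_linear_residual

smpl, cg_info = draw_linear_residual(lh, pos, key)
# Antithetic pair of approximate (linearized) posterior samples around `pos`.
s_plus, s_minus = pos + smpl, pos - smpl
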
def _nonlinearly_update_residual_functions(
    likelihood, jit: Union[Callable, bool] = False
):
    def _draw_linear_non_inverse(primals, key, *, point_estimates):
        # `draw_linear_residual` already handles `point_estimates`; no need to
        # partially insert anything here
        return draw_linear_residual(
            likelihood,
            primals,
            key,
            point_estimates=point_estimates,
            from_inverse=False
        )

    def _residual_vg(e, lh_trafo_at_p, ms_at_p, x, *, point_estimates):
        lh, e_liquid = likelihood.freeze(
            point_estimates=point_estimates, primals=e
        )
        # t = likelihood.transformation(x) - lh_trafo_at_p
        t = tree_map(jnp.subtract, lh.transformation(x), lh_trafo_at_p)
        g = x - e_liquid + lh.left_sqrt_metric(e_liquid, t)
        r = ms_at_p - g
        res = 0.5 * dot(r, r)

        ngrad = tree_map(jnp.conj, r)
        ngrad += lh.left_sqrt_metric(x, lh.right_sqrt_metric(e_liquid, ngrad))
        return (res, -ngrad)

    def _metric(e, primals, tangents, *, point_estimates):
        lh, e_liquid = likelihood.freeze(
            point_estimates=point_estimates, primals=e
        )
        lsm = lh.left_sqrt_metric
        rsm = lh.right_sqrt_metric
        tm = lsm(e_liquid, rsm(primals, tangents)) + tangents
        return lsm(primals, rsm(e_liquid, tm)) + tm

    def _sampnorm(e, natgrad, *, point_estimates):
        lh, e_liquid = likelihood.freeze(
            point_estimates=point_estimates, primals=e
        )
        fpp = lh.right_sqrt_metric(e_liquid, natgrad)
        return jnp.sqrt(vdot(natgrad, natgrad) + vdot(fpp, fpp))

    jit = _parse_jit(jit)
    jit = partial(jit, static_argnames=("point_estimates", ))
    draw_linear_non_inverse = jit(_draw_linear_non_inverse)
    rag = jit(_residual_vg)
    metric = jit(_metric)
    sampnorm = jit(_sampnorm)
    return draw_linear_non_inverse, rag, metric, sampnorm

def nonlinearly_update_residual(
    likelihood=None,
    pos: P = None,
    residual_sample=None,
    metric_sample_key=None,
    metric_sample_sign=1.0,
    *,
    point_estimates=(),
    minimize: Callable[..., optimize.OptimizeResults] = optimize._newton_cg,
    minimize_kwargs={},
    jit: Union[Callable, bool] = False,
    _nonlinear_update_funcs=None,
    _raise_notconverged=False,
) -> tuple[P, optimize.OptimizeResults]:
    assert_arithmetics(pos)
    assert_arithmetics(residual_sample)

    if _nonlinear_update_funcs is None:
        _nonlinear_update_funcs = _nonlinearly_update_residual_functions(
            likelihood, jit=jit
        )
    draw_lni, rag, metric, sampnorm = _nonlinear_update_funcs

    sample = pos + residual_sample
    del residual_sample
    sample = _process_point_estimate(sample, pos, point_estimates, insert=False)
    metric_sample, _ = draw_lni(
        pos, metric_sample_key, point_estimates=point_estimates
    )
    metric_sample *= metric_sample_sign
    metric_sample = _process_point_estimate(
        metric_sample, pos, point_estimates, insert=False
    )
    # HACK for skipping the nonlinear update steps and not calling trafo
    skip = isinstance(
        minimize_kwargs.get("maxiter", None), int
    ) and minimize_kwargs["maxiter"] == 0
    if not skip:
        trafo_at_p = likelihood.transformation(pos)
        options = {
            "fun_and_grad":
                partial(
                    rag,
                    pos,
                    trafo_at_p,
                    metric_sample,
                    point_estimates=point_estimates
                ),
            "hessp":
                partial(metric, pos, point_estimates=point_estimates),
            "custom_gradnorm":
                partial(sampnorm, pos, point_estimates=point_estimates),
        }
        opt_state = minimize(None, x0=sample, **(minimize_kwargs | options))
    else:
        opt_state = optimize.OptimizeResults(sample, True, 0, None, None)
    if _raise_notconverged and (opt_state.status < 0):
        raise ValueError("S: failed to invert map")
    # Subtract position in the reduced space (i.e. space w/o point-estimates)
    # to not pollute the point-estimated parameters with the mean
    sample = opt_state.x - _process_point_estimate(
        pos, pos, point_estimates, insert=False
    )
    # Remove x from state to avoid copy of the samples
    opt_state = opt_state._replace(x=None, jac=None)
    sample = _process_point_estimate(sample, pos, point_estimates, insert=True)
    return sample, opt_state

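A sketch of the nonlinear (geoVI-style) refinement of a linear residual, again reusing the assumed toy `lh`, `pos`, `key`, and the `smpl` drawn above; the `maxiter` value is purely illustrative.

from nifty8.re.evi import nonlinearly_update_residual

nl_smpl, opt_state = nonlinearly_update_residual(
    lh, pos, residual_sample=smpl, metric_sample_key=key,
    metric_sample_sign=1.0, minimize_kwargs={"maxiter": 5}
)
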
def draw_residual(
    likelihood: Likelihood,
    pos: P,
    key,
    *,
    point_estimates: Union[P, Tuple[str]] = (),
    cg: Callable = conjugate_gradient.static_cg,
    cg_name: Optional[str] = None,
    cg_kwargs: Optional[dict] = None,
    minimize: Callable[..., optimize.OptimizeResults] = optimize._newton_cg,
    minimize_kwargs={},
    _nonlinear_update_funcs=None,
    _raise_nonposdef: bool = False,
    _raise_notconverged: bool = False,
) -> tuple[P, optimize.OptimizeResults]:
    residual_sample, _ = draw_linear_residual(
        likelihood,
        pos,
        key,
        point_estimates=point_estimates,
        cg=cg,
        cg_name=cg_name,
        cg_kwargs=cg_kwargs,
        _raise_nonposdef=_raise_nonposdef,
    )
    curve = partial(
        nonlinearly_update_residual,
        likelihood,
        pos,
        metric_sample_key=key,
        point_estimates=point_estimates,
        minimize=minimize,
        minimize_kwargs=minimize_kwargs,
        jit=False,
        _raise_notconverged=_raise_notconverged,
        _nonlinear_update_funcs=_nonlinear_update_funcs,
    )
    return stack(
        (
            curve(residual_sample, metric_sample_sign=1.0),
            curve(-residual_sample, metric_sample_sign=-1.0)
        )
    )

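`draw_residual` bundles the two steps above. A sketch with the same assumed toy setup; it returns the stacked antithetic pair of nonlinearly updated residuals together with the corresponding optimizer states.

from nifty8.re.evi import draw_residual

residuals, opt_states = draw_residual(
    lh, pos, key, minimize_kwargs={"maxiter": 5}
)
# `residuals` stacks the +1/-1 antithetic pair along a new leading axis.
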
@register_pytree_node_class
class Samples():
    """Storage class for samples (relative to some expansion point) that is
    fully compatible with JAX transformations like vmap, pmap, etc.

    This class is used to store samples for the Variational Inference schemes
    MGVI and geoVI where samples are defined relative to some expansion point
    (a.k.a. latent mean or offset).

    See also
    --------
    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
    `<https://doi.org/10.3390/e23070853>`_

    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """

    def __init__(self, *, pos: P = None, samples: P, keys=None):
        self._pos, self._samples, self._keys = pos, samples, keys

    @property
    def pos(self):
        return self._pos

    @property
    def samples(self):
        if self._samples is None:
            raise ValueError(f"{self.__class__.__name__} has no samples")
        smpls = self._samples
        if self.pos is not None:
            smpls = tree_map(lambda p, s: p[jnp.newaxis] + s, self.pos, smpls)
        return smpls

    @property
    def keys(self):
        return self._keys

    def __len__(self):
        if self._samples is None:
            return 0
        return jnp.shape(tree_leaves(self._samples)[0])[0]

    def __getitem__(self, index):
        if self._samples is None:
            raise ValueError(f"{self.__class__.__name__} has no samples")

        def get(b):
            return getitem(b, index)

        if self.pos is None:
            return tree_map(get, self._samples)
        return tree_map(lambda p, s: p + get(s), self.pos, self._samples)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __eq__(self, other) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return self.samples == other.samples

    def at(self, pos, old_pos=None):
        """Update the offset (usually the latent mean) of all samples,
        optionally subtracting `old_pos` from all samples beforehand.
        """
        if self.pos is not None and old_pos is None:
            smpls = self._samples
        elif old_pos is not None:
            smpls = self.samples
            smpls = tree_map(lambda p, s: s - p[jnp.newaxis], old_pos, smpls)
        else:
            raise ValueError("invalid combination of `pos` and `old_pos`")
        return Samples(pos=pos, samples=smpls, keys=self.keys)

    def squeeze(self):
        """Convenience method to merge the two leading axes of stacked
        samples (e.g. from batching).
        """
        smpls = tree_map(
            lambda s: s.reshape((-1, ) + s.shape[2:]), self._samples
        )
        return Samples(pos=self.pos, samples=smpls, keys=self.keys)

    def tree_flatten(self):
        # Include mean in samples when passing to JAX (for e.g. vmap, pmap, ...)
        # return ((self.samples, ), (self.pos, ))  # confuses JAX
        return ((self.pos, self._samples, self.keys), ())

    @classmethod
    def tree_unflatten(cls, aux, children):
        # pos, = aux
        pos, smpls, keys = children
        # if pos is not None:  # confuses JAX
        #     smpls = tree_map(lambda p, s: s - p[jnp.newaxis], pos, smpls)
        return cls(pos=pos, samples=smpls, keys=keys)
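
A self-contained usage sketch of `Samples`, based only on the class above: residuals are stored relative to an expansion point, absolute samples are recovered on access, and `at` moves the expansion point.

from jax import numpy as jnp

residuals = jnp.stack([jnp.ones(3), -jnp.ones(3)])  # antithetic pair of residuals
smpls = Samples(pos=jnp.zeros(3), samples=residuals)
assert len(smpls) == 2
first = smpls[0]                    # pos + residuals[0]
moved = smpls.at(jnp.full(3, 0.5))  # new expansion point, same residuals
for s in moved:                     # iterates over absolute samples
    print(s)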