nifty8.re.evi module

class Samples(*, pos: P = None, samples: P, keys=None)

Bases: object

Storage class for samples defined relative to some expansion point. The class is fully compatible with JAX transformations such as vmap and pmap.

This class stores samples for the variational inference schemes MGVI and geoVI, in which samples are defined relative to an expansion point (also called the latent mean or offset).
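What "compatible with JAX transformations" amounts to can be sketched with a toy pytree. The class below, `ToySamples`, is hypothetical and not part of nifty8: it only assumes that residuals are stacked along a leading sample axis and that an absolute sample is `pos + residual`. Registering it with `jax.tree_util.register_pytree_node_class` through `tree_flatten`/`tree_unflatten` (the same hooks listed for `Samples`) lets `jax.jit` and `jax.vmap` pass through it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToySamples:
    """Hypothetical stand-in for `Samples`: residuals stored relative to `pos`."""

    def __init__(self, *, pos, samples):
        self.pos = pos          # expansion point (latent mean), shape (n,)
        self.samples = samples  # residuals, leading axis = sample index, shape (k, n)

    def __len__(self):
        return self.samples.shape[0]

    def __iter__(self):
        # An absolute sample is the expansion point plus one residual.
        for r in self.samples:
            yield self.pos + r

    def tree_flatten(self):
        # Both arrays are leaves, so JAX can trace, jit and vmap through them.
        return (self.pos, self.samples), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        pos, samples = children
        return cls(pos=pos, samples=samples)


@jax.jit
def sample_mean(s):
    # Mean of the absolute samples pos + r_i.
    return s.pos + jnp.mean(s.samples, axis=0)


pos = jnp.array([1.0, 2.0])
res = jnp.array([[0.5, -0.5], [-0.5, 0.5]])  # antithetic pair of residuals
smpls = ToySamples(pos=pos, samples=res)

# A batch of three sample sets; vmap maps over the leading axis of every leaf.
batched = ToySamples(pos=jnp.stack([pos, 2 * pos, 3 * pos]),
                     samples=jnp.stack([res, res, res]))
means = jax.vmap(sample_mean)(batched)
```

Because the antithetic residuals cancel, `sample_mean(smpls)` returns `pos`, and `means` holds one such mean per batch entry.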

See also

Geometric Variational Inference, Philipp Frank, Reimar Leike, Torsten A. Enßlin, https://arxiv.org/abs/2105.10470, https://doi.org/10.3390/e23070853

Metric Gaussian Variational Inference, Jakob Knollmüller, Torsten A. Enßlin, https://arxiv.org/abs/1901.11033

__init__(*, pos: P = None, samples: P, keys=None)
at(pos, old_pos=None)

Updates the offset (usually the latent mean) of all samples. If old_pos is given, it is first subtracted from all samples.
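A minimal sketch of the behaviour described above, written on plain arrays or pytrees rather than on `Samples` itself. `at_sketch` is a hypothetical helper; it assumes samples are stacked along a leading axis so that `old_pos` broadcasts against them, and the usage assumes the samples were previously stored as absolute positions around `old_pos`.

```python
import jax
import jax.numpy as jnp


def at_sketch(pos, samples, old_pos=None):
    """Hypothetical helper mimicking the documented `Samples.at`: optionally
    subtract `old_pos` from every sample, then attach the result to the new
    offset `pos`. Returns the new `(pos, residuals)` pair."""
    if old_pos is not None:
        # `o` has no sample axis, so it broadcasts over the leading axis of `s`.
        samples = jax.tree_util.tree_map(lambda s, o: s - o, samples, old_pos)
    return pos, samples


old_pos = jnp.array([1.0, 1.0])
absolute = jnp.array([[1.5, 0.5], [0.5, 1.5]])  # two samples stored as absolute positions
new_pos = jnp.array([3.0, 3.0])

pos, residuals = at_sketch(new_pos, absolute, old_pos=old_pos)
# residuals == [[0.5, -0.5], [-0.5, 0.5]]; the absolute samples are now new_pos + residuals
```

Without `old_pos`, the residuals are kept unchanged and only the offset moves.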

property keys
property pos
property samples
squeeze()

Convenience method to merge the two leading axes of stacked samples (e.g., from batching).
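The axis merge can be sketched with `jax.tree_util.tree_map` and a reshape. `squeeze_sketch` is a hypothetical helper, assuming every leaf has the shape `(n_batches, n_per_batch, ...)`, for example after `jax.vmap` over a batch of keys.

```python
import jax
import jax.numpy as jnp


def squeeze_sketch(samples):
    """Hypothetical helper: merge the two leading axes of every leaf,
    (n_batches, n_per_batch, ...) -> (n_batches * n_per_batch, ...)."""
    return jax.tree_util.tree_map(
        lambda x: x.reshape((-1,) + x.shape[2:]), samples
    )


# Samples drawn in 4 batches of 2 each, stored as a pytree of arrays.
stacked = {"a": jnp.zeros((4, 2, 3)), "b": jnp.ones((4, 2))}
flat = squeeze_sketch(stacked)
# flat["a"].shape == (8, 3) and flat["b"].shape == (8,)
```

The reshape is row-major, so sample `j` of batch `i` ends up at index `i * n_per_batch + j`.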

tree_flatten()
classmethod tree_unflatten(aux, children)
draw_linear_residual(likelihood: nifty8.re.likelihood.Likelihood, pos: nifty8.re.evi.P, key, *, from_inverse: bool = True, point_estimates: nifty8.re.evi.P | typing.Tuple[str] = (), cg: typing.Callable = <function static_cg>, cg_name: str | None = None, cg_kwargs: dict | None = None, _raise_nonposdef: bool = False) → tuple[P, int]
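In MGVI, a linear residual is a draw from a Gaussian whose covariance is the inverse of the metric at the expansion point; the `from_inverse` and `cg` parameters above suggest that this inversion is done with conjugate gradient. The sketch below illustrates that idea on a toy linear-Gaussian model in standardized coordinates, assuming the metric M = 1 + Jᵀ N⁻¹ J with an explicit Jacobian `J` and diagonal noise covariance N. It is not the library's implementation, and `toy_linear_residual`, `J` and `noise_std` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical toy model in standardized coordinates:
# latent xi ~ N(0, 1), data d = J @ xi + n, noise n ~ N(0, diag(noise_std**2)).
J = jnp.array([[1.0, 0.5, 0.0],
               [0.0, 1.0, -0.3]])   # Jacobian of the forward model at `pos`
noise_std = jnp.array([0.5, 0.2])

# Metric: prior identity plus the likelihood Fisher metric J^T N^{-1} J.
M = jnp.eye(3) + J.T @ (J / noise_std[:, None] ** 2)


def toy_linear_residual(key):
    """Hypothetical sketch: draw x ~ N(0, M^{-1}) using only products with M."""
    k_prior, k_lh = jax.random.split(key)
    # 1. Draw z ~ N(0, M) as a prior draw plus a likelihood-metric draw J^T N^{-1/2} eta.
    a = jax.random.normal(k_prior, (3,))
    eta = jax.random.normal(k_lh, (2,))
    z = a + J.T @ (eta / noise_std)
    # 2. Solve M x = z with conjugate gradient; then Cov(x) = M^{-1} M M^{-1} = M^{-1}.
    x, _ = cg(lambda v: M @ v, z, tol=1e-6, maxiter=50)
    return x, z


keys = jax.random.split(jax.random.PRNGKey(42), 20000)
xs, zs = jax.vmap(toy_linear_residual)(keys)
emp_cov = xs.T @ xs / xs.shape[0]  # zero-mean draws, so no centring is needed
```

The empirical covariance of many such draws approaches `jnp.linalg.inv(M)`, the covariance of the Gaussian that MGVI uses to approximate the posterior in these standardized coordinates.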
draw_residual(likelihood: nifty8.re.likelihood.Likelihood, pos: nifty8.re.evi.P, key, *, point_estimates: nifty8.re.evi.P | typing.Tuple[str] = (), cg: typing.Callable = <function static_cg>, cg_name: str | None = None, cg_kwargs: dict | None = None, minimize: typing.Callable[[...], nifty8.re.optimize.OptimizeResults] = <function _newton_cg>, minimize_kwargs={}, _nonlinear_update_funcs=None, _raise_nonposdef: bool = False, _raise_notconverged: bool = False) → tuple[P, OptimizeResults]
nonlinearly_update_residual(likelihood=None, pos: nifty8.re.evi.P = None, residual_sample=None, metric_sample_key=None, metric_sample_sign=1.0, *, point_estimates=(), minimize: typing.Callable[[...], nifty8.re.optimize.OptimizeResults] = <function _newton_cg>, minimize_kwargs={}, jit: typing.Callable | bool = False, _nonlinear_update_funcs=None, _raise_notconverged=False) → tuple[P, OptimizeResults]
sample_likelihood(likelihood: Likelihood, primals, key)