nifty8.re.optimize_kl module

class OptimizeVI(likelihood: ~nifty8.re.likelihood.Likelihood, n_total_iterations: int, *, kl_jit=True, residual_jit=True, kl_map=<function vmap>, residual_map='lmap', kl_reduce=functools.partial(<function tree_map>, functools.partial(<function mean>, axis=0)), mirror_samples=True, _kl_value_and_grad: ~typing.Callable | None = None, _kl_metric: ~typing.Callable | None = None, _draw_linear_residual: ~typing.Callable | None = None, _nonlinearly_update_residual: ~typing.Callable | None = None, _get_status_message: ~typing.Callable | None = None)

Bases: object

State-less assembly of all methods needed for an MGVI/geoVI-style VI approximation.

Builds functions for a VI approximation via variants of the Geometric Variational Inference (geoVI) and/or Metric Gaussian Variational Inference (MGVI) algorithms. They produce approximate posterior samples that are used internally to estimate the KL divergence; the final set of samples is the approximation of the posterior. The samples can be linear, i.e. following a standard normal distribution in model space, or nonlinear, i.e. following a standard normal distribution in the canonical coordinate system of the Riemannian manifold associated with the metric of the approximate posterior distribution. The coordinate transformation for the nonlinear samples is approximated by an expansion.

Both linear and nonlinear samples start by drawing a sample from the inverse metric. To do so, we draw a sample that has the metric as its covariance and apply the inverse metric to it; the sample transformed in this way has the inverse metric as its covariance. The first part is straightforward since we can use the left square root L of the metric, which is associated with every likelihood:

\tilde{d} &\leftarrow \mathcal{G}(0, \mathbb{1}) \\
t &= L \tilde{d}

with t now having a covariance structure of

\langle t t^\dagger \rangle = L \langle \tilde{d} \tilde{d}^\dagger \rangle L^\dagger = L L^\dagger = M .
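The covariance identity above can be checked numerically. The following NumPy sketch draws many standard-normal vectors, multiplies them by a square root L, and compares the empirical covariance of t with M = L L^T. The dense 3×3 lower-triangular L is an arbitrary illustrative choice (an assumption, not taken from any NIFTy likelihood); a real model would apply L as a linear operator rather than store it as a matrix.

```python
import numpy as np

rng = np.random.default_rng(42)

# Illustrative left square root L (arbitrary invertible matrix, an assumption).
L = np.array([[2.0, 0.0, 0.0],
              [0.5, 1.0, 0.0],
              [-0.3, 0.2, 1.5]])
M = L @ L.T  # the metric implied by this square root

n_draws = 200_000
d_tilde = rng.standard_normal((n_draws, 3))  # each row d_tilde ~ G(0, 1)
t = d_tilde @ L.T                            # each row is t = L d_tilde

# The empirical covariance <t t^T> approaches M as n_draws grows.
cov_t = t.T @ t / n_draws
print(np.round(cov_t, 2))
print(np.round(M, 2))
```

Because each entry of the estimate has a standard error on the order of 1/sqrt(n_draws), the two printed matrices agree to roughly two decimals.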

To turn t into a sample with the inverse metric as its covariance, we apply the inverse metric to it using the conjugate gradient (CG) algorithm. CG yields the solution to M s = t, i.e. it applies the inverse of M to t:

M s &= t \\
s &= M^{-1} t = \mathrm{cg}(M, t) .

The linear sample is s.
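The full linear-sampling recipe (draw d̃, form t = L d̃, solve M s = t with CG) can be sketched in plain NumPy. This is a minimal textbook conjugate gradient that only needs matrix-vector products, applied to the same illustrative 3×3 square root as before; the function names are hypothetical and this is not NIFTy's own CG. The empirical covariance of the resulting s should approach M^{-1}, since M^{-1} M M^{-1} = M^{-1}.

```python
import numpy as np


def conjugate_gradient(apply_M, t, tol=1e-10, max_iter=100):
    """Solve M s = t for a symmetric positive-definite M, using only matvecs."""
    s = np.zeros_like(t)
    r = t - apply_M(s)  # residual of the current iterate
    p = r.copy()        # first search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) < tol:
            break
        Mp = apply_M(p)
        alpha = rs_old / (p @ Mp)  # optimal step length along p
        s = s + alpha * p
        r = r - alpha * Mp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # new M-conjugate direction
        rs_old = rs_new
    return s


# Illustrative left square root of the metric (an assumption).
L = np.array([[2.0, 0.0, 0.0],
              [0.5, 1.0, 0.0],
              [-0.3, 0.2, 1.5]])
M = L @ L.T


def apply_M(v):
    return M @ v


def draw_linear_sample_toy(rng):
    """Hypothetical helper: one linear sample s = M^{-1} L d_tilde."""
    d_tilde = rng.standard_normal(3)       # d_tilde ~ G(0, 1)
    t = L @ d_tilde                        # t has covariance M
    return conjugate_gradient(apply_M, t)  # s has covariance M^{-1}


rng = np.random.default_rng(0)
samples = np.stack([draw_linear_sample_toy(rng) for _ in range(20_000)])
cov_s = samples.T @ samples / len(samples)
print(np.round(cov_s, 2))
print(np.round(np.linalg.inv(M), 2))
```

CG is used here, as in the text, because it never needs M^{-1} explicitly, only the action of M on a vector.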

The nonlinear sampling uses s as a starting value and curves it nonlinearly so as to better resemble the posterior locally. See the references below for more details on nonlinear sampling.

See also

  • Geometric Variational Inference, Philipp Frank, Reimar Leike, Torsten A. Enßlin, https://arxiv.org/abs/2105.10470, https://doi.org/10.3390/e23070853

  • Metric Gaussian Variational Inference, Jakob Knollmüller, Torsten A. Enßlin, https://arxiv.org/abs/1901.11033

__init__(likelihood: ~nifty8.re.likelihood.Likelihood, n_total_iterations: int, *, kl_jit=True, residual_jit=True, kl_map=<function vmap>, residual_map='lmap', kl_reduce=functools.partial(<function tree_map>, functools.partial(<function mean>, axis=0)), mirror_samples=True, _kl_value_and_grad: ~typing.Callable | None = None, _kl_metric: ~typing.Callable | None = None, _draw_linear_residual: ~typing.Callable | None = None, _nonlinearly_update_residual: ~typing.Callable | None = None, _get_status_message: ~typing.Callable | None = None)

JaxOpt-style minimizer for a sample-based VI approximation of a distribution.

Parameters:
  • likelihood (Likelihood) – Likelihood to be used for inference.

  • n_total_iterations (int) – Total number of iterations. One iteration consists of steps 1) to 3) described in the Notes.

  • kl_jit (bool or callable) – Whether to jit the KL minimization.

  • residual_jit (bool or callable) – Whether to jit the residual sampling functions.

  • kl_map (callable or str) – Map function used for the KL minimization.

  • residual_map (callable or str) – Map function used for the residual sampling functions.

  • kl_reduce (callable) – Reduce function used for the KL minimization.

  • mirror_samples (bool) – Whether to mirror the samples or not.
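To illustrate what mirroring and the default kl_reduce do, the sketch below assumes (as in MGVI/geoVI) that mirroring pairs every drawn residual r with its negation -r, and that the reduction is a mean over the leading sample axis, as the documented default functools.partial(tree_map, partial(mean, axis=0)) suggests. The quadratic energy and all function names are hypothetical. For such an energy the gradient is linear, so with mirrored samples the averaged gradient equals the gradient at the position exactly, while unmirrored samples generally introduce a sampling bias.

```python
import numpy as np

# Hypothetical quadratic energy H(x) = 0.5 x^T A x - b^T x; its gradient is A x - b.
A = np.array([[3.0, 0.4], [0.4, 1.0]])
b = np.array([1.0, -2.0])


def grad_energy(x):
    return A @ x - b


def make_samples(position, residuals, mirror_samples):
    """Offset the position by each residual; optionally also by each -r."""
    if mirror_samples:
        residuals = np.concatenate([residuals, -residuals], axis=0)
    return position + residuals  # one sample per row


def reduce_mean(per_sample_values):
    # Same operation as the documented kl_reduce default, minus the tree_map:
    # a mean over the leading (sample) axis.
    return np.mean(per_sample_values, axis=0)


rng = np.random.default_rng(1)
position = np.array([0.3, -0.7])
residuals = rng.standard_normal((4, 2))  # stand-ins for drawn residuals

for mirror in (False, True):
    samples = make_samples(position, residuals, mirror)
    grads = np.stack([grad_energy(x) for x in samples])
    deviation = reduce_mean(grads) - grad_energy(position)
    print(f"mirror_samples={mirror}: deviation = {deviation}")
```

Mirroring doubles the number of samples used in the KL estimate without drawing new residuals, and cancels all odd-order terms of the residuals.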

Notes

Implements the base logic present in conditional VI approximations such as MGVI and geoVI. First, samples are generated (and/or updated); then their collective mean is optimized using the sample-estimated variational KL divergence between the true distribution and the sampled approximation. This is split into three steps: 1) sample generation, 2) sample update, 3) KL minimization. Steps 1) and 2) may be skipped depending on the minimizer's state, but step 3) is always performed at the end of an iteration. A full loop consists of repeatedly iterating over steps 1) to 3).
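The three-step loop can be sketched on a toy one-dimensional problem. The assumptions here are all illustrative: a standard-normal latent ξ observed directly with Gaussian noise of standard deviation sigma, so the energy is H(ξ) = (d - ξ)²/(2σ²) + ξ²/2 and the metric is M = 1/σ² + 1. Step 1 draws linear residuals as described above, step 2 is skipped because for a linear model geoVI coincides with MGVI, and step 3 runs plain gradient descent on the mirrored-sample average of the energy. All names are hypothetical; NIFTy's own minimizers (for instance the Newton-CG default of kl_minimize) work differently.

```python
import numpy as np

sigma = 0.5                     # toy noise standard deviation (assumption)
data = 1.2                      # toy observed value (assumption)
metric = 1.0 / sigma**2 + 1.0   # likelihood Fisher metric plus identity prior


def grad_energy(xi):
    # Gradient of H(xi) = 0.5 * (data - xi)**2 / sigma**2 + 0.5 * xi**2.
    return -(data - xi) / sigma**2 + xi


def toy_draw_residuals(rng, n_samples):
    # Step 1: t has variance `metric` (noise part plus prior part),
    # and s = t / metric has variance 1 / metric.
    t = rng.standard_normal(n_samples) / sigma + rng.standard_normal(n_samples)
    return t / metric


def toy_kl_minimize(position, residuals, n_steps=50, step_size=0.1):
    # Step 3: gradient descent on the energy averaged over mirrored samples.
    offsets = np.concatenate([residuals, -residuals])
    for _ in range(n_steps):
        grad = np.mean(grad_energy(position + offsets))
        position = position - step_size * grad
    return position


def toy_run(n_total_iterations=5, n_samples=3, seed=0):
    rng = np.random.default_rng(seed)
    position = 0.0
    residuals = None
    for nit in range(n_total_iterations):
        residuals = toy_draw_residuals(rng, n_samples)   # 1) sample generation
        # 2) sample update: skipped, geoVI equals MGVI for this linear model.
        position = toy_kl_minimize(position, residuals)  # 3) KL minimization
    return position, residuals


position, residuals = toy_run()
print(position, data / (1.0 + sigma**2))  # converged mean vs. exact posterior mean
```

For this linear-Gaussian toy model the mirrored samples make the optimum independent of the residuals, so the loop recovers the exact posterior mean d/(1 + σ²).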

draw_linear_samples(primals, keys, **kwargs)
draw_samples(samples: Samples, *, key, sample_mode: Literal['linear_sample', 'linear_resample', 'nonlinear_sample', 'nonlinear_resample', 'nonlinear_update'], n_samples: int, point_estimates, draw_linear_kwargs={}, nonlinearly_update_kwargs={}, **kwargs)
init_state(key, *, nit=0, n_samples: int | Callable[[int], int], draw_linear_kwargs: Callable[[int], dict] | dict = {'cg_kwargs': {}, 'cg_name': 'SL'}, nonlinearly_update_kwargs: Callable[[int], dict] | dict = {'minimize_kwargs': {'cg_kwargs': {'name': 'SNCG'}, 'name': 'SN'}}, kl_kwargs: Callable[[int], dict] | dict = {'minimize_kwargs': {'cg_kwargs': {'name': 'MCG'}, 'name': 'M'}}, sample_mode: Literal['linear_sample', 'linear_resample', 'nonlinear_sample', 'nonlinear_resample', 'nonlinear_update'] | Callable[[int], Literal['linear_sample', 'linear_resample', 'nonlinear_sample', 'nonlinear_resample', 'nonlinear_update']] = 'nonlinear_resample', point_estimates=(), constants=()) -> OptimizeVIState

Initialize the state of the (otherwise state-less) VI approximation.

Parameters:
  • key (jax.random key) – JAX random number generation key.

  • nit (int) – Current iteration number.

  • n_samples (int or callable) – Number of samples to draw.

  • draw_linear_kwargs (dict or callable) – Configuration for drawing linear samples, see draw_linear_residual().

  • nonlinearly_update_kwargs (dict or callable) – Configuration for nonlinearly updating samples, see nonlinearly_update_residual().

  • kl_kwargs (dict or callable) – Keyword arguments for the KL minimizer.

  • sample_mode (str or callable) – One of {“linear_sample”, “linear_resample”, “nonlinear_sample”, “nonlinear_resample”, “nonlinear_update”}. The mode determines how samples are drawn and/or updated: “linear” draws MGVI samples; “nonlinear” draws MGVI samples that are then nonlinearly updated with geoVI; the “_sample” versus “_resample” suffix denotes whether the same or new stochasticity is used for drawing the samples; and “nonlinear_update” nonlinearly updates the existing samples using geoVI.

  • point_estimates (tree-like structure or tuple of str) – Pytree with the same structure as the likelihood input but with boolean leaves indicating whether to sample the corresponding value or use it as a point estimate. As a convenience, a tuple of strings is also accepted for dict-like inputs; the boolean indicator pytree is then constructed automatically from these keys.

  • constants (tree-like structure or tuple of str) – Not implemented yet, sorry :( Do bug me (Gordian) at edh@mpa-garching.mpg.de if you want to run with this option.

Most of the parameters can be callable, in which case they are called with the current iteration number as argument and should return the value to use for the current iteration.
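A minimal sketch of per-iteration configuration, assuming the behavior described above: a parameter that is callable is called with the iteration number, anything else is used as-is. The helper resolve() and both schedules are hypothetical illustrations, not NIFTy functions. The sketch also shows the documented tuple-of-strings convenience for point_estimates, under two assumptions: only top-level keys of a flat dict input are handled, and True marks a point estimate.

```python
from typing import Any, Callable, Union


def resolve(param: Union[Any, Callable[[int], Any]], nit: int) -> Any:
    """Hypothetical helper: call per-iteration parameters, pass others through."""
    return param(nit) if callable(param) else param


def sample_mode_schedule(nit: int) -> str:
    # Cheaper linear (MGVI) samples early on, geoVI samples afterwards.
    return "linear_resample" if nit < 3 else "nonlinear_resample"


def n_samples_schedule(nit: int) -> int:
    # Fewer samples while far from the optimum, more later.
    return 2 if nit < 3 else 4


def point_estimates_from_names(names: tuple, like_input: dict) -> dict:
    """Hypothetical sketch: tuple of keys -> boolean indicator dict (flat dicts only)."""
    return {key: key in names for key in like_input}


for nit in range(5):
    print(nit, resolve(sample_mode_schedule, nit), resolve(n_samples_schedule, nit))

print(point_estimates_from_names(("amplitude",), {"amplitude": 1.0, "xi": 0.0}))
```

Schedules like these could be passed as sample_mode=sample_mode_schedule and n_samples=n_samples_schedule, since the signatures of init_state() and optimize_kl() above accept callables for these parameters.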

kl_minimize(samples: ~nifty8.re.evi.Samples, minimize: ~typing.Callable[[...], ~nifty8.re.optimize.OptimizeResults] = <function _newton_cg>, minimize_kwargs={}, **kwargs) -> OptimizeResults
nonlinearly_update_samples(samples: Samples, **kwargs)
run(samples, *args, **kwargs) -> tuple[Samples, OptimizeVIState]
update(samples: Samples, state: OptimizeVIState, /, **kwargs) -> tuple[Samples, OptimizeVIState]

Advances the VI approximation by one sample update and one KL minimization.

Parameters:
  • samples (Samples) – Current samples.

  • state (OptimizeVIState) – Current state of the VI approximation.

  • kwargs (dict) – Keyword arguments passed to the residual sampling functions.

class OptimizeVIState(nit, key, sample_state, minimization_state, config)

Bases: NamedTuple

config: dict[str, dict | Callable[[int], Any] | Any]

Alias for field number 4

key: Any

Alias for field number 1

minimization_state: OptimizeResults | None

Alias for field number 3

nit: int

Alias for field number 0

sample_state: OptimizeResults | None

Alias for field number 2
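The "Alias for field number" entries mean that each field can be read either by name or by position, following the declaration order nit, key, sample_state, minimization_state, config (the fields above are listed alphabetically). The sketch below uses a hypothetical stand-in class with the same field order to show this; it is not the NIFTy class, which stores OptimizeResults objects in the two *_state fields. Since a NamedTuple is immutable, advancing a state in general means building a new tuple, for example with _replace.

```python
from typing import Any, NamedTuple, Optional


class ToyVIState(NamedTuple):
    # Hypothetical stand-in with the documented field order of OptimizeVIState.
    nit: int
    key: Any
    sample_state: Optional[Any]
    minimization_state: Optional[Any]
    config: dict


state = ToyVIState(nit=0, key=42, sample_state=None, minimization_state=None,
                   config={"n_samples": 4})
print(state[0] == state.nit, state[4] is state.config)  # field number vs. name

# NamedTuples are immutable: "advancing" produces a new tuple.
next_state = state._replace(nit=state.nit + 1)
print(state.nit, next_state.nit)
```

Immutability fits the state-less design of OptimizeVI, where all mutable information is carried explicitly in the state object passed to update().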

concatenate_zip(*arrays)
get_status_message(samples, state, residual=None, *, name='', map='lmap') -> str
optimize_kl(likelihood: ~nifty8.re.likelihood.Likelihood, position_or_samples, *, key, n_total_iterations: int, n_samples, point_estimates=(), constants=(), kl_jit=True, residual_jit=True, kl_map=<function vmap>, residual_map='lmap', kl_reduce=functools.partial(<function tree_map>, functools.partial(<function mean>, axis=0)), mirror_samples=True, draw_linear_kwargs={'cg_kwargs': {}, 'cg_name': 'SL'}, nonlinearly_update_kwargs={'minimize_kwargs': {'cg_kwargs': {'name': 'SNCG'}, 'name': 'SN'}}, kl_kwargs={'minimize_kwargs': {'cg_kwargs': {'name': 'MCG'}, 'name': 'M'}}, sample_mode: ~typing.Literal['linear_sample', 'linear_resample', 'nonlinear_sample', 'nonlinear_resample', 'nonlinear_update'] | ~typing.Callable[[int], ~typing.Literal['linear_sample', 'linear_resample', 'nonlinear_sample', 'nonlinear_resample', 'nonlinear_update']] = 'nonlinear_resample', resume: str | bool = False, callback: ~typing.Callable[[~nifty8.re.evi.Samples, ~nifty8.re.optimize_kl.OptimizeVIState], None] | None = None, odir: str | None = None, _optimize_vi=None, _optimize_vi_state=None) -> tuple[Samples, OptimizeVIState]

One-stop shop for MGVI/geoVI-style VI approximation.

Parameters:
  • position_or_samples (Samples or tree-like) – Initial position for minimization.

  • resume (str or bool) – Resume a partially run optimization. If True, the optimization is resumed from the previous state in odir; otherwise, it is resumed from the location to which resume points.

  • callback (callable or None) – Function called after every global iteration taking the samples and the optimization state.

  • odir (str or None) – Path at which all output files are saved.

See OptimizeVI and OptimizeVI.init_state() for the remaining parameters and further details on the optimization.
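A minimal callback sketch matching the documented signature, a function taking the samples and the optimization state and returning None. It relies only on the nit field of OptimizeVIState; the factory make_logging_callback and the FakeState stand-in used to exercise it are hypothetical.

```python
from collections import namedtuple
from typing import Any, Callable, List


def make_logging_callback(log: List[int]) -> Callable[[Any, Any], None]:
    """Hypothetical factory: build a (samples, state) -> None callback."""
    def callback(samples: Any, state: Any) -> None:
        # Only the documented `nit` field of the state is used here.
        log.append(state.nit)
        print(f"iteration counter: {state.nit}")
    return callback


# Stand-in for OptimizeVIState, only for exercising the callback in isolation.
FakeState = namedtuple("FakeState", ["nit"])

log: List[int] = []
cb = make_logging_callback(log)
for nit in range(3):
    cb(None, FakeState(nit=nit))
```

In a real run the callback would be passed as callback=make_logging_callback(log) to optimize_kl(), which calls it after every global iteration.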