#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
# Authors: Philipp Frank, Jakob Roth, Gordian Edenhofer
import inspect
import os
import pickle
from dataclasses import field
from functools import partial
from os import makedirs
from typing import Any, Callable, Literal, NamedTuple, Optional, TypeVar, Union
import jax
import numpy as np
from jax import numpy as jnp
from jax import random, tree_map
from jax.tree_util import Partial
from . import optimize
from .evi import (
Samples, _parse_jit, draw_linear_residual, nonlinearly_update_residual
)
from .likelihood import Likelihood
from .logger import logger
from .minisanity import minisanity
from .model import LazyModel
from .tree_math import get_map, hide_strings, vdot
# Generic type variable for module-level annotations.
# NOTE(review): not referenced in the code shown here; confirm it is used
# elsewhere before removing it.
P = TypeVar("P")
[docs]
def get_status_message(
    samples, state, residual=None, *, name="", map="lmap"
) -> str:
    """Summarize the current state of a VI optimization as a log message.

    Parameters
    ----------
    samples : Samples
        Current samples, forwarded to :func:`minisanity`.
    state : OptimizeVIState
        Current optimizer state; its `minimization_state` must be set.
    residual : callable or None
        If given, `minisanity` statistics of this residual are reported as
        the likelihood residual(s).
    name : str
        Prefix put in front of every line of the message.
    map : str or callable
        Forwarded as `map` to :func:`minisanity`.

    Returns
    -------
    str
        Multi-line report containing the iteration number, the KL energy, the
        sampling and KL-minimization step counts, and the likelihood and prior
        residual statistics.
    """
    energy = state.minimization_state.fun

    sample_state = state.sample_state
    if isinstance(sample_state, optimize.OptimizeResults):
        # Nonlinear sampling reports one iteration count per sample
        step_counts = tuple(int(n) for n in sample_state.nit)
        sampling_line = f"\n{name}: #(Nonlinear sampling steps) {step_counts}"
    elif isinstance(sample_state, (np.ndarray, jax.Array)):
        # Linear sampling reports one status value per sample
        statuses = tuple(int(s) for s in sample_state)
        sampling_line = f"\n{name}: Linear sampling status {statuses}"
    else:
        sampling_line = ""

    likelihood_report = ""
    if residual is not None:
        _, likelihood_report = minisanity(samples, residual, map=map)
    _, prior_report = minisanity(samples, map=map)

    pieces = (
        f"{name}: Iteration {state.nit:04d} ⛰:{energy:+2.4e}",
        sampling_line,
        f"\n{name}: #(KL minimization steps) {state.minimization_state.nit}",
        f"\n{name}: Likelihood residual(s):\n{likelihood_report}",
        f"\n{name}: Prior residual(s):\n{prior_report}",
        "\n",
    )
    return "".join(pieces)
# Default reduction for per-sample KL quantities: average every leaf of a
# pytree over its leading (sample) axis.
_mean_over_leading_axis = partial(jnp.mean, axis=0)
_reduce = partial(tree_map, _mean_over_leading_axis)
class _StandardHamiltonian(LazyModel):
    """Energy of a posterior made of a user-defined likelihood and a standard
    normal prior.

    The prior contributes ``0.5 * <x, x>`` to the energy and the identity to
    the metric; everything else is delegated to the wrapped likelihood.
    """
    likelihood: Likelihood = field(metadata={"static": False})

    def __init__(self, likelihood: Likelihood, /):
        self.likelihood = likelihood

    def __call__(self, primals, **primals_kw):
        # Calling the Hamiltonian evaluates its energy
        return self.energy(primals, **primals_kw)

    def energy(self, primals, **primals_kw):
        likelihood_energy = self.likelihood(primals, **primals_kw)
        prior_energy = 0.5 * vdot(primals, primals)
        return likelihood_energy + prior_energy

    def metric(self, primals, tangents, **primals_kw):
        likelihood_metric = self.likelihood.metric(
            primals, tangents, **primals_kw
        )
        # The standard normal prior adds the identity applied to `tangents`
        return likelihood_metric + tangents
def _kl_vg(
    likelihood,
    primals,
    primals_samples,
    *,
    map=jax.vmap,
    reduce=_reduce,
):
    """Sample-averaged value and gradient of the standard Hamiltonian.

    `jax.value_and_grad` of the likelihood energy plus a standard normal prior
    is evaluated for every sample (placed at `primals` via `Samples.at`), and
    the per-sample results are combined with `reduce`. Without any samples,
    the value and gradient at `primals` itself are returned.
    """
    assert isinstance(primals_samples, Samples)
    mapper = get_map(map)
    value_and_grad = jax.value_and_grad(_StandardHamiltonian(likelihood))
    if len(primals_samples) == 0:
        return value_and_grad(primals)
    per_sample = mapper(value_and_grad)(primals_samples.at(primals).samples)
    return reduce(per_sample)
def _kl_met(
    likelihood,
    primals,
    tangents,
    primals_samples,
    *,
    map=jax.vmap,
    reduce=_reduce
):
    """Sample-averaged metric of the standard Hamiltonian applied to `tangents`.

    The metric is evaluated at every sample (placed at `primals` via
    `Samples.at`) with the same `tangents` for all samples, and the per-sample
    results are combined with `reduce`. Without any samples, the metric at
    `primals` itself is applied.
    """
    assert isinstance(primals_samples, Samples)
    mapper = get_map(map)
    hamiltonian = _StandardHamiltonian(likelihood)
    if len(primals_samples) == 0:
        return hamiltonian.metric(primals, tangents)
    # Map over the samples only; `tangents` is shared by all of them
    metric_per_sample = mapper(hamiltonian.metric, in_axes=(0, None))
    shifted_samples = primals_samples.at(primals).samples
    return reduce(metric_per_sample(shifted_samples, tangents))
[docs]
@jax.jit
def concatenate_zip(*arrays):
    """Interleave pytrees leaf-wise along their leading axis.

    For ``k`` input pytrees, every output leaf satisfies
    ``out[i * k + j] == arrays[j][i]``, i.e. the i-th entries of all inputs
    end up next to each other.
    """

    def _interleave(*leaves):
        trailing_shape = leaves[0].shape[1:]
        stacked = jnp.stack(leaves, axis=1)
        return stacked.reshape((-1, ) + trailing_shape)

    return tree_map(_interleave, *arrays)
# Sampling strategies understood by `OptimizeVI.draw_samples`.
SMPL_MODE_TYP = Literal["linear_sample", "linear_resample", "nonlinear_sample",
                        "nonlinear_resample", "nonlinear_update", ]
# Schedules: callables mapping the iteration number to the option's value.
_SMPL_MODE_SCHEDULE_TYP = Callable[[int], SMPL_MODE_TYP]
_DICT_SCHEDULE_TYP = Callable[[int], dict]
# Either a fixed value or a schedule producing one per iteration.
SMPL_MODE_GENERIC_TYP = Union[SMPL_MODE_TYP, _SMPL_MODE_SCHEDULE_TYP]
DICT_OR_CALL4DICT_TYP = Union[_DICT_SCHEDULE_TYP, dict]
[docs]
class OptimizeVIState(NamedTuple):
    """State carried between iterations of :class:`OptimizeVI`."""
    # Iteration counter; incremented by one at the end of every
    # `OptimizeVI.update`.
    nit: int
    # Random number generation key; `OptimizeVI.update` splits it once per
    # iteration.
    key: Any
    # Status returned by the sampling step of the last iteration (see
    # `OptimizeVI.draw_samples`); `None` until the first iteration.
    sample_state: Union[optimize.OptimizeResults, None] = None
    # Result of the last KL minimization with `x`, `jac`, `hess` and
    # `hess_inv` stripped; `None` until the first iteration.
    minimization_state: Union[optimize.OptimizeResults, None] = None
    # Options of `OptimizeVI.init_state`, each either a value or a callable of
    # the iteration number.
    # NOTE: the `{}` default is one dict shared by every instance created
    # without `config`; never mutate it in place.
    config: dict[str, Union[dict, Callable[[int], Any], Any]] = {}
def _getitem_at_nit(config, key, nit):
    """Look up `config[key]`, evaluating it at `nit` if it is a schedule.

    A value counts as a schedule when it is callable and its signature has
    exactly one positional parameter; it is then called as ``value(nit)``.
    Any other value (including callables with a different number of
    positional parameters) is returned unchanged.

    Parameters
    ----------
    config : dict
        Mapping of option names to values or schedules.
    key : str
        Option to look up.
    nit : int
        Current iteration number, passed to schedules.

    Raises
    ------
    KeyError
        If `key` is not in `config`.
    """
    c = config[key]
    if not callable(c):
        return c
    # `inspect.signature` drops the already-bound first argument of bound
    # methods and of callable instances (`__call__`). `getfullargspec`, used
    # previously, counted that `self`, so e.g. `scheduler.n_samples` was
    # never evaluated as a schedule.
    params = inspect.signature(c).parameters.values()
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    n_positional = sum(1 for p in params if p.kind in positional_kinds)
    if n_positional == 1:
        return c(nit)
    return c
[docs]
class OptimizeVI:
    """State-less assembly of all methods needed for an MGVI/geoVI style VI
    approximation.

    Builds functions for a VI approximation via variants of the `Geometric
    Variational Inference` and/or `Metric Gaussian Variational Inference`
    algorithms. They produce approximate posterior samples that are used for KL
    estimation internally and the final set of samples are the approximation of
    the posterior. The samples can be linear, i.e. following a standard normal
    distribution in model space, or nonlinear, i.e. following a standard normal
    distribution in the canonical coordinate system of the Riemannian manifold
    associated with the metric of the approximate posterior distribution. The
    coordinate transformation for the nonlinear sample is approximated by an
    expansion.

    Both linear and nonlinear samples start by drawing a sample from the
    inverse metric. To do so, we draw a sample which has the metric as
    covariance structure and apply the inverse metric to it. The sample
    transformed in this way has the inverse metric as covariance. The first
    part is trivial since we can use the left square root of the metric
    :math:`L` associated with every likelihood:

    .. math::

        \\tilde{d} \\leftarrow \\mathcal{G}(0,\\mathbb{1}) \\\\
        t = L \\tilde{d}

    with :math:`t` now having a covariance structure of

    .. math::

        <t t^\\dagger> = L <\\tilde{d} \\tilde{d}^\\dagger> L^\\dagger = M .

    To transform the sample to an inverse sample, we apply the inverse
    metric. We can do so using the conjugate gradient algorithm (CG). The CG
    algorithm yields the solution to :math:`M s = t`, i.e. applies the
    inverse of :math:`M` to :math:`t`:

    .. math::

        M &s = t \\\\
        &s = M^{-1} t = cg(M, t) .

    The linear sample is :math:`s`.

    The nonlinear sampling uses :math:`s` as a starting value and curves it in
    a nonlinear way as to better resemble the posterior locally. See the below
    reference literature for more details on the nonlinear sampling.

    See also
    --------
    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
    `<https://doi.org/10.3390/e23070853>`_

    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """

    def __init__(
        self,
        likelihood: Likelihood,
        n_total_iterations: int,
        *,
        kl_jit=True,
        residual_jit=True,
        kl_map=jax.vmap,
        residual_map="lmap",
        kl_reduce=_reduce,
        mirror_samples=True,
        _kl_value_and_grad: Optional[Callable] = None,
        _kl_metric: Optional[Callable] = None,
        _draw_linear_residual: Optional[Callable] = None,
        _nonlinearly_update_residual: Optional[Callable] = None,
        _get_status_message: Optional[Callable] = None,
    ):
        """JaxOpt style minimizer for a VI approximation of a distribution with
        samples.

        Parameters
        ----------
        likelihood: :class:`~nifty8.re.likelihood.Likelihood`
            Likelihood to be used for inference.
        n_total_iterations: int
            Total number of iterations. One iteration consists of the steps
            1) - 3).
        kl_jit: bool or callable
            Whether to jit the KL minimization.
        residual_jit: bool or callable
            Whether to jit the residual sampling functions.
        kl_map: callable or str
            Map function used for the KL minimization.
        residual_map: callable or str
            Map function used for the residual sampling functions.
        kl_reduce: callable
            Reduce function used for the KL minimization.
        mirror_samples: bool
            Whether to mirror the samples or not. Only `True` is currently
            implemented.
        _kl_value_and_grad, _kl_metric, _draw_linear_residual : callable, optional
        _nonlinearly_update_residual, _get_status_message : callable, optional
            Overrides for the functions that are otherwise built from
            `likelihood`.

        Notes
        -----
        Implements the base logic present in conditional VI approximations
        such as MGVI and geoVI. First samples are generated (and/or updated)
        and then their collective mean is optimized for using the sample
        estimated variational KL between the true distribution and the sampled
        approximation. This is split into three steps:

        1) Sample generation
        2) Sample update
        3) KL minimization.

        Step 1) and 2) may be skipped depending on the minimizer's state, but
        step 3) is always performed at the end of one iteration. A full loop
        consists of repeatedly iterating over the steps 1) - 3).
        """
        kl_jit = _parse_jit(kl_jit)
        residual_jit = _parse_jit(residual_jit)
        residual_map = get_map(residual_map)

        if mirror_samples is False:
            raise NotImplementedError()

        if _kl_value_and_grad is None:
            _kl_value_and_grad = partial(
                kl_jit(_kl_vg, static_argnames=("map", "reduce")),
                likelihood,
                map=kl_map,
                reduce=kl_reduce
            )
        if _kl_metric is None:
            _kl_metric = partial(
                kl_jit(_kl_met, static_argnames=("map", "reduce")),
                likelihood,
                map=kl_map,
                reduce=kl_reduce
            )
        if _draw_linear_residual is None:
            _draw_linear_residual = partial(
                residual_jit(draw_linear_residual), likelihood
            )
        if _nonlinearly_update_residual is None:
            # TODO: Pull out `jit` from `nonlinearly_update_residual` once NCG
            # is jit-able
            from .evi import _nonlinearly_update_residual_functions

            _nonlin_funcs = _nonlinearly_update_residual_functions(
                likelihood=likelihood,
                jit=residual_jit,
            )
            _nonlinearly_update_residual = partial(
                nonlinearly_update_residual,
                likelihood,
                _nonlinear_update_funcs=_nonlin_funcs,
            )
        if _get_status_message is None:
            _get_status_message = partial(
                get_status_message,
                residual=likelihood.normalized_residual,
                name=self.__class__.__name__,
            )

        self.n_total_iterations = n_total_iterations
        self.kl_value_and_grad = _kl_value_and_grad
        self.kl_metric = _kl_metric
        self.draw_linear_residual = _draw_linear_residual
        self.nonlinearly_update_residual = _nonlinearly_update_residual
        self.residual_map = residual_map
        self.get_status_message = _get_status_message

    def draw_linear_samples(self, primals, keys, **kwargs):
        """Draw one linear residual per key at `primals` and mirror it.

        Returns the :class:`Samples` positioned at `primals` whose samples
        interleave every residual with its negation, and the per-key sampling
        states.
        """
        # NOTE, use `Partial` in favor of `partial` to allow the (potentially)
        # re-jitting `residual_map` to trace the kwargs
        kwargs = hide_strings(kwargs)
        sampler = Partial(self.draw_linear_residual, **kwargs)
        sampler = self.residual_map(sampler, in_axes=(None, 0))
        smpls, smpls_states = sampler(primals, keys)
        # zip samples such that the mirrored-counterpart always comes right
        # after the original sample
        smpls = Samples(
            pos=primals, samples=concatenate_zip(smpls, -smpls), keys=keys
        )
        return smpls, smpls_states

    def nonlinearly_update_samples(self, samples: Samples, **kwargs):
        """Nonlinearly update every (mirrored) sample in `samples`.

        Each sample is updated together with the key it was drawn from (every
        key is repeated for the original and its mirrored counterpart) and a
        sign of +1 for the original and -1 for the mirrored sample.
        """
        # NOTE, use `Partial` in favor of `partial` to allow the (potentially)
        # re-jitting `residual_map` to trace the kwargs
        kwargs = hide_strings(kwargs)
        curver = Partial(self.nonlinearly_update_residual, **kwargs)
        curver = self.residual_map(curver, in_axes=(None, 0, 0, 0))
        assert len(samples.keys) == len(samples) // 2
        metric_sample_key = concatenate_zip(*((samples.keys, ) * 2))
        sgn = jnp.ones(len(samples.keys))
        sgn = concatenate_zip(sgn, -sgn)
        smpls, smpls_states = curver(
            samples.pos, samples._samples, metric_sample_key, sgn
        )
        smpls = Samples(pos=samples.pos, samples=smpls, keys=samples.keys)
        return smpls, smpls_states

    def draw_samples(
        self,
        samples: Samples,
        *,
        key,
        sample_mode: SMPL_MODE_TYP,
        n_samples: int,
        point_estimates,
        draw_linear_kwargs={},
        nonlinearly_update_kwargs={},
        **kwargs
    ):
        """Draw and/or update samples according to `sample_mode`.

        `sample_mode` is compared case-insensitively. If `n_samples` differs
        from the number of existing keys, "*_sample" modes are promoted to
        "*_resample" and "nonlinear_update" to "nonlinear_resample". With
        `n_samples == 0` nothing is drawn and the sample state is `0`.

        Returns
        -------
        samples, sample_state
            The new samples and the status reported by the sampling step.

        Raises
        ------
        ValueError
            If `sample_mode` is not a known sampling mode.
        """
        # Normalize once so that every comparison and rewrite below agrees on
        # the casing; previously `.replace` ran on the un-lowered string, so a
        # mixed-case mode escaped the forced resampling.
        mode = sample_mode.lower()
        # Always resample if `n_samples` changed
        n_keys = 0 if samples.keys is None else len(samples.keys)
        if n_samples == 0:
            mode = ""
        elif n_samples != n_keys and mode == "nonlinear_update":
            mode = "nonlinear_resample"
        elif n_samples != n_keys and mode.endswith("_sample"):
            mode = mode.replace("_sample", "_resample")

        if mode in (
            "linear_resample", "linear_sample", "nonlinear_resample",
            "nonlinear_sample"
        ):
            k_smpls = samples.keys  # Re-use the keys if not re-sampling
            if mode.endswith("_resample"):
                k_smpls = random.split(key, n_samples)
            assert n_samples == len(k_smpls)
            samples, st_smpls = self.draw_linear_samples(
                samples.pos,
                k_smpls,
                point_estimates=point_estimates,
                **draw_linear_kwargs,
                **kwargs
            )
            if mode.startswith("nonlinear"):
                samples, st_smpls = self.nonlinearly_update_samples(
                    samples,
                    point_estimates=point_estimates,
                    **nonlinearly_update_kwargs,
                    **kwargs
                )
        elif mode == "nonlinear_update":
            samples, st_smpls = self.nonlinearly_update_samples(
                samples,
                point_estimates=point_estimates,
                **nonlinearly_update_kwargs,
                **kwargs
            )
        elif mode == "":
            samples, st_smpls = samples, 0  # Do nothing for MAP
        else:
            ve = f"invalid sampling mode {sample_mode!r}"
            raise ValueError(ve)
        return samples, st_smpls

    def kl_minimize(
        self,
        samples: Samples,
        minimize: Callable[..., optimize.OptimizeResults] = optimize._newton_cg,
        minimize_kwargs={},
        **kwargs
    ) -> optimize.OptimizeResults:
        """Minimize the sample-estimated KL starting from `samples.pos`.

        `minimize` receives `fun_and_grad` and `hessp` built from
        `kl_value_and_grad` and `kl_metric` for the given samples, plus
        `minimize_kwargs`. The remaining `kwargs` are forwarded to both
        `kl_value_and_grad` and `kl_metric`.
        """
        fun_and_grad = Partial(
            self.kl_value_and_grad, primals_samples=samples, **kwargs
        )
        hessp = Partial(self.kl_metric, primals_samples=samples, **kwargs)
        kl_opt_state = minimize(
            None,
            x0=samples.pos,
            fun_and_grad=fun_and_grad,
            hessp=hessp,
            **minimize_kwargs
        )
        return kl_opt_state

    def init_state(
        self,
        key,
        *,
        nit=0,
        n_samples: Union[int, Callable[[int], int]],
        draw_linear_kwargs: DICT_OR_CALL4DICT_TYP = dict(
            cg_name="SL", cg_kwargs=dict()
        ),
        nonlinearly_update_kwargs: DICT_OR_CALL4DICT_TYP = dict(
            minimize_kwargs=dict(name="SN", cg_kwargs=dict(name="SNCG"))
        ),
        kl_kwargs: DICT_OR_CALL4DICT_TYP = dict(
            minimize_kwargs=dict(name="M", cg_kwargs=dict(name="MCG"))
        ),
        sample_mode: SMPL_MODE_GENERIC_TYP = "nonlinear_resample",
        point_estimates=(),
        constants=(),  # TODO
    ) -> OptimizeVIState:
        """Initialize the state of the (otherwise state-less) VI approximation.

        Parameters
        ----------
        key : jax random number generation key
        nit : int
            Current iteration number.
        n_samples : int or callable
            Number of samples to draw.
        draw_linear_kwargs : dict or callable
            Configuration for drawing linear samples, see
            :func:`draw_linear_residual`.
        nonlinearly_update_kwargs : dict or callable
            Configuration for nonlinearly updating samples, see
            :func:`nonlinearly_update_residual`.
        kl_kwargs : dict or callable
            Keyword arguments for the KL minimizer.
        sample_mode : str or callable
            One in {"linear_sample", "linear_resample", "nonlinear_sample",
            "nonlinear_resample", "nonlinear_update"}. The mode denotes the way
            samples are drawn and/or updated, "linear" draws MGVI samples,
            "nonlinear" draws MGVI samples which are then nonlinearly updated
            with geoVI, the "_sample" versus "_resample" suffix denotes whether
            the same stochasticity or new stochasticity is used for the drawing
            of the samples, and "nonlinear_update" nonlinearly updates existing
            samples using geoVI.
        point_estimates: tree-like structure or tuple of str
            Pytree of same structure as likelihood input but with boolean
            leaves indicating whether to sample the value in the input or use
            it as a point estimate. As a convenience method, for dict-like
            inputs, a tuple of strings is also valid. From these the boolean
            indicator pytree is automatically constructed.
        constants: tree-like structure or tuple of str
            Not implemented yet, sorry :( Do bug me (Gordian) at
            edh@mpa-garching.mpg.de if you wanted to run with this option.

        Most of the parameters can be callable, in which case they are called
        with the current iteration number as argument and should return the
        value to use for the current iteration.
        """
        config = dict(
            n_samples=n_samples,
            sample_mode=sample_mode,
            point_estimates=point_estimates,
            constants=constants,
            draw_linear_kwargs=draw_linear_kwargs,
            nonlinearly_update_kwargs=nonlinearly_update_kwargs,
            kl_kwargs=kl_kwargs,
        )
        return OptimizeVIState(nit, key, config=config)

    def update(
        self,
        samples: Samples,
        state: OptimizeVIState,
        /,
        **kwargs,
    ) -> tuple[Samples, OptimizeVIState]:
        """Moves the VI approximation one sample update and minimization forward.

        Parameters
        ----------
        samples : :class:`Samples`
            Current samples.
        state : :class:`OptimizeVIState`
            Current state of the VI approximation.
        kwargs : dict
            Keyword arguments passed to the residual sampling functions.
        """
        assert isinstance(samples, Samples)
        assert isinstance(state, OptimizeVIState)
        nit, key, config = state.nit, state.key, state.config

        constants = _getitem_at_nit(config, "constants", nit)
        if not (constants == () or constants is None):
            raise NotImplementedError()

        sample_mode = _getitem_at_nit(config, "sample_mode", nit)
        point_estimates = _getitem_at_nit(config, "point_estimates", nit)
        n_samples = _getitem_at_nit(config, "n_samples", nit)
        draw_linear_kwargs = _getitem_at_nit(config, "draw_linear_kwargs", nit)
        nonlinearly_update_kwargs = _getitem_at_nit(
            config, "nonlinearly_update_kwargs", nit
        )
        # Make the `key` tick independently of whether samples are drawn or not
        key, sk = random.split(key, 2)
        samples, st_smpls = self.draw_samples(
            samples,
            key=sk,
            sample_mode=sample_mode,
            point_estimates=point_estimates,
            n_samples=n_samples,
            draw_linear_kwargs=draw_linear_kwargs,
            nonlinearly_update_kwargs=nonlinearly_update_kwargs,
            **kwargs
        )

        kl_kwargs = _getitem_at_nit(config, "kl_kwargs", nit).copy()
        kl_opt_state = self.kl_minimize(samples, **kl_kwargs, **kwargs)
        samples = samples.at(kl_opt_state.x)
        # Remove unnecessary references
        kl_opt_state = kl_opt_state._replace(
            x=None, jac=None, hess=None, hess_inv=None
        )

        state = state._replace(
            nit=nit + 1,
            key=key,
            sample_state=st_smpls,
            minimization_state=kl_opt_state,
        )
        return samples, state

    def run(self, samples, *args, **kwargs) -> tuple[Samples, OptimizeVIState]:
        """Run `update` until `n_total_iterations` is reached.

        The initial state is created via ``init_state(*args, **kwargs)``; a
        status message is logged after every iteration.
        """
        state = self.init_state(*args, **kwargs)
        nm = self.__class__.__name__
        for i in range(state.nit, self.n_total_iterations):
            logger.info(f"{nm}: Starting {i+1:04d}")
            samples, state = self.update(samples, state)
            msg = self.get_status_message(
                samples, state, map=self.residual_map, name=nm
            )
            logger.info(msg)
        return samples, state
[docs]
def optimize_kl(
    likelihood: Likelihood,
    position_or_samples,
    *,
    key,
    n_total_iterations: int,
    n_samples,
    point_estimates=(),
    constants=(),
    kl_jit=True,
    residual_jit=True,
    kl_map=jax.vmap,
    residual_map="lmap",
    kl_reduce=_reduce,
    mirror_samples=True,
    draw_linear_kwargs=dict(cg_name="SL", cg_kwargs=dict()),
    nonlinearly_update_kwargs=dict(
        minimize_kwargs=dict(name="SN", cg_kwargs=dict(name="SNCG"))
    ),
    kl_kwargs=dict(minimize_kwargs=dict(name="M", cg_kwargs=dict(name="MCG"))),
    sample_mode: SMPL_MODE_GENERIC_TYP = "nonlinear_resample",
    resume: Union[str, bool] = False,
    callback: Optional[Callable[[Samples, OptimizeVIState], None]] = None,
    odir: Optional[str] = None,
    _optimize_vi=None,
    _optimize_vi_state=None,
) -> tuple[Samples, OptimizeVIState]:
    """One-stop-shop for MGVI/geoVI style VI approximation.

    Parameters
    ----------
    likelihood : Likelihood
        Likelihood to be used for inference.
    position_or_samples: Samples or tree-like
        Initial position for minimization.
    key : jax random number generation key
    n_total_iterations : int
        Total number of global iterations.
    resume : str or bool
        Resume partially run optimization. If `True`, the optimization is
        resumed from the previous state in `odir` otherwise it is resumed from
        the location toward which `resume` points.
    callback : callable or None
        Function called after every global iteration taking the samples and the
        optimization state.
    odir : str or None
        Path at which all output files are saved. The latest samples and state
        are stored in ``last.pkl`` and the status messages are appended to
        ``minisanity.txt``.

    See :class:`OptimizeVI` and :func:`OptimizeVI.init_state` for the remaining
    parameters and further details on the optimization.

    Raises
    ------
    ValueError
        If `resume` is set but no state file to resume from exists.
    """
    LAST_FILENAME = "last.pkl"
    MINISANITY_FILENAME = "minisanity.txt"

    opt_vi = _optimize_vi
    if opt_vi is None:
        opt_vi = OptimizeVI(
            likelihood,
            n_total_iterations=n_total_iterations,
            kl_jit=kl_jit,
            residual_jit=residual_jit,
            kl_map=kl_map,
            residual_map=residual_map,
            kl_reduce=kl_reduce,
            mirror_samples=mirror_samples
        )

    last_fn = os.path.join(odir, LAST_FILENAME) if odir is not None else None
    # Only treat `resume` as a path if it is path-like: `os.path.isfile`
    # interprets the booleans `True`/`False` as the file descriptors 1/0
    # (stdout/stdin), which are regular files whenever those streams are
    # redirected to one.
    if isinstance(resume, (str, bytes, os.PathLike)) and os.path.isfile(resume):
        resume_fn = resume
    else:
        resume_fn = last_fn
    sanity_fn = (
        os.path.join(odir, MINISANITY_FILENAME) if odir is not None else None
    )

    if isinstance(position_or_samples, Samples):
        samples = position_or_samples
    else:
        samples = Samples(pos=position_or_samples, samples=None, keys=None)

    opt_vi_st = None
    if resume:
        if resume_fn is None:
            raise ValueError(
                f"unable to resume from {resume!r}: not an existing file and "
                "`odir` is not set"
            )
        if not os.path.isfile(resume_fn):
            raise ValueError(f"unable to resume from {resume_fn!r}")
        if samples.pos is not None:
            logger.warning("overwriting `position_or_samples` with `resume`")
        # NOTE(security): `pickle.load` can execute arbitrary code; only resume
        # from state files that were written by this function.
        with open(resume_fn, "rb") as f:
            samples, opt_vi_st = pickle.load(f)
    opt_vi_st_init = opt_vi.init_state(
        key,
        n_samples=n_samples,
        draw_linear_kwargs=draw_linear_kwargs,
        nonlinearly_update_kwargs=nonlinearly_update_kwargs,
        kl_kwargs=kl_kwargs,
        sample_mode=sample_mode,
        point_estimates=point_estimates,
        constants=constants,
    )
    opt_vi_st = _optimize_vi_state if _optimize_vi_state is not None else opt_vi_st
    opt_vi_st = opt_vi_st_init if opt_vi_st is None else opt_vi_st
    if len(opt_vi_st.config) == 0:  # resume or _optimize_vi_state has empty config
        opt_vi_st = opt_vi_st._replace(config=opt_vi_st_init.config)

    if odir:
        makedirs(odir, exist_ok=True)
    # Status messages contain non-ASCII characters (e.g. "⛰"), hence the
    # explicit encoding instead of the locale default.
    if not resume and sanity_fn is not None:
        with open(sanity_fn, "w", encoding="utf-8"):
            pass
    if not resume and last_fn is not None:
        with open(last_fn, "wb"):
            pass

    nm = "OPTIMIZE_KL"
    for i in range(opt_vi_st.nit, opt_vi.n_total_iterations):
        logger.info(f"{nm}: Starting {i+1:04d}")
        samples, opt_vi_st = opt_vi.update(samples, opt_vi_st)
        msg = opt_vi.get_status_message(samples, opt_vi_st, name=nm)
        logger.info(msg)
        if sanity_fn is not None:
            with open(sanity_fn, "a", encoding="utf-8") as f:
                f.write("\n" + msg)
        if last_fn is not None:
            # Write to a temporary file and atomically swap it in, so that an
            # interruption while pickling cannot corrupt the state `resume`
            # relies on.
            tmp_fn = last_fn + ".tmp"
            with open(tmp_fn, "wb") as f:
                # TODO: Make all arrays numpy arrays as to not instantiate on
                # the main device when loading
                pickle.dump((samples, opt_vi_st._replace(config={})), f)
            os.replace(tmp_fn, last_fn)
        if callback is not None:
            callback(samples, opt_vi_st)

    return samples, opt_vi_st