Approximate Inference

In Variational Inference (VI), the posterior \mathcal{P}(\xi|d) is approximated by a simpler, parametrized distribution, often a Gaussian \mathcal{Q}(\xi)=\mathcal{G}(\xi-m,D). The parameters of \mathcal{Q}, the mean m and the covariance D, are obtained by minimizing an appropriate information distance measure between \mathcal{Q} and \mathcal{P}. As a compromise between optimality and computational affordability, the variational Kullback-Leibler (KL) divergence is used:

\mathrm{KL}(m,D|d)= \mathcal{D}_\mathrm{KL}(\mathcal{Q}||\mathcal{P})=
\int \mathcal{D}\xi \,\mathcal{Q}(\xi) \log \left( \frac{\mathcal{Q}(\xi)}{\mathcal{P}(\xi|d)} \right)

NIFTy features two main alternatives for variational inference: Metric Gaussian Variational Inference (MGVI) and geometric Variational Inference (geoVI). A visual comparison of the MGVI and geoVI algorithms can be found in variational_inference_visualized.py.

Metric Gaussian Variational Inference (MGVI)

Minimizing the KL divergence with respect to all entries of the covariance D is infeasible for fields. Therefore, Metric Gaussian Variational Inference (MGVI, [1]) approximates the posterior precision matrix D^{-1} at the location of the current mean m by the Bayesian Fisher information metric,

M \approx \left\langle \frac{\partial \mathcal{H}(d,\xi)}{\partial \xi} \, \frac{\partial \mathcal{H}(d,\xi)}{\partial \xi}^\dagger \right\rangle_{(d,\xi)}.

In practice, the average is performed over \mathcal{P}(d,\xi)\approx \mathcal{P}(d|\xi)\,\delta(\xi-m) by evaluating the expression at the current mean m. This results in the Fisher information metric of the likelihood evaluated at the mean, plus the prior information metric; in NIFTy's standardized parametrization, where the prior on \xi is a zero-mean unit Gaussian, the latter is simply the identity. Consequently, only the mean of the approximate distribution has to be inferred. The only term within the KL divergence that explicitly depends on it is the Hamiltonian of the true problem averaged over the approximation:

\mathrm{KL}(m|d) \;\widehat{=}\;
\left\langle \mathcal{H}(d,\xi) \right\rangle_{\mathcal{Q}(\xi)},

where \widehat{=} expresses equality up to irrelevant (here not m-dependent) terms.
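To make the \widehat{=} explicit: recalling that the Hamiltonian is the negative logarithm of the joint distribution, \mathcal{H}(d,\xi) = -\log \mathcal{P}(d,\xi), the KL divergence decomposes as

\mathrm{KL}(m,D|d) = \left\langle \mathcal{H}(d,\xi) \right\rangle_{\mathcal{Q}(\xi)} - S[\mathcal{Q}] + \log \mathcal{P}(d) \ ,

where the Gaussian entropy S[\mathcal{Q}] = \frac{1}{2}\log\det\left(2\pi e D\right) and the log-evidence \log \mathcal{P}(d) are both independent of m and therefore drop out of the optimization.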

Thus, only the gradient of the KL divergence with respect to the mean m is needed, which can be expressed as

\frac{\partial \mathrm{KL}(m|d)}{\partial m} = \left\langle \frac{\partial \mathcal{H}(d,\xi)}{\partial \xi}  \right\rangle_{\mathcal{G}(\xi-m,D)}.

We stochastically estimate the KL divergence and its gradient with a set of samples drawn from the approximate posterior distribution. The particular structure of the covariance allows us to draw independent samples by solving a certain system of equations. This KL divergence for MGVI is implemented by SampledKLEnergy() within NIFTy8.
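As an illustrative sketch of how this fits together (a toy model is assumed here, and the argument names follow the NIFTy8 demos; treat this as a sketch rather than a definitive recipe), an MGVI loop with SampledKLEnergy() might look like this:

```python
import nifty8 as ift

# Toy setup: a log-normal signal observed with additive Gaussian noise.
sp = ift.RGSpace(128)
signal = ift.ScalingOperator(sp, 1.0).ducktape("xi").exp()  # xi has a unit-Gaussian prior
R = ift.GeometryRemover(sp)                                 # trivial response
signal_response = R @ signal
N = ift.ScalingOperator(R.target, 0.1, float)               # noise covariance

# Synthetic data drawn from the generative model
mock_pos = ift.from_random(signal_response.domain)
data = signal_response(mock_pos) + N.draw_sample()

likelihood_energy = ift.GaussianEnergy(data, inverse_covariance=N.inverse) @ signal_response
ic_sampling = ift.AbsDeltaEnergyController(deltaE=0.05, iteration_limit=100)
ham = ift.StandardHamiltonian(likelihood_energy, ic_sampling)

minimizer = ift.NewtonCG(
    ift.AbsDeltaEnergyController(deltaE=0.5, iteration_limit=15, name="newton"))

# MGVI: alternate between drawing samples around the current mean
# (minimizer_sampling=None selects the MGVI sampling scheme) and
# minimizing the sampled KL with respect to the mean.
position = 0.1 * ift.from_random(ham.domain)
for _ in range(5):
    kl = ift.SampledKLEnergy(position, ham, n_samples=4, minimizer_sampling=None)
    kl, _ = minimizer(kl)
    position = kl.position  # updated mean of the Gaussian approximation
```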

Note that MGVI typically provides only a lower bound on the variance: the variational KL divergence penalizes \mathcal{Q} for placing probability mass where \mathcal{P} has little, which favors under-dispersed approximations.

Geometric Variational Inference (geoVI)

For non-linear posterior distributions \mathcal{P}(\xi|d), an approximation with a Gaussian \mathcal{Q}(\xi) in the coordinates \xi is sub-optimal, as higher-order interactions are ignored. A better approximation can be achieved by constructing a coordinate system y = g\left(\xi\right) in which the posterior is close to a Gaussian, and performing VI with a Gaussian \mathcal{Q}(y) in these coordinates. This approach is called geometric Variational Inference (geoVI) and is discussed in detail in [2].

One useful coordinate system is obtained if the metric M of the posterior can be expressed as the pullback of the Euclidean metric by g:

M = \left(\frac{\partial g}{\partial \xi}\right)^T \frac{\partial g}{\partial \xi} \ .
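As a one-dimensional illustration (an example for intuition only, not taken from [2]): in 1D this condition reduces to

\left(\frac{\partial g}{\partial \xi}\right)^2 = M(\xi)
\quad \Rightarrow \quad
g(\xi) = \int^{\xi} \sqrt{M(\xi')} \, \mathrm{d}\xi' \ .

For M(\xi) = \xi^{-2} on \xi > 0 this gives g(\xi) = \log \xi, so a posterior that is log-normal in \xi becomes exactly a unit Gaussian in the coordinate y = \log \xi.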

In general, such a transformation exists only locally, i.e. in a neighbourhood of some expansion point \bar{\xi}, and is denoted as g_{\bar{\xi}}\left(\xi\right). Using g_{\bar{\xi}}, the geoVI scheme employs a zero-mean, unit-covariance Gaussian \mathcal{Q}(y) = \mathcal{G}(y, 1) as approximation. It can be expressed in the coordinates \xi via the pushforward by the inverse transformation \xi = g_{\bar{\xi}}^{-1}(y):

\mathcal{Q}_{\bar{\xi}}(\xi) = \left(g_{\bar{\xi}}^{-1} * \mathcal{Q}\right)(\xi) = \int \delta\left(\xi - g_{\bar{\xi}}^{-1}(y)\right) \ \mathcal{G}(y, 1) \ \mathcal{D}y \ ,

where \delta denotes the Dirac delta distribution.
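A tiny numerical illustration of this pushforward, continuing the one-dimensional example from above (plain NumPy, independent of the NIFTy machinery):

```python
import numpy as np

rng = np.random.default_rng(0)

# Samples from the standardized approximation Q(y) = G(y, 1)
y = rng.standard_normal(100_000)

# Pushforward through the inverse transformation xi = g^{-1}(y) = exp(y)
xi = np.exp(y)

# xi is log-normally distributed: Gaussian in y = log(xi), non-Gaussian in xi.
# Sanity check against the analytic mean E[xi] = exp(1/2):
print(xi.mean(), np.exp(0.5))
```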

geoVI obtains the optimal expansion point \bar{\xi} such that \mathcal{Q}_{\bar{\xi}} matches the posterior as well as possible. Analogous to the MGVI algorithm, \bar{\xi} is obtained by minimizing the KL divergence between \mathcal{P} and \mathcal{Q}_{\bar{\xi}} w.r.t. \bar{\xi}. Furthermore, the KL is represented as a stochastic estimate, using a set of samples drawn from \mathcal{Q}_{\bar{\xi}}, which is implemented in NIFTy8 via SampledKLEnergy() with minimizer_sampling != None.
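Continuing the MGVI sketch from above (same toy model and names, and equally a sketch rather than a definitive recipe), switching to geoVI amounts to supplying a minimizer for the nonlinear sampling step:

```python
# geoVI: supply a minimizer for drawing samples in the transformed
# coordinates; a non-None minimizer_sampling activates the geoVI scheme.
minimizer_sampling = ift.NewtonCG(
    ift.AbsDeltaEnergyController(deltaE=0.5, iteration_limit=10, name="sampling"))

for _ in range(5):
    kl = ift.SampledKLEnergy(position, ham, n_samples=4,
                             minimizer_sampling=minimizer_sampling)
    kl, _ = minimizer(kl)
    position = kl.position  # updated expansion point
```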

Publications

If you use MGVI or geoVI, the authors of the respective papers would greatly appreciate a citation.
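[1] Knollmüller, J. and Enßlin, T.A. (2019), "Metric Gaussian Variational Inference", arXiv:1901.11033.

[2] Frank, P., Leike, R. and Enßlin, T.A. (2021), "Geometric Variational Inference", Entropy 23(7), 853, arXiv:2105.10470.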