# Source code for this module (the page title was truncated during extraction)

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

from typing import Callable, Optional, TypeVar, Union

import jax
from jax import numpy as jnp
from jax import random

from ..lax import fori_loop
from ..tree_math import ShapeWithDtype

# Type variable for the vector type `lanczos_tridiag` operates on: its `mat`
# argument maps a `V` to a `V`.
V = TypeVar("V")

def lanczos_tridiag(mat: Callable[[V], V], v: V, order: int):
    """Compute the Lanczos decomposition into a tri-diagonal matrix and its
    corresponding orthonormal projection matrix.

    Starting from the normalized `v`, `order` Lanczos iterations are run on
    `mat`. Every new Lanczos vector is fully re-orthogonalized against all
    previously stored ones. In exact arithmetic and for a symmetric
    (Hermitian) `mat`, ``tridiag[i, j]`` equals the inner product of
    ``vecs[i]`` with ``mat(vecs[j])``. Symmetry of `mat` is not checked.

    Parameters
    ----------
    mat : callable
        Linear operator, applied as ``mat(x)``. It must return an array of the
        same shape as `v`.
    v : array
        Starting vector. Any array shape is accepted; inner products are taken
        over the flattened array (and are conjugating for complex dtypes).
    order : int
        Number of Lanczos iterations, i.e. the size of the returned
        tri-diagonal matrix. Must be at least 1.

    Returns
    -------
    tridiag : array of shape ``(order, order)``
        The symmetric tri-diagonal Lanczos matrix.
    vecs : array of shape ``(order,) + v.shape``
        Row ``i`` holds the ``i``-th normalized Lanczos vector.

    Raises
    ------
    ValueError
        If `order` is smaller than 1 or if ``mat(v)`` does not have the shape
        of `v`.
    """
    if order < 1:
        raise ValueError(f"`order` must be a positive integer; got {order!r}")

    # Normalize before taking the dtype so that integer starting vectors
    # (e.g. Rademacher samples) do not truncate `tridiag` and `vecs` to ints.
    v = v / jnp.linalg.norm(v)
    swd = ShapeWithDtype.from_leave(v)
    tridiag = jnp.zeros((order, order), dtype=swd.dtype)
    vecs = jnp.zeros((order, ) + swd.shape, dtype=swd.dtype)
    vecs = vecs.at[0].set(v)

    # TODO: use `tree_math` inner products and norms so that `v` may be an
    # arbitrary pytree instead of a plain array.

    # Zeroth iteration
    w = mat(v)
    if w.shape != swd.shape:
        ve = f"shape of `mat(v)` {w.shape!r} incompatible with {swd}"
        raise ValueError(ve)
    # `vdot` flattens its arguments and conjugates the first one. It therefore
    # yields a scalar inner product for any array shape and for complex
    # dtypes, whereas `jnp.dot` would be a matrix product for 2-D `v`.
    alpha = jnp.vdot(v, w)
    tridiag = tridiag.at[0, 0].set(alpha)
    if order == 1:
        # A 1x1 decomposition has no off-diagonal entries; the general code
        # below would write past the end of `tridiag`/`vecs` and wrap around
        # via `vecs[order - 2]`.
        return tridiag, vecs
    w -= alpha * v
    beta = jnp.linalg.norm(w)
    tridiag = tridiag.at[0, 1].set(beta)
    tridiag = tridiag.at[1, 0].set(beta)
    vecs = vecs.at[1].set(w / beta)

    def reortho_step(j, state):
        # Remove the component of `w` along the j-th stored Lanczos vector.
        vecs, w = state
        tau = vecs[j]
        coeff = jnp.vdot(tau, w)
        w -= coeff * tau
        return vecs, w

    def lanczos_step(i, state):
        tridiag, vecs, beta = state
        # TODO: only keep the current and the previous vector and skip the
        # re-orthogonalization? Check the theory beforehand.
        v = vecs[i]
        v_old = vecs[i - 1]
        w = mat(v) - beta * v_old
        alpha = jnp.vdot(v, w)
        tridiag = tridiag.at[i, i].set(alpha)
        w -= alpha * v

        # Full re-orthogonalization. Rows past `i` are still zero and add
        # nothing. In theory the loop could terminate at `i`, but that would
        # make JAX's default backwards pass not work.
        vecs, w = fori_loop(0, order, reortho_step, (vecs, w))

        # TODO: raise if the Lanczos vectors become linearly dependent, i.e.
        # if `beta` is (close to) zero; the division below is unguarded.
        beta = jnp.linalg.norm(w)
        tridiag = tridiag.at[i, i + 1].set(beta)
        tridiag = tridiag.at[i + 1, i].set(beta)
        vecs = vecs.at[i + 1].set(w / beta)
        return tridiag, vecs, beta

    tridiag, vecs, beta = fori_loop(
        1, order - 1, lanczos_step, (tridiag, vecs, beta)
    )

    # Final diagonal entry. No further Lanczos vector is stored, so the last
    # residual does not need to be re-orthogonalized.
    v = vecs[order - 1]
    v_old = vecs[order - 2]
    w = mat(v) - beta * v_old
    alpha = jnp.vdot(v, w)
    tridiag = tridiag.at[order - 1, order - 1].set(alpha)
    return (tridiag, vecs)
def stochastic_logdet_from_lanczos(
    tridiag_stack: jnp.ndarray, matrix_shape0: int, func: Callable = jnp.log
):
    """Compute a stochastic estimate of the log-determinant of a matrix from
    its Lanczos decompositions, via stochastic Lanczos quadrature.

    For every probe, the quadrature nodes are the eigenvalues of the probe's
    tri-diagonal matrix. The weights are the squared moduli of the first
    components of its eigenvectors. The per-probe sums of
    ``weight * func(node)`` are averaged over the probes and scaled by
    `matrix_shape0`.

    Parameters
    ----------
    tridiag_stack :
        Tri-diagonal Lanczos matrices stacked along the leading axis, i.e. of
        shape ``(n_probes, order, order)``. A single ``(order, order)`` matrix
        is treated as one probe.
    matrix_shape0 :
        Size of the square matrix whose log-determinant is estimated. It is
        used as the squared norm of the (Rademacher) probe vectors.
    func :
        Elementwise function applied to the eigenvalues. The default
        `jnp.log` yields the log-determinant.

    Returns
    -------
    The scalar estimate.
    """
    tridiag_stack = jnp.asarray(tridiag_stack)
    if tridiag_stack.ndim == 2:
        # A single decomposition is one probe, not `order` probes.
        tridiag_stack = tridiag_stack[jnp.newaxis]
    eig_vals, eig_vecs = jnp.linalg.eigh(tridiag_stack)
    # TODO: Mask Eigenvalues <= 0?
    num_random_probes = tridiag_stack.shape[0]
    eig_vecs_first_component = eig_vecs[..., 0, :]
    # `abs` makes the weights correct for complex eigenvectors, whose phase
    # is arbitrary; for real eigenvectors it is identical to plain squaring.
    quad_weights = jnp.abs(eig_vecs_first_component)**2
    func_of_eig_vals = func(eig_vals)
    dot_products = jnp.sum(quad_weights * func_of_eig_vals)
    return matrix_shape0 / float(num_random_probes) * dot_products
def stochastic_lq_logdet(
    mat: Union[jnp.ndarray, Callable],
    order: int,
    n_samples: int,
    key: Union[int, jnp.ndarray],
    *,
    shape0: Optional[int] = None,
    dtype=None,
    cmap=jax.vmap,
):
    """Compute a stochastic estimate of the log-determinant of a matrix
    using the stochastic Lanczos quadrature algorithm.

    `n_samples` Rademacher probe vectors are drawn. `order` Lanczos
    iterations are run from each of them, and the resulting tri-diagonal
    matrices are combined into a single estimate.

    Parameters
    ----------
    mat :
        The matrix, either as an array (applied via ``mat @ x``) or as a
        callable applying it to a vector.
    order :
        Number of Lanczos iterations per probe vector.
    n_samples :
        Number of random probe vectors.
    key :
        A JAX PRNG key; anything that is not a `jnp.ndarray` is treated as
        an integer seed and turned into a key.
    shape0 :
        Length of the probe vectors. Defaults to ``mat.shape[0]``, which
        requires `mat` to have a ``shape`` attribute.
    dtype :
        Forwarded to `jax.random.rademacher` as the probe dtype.
    cmap :
        Transformation used to map the per-probe computation over the keys;
        it is called as ``cmap(f)(keys)``.
    """
    if shape0 is None:
        shape0 = mat.shape[0]
    if not hasattr(mat, "__call__"):
        # Plain arrays are applied through matrix multiplication.
        mat = mat.__matmul__
    if not isinstance(key, jnp.ndarray):
        key = random.PRNGKey(key)
    probe_keys = random.split(key, n_samples)

    def _probe_tridiag(probe_key):
        # Lanczos tri-diagonal of `mat` started from one Rademacher probe.
        probe = random.rademacher(probe_key, (shape0, ), dtype=dtype)
        tridiag, _ = lanczos_tridiag(mat, probe, order=order)
        return tridiag

    tridiag_stack = cmap(_probe_tridiag)(probe_keys)
    return stochastic_logdet_from_lanczos(tridiag_stack, shape0)