# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
from functools import partial
from typing import Callable, NamedTuple, TypeVar, Union
from jax import lax
from jax import numpy as jnp
from jax import random, tree_util
from jax.experimental import host_callback
from jax.scipy.special import expit
from .lax import cond, fori_loop, while_loop
from .tree_math import random_like
# Debugging facilities: when `_DEBUG_FLAG` is set, every leapfrog result and
# the end positions of finished (sub)trees are recorded into these
# module-level stores via `host_callback` calls from within traced code.
_DEBUG_FLAG = False
_DEBUG_TREE_END_IDXS = []
_DEBUG_SUBTREE_END_IDXS = []
_DEBUG_STORE = []
def _DEBUG_ADD_QP(qp):
    """Record a single leapfrog integration result in the global debug store."""
    # Appending mutates the module-level list in place; no rebinding happens,
    # so a `global` declaration is not required.
    _DEBUG_STORE.append(qp)
def _DEBUG_FINISH_TREE(dummy_arg):
    """Mark where a finished tree ends within `_DEBUG_STORE`.

    The current length of `_DEBUG_STORE` is recorded as the (exclusive) end
    index of the tree's leapfrog results.
    """
    _DEBUG_TREE_END_IDXS.append(len(_DEBUG_STORE))
def _DEBUG_FINISH_SUBTREE(dummy_arg):
    """Mark where a finished sub-tree ends within `_DEBUG_STORE`.

    The current length of `_DEBUG_STORE` is recorded as the (exclusive) end
    index of the sub-tree's leapfrog results.
    """
    _DEBUG_SUBTREE_END_IDXS.append(len(_DEBUG_STORE))
[docs]
def select(pred, on_true, on_false):
    """Leaf-wise `lax.select` between two pytrees of identical structure.

    Returns a pytree whose leaves are taken from `on_true` where `pred` is
    true and from `on_false` otherwise.
    """
    def _choose(t, f):
        return lax.select(pred, t, f)

    return tree_util.tree_map(_choose, on_true, on_false)
### COMMON FUNCTIONALITY
Q = TypeVar("Q")
[docs]
class QP(NamedTuple):
    """Pair of a position and its conjugate momentum in phase space.

    Being a NamedTuple (and hence a pytree of two leaves/subtrees), instances
    can be passed through JAX transformations and `tree_util` operations.

    Attributes
    ----------
    position : Q
        Position.
    momentum : Q
        Momentum (must have the same pytree structure as `position`).
    """
    position: Q
    momentum: Q
[docs]
def flip_momentum(qp: QP) -> QP:
    """Return a copy of `qp` with the momentum negated; position is kept."""
    return qp._replace(momentum=-qp.momentum)
[docs]
def sample_momentum_from_diagonal(*, key, mass_matrix_sqrt):
    """Draw a momentum sample from the kinetic-energy term of the Hamiltonian.

    Parameters
    ----------
    key : ndarray
        PRNGKey used as the random key.
    mass_matrix_sqrt : ndarray
        The left square-root mass matrix (i.e. square-root of the inverse
        diagonal covariance) to use for sampling. Diagonal matrix represented
        as (possibly pytree of) ndarray vector containing the entries of the
        diagonal.
    """
    # Draw standard normal variates with the same pytree structure as the
    # mass matrix, then scale each leaf by the square-root mass matrix.
    std_normal = random_like(key=key, primals=mass_matrix_sqrt, rng=random.normal)
    return tree_util.tree_map(jnp.multiply, mass_matrix_sqrt, std_normal)
# TODO: how to randomize step size (neal sect. 3.2)
# @partial(jit, static_argnames=('potential_energy_gradient',))
[docs]
def leapfrog_step(
    potential_energy_gradient,
    kinetic_energy_gradient,
    step_size,
    inverse_mass_matrix,
    qp: QP,
):
    """Perform one iteration of the leapfrog integrator forwards in time.

    This is the standard kick-drift-kick scheme: half momentum step, full
    position step, half momentum step.

    Parameters
    ----------
    potential_energy_gradient: Callable[[ndarray], float]
        Potential energy gradient part of the hamiltonian (V). Depends on
        position only.
    kinetic_energy_gradient: Callable
        Gradient of the kinetic energy; takes the inverse mass matrix and a
        momentum.
    step_size: float
        Step length (usually called epsilon) of the leapfrog integrator.
    inverse_mass_matrix: Q
        Inverse mass matrix forwarded to `kinetic_energy_gradient`.
    qp: QP
        Point in position and momentum space from which to start integration.
    """
    # Kick: half step in momentum along the negative potential gradient.
    p_half = qp.momentum - (step_size / 2.) * potential_energy_gradient(qp.position)
    # Drift: full step in position along the kinetic-energy gradient.
    q_full = qp.position + step_size * kinetic_energy_gradient(
        inverse_mass_matrix, p_half
    )
    # Kick: second momentum half step, evaluated at the new position.
    p_full = p_half - (step_size / 2.) * potential_energy_gradient(q_full)

    qp_fullstep = QP(position=q_full, momentum=p_full)
    global _DEBUG_FLAG
    if _DEBUG_FLAG:
        # append result to global list variable
        host_callback.call(_DEBUG_ADD_QP, qp_fullstep)
    return qp_fullstep
### SIMPLE HMC
[docs]
class AcceptedAndRejected(NamedTuple):
    """Outcome of a single HMC accept/reject (Metropolis) step.

    Attributes
    ----------
    accepted_qp, rejected_qp : QP
        The (position, momentum) pair that was accepted respectively
        rejected; one of the two is the initial point, the other the
        proposal.
    accepted : Union[jnp.ndarray, bool]
        Whether the proposal was accepted.
    diverging : Union[jnp.ndarray, bool]
        Whether the absolute energy difference exceeded the divergence bound.
    """
    accepted_qp: QP
    rejected_qp: QP
    accepted: Union[jnp.ndarray, bool]
    diverging: Union[jnp.ndarray, bool]
# @partial(jit, static_argnames=('potential_energy', 'potential_energy_gradient'))
[docs]
def generate_hmc_acc_rej(
    *, key, initial_qp, potential_energy, kinetic_energy, inverse_mass_matrix,
    stepper, num_steps, step_size, max_energy_difference
) -> AcceptedAndRejected:
    """
    Generate a sample given the initial position.

    Parameters
    ----------
    key: ndarray
        a PRNGKey used as the random key
    initial_qp: QP
        The starting (position, momentum) pair of this step of the Markov
        chain.
    potential_energy: Callable[[ndarray], float]
        The potential energy, which is the distribution to be sampled from.
    kinetic_energy: Callable[[Q, Q], float]
        Maps the inverse mass matrix and a momentum to the kinetic energy.
    inverse_mass_matrix: Q
        The inverse mass matrix used in the kinetic energy.
    stepper: Callable[[float, Q, QP], QP]
        Function performing one (leapfrog) integration step; takes the step
        size, the inverse mass matrix and the current QP.
    num_steps: int
        The number of steps the leapfrog integrator should perform.
    step_size: float
        The step size (usually epsilon) for the leapfrog integrator.
    max_energy_difference: float
        Absolute energy difference above which the proposal is flagged as
        diverging.
    """
    loop_body = partial(stepper, step_size, inverse_mass_matrix)
    new_qp = fori_loop(
        lower=0,
        upper=num_steps,
        body_fun=lambda _, args: loop_body(args),
        init_val=initial_qp
    )
    # this flipping is needed to make the proposal distribution symmetric
    # doesn't have any effect on acceptance though because kinetic energy depends on momentum^2
    # might have an effect with other kinetic energies though
    proposed_qp = flip_momentum(new_qp)
    total_energy = partial(
        total_energy_of_qp,
        potential_energy=potential_energy,
        kinetic_energy_w_inv_mass=partial(kinetic_energy, inverse_mass_matrix)
    )
    energy_diff = total_energy(initial_qp) - total_energy(proposed_qp)
    # BUGFIX: a NaN energy difference (e.g. the proposal's energy overflowed
    # to NaN) must lead to rejection. Mapping NaN to -inf drives the
    # acceptance probability exp(energy_diff) to zero; the previous mapping
    # to +inf made it one and thus always *accepted* a NaN-energy proposal.
    # This matches NumPyro's guard (whose `delta_energy = new - old` is set
    # to +inf, i.e. -inf in this file's `old - new` sign convention).
    energy_diff = jnp.where(jnp.isnan(energy_diff), -jnp.inf, energy_diff)
    transition_probability = jnp.minimum(1., jnp.exp(energy_diff))
    accept = random.bernoulli(key, transition_probability)
    accepted_qp, rejected_qp = select(
        accept,
        (proposed_qp, initial_qp),
        (initial_qp, proposed_qp),
    )
    # |−inf| > threshold still flags the NaN case as diverging, as before.
    diverging = jnp.abs(energy_diff) > max_energy_difference
    return AcceptedAndRejected(
        accepted_qp, rejected_qp, accepted=accept, diverging=diverging
    )
### NUTS
[docs]
class Tree(NamedTuple):
    """Object carrying tree metadata.

    Attributes
    ----------
    left, right : QP
        Respective endpoints of the tree's path.
    logweight: Union[jnp.ndarray, float]
        Log of the sum over all exp(-H(q, p)) in the tree's path, used for
        progressive sampling and for weighting merged trees.
    proposal_candidate: QP
        Sample from the tree's path, distributed as exp(-H(q, p)).
    turning: Union[jnp.ndarray, bool]
        Indicator for whether the left and right endpoints form a U-turn or
        whether any subtree is a U-turn.
    diverging: Union[jnp.ndarray, bool]
        Indicator for a large increase in energy in the next larger tree.
    depth: Union[jnp.ndarray, int]
        Levels of the tree; -1 marks a tree that is not a perfect binary
        tree (e.g. one that was aborted early).
    cumulative_acceptance: Union[jnp.ndarray, float]
        Sum of all acceptance probabilities relative to some initial energy
        value. This value is distinct from `logweight` as its absolute value is
        only well defined for the very final tree of NUTS.
    """
    left: QP
    right: QP
    logweight: Union[jnp.ndarray, float]
    proposal_candidate: QP
    turning: Union[jnp.ndarray, bool]
    diverging: Union[jnp.ndarray, bool]
    depth: Union[jnp.ndarray, int]
    cumulative_acceptance: Union[jnp.ndarray, float]
[docs]
def total_energy_of_qp(qp, potential_energy, kinetic_energy_w_inv_mass):
    """Total energy H(q, p) = V(q) + T(p) of a position/momentum pair.

    `kinetic_energy_w_inv_mass` must already have the inverse mass matrix
    bound, i.e. it takes only the momentum.
    """
    potential = potential_energy(qp.position)
    kinetic = kinetic_energy_w_inv_mass(qp.momentum)
    return potential + kinetic
[docs]
def generate_nuts_tree(
    initial_qp,
    key,
    step_size,
    max_tree_depth,
    stepper: Callable[[Union[jnp.ndarray, float], Q, QP], QP],
    potential_energy,
    kinetic_energy: Callable[[Q, Q], float],
    inverse_mass_matrix: Q,
    bias_transition: bool = True,
    max_energy_difference: Union[jnp.ndarray, float] = jnp.inf
) -> Tree:
    """Generate a sample given the initial position.

    This call implements a No-U-Turn-Sampler.

    Parameters
    ----------
    initial_qp: QP
        Starting pair of (position, momentum). **NOTE**, the momentum must be
        resampled from conditional distribution **BEFORE** passing it into this
        function!
    key: ndarray
        PRNGKey used as the random key.
    step_size: float
        Step size (usually called epsilon) for the leapfrog integrator.
    max_tree_depth: int
        The maximum depth of the trajectory tree before the expansion is
        terminated. At the maximum iteration depth, the current value is
        returned even if the U-turn condition is not met. The maximum number of
        points (/integration steps) per trajectory is :math:`N =
        2^{\\mathrm{max\\_tree\\_depth}}`. This function requires memory linear
        in max_tree_depth, i.e. logarithmic in trajectory length. It is used to
        statically allocate memory in advance.
    stepper: Callable[[float, Q, QP], QP]
        The function that performs (Leapfrog) steps. Takes as arguments (in order)
        1) step size (containing the direction): float ,
        2) inverse mass matrix: Q ,
        3) starting point: QP .
    potential_energy: Callable[[Q], float]
        The potential energy, of the distribution to be sampled from. Takes
        only the position part (QP.position) as argument.
    kinetic_energy: Callable[[Q, Q], float], optional
        Mapping of the momentum to its corresponding kinetic energy. As
        argument the function takes the inverse mass matrix and the momentum.
    inverse_mass_matrix: Q
        Inverse mass matrix forwarded to `stepper` and `kinetic_energy`.
    bias_transition: bool
        Whether to bias the transition towards the newly built subtree when
        merging trees (see `merge_trees`). Defaults to True.
    max_energy_difference: Union[jnp.ndarray, float]
        Absolute energy difference above which a (sub)tree is flagged as
        diverging; expansion stops on divergence. Defaults to infinity.

    Returns
    -------
    current_tree: Tree
        The final tree, carrying a sample from the target distribution.

    See Also
    --------
    No-U-Turn Sampler original paper (2011): https://arxiv.org/abs/1111.4246
    NumPyro Iterative NUTS paper: https://arxiv.org/abs/1912.11554
    Combination of samples from two trees, Sampling from trajectories according to target distribution in this paper's Appendix: https://arxiv.org/abs/1701.02434
    """
    # initialize depth 0 tree, containing 2**0 = 1 points
    initial_neg_energy = -total_energy_of_qp(
        initial_qp, potential_energy,
        partial(kinetic_energy, inverse_mass_matrix)
    )
    current_tree = Tree(
        left=initial_qp,
        right=initial_qp,
        logweight=initial_neg_energy,
        proposal_candidate=initial_qp,
        turning=False,
        diverging=False,
        depth=0,
        cumulative_acceptance=jnp.zeros_like(initial_neg_energy)
    )

    def _cont_cond(loop_state):
        # Continue doubling while no stop was signalled and the depth bound
        # has not been exceeded.
        _, current_tree, stop = loop_state
        return (~stop) & (current_tree.depth <= max_tree_depth)

    def cond_tree_doubling(loop_state):
        # One tree-doubling iteration: build an adjacent subtree of the same
        # size in a random direction and (possibly) merge it in.
        key, current_tree, _ = loop_state
        key, key_dir, key_subtree, key_merge = random.split(key, 4)
        go_right = random.bernoulli(key_dir, 0.5)
        # build tree adjacent to current_tree
        new_subtree = iterative_build_tree(
            key_subtree,
            current_tree,
            step_size,
            go_right,
            stepper,
            potential_energy,
            kinetic_energy,
            inverse_mass_matrix,
            max_tree_depth,
            initial_neg_energy=initial_neg_energy,
            max_energy_difference=max_energy_difference
        )
        # Mark current tree as diverging if it diverges in the next step
        current_tree = current_tree._replace(diverging=new_subtree.diverging)
        # combine current_tree and new_subtree into a tree which is one layer deeper only if new_subtree has no turning subtrees (including itself)
        current_tree = cond(
            # If new tree is turning or diverging, do not merge
            pred=new_subtree.turning | new_subtree.diverging,
            true_fun=lambda old_and_new: old_and_new[0],
            false_fun=lambda old_and_new: merge_trees(
                key_merge,
                old_and_new[0],
                old_and_new[1],
                go_right,
                bias_transition=bias_transition
            ),
            operand=(current_tree, new_subtree),
        )
        # stop if new subtree was turning -> we sample from the old one and don't expand further
        # stop if new total tree is turning -> we sample from the combined trajectory and don't expand further
        stop = new_subtree.turning | current_tree.turning
        stop |= new_subtree.diverging
        return (key, current_tree, stop)

    loop_state = (key, current_tree, False)
    _, current_tree, _ = while_loop(_cont_cond, cond_tree_doubling, loop_state)
    global _DEBUG_FLAG
    if _DEBUG_FLAG:
        host_callback.call(_DEBUG_FINISH_TREE, None)
    return current_tree
[docs]
def tree_index_get(ptree, idx):
    """Index every leaf array of the pytree `ptree` at `idx`."""
    def _get_leaf(leaf):
        return leaf[idx]

    return tree_util.tree_map(_get_leaf, ptree)
[docs]
def tree_index_update(x, idx, y):
    """Out-of-place update: set every leaf of `x` at `idx` to the
    corresponding leaf of `y` (via `jax.numpy.ndarray.at[...].set`).
    """
    # Use the module-level `tree_util` import like the sibling
    # `tree_index_get` instead of a redundant function-local import.
    return tree_util.tree_map(lambda x_el, y_el: x_el.at[idx].set(y_el), x, y)
[docs]
def count_trailing_ones(n):
    """Count the number of trailing, consecutive ones in the binary
    representation of `n`.

    Warning
    -------
    `n` must be positive and strictly smaller than 2**64

    Examples
    --------
    >>> print(bin(23), count_trailing_ones(23))
    0b10111 3
    """
    # NOTE: the doctest previously referenced a non-existent
    # `count_trailing_one_bits`; it now calls this function.
    # taken from http://num.pyro.ai/en/stable/_modules/numpyro/infer/hmc_util.html
    # Shift right while the lowest bit is set, counting the shifts.
    _, trailing_ones_count = while_loop(
        lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
    )
    return trailing_ones_count
[docs]
def is_euclidean_uturn(qp_left, qp_right):
    """Check whether the trajectory from `qp_left` to `qp_right` makes a
    U-turn under the Euclidean metric.

    See Also
    --------
    Betancourt - A conceptual introduction to Hamiltonian Monte Carlo
    """
    separation = qp_right.position - qp_left.position
    # Both endpoint momenta must project negatively onto the separation
    # vector pointing away from the respective other endpoint.
    right_turning = qp_right.momentum.dot(separation) < 0.
    left_turning = qp_left.momentum.dot(-separation) < 0.
    return right_turning & left_turning
# Essentially algorithm 2 from https://arxiv.org/pdf/1912.11554.pdf
[docs]
def iterative_build_tree(
    key, initial_tree, step_size, go_right, stepper, potential_energy,
    kinetic_energy, inverse_mass_matrix, max_tree_depth, initial_neg_energy,
    max_energy_difference
):
    """
    Starting from either the left or right endpoint of a given tree, builds a
    new adjacent tree of the same size.

    Parameters
    ----------
    key: ndarray
        PRNGKey to choose a sample when adding QPs to the tree.
    initial_tree: Tree
        Tree to be extended (doubled) on the left or right.
    step_size: float
        The step size (usually called epsilon) for the leapfrog integrator.
    go_right: bool
        If `go_right` start at the right end, going right else start at the
        left end, going left.
    stepper: Callable[[float, Q, QP], QP]
        The function that performs (Leapfrog) steps. Takes as arguments (in order)
        1) step size (containing the direction): float ,
        2) inverse mass matrix: Q ,
        3) starting point: QP .
    potential_energy: Callable[[Q], float]
        Potential energy, of the distribution to be sampled from. Takes
        only the position part (QP.position) as argument.
    kinetic_energy: Callable[[Q, Q], float], optional
        Mapping of the momentum to its corresponding kinetic energy. As
        argument the function takes the inverse mass matrix and the momentum.
    max_tree_depth: int
        An upper bound on the 'depth' argument, but has no effect on the
        functions behaviour. It's only required to statically set the size of
        the `S` array (Q).
    initial_neg_energy: Union[jnp.ndarray, float]
        Negative energy of the very first point of the NUTS trajectory,
        relative to which divergence and acceptance statistics are computed.
    max_energy_difference: Union[jnp.ndarray, float]
        Absolute energy difference above which a point is flagged diverging.
    """
    # 1. choose start point of integration
    z = select(go_right, initial_tree.right, initial_tree.left)
    depth = initial_tree.depth
    # The new tree must match the old one in size: 2**depth points.
    max_num_proposals = 2**depth
    # 2. build / collect new states
    # Create a storage for left endpoints of subtrees. Size is determined
    # statically by the `max_tree_depth` parameter.
    # NOTE, let's hope this does not break anything but in principle we only
    # need `max_tree_depth` elements even though the tree can be of depth
    # `max_tree_depth + 1`. This is because we will never access the last
    # element.
    S = tree_util.tree_map(
        lambda proto: jnp.
        empty_like(proto, shape=(max_tree_depth, ) + jnp.shape(proto)), z
    )
    # First integration step away from the old tree's endpoint; the sign of
    # the step size encodes the integration direction.
    z = stepper(
        jnp.where(go_right, 1., -1.) * step_size, inverse_mass_matrix, z
    )
    neg_energy = -total_energy_of_qp(
        z, potential_energy, partial(kinetic_energy, inverse_mass_matrix)
    )
    diverging = jnp.abs(neg_energy - initial_neg_energy) > max_energy_difference
    # Acceptance statistic of this single point relative to the initial energy.
    cum_acceptance = jnp.minimum(1., jnp.exp(initial_neg_energy - neg_energy))
    # Singleton tree containing only the first new point.
    incomplete_tree = Tree(
        left=z,
        right=z,
        logweight=neg_energy,
        proposal_candidate=z,
        turning=False,
        diverging=diverging,
        depth=-1,
        cumulative_acceptance=cum_acceptance
    )
    S = tree_index_update(S, 0, z)

    def amend_incomplete_tree(state):
        # Loop body: integrate one more step, fold the new point into the
        # tree and check the U-turn conditions required at this iteration.
        n, incomplete_tree, z, S, key = state
        key, key_choose_candidate = random.split(key)
        z = stepper(
            jnp.where(go_right, 1., -1.) * step_size, inverse_mass_matrix, z
        )
        incomplete_tree = add_single_qp_to_tree(
            key_choose_candidate,
            incomplete_tree,
            z,
            go_right,
            potential_energy,
            kinetic_energy,
            inverse_mass_matrix,
            initial_neg_energy=initial_neg_energy,
            max_energy_difference=max_energy_difference
        )

        def _even_fun(S):
            # n is even, the current z is w.l.o.g. a left endpoint of some
            # subtrees. Register the current z to be used in turning condition
            # checks later, when the right endpoints of it's subtrees are
            # generated.
            S = tree_index_update(S, lax.population_count(n), z)
            return S, False

        def _odd_fun(S):
            # n is odd, the current z is w.l.o.g a right endpoint of some
            # subtrees. Check turning condition against all left endpoints of
            # subtrees that have the current z (/n) as their right endpoint.
            # l = number of subtrees that have current z as their right endpoint.
            l = count_trailing_ones(n)
            # inclusive indices into S referring to the left endpoints of the l subtrees.
            i_max_incl = lax.population_count(n - 1)
            i_min_incl = i_max_incl - l + 1
            # TODO: this should traverse the range in reverse
            turning = fori_loop(
                lower=i_min_incl,
                upper=i_max_incl + 1,
                # TODO: conditional for early termination
                body_fun=lambda k, turning: turning |
                is_euclidean_uturn(tree_index_get(S, k), z),
                init_val=False
            )
            return S, turning

        # Even step: record a left endpoint; odd step: run U-turn checks.
        S, turning = cond(
            pred=n % 2 == 0, true_fun=_even_fun, false_fun=_odd_fun, operand=S
        )
        incomplete_tree = incomplete_tree._replace(turning=turning)
        return (n + 1, incomplete_tree, z, S, key)

    def _cont_cond(state):
        # Continue until the tree is complete, turns or diverges.
        n, incomplete_tree, *_ = state
        return (n < max_num_proposals) & (~incomplete_tree.turning
                                          ) & (~incomplete_tree.diverging)

    n, incomplete_tree, *_ = while_loop(
        # while n < 2**depth and not stop
        cond_fun=_cont_cond,
        body_fun=amend_incomplete_tree,
        init_val=(1, incomplete_tree, z, S, key)
    )
    global _DEBUG_FLAG
    if _DEBUG_FLAG:
        host_callback.call(_DEBUG_FINISH_SUBTREE, None)
    # The depth of a tree which was aborted early is possibly ill defined
    depth = jnp.where(n == max_num_proposals, depth, -1)
    return incomplete_tree._replace(depth=depth)
[docs]
def add_single_qp_to_tree(
    key, tree, qp, go_right, potential_energy, kinetic_energy,
    inverse_mass_matrix, initial_neg_energy, max_energy_difference
):
    """Extend `tree` by the single endpoint `qp`, propagating the proposal
    sample (progressive sampling).

    This is effectively a special case of `merge_trees` where one of the two
    trees is a singleton, depth-0 tree; however, no turning check is needed
    and the transition cannot be biased.
    """
    # Attach `qp` on the side towards which we are integrating.
    left, right = select(go_right, (tree.left, qp), (qp, tree.right))
    neg_energy = -total_energy_of_qp(
        qp, potential_energy, partial(kinetic_energy, inverse_mass_matrix)
    )
    diverging = jnp.abs(neg_energy - initial_neg_energy) > max_energy_difference
    # ln(e^-H_1 + e^-H_2)
    total_logweight = jnp.logaddexp(tree.logweight, neg_energy)
    # expit(x-y) := 1 / (1 + e^(-(x-y))) = 1 / (1 + e^(y-x)) = e^x / (e^y + e^x)
    prob_of_keeping_old = expit(tree.logweight - neg_energy)
    keep_old = random.bernoulli(key, prob_of_keeping_old)
    proposal_candidate = select(keep_old, tree.proposal_candidate, qp)
    cum_acceptance = tree.cumulative_acceptance + jnp.minimum(
        1., jnp.exp(initial_neg_energy - neg_energy)
    )
    # NOTE, depth=-1 marks an invalid depth, as adding a single QP to a
    # perfect binary tree does not yield another perfect binary tree.
    return Tree(
        left,
        right,
        total_logweight,
        proposal_candidate,
        turning=tree.turning,
        diverging=diverging,
        depth=-1,
        cumulative_acceptance=cum_acceptance
    )
[docs]
def merge_trees(key, current_subtree, new_subtree, go_right, bias_transition):
    """Merge two trees into one, propagating the proposal_candidate."""
    # 5. decide which sample to take based on total weights (merge trees)
    logweight_diff = new_subtree.logweight - current_subtree.logweight
    if bias_transition:
        # Bias the transition towards the new subtree (see Betancourt
        # conceptual intro (and Numpyro))
        transition_probability = jnp.minimum(1., jnp.exp(logweight_diff))
    else:
        # expit(x-y) := 1 / (1 + e^(-(x-y))) = 1 / (1 + e^(y-x)) = e^x / (e^y + e^x)
        transition_probability = expit(logweight_diff)
    # print(f"prob of choosing new sample: {transition_probability}")
    take_new = random.bernoulli(key, transition_probability)
    new_sample = select(
        take_new, new_subtree.proposal_candidate,
        current_subtree.proposal_candidate
    )
    # 6. define new tree
    left, right = select(
        go_right,
        (current_subtree.left, new_subtree.right),
        (new_subtree.left, current_subtree.right),
    )
    turning = is_euclidean_uturn(left, right)
    diverging = current_subtree.diverging | new_subtree.diverging
    # Combined log weight: ln(e^lw_new + e^lw_current).
    combined_logweight = jnp.logaddexp(
        new_subtree.logweight, current_subtree.logweight
    )
    combined_acceptance = (
        current_subtree.cumulative_acceptance +
        new_subtree.cumulative_acceptance
    )
    return Tree(
        left=left,
        right=right,
        logweight=combined_logweight,
        proposal_candidate=new_sample,
        turning=turning,
        diverging=diverging,
        depth=current_subtree.depth + 1,
        cumulative_acceptance=combined_acceptance
    )