# Source code for nifty8.re.hmc_oo

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Tuple, Union

import numpy as np
from jax import grad
from jax import numpy as jnp
from jax import random, tree_util

from .lax import fori_loop
from .hmc import AcceptedAndRejected, Q, QP, Tree
from .hmc import (
    generate_hmc_acc_rej,
    generate_nuts_tree,
    leapfrog_step,
    sample_momentum_from_diagonal,
    tree_index_update,
)


def _parse_diag_mass_matrix(mass_matrix, position_proto: Q) -> Q:
    """Bring a diagonal mass matrix into the layout of `position_proto`.

    Parameters
    ----------
    mass_matrix :
        Either a scalar (Python or NumPy number, or a NumPy/JAX array holding
        exactly one element) or a pytree with the same tree structure and
        the same leaf shapes as `position_proto`.
    position_proto :
        Pytree that defines the required structure and leaf shapes.

    Returns
    -------
    A scalar is expanded with `jnp.full_like` onto every leaf of
    `position_proto`; a matching pytree is handed back unchanged.

    Raises
    ------
    ValueError
        If the tree structures agree but at least one leaf shape differs.
    TypeError
        If `mass_matrix` is neither scalar-like nor structured like
        `position_proto`.
    """
    # Plain ints and NumPy scalars are accepted alongside floats and JAX
    # arrays so that every size-one input is broadcast the same way.
    scalar_types = (int, float, np.number, np.ndarray, jnp.ndarray)
    if isinstance(mass_matrix, scalar_types) and jnp.size(mass_matrix) == 1:
        mass_matrix = tree_util.tree_map(
            partial(jnp.full_like, fill_value=mass_matrix), position_proto
        )
    elif tree_util.tree_structure(mass_matrix
                                 ) == tree_util.tree_structure(position_proto):
        shape_match_tree = tree_util.tree_map(
            lambda a1, a2: jnp.shape(a1) == jnp.shape(a2), mass_matrix,
            position_proto
        )
        # `tree_flatten` returns a `(leaves, treedef)` pair; feeding that pair
        # to `all` is truthy for any non-empty tree and hides mismatches.
        # Only the per-leaf booleans must be inspected.
        shape_and_structure_match = all(
            tree_util.tree_leaves(shape_match_tree)
        )
        if not shape_and_structure_match:
            ve = "matrix has same tree_structe as the position but shapes do not match up"
            raise ValueError(ve)
    else:
        te = "matrix must either be float or have same tree structure as the position"
        raise TypeError(te)

    return mass_matrix


class Chain(NamedTuple):
    """Record of one sampling run; think of it as a transposed Tree with an
    additional leading axis.

    The samplers in this module create a `Chain` in `init_chain` and fill it
    entry by entry in `update_chain`.
    """
    # Same pytree as a single position `Q`, except that every leaf tensor
    # carries one more dimension on its first axis.
    samples: Q
    # Divergence flag stored for every drawn sample.
    divergences: jnp.ndarray
    # Acceptance value, maintained as a running mean by `update_chain`.
    acceptance: Union[jnp.ndarray, float]
    # Tree depth per sample; left at `None` when the sampler records none.
    depths: Optional[jnp.ndarray] = None
    # Stacked intermediate results; only filled when they are requested.
    trees: Optional[Union[Tree, AcceptedAndRejected]] = None
class _Sampler:
    """Common scaffolding for the chain samplers defined in this module.

    The constructor validates its arguments, stores the (parsed) inverse
    mass matrix and builds the leapfrog stepper. `sample_next_state`,
    `init_chain` and `update_chain` only raise `NotImplementedError` here and
    have to be supplied by a subclass.
    """
    def __init__(
        self,
        potential_energy: Callable[[Q], Union[jnp.ndarray, float]],
        inverse_mass_matrix,
        position_proto: Q,
        step_size: Union[jnp.ndarray, float] = 1.0,
        max_energy_difference: Union[jnp.ndarray, float] = jnp.inf
    ):
        """Set up energies, their gradients and the integrator.

        Parameters
        ----------
        potential_energy :
            Callable mapping a position to its potential energy. It is
            differentiated with `jax.grad`.
        inverse_mass_matrix :
            Diagonal inverse mass matrix; parsed by
            `_parse_diag_mass_matrix` against `position_proto`.
        position_proto :
            Prototype of a position used to lay out the mass matrix.
        step_size :
            Step size stored on the instance for the sampling routines.
        max_energy_difference :
            Energy threshold stored on the instance for the sampling
            routines.

        Raises
        ------
        TypeError
            If `potential_energy` is not callable or `step_size` is neither a
            float nor a JAX array.
        """
        if not callable(potential_energy):
            te = (
                "`potential_energy` must be callable;"
                f" got {type(potential_energy)!r}"
            )
            raise TypeError(te)
        if not isinstance(step_size, (jnp.ndarray, float)):
            te = (
                "`step_size` must be a float or a JAX array;"
                f" got {type(step_size)!r}"
            )
            raise TypeError(te)

        self.potential_energy = potential_energy
        self.inverse_mass_matrix = _parse_diag_mass_matrix(
            inverse_mass_matrix, position_proto=position_proto
        )
        # NOTE(review): relies on the parsed matrix supporting `**` (and
        # `.dot` below) -- presumably an array or the project's vector type;
        # confirm for plain container pytrees.
        self.mass_matrix_sqrt = self.inverse_mass_matrix**(-0.5)
        self.step_size = step_size

        def kinetic_energy(inverse_mass_matrix, momentum):
            # NOTE, assume a diagonal mass-matrix
            return inverse_mass_matrix.dot(momentum**2) / 2.

        self.kinetic_energy = kinetic_energy
        # Derivative of the kinetic energy above w.r.t. the momentum.
        kinetic_energy_gradient = lambda inv_m, mom: inv_m * mom
        potential_energy_gradient = grad(self.potential_energy)
        self.stepper = partial(
            leapfrog_step, potential_energy_gradient, kinetic_energy_gradient
        )
        self.max_energy_difference = max_energy_difference

        def sample_next_state(key,
                              prev_position: Q) -> Tuple[Any, Tuple[Any, Q]]:
            # Placeholder; subclasses overwrite this attribute.
            raise NotImplementedError()

        self.sample_next_state = sample_next_state

    @staticmethod
    def init_chain(
        num_samples: int, position_proto, save_intermediates: bool
    ) -> Chain:
        """Allocate an empty chain; must be provided by a subclass."""
        raise NotImplementedError()

    @staticmethod
    def update_chain(
        chain: Chain, idx: Union[jnp.ndarray, int], tree: Tree
    ) -> Chain:
        """Write one result into slot `idx`; must be provided by a subclass."""
        raise NotImplementedError()

    def generate_n_samples(
        self,
        key: Any,
        initial_position: Q,
        num_samples,
        *,
        save_intermediates: bool = False
    ) -> Tuple[Chain, Tuple[Any, Q]]:
        """Draw `num_samples` states starting from `initial_position`.

        Parameters
        ----------
        key :
            Random key given as a NumPy/JAX array, or an `int` seed that is
            turned into a key via `random.PRNGKey`.
        initial_position :
            Position the first transition starts from; also used as the
            prototype handed to `init_chain`.
        num_samples :
            Upper bound of the sampling loop, i.e. the number of transitions.
        save_intermediates :
            Forwarded to `init_chain`.

        Returns
        -------
        The filled chain and the final loop state `(key, position)`.

        Raises
        ------
        TypeError
            If `key` is neither an array nor an `int`.
        """
        if not isinstance(key, (jnp.ndarray, np.ndarray)):
            if isinstance(key, int):
                key = random.PRNGKey(key)
            else:
                te = (
                    "`key` must be a random key array or an integer seed;"
                    f" got {type(key)!r}"
                )
                raise TypeError(te)

        chain = self.init_chain(
            num_samples, initial_position, save_intermediates
        )

        def amend_chain(idx, state):
            # One loop iteration: advance the sampler and store its result.
            chain, core_state = state
            tree, core_state = self.sample_next_state(*core_state)
            chain = self.update_chain(chain, idx, tree)
            return chain, core_state

        chain, core_state = fori_loop(
            lower=0,
            upper=num_samples,
            body_fun=amend_chain,
            init_val=(chain, (key, initial_position))
        )

        return chain, core_state
class NUTSChain(_Sampler):
    """Sampler whose transitions are produced by `generate_nuts_tree`.

    Besides the samples, the chain records the depth of every tree, its
    divergence flag and a running mean of the per-tree acceptance.
    """
    def __init__(
        self,
        potential_energy: Callable[[Q], Union[float, jnp.ndarray]],
        inverse_mass_matrix,
        position_proto: Q,
        step_size: float = 1.0,
        max_tree_depth: int = 10,
        bias_transition: bool = True,
        max_energy_difference: float = jnp.inf
    ):
        """Configure the sampler.

        Parameters
        ----------
        potential_energy, inverse_mass_matrix, position_proto, step_size, \
        max_energy_difference :
            Forwarded unchanged to the `_Sampler` constructor.
        max_tree_depth :
            Passed as `max_tree_depth` to `generate_nuts_tree`.
        bias_transition :
            Passed as `bias_transition` to `generate_nuts_tree`.

        Raises
        ------
        TypeError
            If `max_tree_depth` is not an `int` (in addition to the errors
            raised by the base constructor).
        """
        super().__init__(
            potential_energy=potential_energy,
            inverse_mass_matrix=inverse_mass_matrix,
            position_proto=position_proto,
            step_size=step_size,
            max_energy_difference=max_energy_difference
        )

        if not isinstance(max_tree_depth, int):
            te = (
                "`max_tree_depth` must be an int;"
                f" got {type(max_tree_depth)!r}"
            )
            raise TypeError(te)
        self.bias_transition = bias_transition
        self.max_tree_depth = max_tree_depth

        def sample_next_state(key,
                              prev_position: Q) -> Tuple[Tree, Tuple[Any, Q]]:
            # Separate keys for the carried state, the momentum draw and
            # the tree construction.
            key, key_momentum, key_nuts = random.split(key, 3)

            resampled_momentum = sample_momentum_from_diagonal(
                key=key_momentum, mass_matrix_sqrt=self.mass_matrix_sqrt
            )

            qp = QP(position=prev_position, momentum=resampled_momentum)

            tree = generate_nuts_tree(
                initial_qp=qp,
                key=key_nuts,
                step_size=self.step_size,
                max_tree_depth=self.max_tree_depth,
                stepper=self.stepper,
                potential_energy=self.potential_energy,
                kinetic_energy=self.kinetic_energy,
                inverse_mass_matrix=self.inverse_mass_matrix,
                bias_transition=self.bias_transition,
                max_energy_difference=self.max_energy_difference
            )
            # The proposal of this tree is the start of the next transition.
            return tree, (key, tree.proposal_candidate.position)

        self.sample_next_state = sample_next_state

    @staticmethod
    def init_chain(
        num_samples: int, position_proto, save_intermediates: bool
    ) -> Chain:
        """Allocate a zero-filled chain with room for `num_samples` entries.

        Every leaf of `position_proto` gets a leading axis of length
        `num_samples`. With `save_intermediates` set, a zero-filled stack of
        `Tree` objects with the same leading axis is attached as well.
        """
        def stacked_zeros(leaf):
            # Zero buffer shaped like `leaf` plus a leading sample axis.
            return jnp.zeros_like(
                leaf, shape=(num_samples, ) + jnp.shape(leaf)
            )

        samples = tree_util.tree_map(stacked_zeros, position_proto)
        depths = jnp.zeros(num_samples, dtype=jnp.uint64)
        divergences = jnp.zeros(num_samples, dtype=bool)
        chain = Chain(
            samples=samples,
            divergences=divergences,
            acceptance=0.,
            depths=depths
        )
        if save_intermediates:
            # Dummy tree that only serves as a structural template.
            _qp_proto = QP(position_proto, position_proto)
            _tree_proto = Tree(
                _qp_proto,
                _qp_proto,
                0.,
                _qp_proto,
                turning=True,
                diverging=True,
                depth=0,
                cumulative_acceptance=0.
            )
            trees = tree_util.tree_map(stacked_zeros, _tree_proto)
            chain = chain._replace(trees=trees)

        return chain

    @staticmethod
    def update_chain(
        chain: Chain, idx: Union[jnp.ndarray, int], tree: Tree
    ) -> Chain:
        """Store the outcome of one tree in slot `idx` of `chain`.

        The per-tree acceptance is `tree.cumulative_acceptance` divided by
        `2**tree.depth - 1` (zero if that count is zero) and is folded into
        the running mean kept in `chain.acceptance`.
        """
        num_proposals = 2**jnp.array(tree.depth, dtype=jnp.uint64) - 1
        tree_acceptance = jnp.where(
            num_proposals > 0, tree.cumulative_acceptance / num_proposals, 0.
        )

        samples = tree_index_update(
            chain.samples, idx, tree.proposal_candidate.position
        )
        divergences = chain.divergences.at[idx].set(tree.diverging)
        depths = chain.depths.at[idx].set(tree.depth)
        # Incremental mean over the `idx + 1` entries seen so far.
        acceptance = (
            chain.acceptance +
            (tree_acceptance - chain.acceptance) / (idx + 1)
        )
        chain = chain._replace(
            samples=samples,
            divergences=divergences,
            acceptance=acceptance,
            depths=depths
        )
        if chain.trees is not None:
            trees = tree_index_update(chain.trees, idx, tree)
            chain = chain._replace(trees=trees)

        return chain
class HMCChain(_Sampler):
    """Sampler whose transitions are produced by `generate_hmc_acc_rej`.

    The chain records the accepted positions, the divergence flags and a
    running mean of the `accepted` indicator.
    """
    def __init__(
        self,
        potential_energy: Callable,
        inverse_mass_matrix,
        position_proto,
        num_steps,
        step_size: float = 1.0,
        max_energy_difference: float = jnp.inf
    ):
        """Configure the sampler.

        Parameters
        ----------
        potential_energy, inverse_mass_matrix, position_proto, step_size, \
        max_energy_difference :
            Forwarded unchanged to the `_Sampler` constructor.
        num_steps :
            Passed as `num_steps` to `generate_hmc_acc_rej`.

        Raises
        ------
        TypeError
            If `num_steps` is neither an `int` nor a JAX array (in addition
            to the errors raised by the base constructor).
        """
        super().__init__(
            potential_energy=potential_energy,
            inverse_mass_matrix=inverse_mass_matrix,
            position_proto=position_proto,
            step_size=step_size,
            max_energy_difference=max_energy_difference
        )

        if not isinstance(num_steps, (jnp.ndarray, int)):
            te = (
                "`num_steps` must be an int or a JAX array;"
                f" got {type(num_steps)!r}"
            )
            raise TypeError(te)
        self.num_steps = num_steps

        def sample_next_state(key,
                              prev_position: Q) -> Tuple[Tree, Tuple[Any, Q]]:
            # Separate keys for the carried state, the accept/reject step
            # and the momentum draw.
            key, key_choose, key_momentum_resample = random.split(key, 3)

            resampled_momentum = sample_momentum_from_diagonal(
                key=key_momentum_resample,
                mass_matrix_sqrt=self.mass_matrix_sqrt
            )

            qp = QP(position=prev_position, momentum=resampled_momentum)

            acc_rej = generate_hmc_acc_rej(
                key=key_choose,
                initial_qp=qp,
                potential_energy=self.potential_energy,
                kinetic_energy=self.kinetic_energy,
                inverse_mass_matrix=self.inverse_mass_matrix,
                stepper=self.stepper,
                num_steps=self.num_steps,
                step_size=self.step_size,
                max_energy_difference=self.max_energy_difference
            )

            # The accepted position is the start of the next transition.
            return acc_rej, (key, acc_rej.accepted_qp.position)

        self.sample_next_state = sample_next_state

    @staticmethod
    def init_chain(
        num_samples: int, position_proto, save_intermediates: bool
    ) -> Chain:
        """Allocate a zero-filled chain with room for `num_samples` entries.

        Every leaf of `position_proto` gets a leading axis of length
        `num_samples`. With `save_intermediates` set, a zero-filled stack of
        `AcceptedAndRejected` objects with the same leading axis is attached
        as well.
        """
        def stacked_zeros(leaf):
            # Zero buffer shaped like `leaf` plus a leading sample axis.
            return jnp.zeros_like(
                leaf, shape=(num_samples, ) + jnp.shape(leaf)
            )

        samples = tree_util.tree_map(stacked_zeros, position_proto)
        divergences = jnp.zeros(num_samples, dtype=bool)
        chain = Chain(samples=samples, divergences=divergences, acceptance=0.)
        if save_intermediates:
            # Dummy result that only serves as a structural template.
            _qp_proto = QP(position_proto, position_proto)
            _acc_rej_proto = AcceptedAndRejected(
                _qp_proto, _qp_proto, True, True
            )
            trees = tree_util.tree_map(stacked_zeros, _acc_rej_proto)
            chain = chain._replace(trees=trees)

        return chain

    @staticmethod
    def update_chain(
        chain: Chain, idx: Union[jnp.ndarray, int],
        acc_rej: AcceptedAndRejected
    ) -> Chain:
        """Store the outcome of one transition in slot `idx` of `chain`.

        `acc_rej.accepted` is folded into the running mean kept in
        `chain.acceptance`.
        """
        samples = tree_index_update(
            chain.samples, idx, acc_rej.accepted_qp.position
        )
        divergences = chain.divergences.at[idx].set(acc_rej.diverging)
        # Incremental mean over the `idx + 1` entries seen so far.
        acceptance = (
            chain.acceptance +
            (acc_rej.accepted - chain.acceptance) / (idx + 1)
        )
        chain = chain._replace(
            samples=samples, divergences=divergences, acceptance=acceptance
        )
        if chain.trees is not None:
            trees = tree_index_update(chain.trees, idx, acc_rej)
            chain = chain._replace(trees=trees)

        return chain