nifty8.re.hmc_oo module

class Chain(samples: Q, divergences: Array, acceptance: Array | float, depths: Array | None = None, trees: Tree | AcceptedAndRejected | None = None)[source]

Bases: NamedTuple

Object carrying chain metadata; conceptually, a transposed Tree with an additional axis.

acceptance: Array | float

Alias for field number 2

depths: Array | None

Alias for field number 3

divergences: Array

Alias for field number 1

samples: Q

Alias for field number 0

trees: Tree | AcceptedAndRejected | None

Alias for field number 4

class HMCChain(potential_energy: Callable, inverse_mass_matrix, position_proto, num_steps, step_size: float = 1.0, max_energy_difference: float = inf)[source]

Bases: _Sampler

__init__(potential_energy: Callable, inverse_mass_matrix, position_proto, num_steps, step_size: float = 1.0, max_energy_difference: float = inf)[source]
static init_chain(num_samples: int, position_proto, save_intermediates: bool) -> Chain[source]
static update_chain(chain: Chain, idx: Array | int, acc_rej: AcceptedAndRejected) -> Chain[source]
class NUTSChain(potential_energy: Callable[[Q], Array | float], inverse_mass_matrix, position_proto: Q, step_size: float = 1.0, max_tree_depth: int = 10, bias_transition: bool = True, max_energy_difference: float = inf)[source]

Bases: _Sampler

__init__(potential_energy: Callable[[Q], Array | float], inverse_mass_matrix, position_proto: Q, step_size: float = 1.0, max_tree_depth: int = 10, bias_transition: bool = True, max_energy_difference: float = inf)[source]
static init_chain(num_samples: int, position_proto, save_intermediates: bool) -> Chain[source]
static update_chain(chain: Chain, idx: Array | int, tree: Tree) -> Chain[source]