# Source code for nifty8.re.minisanity

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

import pprint
from typing import Any, NamedTuple, TypeVar

import jax
from jax import numpy as jnp
from jax.tree_util import tree_map

from .evi import Samples
from .tree_math import Vector, get_map

# Module-level generic type variables (names presumably "output"/"input");
# they are not referenced by the code visible in this module.
O, I = TypeVar("O"), TypeVar("I")


def _residual_params(inp):
    """Return ``(mean, reduced chi-squared, ndof)`` for a single array.

    ``ndof`` counts real degrees of freedom: every complex entry contributes
    two, so it equals ``2 * inp.size`` for complex input and ``inp.size``
    otherwise. The reduced chi-squared is the sum of the squared magnitudes
    of the entries divided by ``ndof``.
    """
    n_entries = inp.size
    ndof = 2 * n_entries if jnp.iscomplexobj(inp) else n_entries
    avg = jnp.sum(inp) / n_entries
    # `vdot` conjugates its first argument, so this is sum(|inp|**2)
    sq_norm = jnp.vdot(inp, inp).real
    return avg, sq_norm / ndof, ndof


class ChiSqStats(NamedTuple):
    """Per-leaf residual summary as assembled by `reduced_residual_stats`."""

    # Mean of the leaf's entries (stored as ``[sample mean, sample std]``
    # by `reduced_residual_stats`)
    mean: Any
    # Reduced chi-squared of the leaf (same ``[mean, std]`` layout there)
    reduced_chisq: Any
    # Number of real degrees of freedom of the leaf
    ndof: Any
def reduced_residual_stats(position_or_samples, func=None, *, map="lmap"):
    """Compute mean, reduced chi-squared and degrees of freedom per leaf.

    Parameters
    ----------
    position_or_samples : tree-like or Samples
        Input values for which to compute the reduced chi-squared
        statistics. The statistics are computed separately for each leaf of
        the pytree, i.e. every array-like leaf is reduced on its own. If
        `position_or_samples` is a non-empty `Samples` object, the statistics
        are computed for each sample, and the sample mean and sample standard
        deviation of the statistics are returned. An empty `Samples` object
        is treated like its position ``Samples.pos``.
    func : Callable, optional
        Function to apply to `position_or_samples` before computing the
        statistics. If provided, the statistics are computed for `func(x)`
        instead of `x`, where `x` is either the position or a single sample.
    map : optional
        Passed to `get_map` to obtain the function used to map `func` and the
        per-leaf statistics over the leading (sample) axis. Defaults to
        ``"lmap"``.

    Returns
    -------
    stats : tree-like
        Pytree with the structure of the (possibly `func`-transformed) input
        whose leaves are `ChiSqStats` tuples. For `mean` and `reduced_chisq`
        a length-2 array ``[sample mean, sample std]`` is returned; `ndof` is
        the number of real degrees of freedom of the leaf (twice its size for
        complex leaves). If the input is not a non-empty `Samples` object,
        the std entries are always zero.
    """
    map = get_map(map)
    if isinstance(position_or_samples, Samples) and len(position_or_samples) > 0:
        # Leaves are presumably stacked along a leading sample axis, mirroring
        # the singleton axis added in the other branch
        samples = position_or_samples.samples
    else:
        if isinstance(position_or_samples, Samples):
            # An empty `Samples` object only carries its position
            position_or_samples = position_or_samples.pos
        # Add a singleton sample axis so both cases are mapped uniformly
        samples = tree_map(lambda x: x[jnp.newaxis, ...], position_or_samples)
    if func is not None:
        samples = map(func)(samples)

    get_stats = map(_residual_params)

    def red_chisq_stat(s):
        m, rx, nd = get_stats(s)
        m = jnp.array([jnp.mean(m), jnp.std(m)])
        rx = jnp.array([jnp.mean(rx), jnp.std(rx)])
        # All samples of a leaf share shape and dtype, hence the same ndof
        return ChiSqStats(m, rx, nd[0])

    return tree_map(red_chisq_stat, samples)
def minisanity(position_or_samples, func=None, *, map="lmap"):
    """Return the reduced chi-squared statistics and a pretty-printed summary.

    Thin wrapper around `reduced_residual_stats` that additionally formats the
    resulting statistics into a human-readable message.

    Parameters
    ----------
    position_or_samples : tree-like or Samples
        Forwarded to `reduced_residual_stats`.
    func : Callable, optional
        Forwarded to `reduced_residual_stats`.
    map : optional
        Forwarded to `reduced_residual_stats`. Defaults to ``"lmap"``.

    Returns
    -------
    stat_tree : tree-like
        The pytree of `ChiSqStats` returned by `reduced_residual_stats`.
    msg : str
        Formatted summary. If the (unwrapped) result tree is a dict, one
        entry per key in sorted key order; if it is a single leaf, the single
        formatted line; otherwise the `pprint`-formatted tree.
    """
    stat_tree = reduced_residual_stats(position_or_samples, func=func, map=map)

    def make_pretty_string(stats):
        rsq = stats.reduced_chisq
        return (
            f"reduced χ²:{rsq[0]:8.2}±{rsq[1]:8.2}"
            f", avg:{stats.mean[0]:+9.2}±{stats.mean[1]:8.2}"
            f", #dof:{int(stats.ndof):7d}"
        )

    def is_stats_leaf(node):
        # Stop at `ChiSqStats` instead of descending into the NamedTuple
        return isinstance(node, ChiSqStats)

    ps = tree_map(make_pretty_string, stat_tree, is_leaf=is_stats_leaf)
    # HACK to make the most common primal types look pretty: unwrap `Vector`
    # so that a wrapped dict is printed key by key below
    ps = ps.tree if isinstance(ps, Vector) else ps

    if isinstance(ps, str):
        return stat_tree, ps

    pp = pprint.PrettyPrinter()
    if not isinstance(ps, dict):
        return stat_tree, pp.pformat(ps)

    lines = []
    for k in sorted(ps.keys()):
        v = ps[k]
        if isinstance(v, str):
            lines.append(f"{str(k):22s}:: {v}\n")
        else:
            lines.append(f"{str(k):22s}::\n{pp.pformat(v)}\n")
    return stat_tree, "".join(lines)