# Source code for nifty8.extra

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2022 Max-Planck-Society
# Author: Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from itertools import combinations

import numpy as np

from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.adder import Adder
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.energy_operators import (EnergyOperator,
                                         LikelihoodEnergyOperator)
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
from .probing import StatCalculator
from .sugar import from_random, is_fieldlike, is_operator
from .utilities import issingleprec, myassert

# Public API of this module (used by ``from ... import *``).
__all__ = [
    "check_linear_operator",
    "check_operator",
    "assert_allclose",
    "minisanity",
]


def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
                          atol=1e-12, rtol=1e-12, only_r_linear=False):
    """Checks an operator for algebraic consistency of its capabilities.

    Checks whether times(), adjoint_times(), inverse_times() and
    adjoint_inverse_times() (if in capability list) are implemented
    consistently. Additionally, it checks whether the operator is linear and,
    for endomorphic operators, whether a provided square root is consistent.

    Parameters
    ----------
    op : LinearOperator
        Operator which shall be checked.
    domain_dtype : dtype
        The data type of the random vectors in the operator's domain. Default
        is `np.float64`.
    target_dtype : dtype
        The data type of the random vectors in the operator's target. Default
        is `np.float64`.
    atol : float
        Absolute tolerance for the check. Default: 1e-12.
    rtol : float
        Relative tolerance for the check. Default: 1e-12.
        Both tolerances are combined as in
        :func:`numpy.testing.assert_allclose`, i.e. a comparison passes if
        ``|actual - desired| <= atol + rtol * |desired|`` element-wise.
    only_r_linear: bool
        Set to True if the operator is only R-linear, not C-linear. This
        relaxes the adjointness test to the real part of the scalar products.

    Raises
    ------
    TypeError
        If `op` is not a :class:`LinearOperator`.
    AssertionError
        If one of the numerical comparisons fails (other checks may raise
        different exception types).
    """
    if not isinstance(op, LinearOperator):
        raise TypeError('This test tests only linear operators.')
    # Domain/target bookkeeping of the operator and all derived operators.
    _domain_check_linear(op, domain_dtype)
    _domain_check_linear(op.adjoint, target_dtype)
    _domain_check_linear(op.inverse, target_dtype)
    _domain_check_linear(op.adjoint.inverse, domain_dtype)
    # Repeated application to the same input must give the same result.
    _purity_check(op, from_random(op.domain, dtype=domain_dtype))
    _purity_check(op.adjoint.inverse, from_random(op.domain, dtype=domain_dtype))
    _purity_check(op.adjoint, from_random(op.target, dtype=target_dtype))
    _purity_check(op.inverse, from_random(op.target, dtype=target_dtype))
    # op(a*x + y) == a*op(x) + op(y)
    _check_linearity(op, domain_dtype, atol, rtol)
    _check_linearity(op.adjoint, target_dtype, atol, rtol)
    _check_linearity(op.inverse, target_dtype, atol, rtol)
    _check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol)
    # Adjointness and inverse consistency.
    _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                         only_r_linear)
    _full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
                         only_r_linear)
    _full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol,
                         only_r_linear)
    _full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
                         rtol, only_r_linear)
    # Consistency of get_sqrt() where applicable.
    _check_sqrt(op, domain_dtype)
    _check_sqrt(op.adjoint, target_dtype)
    _check_sqrt(op.inverse, target_dtype)
    _check_sqrt(op.adjoint.inverse, domain_dtype)
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
                   only_r_differentiable=True, metric_sampling=True):
    """Performs various checks of the implementation of linear and nonlinear
    operators.

    Computes the Jacobian with finite differences and compares it to the
    implemented Jacobian.

    Parameters
    ----------
    op : Operator
        Operator which shall be checked.
    loc : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
        A Field or MultiField instance which has the same domain as op.
        The location at which the gradient is checked.
    tol : float
        Tolerance for the check. The finite-difference comparison of the
        Jacobian is performed with tolerance ``sqrt(tol)``.
    ntries : int
        Number of random directions for which the Jacobian is compared
        against finite differences. Default: 100.
    perf_check : Boolean
        Do performance check. May be disabled for very unimportant operators.
    only_r_differentiable : Boolean
        Jacobians of C-differentiable operators need to be C-linear.
        Default: True
    metric_sampling: Boolean
        If op is an EnergyOperator, metric_sampling determines whether the
        test shall try to sample from the metric or not.

    Raises
    ------
    TypeError
        If `op` is not an :class:`Operator`.
    """
    if not isinstance(op, Operator):
        raise TypeError('This test tests only (nonlinear) operators.')
    _domain_check_nonlinear(op, loc)
    _purity_check(op, loc)
    _performance_check(op, loc, bool(perf_check))
    _linearization_value_consistency(op, loc)
    # Finite differences are only accurate to roughly sqrt(machine precision),
    # hence the square root of the requested tolerance.
    _jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
                               only_r_differentiable)
    _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                               metric_sampling)
    _check_likelihood_energy(op, loc)
def assert_allclose(f1, f2, atol=0, rtol=1e-7):
    """Assert that two (multi-)fields agree up to the given tolerances.

    Parameters
    ----------
    f1, f2 : Field or MultiField
        The objects to compare. If `f1` is not a :class:`Field`, both must
        live on the identical (``is``) domain and are compared key by key.
    atol : float
        Absolute tolerance, passed to :func:`numpy.testing.assert_allclose`.
        Default: 0.
    rtol : float
        Relative tolerance, passed to :func:`numpy.testing.assert_allclose`.
        Default: 1e-7.

    Raises
    ------
    AssertionError
        If the values are not close or the domains are not identical.
    """
    if isinstance(f1, Field):
        return np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
    if f1.domain is not f2.domain:
        raise AssertionError(
            "The domains of the two fields to be compared are not identical.")
    for key, val in f1.items():
        assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def assert_equal(f1, f2):
    """Assert that two (multi-)fields are exactly equal.

    If `f1` is a :class:`Field`, the underlying arrays are compared with
    :func:`numpy.testing.assert_equal`. Otherwise both objects must live on
    the identical (``is``) domain and are compared key by key.

    Raises
    ------
    AssertionError
        If the values differ or the domains are not identical.
    """
    if isinstance(f1, Field):
        return np.testing.assert_equal(f1.val, f2.val)
    if f1.domain is not f2.domain:
        raise AssertionError(
            "The domains of the two fields to be compared are not identical.")
    for key, val in f1.items():
        assert_equal(val, f2[key])


def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
                            only_r_linear):
    """Check <f1, op^† f2> == <op f1, f2> for random f1, f2.

    Skipped if `op` does not support both TIMES and ADJOINT_TIMES.
    """
    needed_cap = op.TIMES | op.ADJOINT_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    f1 = from_random(op.domain, "normal", dtype=domain_dtype)
    f2 = from_random(op.target, "normal", dtype=target_dtype)
    res1 = f1.s_vdot(op.adjoint_times(f2))
    res2 = op.times(f1).s_vdot(f2)
    if only_r_linear:
        # Only the real parts of the scalar products are compared for
        # operators that are R-linear but not C-linear.
        res1, res2 = res1.real, res2.real
    np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)


def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
    """Check op(op^-1 x) == x on the target and op^-1(op x) == x on the domain.

    Skipped if `op` does not support both TIMES and INVERSE_TIMES.
    """
    needed_cap = op.TIMES | op.INVERSE_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    tgt_fld = from_random(op.target, "normal", dtype=target_dtype)
    assert_allclose(op(op.inverse_times(tgt_fld)), tgt_fld, atol=atol,
                    rtol=rtol)
    dom_fld = from_random(op.domain, "normal", dtype=domain_dtype)
    assert_allclose(op.inverse_times(op(dom_fld)), dom_fld, atol=atol,
                    rtol=rtol)


def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                         only_r_linear):
    """Run the adjointness and inverse consistency checks for `op`."""
    _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
                            only_r_linear)
    _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)


def _check_linearity(op, domain_dtype, atol, rtol):
    """Check op(alpha*x + y) == alpha*op(x) + op(y) for random x, y."""
    needed_cap = op.TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    fld1 = from_random(op.domain, "normal", dtype=domain_dtype)
    fld2 = from_random(op.domain, "normal", dtype=domain_dtype)
    alpha = 0.42
    val1 = op(alpha*fld1+fld2)
    val2 = alpha*op(fld1)+op(fld2)
    assert_allclose(val1, val2, atol=atol, rtol=rtol)


def _domain_check_linear(op, domain_dtype=None, inp=None):
    """Check that applying the linear `op` maps its domain onto its target.

    If `domain_dtype` is given, a random input of that dtype is drawn (and
    `inp` is ignored); otherwise `inp` is used.

    Raises
    ------
    ValueError
        If neither `domain_dtype` nor `inp` is given.
    """
    _domain_check(op)
    needed_cap = op.TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    if domain_dtype is not None:
        inp = from_random(op.domain, "normal", dtype=domain_dtype)
    elif inp is None:
        raise ValueError('Need to specify either dtype or inp')
    myassert(inp.domain is op.domain)
    myassert(op(inp).domain is op.target)


def _check_sqrt(op, domain_dtype):
    """Check that ``op == sqrt^† @ sqrt`` where ``sqrt = op.get_sqrt()``.

    Non-endomorphic operators must not provide `get_sqrt`. Endomorphic
    operators whose `get_sqrt` raises NotImplementedError or ValueError are
    skipped.
    """
    if not isinstance(op, EndomorphicOperator):
        try:
            op.get_sqrt()
            raise RuntimeError("Operator implements get_sqrt() although it is not an endomorphic operator.")
        except AttributeError:
            return
    try:
        sqop = op.get_sqrt()
    except (NotImplementedError, ValueError):
        return
    fld = from_random(op.domain, dtype=domain_dtype)
    a = op(fld)
    b = (sqop.adjoint @ sqop)(fld)
    return assert_allclose(a, b, rtol=1e-15)


def _domain_check_nonlinear(op, loc):
    """Check the domain bookkeeping of `op` applied to a Linearization at `loc`.

    Both with and without `want_metric`, the value, Jacobian (and its adjoint)
    and metric of the result must live on the expected domains.
    """
    _domain_check(op)
    myassert(isinstance(loc, (Field, MultiField)))
    myassert(loc.domain is op.domain)
    for wm in [False, True]:
        lin = Linearization.make_var(loc, wm)
        reslin = op(lin)
        myassert(lin.domain is op.domain)
        myassert(lin.target is op.domain)
        myassert(lin.val.domain is lin.domain)
        myassert(reslin.domain is op.domain)
        myassert(reslin.target is op.target)
        myassert(reslin.val.domain is reslin.target)
        myassert(reslin.jac.domain is reslin.domain)
        myassert(reslin.jac.target is reslin.target)
        myassert(lin.want_metric == reslin.want_metric)
        _domain_check_linear(reslin.jac, inp=loc)
        _domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
        if reslin.metric is not None:
            myassert(reslin.metric.domain is reslin.metric.target)
            myassert(reslin.metric.domain is op.domain)


def _domain_check(op):
    """Ensure domain and target of `op` are DomainTuples or MultiDomains.

    Raises
    ------
    TypeError
        If the domain or the target has another type.
    """
    for dd in [op.domain, op.target]:
        if not isinstance(dd, (DomainTuple, MultiDomain)):
            # A single string: passing two arguments would turn the message
            # into a tuple.
            raise TypeError(
                'The domain and the target of an operator need to '
                'be instances of either DomainTuple or MultiDomain.')


def _performance_check(op, pos, raise_on_fail):
    """Detect redundant evaluations of the input of `op`.

    `op` is composed with an identity operator that counts its applications.
    Evaluating the value, the Jacobian, its adjoint and (if present) the
    metric must each apply the counting operator exactly the expected number
    of times. Violations are logged; a RuntimeError is raised if
    `raise_on_fail` is set.
    """
    class CountingOp(LinearOperator):
        # Identity operator that counts how often it is applied.
        def __init__(self, domain):
            from .sugar import makeDomain
            self._domain = self._target = makeDomain(domain)
            self._capability = self.TIMES | self.ADJOINT_TIMES
            self._count = 0

        def apply(self, x, mode):
            self._count += 1
            return x

        @property
        def count(self):
            return self._count

    for wm in [False, True]:
        cop = CountingOp(op.domain)
        myop = op @ cop
        myop(pos)
        cond = [cop.count != 1]
        lin = myop(Linearization.make_var(pos, wm))
        cond.append(cop.count != 2)
        lin.jac(pos)
        cond.append(cop.count != 3)
        lin.jac.adjoint(lin.val)
        cond.append(cop.count != 4)
        if lin.metric is not None:
            lin.metric(pos)
            cond.append(cop.count != 6)
        if any(cond):
            s = 'The operator has a performance problem (want_metric={}).'.format(wm)
            from .logger import logger
            logger.error(s)
            logger.info(cond)
            if raise_on_fail:
                raise RuntimeError(s)


def _purity_check(op, pos):
    """Check that applying `op` twice to `pos` gives exactly the same result."""
    if isinstance(op, LinearOperator) and (op.capability & op.TIMES) != op.TIMES:
        return
    res0 = op(pos)
    res1 = op(pos)
    assert_equal(res0, res1)


def _get_acceptable_location(op, loc, lin):
    """Find a small random step away from `loc` where `op` is finite.

    The random direction is scaled so that the linearized change of the
    output is a small fraction of the output norm. It is then halved (up to
    50 times) until `op` evaluates to a finite value of moderate size.

    Returns
    -------
    loc2, lin2
        The new location and the Linearization of `op` there.
    """
    if not np.isfinite(lin.val.s_sum()):
        raise ValueError('Initial value must be finite')
    direction = from_random(loc.domain, dtype=loc.dtype)
    dirder = lin.jac(direction)
    fac = 1e-3 if issingleprec(loc.dtype) else 1e-6
    if dirder.norm() == 0:
        direction = direction * (lin.val.norm() * fac)
    else:
        direction = direction * (lin.val.norm() * fac / dirder.norm())
    # Find a step length that leads to a "reasonable" location
    for _ in range(50):
        try:
            loc2 = loc + direction
            lin2 = op(Linearization.make_var(loc2, lin.want_metric))
            if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20:
                break
        except FloatingPointError:
            pass
        direction = direction * 0.5
    else:
        raise ValueError("could not find a reasonable initial step")
    return loc2, lin2


def _linearization_value_consistency(op, loc):
    """Check that op(loc) equals the value of op applied to a Linearization."""
    for wm in [False, True]:
        lin = Linearization.make_var(loc, wm)
        fld0 = op(loc)
        fld1 = op(lin).val
        assert_allclose(fld0, fld1, 0, 1e-7)


def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                               metric_sampling):
    """Check `simplify_for_constant_input` for subsets of constant keys.

    Only applies to operators on a MultiDomain. For every tested subset of
    constant keys, the simplified operator must reproduce the value and the
    adjoint Jacobian of the full operator. The adjoint Jacobian must vanish on
    the constant keys.

    A finite-difference check of the simplified operator is currently
    disabled, so `tol`, `ntries` and `only_r_differentiable` are unused.
    """
    if isinstance(op.domain, DomainTuple):
        return
    keys = op.domain.keys()
    combis = []
    if len(keys) > 4:
        from .logger import logger
        logger.warning('Operator domain has more than 4 keys.')
        logger.warning('Check derivatives only with one constant key at a time.')
        combis = [[kk] for kk in keys]
    else:
        # All proper, non-empty subsets of keys.
        for ll in range(1, len(keys)):
            combis.extend(list(combinations(keys, ll)))
    # The full operator's value does not depend on the chosen key subset.
    val0 = op(loc)
    for cstkeys in combis:
        varkeys = set(keys) - set(cstkeys)
        cstloc = loc.extract_by_keys(cstkeys)
        varloc = loc.extract_by_keys(varkeys)
        _, op0 = op.simplify_for_constant_input(cstloc)
        myassert(op0.domain is varloc.domain)
        val1 = op0(varloc)
        assert_equal(val0, val1)
        lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
        lin0 = Linearization.make_var(varloc, want_metric=True)
        oplin0 = op0(lin0)
        oplin = op(lin)
        myassert(oplin.jac.target is oplin0.jac.target)
        rndinp = from_random(oplin.jac.target, dtype=oplin.val.dtype)
        assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
                        oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
        cst_grad = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
        assert_equal(cst_grad, 0*cst_grad)
        if isinstance(op, EnergyOperator) and metric_sampling:
            oplin.metric.draw_sample()


def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
    """Compare the Jacobian of `op` with finite differences.

    For `ntries` random directions, the difference ``op(loc+d) - op(loc)`` is
    compared with the Jacobian at the midpoint ``loc + d/2`` applied to `d`.
    The step `d` is halved (up to 50 times) until they agree within `tol`.
    The Jacobian at the final midpoint is then checked with
    :func:`check_linear_operator`. Successive tries start from the accepted
    step location.

    Raises
    ------
    ValueError
        If no step size leads to agreement.
    """
    for _ in range(ntries):
        lin = op(Linearization.make_var(loc))
        loc2, lin2 = _get_acceptable_location(op, loc, lin)
        direction = loc2 - loc
        locnext = loc2
        hist = []
        for _ in range(50):
            locmid = loc + 0.5 * direction
            linmid = op(Linearization.make_var(locmid))
            dirder = linmid.jac(direction)
            numgrad = (lin2.val - lin.val)
            xtol = tol * dirder.norm() / np.sqrt(dirder.size)
            hist.append((numgrad - dirder).norm())
            if (abs(numgrad - dirder) <= xtol).s_all():
                break
            # Halve the step; the old midpoint becomes the new end point.
            direction = direction * 0.5
            loc2, lin2 = locmid, linmid
        else:
            print(hist)
            raise ValueError("gradient and value seem inconsistent")
        loc = locnext
        check_linear_operator(linmid.jac, domain_dtype=loc.dtype,
                              target_dtype=dirder.dtype,
                              only_r_linear=only_r_differentiable,
                              atol=tol**2, rtol=tol**2)


def _check_likelihood_energy(op, loc):
    """Check the data-space interface of a LikelihoodEnergyOperator.

    Skipped for other operators and when `op.data_domain` is None. Otherwise
    the square-root data metric and the normalized residual must live on the
    data domain, and `get_transformation` must return a pair.

    Raises
    ------
    RuntimeError
        If `get_transformation` returns None or not exactly two items.
    """
    if not isinstance(op, LikelihoodEnergyOperator):
        return
    data_domain = op.data_domain
    if data_domain is None:
        return
    smet = op._sqrt_data_metric_at(loc)
    myassert(smet.domain == smet.target == data_domain)
    nres = op.normalized_residual(loc)
    myassert(nres.domain is data_domain)
    res = op.get_transformation()
    if res is None:
        raise RuntimeError("`get_transformation` is not implemented for "
                           "this LikelihoodEnergyOperator")
    if len(res) != 2:
        raise RuntimeError("`get_transformation` has to return a dtype and the transformation")
def minisanity(likelihood_energy, samples, terminal_colors=True,
               return_values=False):
    """Log information about the current fit quality and prior compatibility.

    Compute a table with fitting information for the likelihood and the
    prior. The latent variables (the domain of `samples`) are assumed to be
    standard-normal distributed a priori.

    The table contains the reduced chi^2 value, the mean and the number of
    degrees of freedom for every key of a `MultiDomain`. If the data domain is
    a `DomainTuple`, the displayed key is the `name` of `likelihood_energy`
    (or `<None>` if it has no name). If the latent domain is a `DomainTuple`,
    the displayed key is `<None>`. If the variance over the samples is
    available, values are displayed as mean ± standard deviation over the
    samples.

    If everything is consistent, the reduced chi^2 values should be close to
    one and the mean of the data residuals close to zero. If the reduced
    chi^2 value in latent space is significantly bigger than one and only one
    degree of freedom is present, the mean column gives an indication in
    which direction to change the respective hyper parameters.

    NaN entries in the normalized residuals and in the latent samples are
    ignored and do not count as degrees of freedom.

    Reduced chi^2 values above 2 or below 1/2 are highlighted in orange.
    Values above 5 or below 1/5 are highlighted in red.

    Parameters
    ----------
    likelihood_energy: LikelihoodEnergyOperator
        Likelihood energy of which the normalized residuals shall be computed.
        If it is not a `LikelihoodEnergyOperator`, an empty table is returned.
    samples : SampleListBase
        List of samples.
    terminal_colors : bool, optional
        Setting this to false disables terminal colors. This may be useful if
        the output of minisanity is written to a file. Default: True
    return_values : bool, optional
        If true, in addition to the table in string format, `minisanity` will
        return the computed values as a dictionary. Default: `False`.

    Returns
    -------
    printable_table : string
    values : dictionary
        Only returned if `return_values` is `True`. Maps 'redchisq', 'scmean'
        and 'ndof' to dictionaries with the entries 'data_residuals' and
        'latent_variables'. It is empty if no table was computed.

    Raises
    ------
    TypeError
        If `samples` is not a `SampleListBase`.

    Note
    ----
    The normalized residuals are obtained from
    `likelihood_energy.normalized_residual`, evaluated at every sample.
    """
    from .minimization.sample_list import SampleListBase
    from .sugar import makeDomain

    if not isinstance(samples, SampleListBase):
        raise TypeError(
            "Minisanity takes only SampleLists as input. If you happen to have "
            "only one field (i.e. no samples), you may wrap it via "
            "`ift.SampleList([field])` and pass it to minisanity."
        )
    if not isinstance(likelihood_energy, LikelihoodEnergyOperator):
        # Nothing to report. Honour `return_values` so that callers who asked
        # for `(table, values)` can always unpack the result.
        return ("", {}) if return_values else ""

    data_domain = likelihood_energy.data_domain
    latent_domain = samples.domain

    # Width of the key column: wide enough for the longest MultiDomain key,
    # but at least 18 and at most 42 characters.
    keylen = 18
    for dom in (data_domain, latent_domain):
        if isinstance(dom, MultiDomain):
            keylen = max(keylen, max(map(len, dom.keys()), default=0))
    keylen = min(keylen, 42)

    # Maps from a sample to the analysed quantities. Both are made to return
    # MultiFields so that statistics can be accumulated per key.
    nres = likelihood_energy.normalized_residual
    if isinstance(data_domain, MultiDomain):
        data_op = lambda x: nres(x)
    else:
        name = likelihood_energy.name
        if name is None:
            name = "<None>"
        data_domain = makeDomain({name: data_domain})
        data_op = lambda x: nres(x).ducktape_left(name)
    if isinstance(latent_domain, MultiDomain):
        latent_op = lambda x: x
    else:
        latent_domain = makeDomain({"<None>": latent_domain})
        latent_op = lambda x: x.ducktape_left("<None>")
    xops = [data_op, latent_op]
    xdoms = [data_domain, latent_domain]

    # Index 0: data residuals, index 1: latent space.
    xredchisq, xscmean, xndof = [], [], []
    for dd in xdoms:
        xredchisq.append({kk: StatCalculator() for kk in dd.keys()})
        xscmean.append({kk: StatCalculator() for kk in dd.keys()})
        xndof.append({})

    for ss1, ss2 in zip(samples.iterator(xops[0]), samples.iterator(xops[1])):
        if isinstance(data_domain, MultiDomain):
            myassert(ss1.domain == data_domain)
        if isinstance(samples.domain, MultiDomain):
            myassert(ss2.domain == samples.domain)
        for ii, ss in enumerate((ss1, ss2)):
            for kk in ss.domain.keys():
                vals = ss[kk].val
                # NaN entries are ignored and not counted as degrees of freedom.
                lsize = ss[kk].size - np.sum(np.isnan(vals))
                xredchisq[ii][kk].add(np.nansum(abs(vals) ** 2) / lsize)
                xscmean[ii][kk].add(np.nanmean(vals))
                xndof[ii][kk] = lsize

    # Replace the accumulators by their summary statistics.
    for ii in range(2):
        for kk in xredchisq[ii].keys():
            rcs_mean = xredchisq[ii][kk].mean
            sc_mean = xscmean[ii][kk].mean
            try:
                rcs_std = np.sqrt(xredchisq[ii][kk].var)
                sc_std = np.sqrt(xscmean[ii][kk].var)
            except RuntimeError:
                # NOTE(review): presumably raised when there are too few
                # samples for a variance; confirm against StatCalculator.
                rcs_std = None
                sc_std = None
            xredchisq[ii][kk] = {'mean': rcs_mean, 'std': rcs_std}
            xscmean[ii][kk] = {'mean': sc_mean, 'std': sc_std}

    s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen,
                       terminal_colors)
    s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen,
                       terminal_colors)

    n = 38 + keylen
    s = [
        n * "=",
        (keylen + 2) * " "
        + "{:>11}".format("reduced χ²")
        + "{:>14}".format("mean")
        + "{:>11}".format("# dof"),
        n * "-",
        "Data residuals",
        s0,
        "Latent space",
        s1,
        n * "=",
    ]
    res_string = "\n".join(s)
    if not return_values:
        return res_string
    res_dict = {
        'redchisq': {
            'data_residuals': xredchisq[0],
            'latent_variables': xredchisq[1]
        },
        'scmean': {
            'data_residuals': xscmean[0],
            'latent_variables': xscmean[1],
        },
        'ndof': {
            'data_residuals': xndof[0],
            'latent_variables': xndof[1]
        }
    }
    return res_string, res_dict
def _tableentries(redchisq, scmean, ndof, keylen, colors):
    """Format the per-key rows of the minisanity table.

    Parameters
    ----------
    redchisq, scmean : dict
        Map each key to a dict with the entries 'mean' and 'std'. If 'std' is
        None, only the mean is printed.
    ndof : dict
        Maps each key to its number of degrees of freedom.
    keylen : int
        Width of the key column. Longer keys are truncated and end in "…".
    colors : bool
        Whether to highlight out-of-range reduced chi^2 values with ANSI
        escape codes.

    Returns
    -------
    str
        One row per key of `redchisq`, joined by newlines (no trailing
        newline).
    """
    if colors:
        warn, fail, endc, bold = "\033[33m", "\033[31m", "\033[0m", "\033[1m"
    else:
        warn = fail = endc = bold = ""

    rows = []
    for key in redchisq:
        label = key[: keylen - 1] + "…" if len(key) > keylen else key.ljust(keylen)

        chi = redchisq[key]
        chi_txt = f"{chi['mean']:.1f}"
        if chi['std'] is not None:
            chi_txt += f" ± {chi['std']:.1f}"
        chi_cell = f"{chi_txt:>11}"
        chi_mean = chi['mean']
        if chi_mean > 5 or chi_mean < 1/5:
            chi_cell = fail + bold + chi_cell + endc
        elif chi_mean > 2 or chi_mean < 1/2:
            chi_cell = warn + bold + chi_cell + endc

        sc = scmean[key]
        sc_txt = f"{sc['mean']:.1f}"
        if sc['std'] is not None:
            sc_txt += f" ± {sc['std']:.1f}"

        rows.append(" " + label + chi_cell + f"{sc_txt:>14}" + f"{ndof[key]:>11}")
    return "\n".join(rows)