Source code for nifty8.re.refine.util

#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

from collections import namedtuple
from functools import partial
from math import ceil
from typing import Callable, Iterable, Literal, Optional, Tuple, Union
from warnings import warn

import jax
import numpy as np
from jax import numpy as jnp
from scipy.spatial import distance_matrix

from ..tree_math import ShapeWithDtype, zeros_like
from ..logger import logger
from ..model import LazyModel

NDARRAY = Union[jnp.ndarray, np.ndarray]

RefinementMatrices = namedtuple(
    "RefinementMatrices",
    ("filter", "propagator_sqrt", "cov_sqrt0", "index_map")
)


def refinement_matrices(cov, n_fsz: int, coerce_fine_kernel: bool):
    cov_ff = cov[-n_fsz:, -n_fsz:]
    cov_fc = cov[-n_fsz:, :-n_fsz]
    cov_cc = cov[:-n_fsz, :-n_fsz]
    del cov
    cov_cc_inv = jnp.linalg.inv(cov_cc)
    del cov_cc

    olf = cov_fc @ cov_cc_inv
    # Also see Schur complement
    fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T
    del cov_cc_inv, cov_fc, cov_ff
    if coerce_fine_kernel:
        # TODO: Try to work with NaN to avoid the expensive eigendecomposition;
        # work with nan_to_num?
        # Implicitly assume a white power spectrum beyond the numerics limit.
        # Use the diagonal as estimate for the magnitude of the variance.
        fine_kernel_fallback = jnp.diag(jnp.abs(jnp.diag(fine_kernel)))
        # Never produce NaNs (https://github.com/google/jax/issues/1052)
        # This is expensive but necessary (worse but cheaper:
        # `jnp.all(jnp.diag(fine_kernel) > 0.)`)
        is_pos_def = jnp.all(jnp.linalg.eigvalsh(fine_kernel) > 0)
        fine_kernel = jnp.where(is_pos_def, fine_kernel, fine_kernel_fallback)
        # NOTE, subsequently use the Cholesky decomposition, even though the
        # eigenvalues have already been computed, so as to get consistent
        # results across platforms
    # Matrices are symmetrized by JAX, i.e. gradients are projected onto the
    # subspace of symmetric matrices (see
    # https://github.com/google/jax/issues/10815)
    fine_kernel_sqrt = jnp.linalg.cholesky(fine_kernel)

    return olf, fine_kernel_sqrt
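

# A minimal usage sketch (not part of the module; the exponential kernel and
# the point layout are illustrative assumptions): the joint covariance of
# three coarse and two fine points is split via the Schur complement into the
# conditional-mean filter `olf` and the Cholesky factor of the conditional
# (fine) covariance.
#
#   >>> import jax.numpy as jnp
#   >>> kernel = lambda r: jnp.exp(-r)  # assumed stationary kernel
#   >>> loc = jnp.array([0., 1., 2., 0.75, 1.25])  # 3 coarse + 2 fine points
#   >>> cov = kernel(jnp.abs(loc[:, None] - loc[None, :]))
#   >>> olf, fks = refinement_matrices(cov, n_fsz=2, coerce_fine_kernel=False)
#   >>> olf.shape, fks.shape
#   ((2, 3), (2, 2))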


def get_cov_from_loc(
    kernel=None, cov_from_loc=None
) -> Callable[[NDARRAY, NDARRAY], NDARRAY]:
    if cov_from_loc is None and callable(kernel):
        # TODO: extend to non-stationary kernels
        def cov_from_loc_sngl(x, y):
            return kernel(jnp.linalg.norm(x - y))

        cov_from_loc = jax.vmap(
            jax.vmap(cov_from_loc_sngl, in_axes=(None, 0)), in_axes=(0, None)
        )
    elif not callable(cov_from_loc):
        ve = (
            "exactly one of `cov_from_loc` or `kernel` must be set"
            " and callable"
        )
        raise ValueError(ve)
    # TODO: benchmark whether using `triu_indices(n, k=1)` and
    # `diag_indices(n)` is advantageous
    return cov_from_loc
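

# Illustrative use of `get_cov_from_loc` (the squared-exponential kernel is an
# assumption for the example): the doubly vmapped function maps two batches of
# coordinates to their full cross-covariance matrix.
#
#   >>> import jax.numpy as jnp
#   >>> cov_from_loc = get_cov_from_loc(kernel=lambda r: jnp.exp(-r**2))
#   >>> x = jnp.linspace(0., 1., 4)[:, None]  # 4 points in 1D
#   >>> cov_from_loc(x, x).shape
#   (4, 4)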


def get_refinement_shapewithdtype(
    shape0: Union[int, tuple],
    depth: int,
    dtype=None,
    *,
    _coarse_size: int = 3,
    _fine_size: int = 2,
    _fine_strategy: Literal["jump", "extend"] = "jump",
    skip0: bool = False,
):
    if depth < 0:
        raise ValueError(f"invalid `depth`; got {depth!r}")
    csz = int(_coarse_size)  # coarse size
    fsz = int(_fine_size)  # fine size

    swd = partial(ShapeWithDtype, dtype=dtype)

    shape0 = (shape0, ) if isinstance(shape0, int) else shape0
    ndim = len(shape0)
    exc_shp = [swd(shape0)] if not skip0 else [None]
    if depth > 0:
        if _fine_strategy == "jump":
            exc_lvl = tuple(el - (csz - 1) for el in shape0) + (fsz**ndim, )
        elif _fine_strategy == "extend":
            exc_lvl = tuple(
                ceil((el - (csz - 1)) / (fsz // 2)) for el in shape0
            ) + (fsz**ndim, )
        else:
            raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}")
        exc_shp += [swd(exc_lvl)]
    for lvl in range(1, depth):
        if _fine_strategy == "jump":
            exc_lvl = tuple(
                fsz * el - (csz - 1) for el in exc_shp[-1].shape[:-1]
            ) + (fsz**ndim, )
        elif _fine_strategy == "extend":
            exc_lvl = tuple(
                ceil((fsz * el - (csz - 1)) / (fsz // 2))
                for el in exc_shp[-1].shape[:-1]
            ) + (fsz**ndim, )
        else:
            raise AssertionError()
        if any(el <= 0 for el in exc_lvl):
            ve = (
                f"`shape0` ({shape0}) with `depth` ({depth}) yields an"
                f" invalid shape ({exc_lvl}) at level {lvl}"
            )
            raise ValueError(ve)
        exc_shp += [swd(exc_lvl)]

    return exc_shp
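

# Example shapes for the default "jump" strategy with `_coarse_size=3` and
# `_fine_size=2` (a sketch using only this module): each refinement level
# needs `size - 2` excitations per axis, each carrying `2**ndim` fine values.
#
#   >>> swds = get_refinement_shapewithdtype((12, ), depth=2)
#   >>> [s.shape for s in swds]
#   [(12,), (10, 2), (18, 2)]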


def coarse2fine_shape(
    shape0: Union[int, Iterable[int]],
    depth: int,
    *,
    _coarse_size: int = 3,
    _fine_size: int = 2,
    _fine_strategy: Literal["jump", "extend"] = "jump",
):
    """Translates a coarse shape to its corresponding fine shape."""
    shape0 = (shape0, ) if isinstance(shape0, int) else shape0
    csz = int(_coarse_size)  # coarse size
    fsz = int(_fine_size)  # fine size
    if _fine_size % 2 != 0:
        raise ValueError("only even numbers allowed for `_fine_size`")

    shape = []
    for shp in shape0:
        sz_at = shp
        for lvl in range(depth):
            if _fine_strategy == "jump":
                sz_at = fsz * (sz_at - (csz - 1))
            elif _fine_strategy == "extend":
                sz_at = fsz * ceil((sz_at - (csz - 1)) / (fsz // 2))
            else:
                ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
                raise ValueError(ve)
            if sz_at <= 0:
                ve = (
                    f"`shape0` ({shape0}) with `depth` ({depth}) yields an"
                    f" invalid shape ({sz_at}) at level {lvl}"
                )
                raise ValueError(ve)
        shape.append(int(sz_at))
    return tuple(shape)


def fine2coarse_shape(
    shape: Union[int, Iterable[int]],
    depth: int,
    *,
    _coarse_size: int = 3,
    _fine_size: int = 2,
    _fine_strategy: Literal["jump", "extend"] = "jump",
    ceil_sizes: bool = False,
):
    """Translates a fine shape to its corresponding coarse shape."""
    shape = (shape, ) if isinstance(shape, int) else shape
    csz = int(_coarse_size)  # coarse size
    fsz = int(_fine_size)  # fine size
    if _fine_size % 2 != 0:
        raise ValueError("only even numbers allowed for `_fine_size`")

    shape0 = []
    for shp in shape:
        sz_at = shp
        for lvl in range(depth, 0, -1):
            if _fine_strategy == "jump":
                # solve for n: `fsz * (n - (csz - 1))`
                sz_at = sz_at / fsz + (csz - 1)
            elif _fine_strategy == "extend":
                # solve for n: `fsz * ceil((n - (csz - 1)) / (fsz // 2))`
                # NOTE, not unique because of `ceil`; use lower limit
                sz_at_max = (sz_at / fsz) * (fsz // 2) + (csz - 1)
                sz_at_min = ceil(sz_at_max - (fsz // 2 - 1))
                for sz_at_cand in range(sz_at_min, ceil(sz_at_max) + 1):
                    try:
                        shp_cand = coarse2fine_shape(
                            (sz_at_cand, ),
                            depth=depth - lvl + 1,
                            _coarse_size=csz,
                            _fine_size=fsz,
                            _fine_strategy=_fine_strategy
                        )[0]
                    except ValueError as e:
                        if "invalid shape" not in "".join(e.args):
                            ve = "unexpected behavior of `coarse2fine_shape`"
                            raise ValueError(ve) from e
                        shp_cand = -1
                    if shp_cand >= shp:
                        sz_at = sz_at_cand
                        break
                else:
                    ve = (
                        f"interval search within"
                        f" [{sz_at_min}, {ceil(sz_at_max)}] failed"
                    )
                    raise ValueError(ve)
            else:
                ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
                raise ValueError(ve)

            sz_at = ceil(sz_at) if ceil_sizes else sz_at
            if sz_at != int(sz_at):
                raise ValueError(f"invalid shape at level {lvl}")
        shape0.append(int(sz_at))
    return tuple(shape0)
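

# Round trip of the two shape translations above ("jump" strategy, default
# sizes): a coarse axis of 12 pixels refined twice yields
# 2 * (2 * (12 - 2) - 2) = 36 pixels, and `fine2coarse_shape` inverts this.
#
#   >>> coarse2fine_shape((12, ), depth=2)
#   (36,)
#   >>> fine2coarse_shape((36, ), depth=2)
#   (12,)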


def coarse2fine_distances(
    distances0: Union[float, Iterable[float]],
    depth: int,
    *,
    _fine_size: int = 2,
    _fine_strategy: Literal["jump", "extend"] = "jump",
):
    """Translates coarse distances to their corresponding fine distances."""
    fsz = int(_fine_size)  # fine size
    if _fine_strategy == "jump":
        fpx_in_cpx = fsz**depth
    elif _fine_strategy == "extend":
        fpx_in_cpx = 2**depth
    else:
        ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
        raise ValueError(ve)

    return jnp.atleast_1d(distances0) / fpx_in_cpx


def fine2coarse_distances(
    distances: Union[float, Iterable[float]],
    depth: int,
    *,
    _fine_size: int = 2,
    _fine_strategy: Literal["jump", "extend"] = "jump",
):
    """Translates fine distances to their corresponding coarse distances."""
    fsz = int(_fine_size)  # fine size
    if _fine_strategy == "jump":
        fpx_in_cpx = fsz**depth
    elif _fine_strategy == "extend":
        fpx_in_cpx = 2**depth
    else:
        ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
        raise ValueError(ve)

    return jnp.atleast_1d(distances) * fpx_in_cpx
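

# The distance translations are the pixel-size counterparts of the shape
# translations: with the "jump" strategy, every one of the `depth` levels
# splits each pixel into `_fine_size` pixels per axis.
#
#   >>> float(coarse2fine_distances(1., depth=3)[0])
#   0.125
#   >>> float(fine2coarse_distances(0.125, depth=3)[0])
#   1.0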


def _clipping_posdef_logdet(mat, msg_prefix=""):
    sign, logdet = jnp.linalg.slogdet(mat)
    if sign <= 0:
        ve = "not positive definite; clipping eigenvalues"
        warn(msg_prefix + ve)
        eps = jnp.finfo(mat.dtype.type).eps
        evs = jnp.linalg.eigvalsh(mat)
        logdet = jnp.sum(jnp.log(jnp.clip(evs, a_min=eps * evs.max())))
    return logdet


def gauss_kl(cov_desired, cov_approx, *, m_desired=None, m_approx=None):
    cov_t_dl = _clipping_posdef_logdet(cov_desired, msg_prefix="`cov_desired` ")
    cov_a_dl = _clipping_posdef_logdet(cov_approx, msg_prefix="`cov_approx` ")
    cov_a_inv = jnp.linalg.inv(cov_approx)

    kl = -cov_desired.shape[0]  # number of dimensions
    kl += cov_a_dl - cov_t_dl + jnp.trace(cov_a_inv @ cov_desired)
    if m_approx is not None and m_desired is not None:
        m_diff = m_approx - m_desired
        kl += m_diff @ cov_a_inv @ m_diff
    elif not (m_approx is None and m_desired is None):
        ve = (
            "either both or neither of `m_approx` and `m_desired`"
            " must be `None`"
        )
        raise ValueError(ve)
    return 0.5 * kl
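

# `gauss_kl` implements the closed-form Kullback-Leibler divergence between
# the Gaussians N(m_d, C_d) ("desired") and N(m_a, C_a) ("approx"),
#
#   KL(N_d || N_a) = 1/2 * [ tr(C_a^{-1} C_d)
#                            + (m_a - m_d)^T C_a^{-1} (m_a - m_d)
#                            - k + ln det(C_a) - ln det(C_d) ]
#
# with `k` the dimension. A quick sanity check (the matrices are illustrative):
#
#   >>> import jax.numpy as jnp
#   >>> c = jnp.eye(3)
#   >>> float(gauss_kl(c, c))
#   0.0
#   >>> float(gauss_kl(c, 2. * c))  # 0.5 * (3/2 - 3 + 3*ln(2)) ~= 0.2897
#   0.2897...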


def refinement_covariance(chart_or_model, kernel=None, jit=True):
    """Computes the implied covariance as modeled by the refinement scheme."""
    from .chart import CoordinateChart, HEALPixChart
    from .charted_field import RefinementField

    if isinstance(chart_or_model, (CoordinateChart, HEALPixChart)):
        cf = RefinementField(chart_or_model, kernel=kernel)
        shape = chart_or_model.shape
    elif isinstance(chart_or_model, LazyModel):
        cf = chart_or_model
        shape = chart_or_model.target.shape
    else:
        te = f"expected a model or a chart; got {type(chart_or_model)!r}"
        raise TypeError(te)
    ndim = len(shape)

    try:
        cf_T = jax.linear_transpose(cf, cf.domain)
        cov_implicit = lambda x: cf(*cf_T(x))
        cov_implicit = jax.jit(cov_implicit) if jit else cov_implicit
        _ = cov_implicit(jnp.zeros(shape))  # Test transpose
    except (NotImplementedError, AssertionError):
        # Workaround JAX not yet implementing the transpose of the scanned
        # refinement
        _, cf_T = jax.vjp(cf, zeros_like(cf.domain))
        cov_implicit = lambda x: cf(*cf_T(x))
        cov_implicit = jax.jit(cov_implicit) if jit else cov_implicit

    probe = jnp.zeros(shape)
    indices = np.indices(shape).reshape(ndim, -1)
    cov_empirical = jax.lax.map(
        lambda idx: cov_implicit(probe.at[tuple(idx)].set(1.)).ravel(),
        indices.T
    ).T  # vmap over `indices` w/ `in_axes=1, out_axes=-1`

    return cov_empirical
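

# A hedged usage sketch (not part of the module): assuming `CoordinateChart`
# from `.chart` accepts `shape0`, `depth`, and `distances` (the names mirror
# its use in `get_optimal_refinement_chart` below; treat them as assumptions),
# the implied covariance of a refinement field can be probed column by column.
#
#   >>> import numpy as np
#   >>> from jax import numpy as jnp
#   >>> from .chart import CoordinateChart
#   >>> cc = CoordinateChart(shape0=(12, ), depth=1, distances=(1., ))
#   >>> cov = refinement_covariance(cc, kernel=lambda r: jnp.exp(-r))
#   >>> cov.shape == (int(np.prod(cc.shape)), ) * 2
#   True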


def true_covariance(chart, kernel, depth=None):
    """Computes the true covariance at the final grid."""
    depth = chart.depth if depth is None else depth

    c0_slc = tuple(slice(sz) for sz in chart.shape_at(depth))
    pos = jnp.stack(
        chart.ind2cart(jnp.mgrid[c0_slc], depth), axis=-1
    ).reshape(-1, chart.ndim)
    dist_mat = distance_matrix(pos, pos)
    return kernel(dist_mat)


def refinement_approximation_error(
    chart,
    kernel: Callable,
    cutout: Optional[Union[slice, int, Tuple[slice], Tuple[int]]] = None,
):
    """Computes the Kullback-Leibler (KL) divergence of the true covariance
    versus the approximate one for a given kernel and shape of the fine grid.

    If the desired shape cannot be matched, the next larger one is used and
    the field is subsequently cropped to the desired shape.
    """
    suggested_min_shape = 2 * 4**chart.depth
    if any(s <= suggested_min_shape for s in chart.shape):
        msg = (
            f"shape {chart.shape} potentially too small"
            f" (desired {(suggested_min_shape, ) * chart.ndim} (=`2*4^depth`))"
        )
        warn(msg)

    cov_empirical = refinement_covariance(chart, kernel)
    cov_truth = true_covariance(chart, kernel)

    if cutout is None and all(s > suggested_min_shape for s in chart.shape):
        cutout = (suggested_min_shape, ) * chart.ndim
        logger.info(f"cropping field (w/ shape {chart.shape}) to {cutout}")
    if cutout is not None:
        if isinstance(cutout, slice):
            cutout = (cutout, ) * chart.ndim
        elif isinstance(cutout, int):
            cutout = (slice(cutout), ) * chart.ndim
        elif isinstance(cutout, tuple):
            if all(isinstance(el, slice) for el in cutout):
                pass
            elif all(isinstance(el, int) for el in cutout):
                cutout = tuple(slice(el) for el in cutout)
            else:
                raise TypeError("elements of `cutout` of invalid type")
        else:
            raise TypeError("`cutout` of invalid type")
        cov_empirical = cov_empirical.reshape(chart.shape * 2)[cutout * 2]
        cov_truth = cov_truth.reshape(chart.shape * 2)[cutout * 2]

    sz = np.prod(cov_empirical.shape[:chart.ndim])
    if np.prod(cov_truth.shape[:chart.ndim]) != sz or not sz.dtype == int:
        raise AssertionError()
    cov_empirical = cov_empirical.reshape(sz, sz)
    cov_truth = cov_truth.reshape(sz, sz)

    aux = {
        "cov_empirical": cov_empirical,
        "cov_truth": cov_truth,
    }
    return gauss_kl(cov_truth, cov_empirical), aux
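

# Continuing the hedged sketch from `refinement_covariance` above (the chart
# `cc` and the kernel are illustrative assumptions), the KL divergence between
# the true and the empirically probed covariance quantifies the refinement
# error:
#
#   >>> kl, aux = refinement_approximation_error(cc, lambda r: jnp.exp(-r))
#   >>> sorted(aux.keys())
#   ['cov_empirical', 'cov_truth']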


REFINEMENT_STRATEGIES = [
    {"_coarse_size": 3, "_fine_size": 2},
    {"_coarse_size": 3, "_fine_size": 4},
    {"_coarse_size": 5, "_fine_size": 2},
    {"_coarse_size": 5, "_fine_size": 4},
    {"_coarse_size": 5, "_fine_size": 6},
]


def get_optimal_refinement_chart(
    kernel,
    *,
    shape0,
    refinement_strategies=REFINEMENT_STRATEGIES,
    **common_kwargs
):
    """Computes the Kullback-Leibler divergence for a given `kernel`, initial
    `shape0`, and parameters of a coordinate chart (`common_kwargs`) and
    returns the chart within `refinement_strategies` that has the lowest
    Kullback-Leibler divergence.
    """
    from .chart import CoordinateChart

    charts = []
    min_shape = (1 << 31, ) * len(shape0)  # Absurdly large placeholder value
    for kwargs in refinement_strategies:
        cc = CoordinateChart(shape0=shape0, **(common_kwargs | kwargs))
        charts.append(cc)
        min_shape = tuple(min(ms, s) for ms, s in zip(min_shape, cc.shape))

    errors = []
    for cc in charts:
        err, _ = refinement_approximation_error(cc, kernel, cutout=min_shape)
        errors.append(err)

    chosen = np.argmin(errors)
    return charts[chosen], tuple(zip(errors, charts))
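

# A hedged usage sketch (the kernel and the chart parameters are illustrative
# assumptions, as noted for `refinement_covariance` above):
#
#   >>> import jax.numpy as jnp
#   >>> cc_opt, errs = get_optimal_refinement_chart(
#   ...     lambda r: jnp.exp(-r), shape0=(12, ), depth=1, distances=(1., )
#   ... )
#   >>> len(errs)  # one (KL, chart) pair per entry in REFINEMENT_STRATEGIES
#   5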