Source code for nifty8.re.refine.healpix_field

#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

from functools import partial
from typing import Callable, Optional, Union

import jax
import numpy as np
from jax import numpy as jnp

from ..tree_math import ShapeWithDtype
from ..logger import logger
from ..model import LazyModel
from ..num import amend_unique_
from .chart import HEALPixChart
from .healpix_refine import cov_sqrt as cov_sqrt_hp
from .healpix_refine import refine as refine_hp
from .util import (
    RefinementMatrices, get_cov_from_loc, get_refinement_shapewithdtype,
    refinement_matrices
)


def _matrices_naive(
    chart,
    kernel: Callable,
    depth: int,
    *,
    coerce_fine_kernel: bool = False,
):
    """Build the refinement matrices for every pixel of every level.

    No attempt is made to detect redundant matrices: one pair of matrices is
    computed per HEALPix pixel (and, for two-dimensional charts, per radial
    offset) at each of the `depth` refinement levels.

    Parameters
    ----------
    chart :
        Coordinate chart providing `get_coarse_fine_pair`, `shape_at`,
        `ndim`, `coarse_size` and `fine_size`.
    kernel :
        Covariance kernel handed to `get_cov_from_loc`.
    depth :
        Number of refinement levels for which to compute matrices.
    coerce_fine_kernel :
        Forwarded unchanged to `refinement_matrices`.

    Returns
    -------
    RefinementMatrices
        Per-level lists of filters and propagator square roots; the third
        field is `None` and the index map is `None` for every level.
    """
    cov_from_loc = get_cov_from_loc(kernel)

    def pixel_matrices(lvl, idx_hp, idx_r):
        # `idx_r` is the left-most radial pixel of the to-be-refined slice
        coarse, fine = chart.get_coarse_fine_pair((idx_hp, idx_r), lvl)
        assert fine.shape[0] == chart.fine_size**(chart.ndim + 1)

        # Joint covariance of the coarse pixels followed by the fine pixels
        pos = jnp.concatenate((coarse, fine), axis=0)
        joint_cov = cov_from_loc(pos, pos)

        filt, prop_sqrt = refinement_matrices(
            joint_cov,
            chart.fine_size**(chart.ndim + 1),
            coerce_fine_kernel=coerce_fine_kernel
        )
        if chart.ndim > 1:
            # Separate the HEALPix axis from the radial axis
            filt = filt.reshape(
                chart.fine_size**2, chart.fine_size, chart.coarse_size**2,
                chart.coarse_size
            )
        return filt, prop_sqrt

    filters, propagators = [], []
    for lvl in range(depth):
        at_level = partial(pixel_matrices, lvl)
        hp_pixels = jnp.arange(chart.shape_at(lvl)[0])
        if chart.ndim == 1:
            radial_offsets = None
            mapped = jax.vmap(at_level, in_axes=(0, 0))
        elif chart.ndim == 2:
            n_offsets = chart.shape_at(lvl)[1] - chart.coarse_size + 1
            radial_offsets = jnp.arange(n_offsets)
            # Inner map runs over radial offsets, outer map over HEALPix pixels
            over_radial = jax.vmap(at_level, in_axes=(None, 0))
            mapped = jax.vmap(over_radial, in_axes=(0, None))
        else:
            raise AssertionError()
        filt, prop_sqrt = mapped(hp_pixels, radial_offsets)
        filters.append(filt)
        propagators.append(prop_sqrt)

    no_index_maps = (None, ) * len(filters)
    return RefinementMatrices(filters, propagators, None, no_index_maps)


def _matrices_tol(
    chart,
    kernel: Callable,
    depth: int,
    *,
    coerce_fine_kernel: bool = False,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    which: Optional[str] = "dist",
    mat_buffer_size: Optional[int] = None,
    _verbosity: int = 0,
):
    dist_mat_from_loc = get_cov_from_loc(lambda x: x, None)
    if which not in (None, "dist", "cov"):
        ve = f"expected `which` to be one of 'dist' or 'cov'; got {which!r}"
        raise ValueError(ve)
    if atol is not None and rtol is not None and mat_buffer_size is None:
        raise TypeError("must specify `mat_buffer_size`")

    def dist_mat(lvl, idx_hp, idx_r):
        # `idx_r` is the left-most radial pixel of the to-be-refined slice
        # Extend `gc` and `gf` radially
        gc, gf = chart.get_coarse_fine_pair((idx_hp, idx_r), lvl)
        assert gf.shape[0] == chart.fine_size**(chart.ndim + 1)

        coord = jnp.concatenate((gc, gf), axis=0)
        return dist_mat_from_loc(coord, coord)

    def ref_mat(cov):
        olf, ks = refinement_matrices(
            cov,
            chart.fine_size**(chart.ndim + 1),
            coerce_fine_kernel=coerce_fine_kernel
        )
        if chart.ndim > 1:
            olf = olf.reshape(
                chart.fine_size**2, chart.fine_size, chart.coarse_size**2,
                chart.coarse_size
            )
        return olf, ks

    opt_lin_filter, kernel_sqrt, idx_map = [], [], []
    for lvl in range(depth):
        pix_hp_idx = jnp.arange(chart.shape_at(lvl)[0])
        assert chart.ndim == 2
        pix_r_off = jnp.arange(chart.shape_at(lvl)[1] - chart.coarse_size + 1)
        # Map only over the radial axis b/c that is irregular anyways simply
        # due to the HEALPix geometry and manually scan over the HEALPix axis
        # to only save unique values
        vdist = jax.vmap(partial(dist_mat, lvl), in_axes=(None, 0))

        # Successively amend the duplicate-free distance/covariance matrices
        d = jax.eval_shape(vdist, pix_hp_idx[0], pix_r_off)
        u = jnp.full((mat_buffer_size, ) + d.shape, jnp.nan, dtype=d.dtype)

        def scanned_amend_unique(u, pix):
            d = vdist(pix, pix_r_off)
            d = kernel(d) if which == "cov" else d
            if _verbosity > 1:
                # Magic code to move up curser by one line and delete whole line
                msg = "\x1b[1A\x1b[2K{pix}/{n}"
                jax.debug.print(msg, pix=pix, n=pix_hp_idx[-1])
            return amend_unique_(u, d, axis=0, atol=atol, rtol=rtol)

        if _verbosity > 1:
            jax.debug.print("")
        u, inv = jax.lax.scan(scanned_amend_unique, u, pix_hp_idx)
        # Cut away the placeholder for preserving static shapes
        n = np.unique(inv).size
        if n >= u.shape[0] or not np.all(np.isnan(u[n:])):
            raise ValueError("`mat_buffer_size` too small")
        u = u[:n]
        u = kernel(u) if which == "dist" else u
        inv = np.array(inv)
        if _verbosity > 0:
            logger.info(f"Post uniquifying: {u.shape}")

        # Finally, all distance/covariance matrices are assembled and we
        # can map over them to construct the refinement matrices as usual
        vmat = jax.vmap(jax.vmap(ref_mat, in_axes=(0, )), in_axes=(0, ))
        olf, ks = vmat(u)

        opt_lin_filter.append(olf)
        kernel_sqrt.append(ks)
        idx_map.append(inv)

    return RefinementMatrices(opt_lin_filter, kernel_sqrt, None, idx_map)


class RefinementHPField(LazyModel):
    """Iterative Charted Refinement (ICR) field for a HEALPix map with a
    radial extent.
    """
    def __init__(
        self,
        chart: HEALPixChart,
        kernel: Optional[Callable] = None,
        dtype=None,
        skip0: bool = False,
    ):
        """Set up the refinement field.

        Parameters
        ----------
        chart :
            HEALPix coordinate chart with which to iteratively refine.
        kernel :
            Covariance kernel of the refinement field.
        dtype :
            Data-type of the excitations which to add during refining.
        skip0 :
            Whether to skip the first refinement level.
        """
        self._chart = chart
        self._kernel = kernel
        self._dtype = dtype
        self._skip0 = skip0

    @property
    def kernel(self):
        """Kernel given at initialization; raises a `TypeError` if there is
        none.
        """
        if self._kernel is not None:
            return self._kernel
        te = (
            "either specify a fixed kernel during initialization of the"
            f" {self.__class__.__name__} class or provide one here"
        )
        raise TypeError(te)

    @property
    def dtype(self):
        """Data-type of the excitations; `jnp.float64` if none was given."""
        if self._dtype is None:
            return jnp.float64
        return self._dtype

    @property
    def skip0(self):
        """Whether the zeroth refinement is skipped."""
        return self._skip0

    @property
    def chart(self):
        """The `HEALPixChart` with which to iteratively refine."""
        return self._chart

    def matrices(
        self,
        kernel: Optional[Callable] = None,
        depth: Optional[int] = None,
        skip0: Optional[bool] = None,
        *,
        coerce_fine_kernel: bool = False,
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
        which: Optional[str] = "dist",
        mat_buffer_size: Optional[int] = None,
        _verbosity: int = 0,
    ) -> RefinementMatrices:
        """Compute the refinement matrices, i.e. the optimal linear filter
        and the square root of the information propagator (a.k.a. the square
        root of the fine covariance matrix for the excitations), for all
        refinement levels and all pixel indices in the coordinate chart.

        Parameters
        ----------
        kernel :
            Covariance kernel of the refinement field if not specified during
            initialization.
        depth :
            Maximum refinement depth if different to the one of the
            `HEALPixChart`.
        skip0 :
            Whether to skip the first refinement level.
        coerce_fine_kernel :
            Whether to coerce the refinement matrices at scales at which the
            kernel matrix becomes singular or numerically highly unstable.
        atol :
            Absolute tolerance of the element-wise difference between
            distance matrices up to which refinement matrices are merged.
        rtol :
            Relative tolerance of the element-wise difference between
            distance matrices up to which refinement matrices are merged.
        which : 'cov' or 'dist'
            Type of the matrix for which the unique instances shall be found.
            If the covariance matrices (`cov`) are "uniquified", then the
            procedure depends on the kernel used. If the distance matrices
            (`dist`) are "uniquified", then the search for redundancies
            becomes independent of the kernel. For a fixed charting, the
            latter can always be done ahead of time while it might be easier
            to set sensible `atol` and `rtol` parameters for the former.
        mat_buffer_size :
            Size of the buffer reserved for the unique matrices of a level.
            Only used if `atol` or `rtol` is given.

        Returns
        -------
        RefinementMatrices
            The matrices of all levels with `cov_sqrt0` set to the result of
            `cov_sqrt_hp` or to `None` if the zeroth level is skipped.

        Note
        ----
        Finding duplicates in the distance matrices is computationally
        expensive if there are only few of them. However, if the chart is
        kept fixed, then this one time investment might still be worth it
        though.
        """
        # Fall back to the values fixed at initialization
        if kernel is None:
            kernel = self.kernel
        if depth is None:
            depth = self.chart.depth
        if skip0 is None:
            skip0 = self.skip0

        merge_duplicates = not (atol is None and rtol is None)
        if merge_duplicates:
            rfm = _matrices_tol(
                self.chart,
                kernel,
                depth,
                coerce_fine_kernel=coerce_fine_kernel,
                atol=atol,
                rtol=rtol,
                which=which,
                mat_buffer_size=mat_buffer_size,
                _verbosity=_verbosity
            )
        else:
            rfm = _matrices_naive(
                self.chart,
                kernel,
                depth,
                coerce_fine_kernel=coerce_fine_kernel
            )

        if skip0:
            cov_sqrt0 = None
        else:
            cov_sqrt0 = cov_sqrt_hp(self.chart, kernel)
        return rfm._replace(cov_sqrt0=cov_sqrt0)

    @property
    def domain(self):
        """The list of `ShapeWithDtype` of the primals, one entry per level;
        the zeroth entry is `None` if the zeroth level is skipped.
        """
        nonhp_domain = get_refinement_shapewithdtype(
            shape0=self.chart.shape0[1:],
            depth=self.chart.depth,
            dtype=self.dtype,
            skip0=self.skip0,
            _coarse_size=self.chart.coarse_size,
            _fine_size=self.chart.fine_size,
            _fine_strategy=self.chart.fine_strategy,
        )

        if self.skip0:
            domain = [None]
        else:
            swd0 = nonhp_domain[0]
            shape0 = (12 * self.chart.nside0**2, ) + swd0.shape
            domain = [ShapeWithDtype(shape0, swd0.dtype)]

        levels = range(self.chart.depth + 1)
        for lvl, swd in zip(levels, nonhp_domain[1:]):
            # Prepend the HEALPix axis and scale the last axis by four
            n_hp = 12 * self.chart.nside_at(lvl)**2
            shape = (n_hp, ) + swd.shape[:-1] + (4 * swd.shape[-1], )
            domain.append(ShapeWithDtype(shape, swd.dtype))
        return domain

    @staticmethod
    def apply(
        xi,
        chart: HEALPixChart,
        kernel: Union[Callable, RefinementMatrices],
        *,
        skip0: bool = False,
        coerce_fine_kernel: bool = False,
        _refine: Optional[Callable] = None,
        precision=None,
    ):
        """Apply a refinement field given some excitations, a chart and a
        kernel.

        Parameters
        ----------
        xi :
            Latent parameters which to use for refining.
        chart :
            Chart with which to refine the non-HEALPix axis.
        kernel :
            Covariance kernel with which to build the refinement matrices or
            already computed `RefinementMatrices`.
        skip0 :
            Whether to skip the first refinement level.
        coerce_fine_kernel :
            Whether to coerce the refinement matrices at scales at which the
            kernel matrix becomes singular or numerically highly unstable.
        precision :
            See JAX's precision.

        Raises
        ------
        ValueError
            If the shape of the zeroth excitations differs from
            `chart.shape0`.
        """
        if xi[0].shape != chart.shape0:
            ve = "zeroth excitations do not fit to chart"
            raise ValueError(ve)

        if not isinstance(kernel, RefinementMatrices):
            field = RefinementHPField(chart, None, xi[0].dtype)
            refinement = field.matrices(
                kernel=kernel,
                skip0=skip0,
                coerce_fine_kernel=coerce_fine_kernel
            )
        else:
            refinement = kernel

        refine_fn = refine_hp if _refine is None else _refine
        refine_w_chart = partial(refine_fn, chart=chart, precision=precision)

        if skip0:
            # Skipping the zeroth level and providing its matrix contradict
            if refinement.cov_sqrt0 is not None:
                raise AssertionError()
            fine = xi[0]
        else:
            flat0 = refinement.cov_sqrt0 @ xi[0].ravel()
            fine = flat0.reshape(xi[0].shape)

        per_level = zip(
            xi[1:], refinement.filter, refinement.propagator_sqrt,
            refinement.index_map
        )
        for x, olf, k, im in per_level:
            fine = refine_w_chart(fine, x, olf, k, im)
        return fine

    def __call__(self, xi, kernel=None, *, skip0=None, **kwargs):
        """See `RefinementHPField.apply`."""
        if kernel is None:
            kernel = self.kernel
        if skip0 is None:
            skip0 = self.skip0
        return self.apply(xi, self.chart, kernel=kernel, skip0=skip0, **kwargs)

    def __repr__(self):
        # Only mention the settings that differ from their defaults
        descr = f"{self.__class__.__name__}(chart={self.chart!r}"
        if self._kernel is not None:
            descr += f", kernel={self._kernel!r}"
        if self._dtype is not None:
            descr += f", dtype={self._dtype!r}"
        if self.skip0 is not False:
            descr += f", skip0={self.skip0!r}"
        descr += ")"
        return descr

    def __eq__(self, other):
        # Equality and hashing are both defined via the representation
        return repr(self) == repr(other)

    def __hash__(self):
        return hash(repr(self))