# Source code for nifty8.re.refine.charted_field

#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union

from jax import numpy as jnp
from jax import vmap

from ..model import LazyModel
from ..tree_math import Vector
from .chart import CoordinateChart
from .charted_refine import refine
from .util import (
    RefinementMatrices, get_cov_from_loc, get_refinement_shapewithdtype,
    refinement_matrices
)


def _coordinate_pixel_refinement_matrices(
    chart: CoordinateChart,
    level: int,
    pixel_index: Optional[Iterable[int]] = None,
    kernel: Optional[Callable] = None,
    *,
    coerce_fine_kernel: bool = False,
    _cov_from_loc: Optional[Callable] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Refinement matrices for a single coarse pixel of `chart` at `level`.

    A grid of ``coarse_size**ndim`` coarse index offsets (integers from
    ``-(coarse_size - 1) // 2`` to ``(coarse_size - 1) // 2`` per axis) and
    a grid of ``fine_size**ndim`` fine index offsets (spacing depends on
    ``chart.fine_strategy``) are shifted by `pixel_index`. They are mapped to
    Cartesian coordinates via ``chart.ind2cart(..., level)``. The joint
    covariance of all these points is then passed to `refinement_matrices`.

    Parameters
    ----------
    chart :
        Coordinate chart providing sizes, strategy and `ind2cart`.
    level :
        Refinement level forwarded to `chart.ind2cart`.
    pixel_index :
        Index of the central coarse pixel; defaults to all zeros. Must have
        exactly `chart.ndim` entries.
    kernel, _cov_from_loc :
        Forwarded to `get_cov_from_loc` to obtain the covariance function.
    coerce_fine_kernel :
        Forwarded to `refinement_matrices`.

    Returns
    -------
    Whatever `refinement_matrices` returns for the joint covariance; callers
    in this module unpack it as (optimal linear filter, kernel square root).

    Raises
    ------
    ValueError
        If `coarse_size` is not odd, `fine_size` is not even, `pixel_index`
        has the wrong number of entries, or `fine_strategy` is unknown.
    """
    cov_from_loc = get_cov_from_loc(kernel, _cov_from_loc)

    n_coarse = int(chart.coarse_size)
    if n_coarse % 2 != 1:
        raise ValueError("only odd numbers allowed for `_coarse_size`")
    n_fine = int(chart.fine_size)
    if n_fine % 2 != 0:
        raise ValueError("only even numbers allowed for `_fine_size`")

    ndim = chart.ndim
    center = jnp.asarray((0, ) * ndim if pixel_index is None else pixel_index)
    if center.size != ndim:
        raise ValueError(
            f"`pixel_index` has {center.size} dimensions but `chart` has {ndim}"
        )

    def flat_grid(offsets):
        # Cartesian product of the 1D `offsets` along every axis, flattened
        # to shape (len(offsets)**ndim, ndim) in "ij" (row-major) order.
        mesh = jnp.meshgrid(*([offsets] * ndim), indexing="ij")
        return jnp.stack(mesh, axis=-1).reshape(-1, ndim)

    half_width = (n_coarse - 1) // 2
    coarse_offsets = flat_grid(
        jnp.arange(-half_width, half_width + 1, dtype=float)
    )

    strategy = chart.fine_strategy
    if strategy == "jump":
        fine_1d = jnp.arange(n_fine, dtype=float) / n_fine - 0.5 + 0.5 / n_fine
    elif strategy == "extend":
        fine_1d = jnp.arange(n_fine, dtype=float) / 2 - 0.25 * (n_fine - 1)
    else:
        raise ValueError(f"invalid `_fine_strategy`; got {strategy}")
    fine_offsets = flat_grid(fine_1d)

    # A single joint `cov_from_loc` call is about twice as fast on the GPU as
    # three separate coarse-coarse, fine-fine and coarse-fine calls.
    index_coords = jnp.concatenate((coarse_offsets, fine_offsets), axis=0)
    index_coords = index_coords + center.reshape((1, ndim))
    cart = jnp.stack(chart.ind2cart(index_coords.T, level), axis=-1)

    return refinement_matrices(
        cov_from_loc(cart, cart),
        n_fine**ndim,
        coerce_fine_kernel=coerce_fine_kernel,
    )


def _coordinate_refinement_matrices(
    chart: CoordinateChart,
    kernel: Optional[Callable],
    *,
    depth: Optional[int] = None,
    skip0=False,
    coerce_fine_kernel: bool = False,
    _cov_from_loc=None
) -> RefinementMatrices:
    """Compute the refinement matrices for all levels of `chart`.

    For every level ``lvl < depth`` the per-pixel matrices of
    `_coordinate_pixel_refinement_matrices` are evaluated, vectorized with
    `vmap`. Along irregular axes this happens at every `stride`-th coarse
    pixel, excluding `pad` pixels at either border. Along all other axes it
    happens at the single pixel index `pad`, and the result keeps a
    broadcastable axis of length 1 there.

    Parameters
    ----------
    chart :
        Coordinate chart to refine.
    kernel :
        Covariance kernel. May be ``None`` when `_cov_from_loc` is given
        (both are forwarded to `get_cov_from_loc`).
    depth :
        Number of refinement levels; defaults to `chart.depth`.
    skip0 :
        If true, the Cholesky factor of the level-0 covariance is not
        computed and `cov_sqrt0` is ``None``.
    coerce_fine_kernel :
        Forwarded to `_coordinate_pixel_refinement_matrices`.
    _cov_from_loc :
        Forwarded to `get_cov_from_loc`.

    Returns
    -------
    RefinementMatrices
        Per-level optimal linear filters, per-level kernel square roots, the
        level-0 covariance square root (or ``None``) and one ``None`` per
        level as fourth field.

    Raises
    ------
    ValueError
        If `chart.coarse_size` is not odd or, for the "extend" strategy,
        `chart.fine_size` is not even.
    AssertionError
        If `chart.fine_strategy` is neither "jump" nor "extend".
    """
    cov_from_loc = get_cov_from_loc(kernel, _cov_from_loc)
    depth = chart.depth if depth is None else depth

    if not skip0:
        rg0 = jnp.mgrid[tuple(slice(s) for s in chart.shape0)]
        c0 = jnp.stack(chart.ind2cart(rg0, 0), axis=-1).reshape(-1, chart.ndim)
        # Matrices are symmetrized by JAX, i.e. gradients are projected to the
        # subspace of symmetric matrices (see
        # https://github.com/google/jax/issues/10815)
        cov_sqrt0 = jnp.linalg.cholesky(cov_from_loc(c0, c0))
    else:
        cov_sqrt0 = None

    # `pad` is the half-width of the coarse neighborhood around a refined
    # pixel and `stride` the step between refined pixels along irregular
    # axes. Both depend only on the chart, so validate and compute them once
    # instead of for every level and axis.
    pad = (chart.coarse_size - 1) / 2
    if int(pad) != pad:
        raise ValueError("`coarse_size` must be odd")
    pad = int(pad)
    if chart.fine_strategy == "jump":
        stride = 1
    elif chart.fine_strategy == "extend":
        stride = chart.fine_size / 2
        if int(stride) != stride:
            raise ValueError("`fine_size` must be even")
        stride = int(stride)
    else:
        raise AssertionError(
            f"invalid `fine_strategy`; got {chart.fine_strategy}"
        )

    # Map over pixel indices only; the level is shared by all pixels.
    olf_at = vmap(
        partial(
            _coordinate_pixel_refinement_matrices,
            chart,
            coerce_fine_kernel=coerce_fine_kernel,
            _cov_from_loc=cov_from_loc,
        ),
        in_axes=(None, 0),
        out_axes=(0, 0)
    )

    opt_lin_filter, kernel_sqrt = [], []
    for lvl in range(depth):
        shape_lvl = chart.shape_at(lvl)
        axis_indices = [
            jnp.arange(pad, shape_lvl[ax] - pad, stride)
            if ax in chart.irregular_axes else jnp.array([pad])
            for ax in range(chart.ndim)
        ]
        pixel_grid = jnp.stack(
            jnp.meshgrid(*axis_indices, indexing="ij"), axis=-1
        )
        # Regular axes contribute exactly one index, so the grid shape
        # already has a broadcastable 1 on every regular axis.
        shape_bc_lvl = pixel_grid.shape[:-1]

        olf, ks = olf_at(lvl, pixel_grid.reshape(-1, chart.ndim))
        opt_lin_filter.append(olf.reshape(shape_bc_lvl + olf.shape[-2:]))
        kernel_sqrt.append(ks.reshape(shape_bc_lvl + ks.shape[-2:]))

    return RefinementMatrices(
        opt_lin_filter, kernel_sqrt, cov_sqrt0, (None, ) * len(opt_lin_filter)
    )


class RefinementField(LazyModel):
    def __init__(
        self,
        *args,
        kernel: Optional[Callable] = None,
        dtype=None,
        skip0: bool = False,
        **kwargs
    ):
        """Initialize an Iterative Charted Refinement (ICR) field.

        There are multiple ways to initialize a charted refinement field. The
        recommended way is to first instantiate a `CoordinateChart` and pass
        it as first argument to this method. Alternatively, you may pass any
        and all arguments of `CoordinateChart` also to this method and it will
        instantiate the `CoordinateChart` for you and use it in the same way
        as if directly specified.

        When a `CoordinateChart` is the first positional argument, `kernel`,
        `dtype` and `skip0` may follow it positionally (in that order)
        instead of being passed by keyword. The positional kernel must be
        callable, and the corresponding keywords must not be given as well.

        Parameters
        ----------
        chart : CoordinateChart
            The `CoordinateChart` with which to refine.
        kernel :
            Covariance kernel of the refinement field.
        dtype :
            Data-type of the excitations which to add during refining.
        skip0 :
            Whether to skip the first refinement level. This is useful to e.g.
            stack multiple refinement fields on top of each other.
        **kwargs :
            Alternatively to `chart` any parameters accepted by
            `CoordinateChart`.

        Raises
        ------
        TypeError
            If a `CoordinateChart` is passed positionally together with
            additional keyword arguments, or with extra positional arguments
            that do not match the forms described above.
        """
        self._kernel = kernel
        self._dtype = dtype
        self._skip0 = skip0

        if len(args) > 0 and isinstance(args[0], CoordinateChart):
            if kwargs:
                raise TypeError(f"expected no keyword arguments, got {kwargs}")
            if len(args) == 1:
                self._chart, = args
            elif len(args) == 2 and callable(args[1]) and kernel is None:
                self._chart, self._kernel = args
            elif len(args) == 3 and callable(
                args[1]
            ) and kernel is None and dtype is None:
                self._chart, self._kernel, self._dtype = args
            elif len(args) == 4 and callable(
                args[1]
            ) and kernel is None and dtype is None and skip0 is False:
                self._chart, self._kernel, self._dtype, self._skip0 = args
            else:
                te = "got unexpected arguments in addition to CoordinateChart"
                raise TypeError(te)
        else:
            self._chart = CoordinateChart(*args, **kwargs)

    @property
    def kernel(self):
        """Yields the kernel specified during initialization or throws a
        `TypeError`.
        """
        if self._kernel is None:
            te = (
                "either specify a fixed kernel during initialization of the"
                f" {self.__class__.__name__} class or provide one here"
            )
            raise TypeError(te)
        return self._kernel

    @property
    def dtype(self):
        """Yields the data-type of the excitations (`jnp.float64` if unset)."""
        return jnp.float64 if self._dtype is None else self._dtype

    @property
    def skip0(self):
        """Whether to skip the zeroth refinement."""
        return self._skip0

    @property
    def chart(self):
        """Associated `CoordinateChart` with which to iteratively refine."""
        return self._chart

    def matrices(
        self,
        kernel: Optional[Callable] = None,
        depth: Optional[int] = None,
        skip0: Optional[bool] = None,
        **kwargs
    ) -> RefinementMatrices:
        """Computes the refinement matrices namely the optimal linear filter
        and the square root of the information propagator (a.k.a. the square
        root of the fine covariance matrix for the excitations) for all
        refinement levels and all pixel indices in the coordinate chart.

        Parameters
        ----------
        kernel :
            Covariance kernel of the refinement field if not specified during
            initialization.
        depth :
            Maximum refinement depth if different to the one of the
            `CoordinateChart`.
        skip0 :
            Whether to skip the first refinement level.
        """
        if kernel is None and "_cov_from_loc" not in kwargs:
            kernel = self.kernel
        depth = self.chart.depth if depth is None else depth
        skip0 = self.skip0 if skip0 is None else skip0

        return _coordinate_refinement_matrices(
            self.chart, kernel=kernel, depth=depth, skip0=skip0, **kwargs
        )

    def matrices_at(
        self,
        level: int,
        pixel_index: Optional[Iterable[int]] = None,
        kernel: Optional[Callable] = None,
        **kwargs
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Computes the refinement matrices namely the optimal linear filter
        and the square root of the information propagator (a.k.a. the square
        root of the fine covariance matrix for the excitations) at the
        specified level and pixel index.

        Parameters
        ----------
        level :
            Refinement level.
        pixel_index :
            Index of the NDArray at the refinement level `level` which to
            refine, i.e. use as center coarse pixel.
        kernel :
            Covariance kernel of the refinement field if not specified during
            initialization.
        """
        if kernel is None and "_cov_from_loc" not in kwargs:
            kernel = self.kernel

        return _coordinate_pixel_refinement_matrices(
            self.chart,
            level=level,
            pixel_index=pixel_index,
            kernel=kernel,
            **kwargs
        )

    @property
    def domain(self):
        """Yields the `ShapeWithDtype` of the primals."""
        return get_refinement_shapewithdtype(
            shape0=self.chart.shape0,
            depth=self.chart.depth,
            dtype=self.dtype,
            skip0=self.skip0,
            _coarse_size=self.chart.coarse_size,
            _fine_size=self.chart.fine_size,
            _fine_strategy=self.chart.fine_strategy,
        )

    @staticmethod
    def apply(
        xi,
        chart,
        kernel: Union[Callable, RefinementMatrices],
        *,
        skip0: bool = False,
        depth: Optional[int] = None,
        coerce_fine_kernel: bool = False,
        _refine: Optional[Callable] = None,
        _cov_from_loc: Optional[Callable] = None,
        precision=None,
    ):
        """Static method to apply a refinement field given some excitations, a
        chart and a kernel.

        Parameters
        ----------
        xi :
            Latent parameters which to use for refining; one entry per level,
            i.e. ``depth + 1`` entries. A `Vector` is counted via its `.tree`.
        chart :
            Chart with which to refine.
        kernel :
            Covariance kernel with which to build the refinement matrices, or
            precomputed `RefinementMatrices` which are then used as-is.
        skip0 :
            Whether to skip the first refinement level.
        depth :
            Refinement depth if different to the depth of the coordinate
            chart.
        coerce_fine_kernel :
            Whether to coerce the refinement matrices at scales at which the
            kernel matrix becomes singular or numerically highly unstable.
        _refine :
            Replacement for the default `refine` function.
        _cov_from_loc :
            Forwarded to the construction of the refinement matrices.
        precision :
            See JAX's precision.

        Raises
        ------
        ValueError
            If the number of levels in `xi` does not equal ``depth + 1``.
        """
        depth = chart.depth if depth is None else depth
        # Count levels on the same container in the check and the message; a
        # `Vector` is measured via its underlying `.tree`.
        n_levels = len(xi.tree if isinstance(xi, Vector) else xi)
        if depth != n_levels - 1:
            ve = (
                f"incompatible refinement depths of `xi` ({n_levels - 1})"
                f" and `depth` (of chart) {depth}"
            )
            raise ValueError(ve)

        if isinstance(kernel, RefinementMatrices):
            refinement = kernel
        else:
            refinement = _coordinate_refinement_matrices(
                chart,
                kernel=kernel,
                depth=depth,
                skip0=skip0,
                coerce_fine_kernel=coerce_fine_kernel,
                _cov_from_loc=_cov_from_loc,
            )

        refine_w_chart = partial(
            refine if _refine is None else _refine,
            _coarse_size=chart.coarse_size,
            _fine_size=chart.fine_size,
            _fine_strategy=chart.fine_strategy,
            precision=precision
        )

        if not skip0:
            fine = (refinement.cov_sqrt0 @ xi[0].ravel()).reshape(xi[0].shape)
        else:
            if refinement.cov_sqrt0 is not None:
                raise AssertionError()
            fine = xi[0]
        for x, olf, k in zip(
            xi[1:], refinement.filter, refinement.propagator_sqrt
        ):
            fine = refine_w_chart(fine, x, olf, k)
        return fine

    def __call__(self, xi, kernel=None, *, skip0=None, **kwargs):
        """See `RefinementField.apply`."""
        if kernel is None and "_cov_from_loc" not in kwargs:
            kernel = self.kernel
        skip0 = self.skip0 if skip0 is None else skip0
        return self.apply(xi, self.chart, kernel=kernel, skip0=skip0, **kwargs)

    def __repr__(self):
        descr = f"{self.__class__.__name__}({self.chart!r}"
        descr += f", kernel={self._kernel!r}" if self._kernel is not None else ""
        descr += f", dtype={self._dtype!r}" if self._dtype is not None else ""
        descr += f", skip0={self.skip0!r}" if self.skip0 is not False else ""
        descr += ")"
        return descr

    def __eq__(self, other):
        return repr(self) == repr(other)

    def __hash__(self):
        return hash(repr(self))