Source code for nifty8.re.extra.sampling_los

#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

import dataclasses
from functools import partial

import jax
from jax import numpy as jnp

from ..tree_math import ShapeWithDtype
from ..model import Model


def _los(x, /, start, end, *, distances, shape, n_sampling_points, order=1):
    from jax.scipy.ndimage import map_coordinates

    l2i = ((shape - 1) / shape) / distances
    start_iloc = start * l2i
    end_iloc = end * l2i
    ddi = (end_iloc - start_iloc) / n_sampling_points
    adi = jnp.arange(0, n_sampling_points) + 0.5
    dist = jnp.linalg.norm(end - start)
    pp = start_iloc[:, jnp.newaxis] + ddi[:, jnp.newaxis] * adi[jnp.newaxis]
    return map_coordinates(x, pp, order=order, cval=jnp.nan).sum() * (
        dist / n_sampling_points
    )


class SamplingCartesianGridLOS(Model):
    """Line-of-sight integrator on a regular Cartesian grid.

    See `__init__` for the parameters.
    """

    start: jax.Array = dataclasses.field(metadata=dict(static=False))
    end: jax.Array = dataclasses.field(metadata=dict(static=False))
    distances: jax.Array = dataclasses.field(metadata=dict(static=False))

    def __init__(
        self,
        start,
        end,
        *,
        shape,
        distances,
        n_sampling_points=500,
        interpolation_order=1,
        dtype=None,
    ):
        """Sampling Line-Of-Sight (LOS) integrator.

        Samples the LOS at a number of points and sums up the result to
        estimate the integral from a starting point to an end point in
        n-dimensional space.

        Parameters
        ----------
        start :
            Location of the start point(s) in Cartesian space of shape
            `(n_points, n_dim)` or `(n_dim,)`.
        end :
            Location of the end point(s) in Cartesian space of shape
            `(n_points, n_dim)` or `(n_dim,)`.
        shape :
            Shape of the input.
        distances :
            Tuple of distances for each axis of the shape of the input
        n_sampling_points : int, optional
            Number of sampling points per LOS for the integration.
        interpolation_order : int, optional
            Order of the interpolation for reading out the sampling points.
        dtype : data-type, optional
            Hint specifying the dtype for the construction of the domain.

        Notes
        -----
        Applying the model yields one value per line of sight: an array of
        shape `(n_points,)` if `start` and/or `end` is two-dimensional, and
        a scalar if both are of shape `(n_dim,)`.
        """
        # Convert first so that plain sequences (lists, tuples) are accepted.
        self.start = jnp.array(start)
        self.end = jnp.array(end)
        self.distances = jnp.array(distances)
        self._los = partial(
            _los,
            n_sampling_points=n_sampling_points,
            order=interpolation_order,
            distances=self.distances,
            shape=jnp.array(shape),
        )

        # Every line of sight is reduced to a single number, so the target is
        # the shape of the higher-dimensional end-point array without its
        # trailing coordinate axis. Use the converted arrays here: the raw
        # arguments need not have a `.shape` attribute.
        if self.start.ndim >= self.end.ndim:
            points_shape = self.start.shape
        else:
            points_shape = self.end.shape
        super().__init__(
            domain=ShapeWithDtype(shape, dtype),
            target=ShapeWithDtype(points_shape[:-1], dtype),
        )

    def __call__(self, x):
        """Integrate `x` along every configured line of sight."""
        n_start, n_end = self.start.ndim, self.end.ndim

        # A single LOS given by two `(n_dim,)` points: there is no batch axis,
        # and mapping over axis 0 would wrongly iterate over the coordinates.
        if n_start == 1 and n_end == 1:
            return self._los(x, self.start, self.end)

        # Map over the leading axis of whichever end-point array is batched;
        # a lower-dimensional one is shared by all lines of sight.
        if n_start < n_end:
            in_axes = (None, None, 0)
        elif n_start > n_end:
            in_axes = (None, 0, None)
        else:
            in_axes = (None, 0, 0)
        return jax.vmap(self._los, in_axes=in_axes)(x, self.start, self.end)