# Source code for this module (the module name was lost when this file was extracted from its documentation page).

#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

import sys
from functools import partial
from typing import Tuple

import jax
import numpy as np
from jax import numpy as jnp
from numpy.typing import NDArray

from ..logger import logger

def unique(
    ar, *, return_inverse=False, axis=-1, atol=1e-10, rtol=1e-5, _verbosity=0
):
    """Find unique sub-arrays in `ar` along `axis` within a given tolerance.

    A sub-array is a slice of `ar` at a fixed position along `axis`. The
    first not-yet-assigned sub-array `u` becomes a new unique value, and every
    remaining sub-array `a` with ``abs(u - a) <= atol + rtol * abs(a)``
    (element-wise, for all elements) is assigned to it. This repeats until
    every sub-array is assigned.

    The algorithm is efficient if the number of approximately unique values
    is small compared to the overall size of `ar` along `axis`.

    Parameters
    ----------
    ar : array-like
        Input array. It may be a NumPy or a JAX array.
    return_inverse : bool
        If True, also return for every position along `axis` the index of
        the unique sub-array it was assigned to.
    axis : int
        Axis along which to look for unique sub-arrays. Negative values count
        from the end.
    atol, rtol : float
        Absolute and relative tolerance of the closeness test.
    _verbosity : int
        If positive, print the number of still unassigned sub-arrays to
        stderr after every iteration.

    Returns
    -------
    uniqs : array
        The unique sub-arrays in order of first occurrence, concatenated
        along `axis`. If `ar` is empty along `axis`, this is an empty
        selection of `ar` along `axis`.
    inverse : ndarray of int
        Only returned if `return_inverse` is True.

    Raises
    ------
    TypeError
        If `axis` is not an `int`.
    """
    if not isinstance(axis, int):
        raise TypeError(f"`axis` needs to be of type `int`; got {type(axis)!r}")

    uniqs = None
    inverse = None
    if return_inverse:
        inverse = np.full(ar.shape[axis], -1, dtype=int)

    # Ensure a non-negative axis so it can be excluded from the reduction axes
    axis = np.arange(ar.ndim)[axis]
    ra = tuple(set(range(ar.ndim)) - {axis})
    to_sort = np.ones(ar.shape[axis], dtype=bool)
    while to_sort.any():
        remaining = np.nonzero(to_sort)[0]
        i = remaining[0]
        u = np.take(ar, (i, ), axis=axis)
        if uniqs is None:
            uniqs = u
        else:
            uniqs = np.concatenate((uniqs, u), axis=axis)

        isclose = np.zeros(to_sort.shape, dtype=bool)
        # Set the `mode` to work around `ar` potentially being a JAX array and
        # not supporting NumPy's default `mode='raise'`
        a = np.take(ar, remaining, axis=axis, mode=None)
        isclose[to_sort] = np.all(
            np.abs(u - a) <= (atol + rtol * np.abs(a)), axis=ra
        )
        # The pivot always belongs to its own group. Without this, a pivot
        # containing NaN (or inf, since inf - inf is NaN) never compares close
        # to itself, `to_sort` is never cleared and the loop never terminates.
        isclose[i] = True
        to_sort &= ~isclose
        if return_inverse:
            assert inverse is not None
            inverse[isclose] = uniqs.shape[axis] - 1
        if _verbosity > 0:
            print(f"to-sort: {np.sum(to_sort)}", file=sys.stderr)

    if uniqs is None:
        # `ar` is empty along `axis`: return an empty selection with the
        # correct shape rather than `None`
        uniqs = np.take(ar, np.arange(0), axis=axis, mode=None)
    if return_inverse:
        assert np.all(inverse != -1)
        return uniqs, inverse
    return uniqs
def amend_unique(
    ar, el, *, axis=-1, atol=1e-10, rtol=1e-5
) -> Tuple[NDArray, int]:
    """Append `el` to `ar` along `axis` unless an approximately equal entry exists.

    An existing sub-array `a` of `ar` (a slice at a fixed position along
    `axis`) matches `el` when ``abs(a - el) <= atol + rtol * abs(el)`` holds
    for every element.

    Returns
    -------
    ar : ndarray
        `ar` unchanged if a match was found, otherwise `ar` with `el`
        concatenated at the end along `axis`.
    index : int
        Position along `axis` of the first match (as a NumPy integer), or the
        position of the newly appended entry.

    Raises
    ------
    TypeError
        If `axis` is not an `int`.
    """
    if not isinstance(axis, int):
        raise TypeError(f"`axis` needs to be of type `int`; got {type(axis)!r}")

    # A non-negative axis is needed to exclude it from the reduction axes
    axis = np.arange(ar.ndim)[axis]
    reduce_over = tuple(d for d in range(ar.ndim) if d != axis)

    candidate = np.expand_dims(el, axis=axis)
    within_tol = np.abs(ar - candidate) <= atol + rtol * np.abs(candidate)
    matches = np.all(within_tol, axis=reduce_over)
    assert matches.size == ar.shape[axis]

    if not np.any(matches):
        n_existing = ar.shape[axis]
        return np.concatenate((ar, candidate), axis=axis), n_existing
    return ar, np.nonzero(matches)[0][0]
@partial(jax.jit, static_argnames=("axis", ))
def amend_unique_(ar, el, *, axis=-1, atol=1e-10, rtol=1e-5):
    """Fixed-size, JIT-compiled variant of `amend_unique`.

    `ar` is used as a buffer along `axis`: a slice along `axis` whose entries
    are all NaN counts as a free slot. If `el` is close to an existing slice
    (``abs(ar - el) <= atol + rtol * abs(el)`` element-wise), `ar` is left
    unchanged and the index of the first such slice is returned. Otherwise
    `el` is written into the first free slot and that slot's index is
    returned.

    If there is no free slot, the slot index is a sentinel: the most negative
    value of JAX's default integer type. In that case the write targets the
    sentinel index (presumably dropped by JAX as out of bounds; TODO confirm),
    and the returned index is the sentinel whenever no match was found.

    Raises
    ------
    TypeError
        If `axis` is not an `int`. `axis` is a static argument of the jit.
    """
    if not isinstance(axis, int):
        raise TypeError(f"`axis` needs to be of type `int`; got {type(axis)!r}")
    # Fill value for `jnp.nonzero` when nothing is found: the most negative
    # integer of JAX's default integer width
    if jnp.array(0).dtype == jnp.int64:
        sentinel = -1 << 63
    else:
        sentinel = -1 << 31

    # A non-negative axis is needed to exclude it from the reduction axes
    axis = np.arange(ar.ndim)[axis]
    reduce_over = tuple(d for d in range(ar.ndim) if d != axis)

    candidate = jnp.expand_dims(el, axis=axis)
    within_tol = jnp.abs(ar - candidate) <= atol + rtol * jnp.abs(candidate)
    matches = jnp.all(within_tol, axis=reduce_over)
    found = jnp.any(matches)

    # First slot along `axis` whose entries are all NaN, i.e. still free
    is_free = jnp.all(jnp.isnan(ar), axis=reduce_over)
    free_slot = jnp.nonzero(is_free, size=1, fill_value=sentinel)[0][0]

    # Write NaN (a no-op on an already free slot) if a match exists,
    # otherwise store the new element in the free slot
    payload = jnp.where(found, jnp.full_like(candidate, jnp.nan), candidate)
    target = (slice(None), ) * axis + (free_slot, )
    ar = ar.at[target].set(jnp.squeeze(payload, axis=axis))

    first_match = jnp.nonzero(matches, size=1, fill_value=sentinel)[0][0]
    return ar, jnp.where(found, first_match, free_slot)