# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
import operator
from collections import namedtuple
from collections.abc import Mapping
from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
from jax import numpy as jnp
from ..config import _config
from .gauss_markov import IntegratedWienerProcess
from .logger import logger
from .misc import wrap
from .model import Model, WrappedCall
from .num import lognormal_prior, normal_prior
from .tree_math import ShapeWithDtype, random_like
[docs]
def hartley(p, axes=None):
    """Compute the discrete Hartley transform of `p` via an FFT.

    The result is assembled from ``F = jnp.fft.fftn(p, axes=axes)``:
    it is ``F.real - F.imag``, or ``F.real + F.imag`` when the configured
    ``"hartley_convention"`` equals ``"non_canonical_hartley"``.

    Parameters
    ----------
    p : array_like
        Input array.
    axes : sequence of int, optional
        Axes to transform; forwarded unchanged to ``fftn``.

    Returns
    -------
    jnp.ndarray
        Real-valued Hartley transform of `p`.
    """
    from jax.numpy import fft

    transformed = fft.fftn(p, axes=axes)
    convention = _config.get("hartley_convention")
    if convention == "non_canonical_hartley":
        return transformed.real + transformed.imag
    return transformed.real - transformed.imag
[docs]
def get_fourier_mode_distributor(
    shape: Union[tuple, int], distances: Union[tuple, float]
):
    """Bin the harmonic modes of a regular grid by the length of their
    wave vector.

    Along every axis of size ``n`` the mode index ``k`` is folded to
    ``min(k, n - k)`` and scaled by ``1 / (n * distance)``; for more than
    one axis the per-axis lengths are combined as a Euclidean norm. Lengths
    whose gap to the next larger one is within ``1e-12`` times the largest
    length are merged, and every mode is assigned to the nearest remaining
    unique length.

    Parameters
    ----------
    shape : tuple of int or int
        Position-space shape.
    distances : tuple of float or float
        Position-space distances.

    Returns
    -------
    mode_length_idx : jnp.ndarray
        For every mode of the harmonic grid, the index of its unique length.
        Can be used to distribute a power spectrum onto the full grid.
    unique_mode_length : jnp.ndarray
        Sorted unique lengths of the Fourier modes.
    mode_multiplicity : jnp.ndarray
        Number of modes sharing each unique length.

    Raises
    ------
    RuntimeError
        If some unique length ends up without any mode.
    """
    if isinstance(shape, int):
        shape = (shape, )
    else:
        shape = tuple(shape)
    inv_dist = 1. / (jnp.array(shape) * jnp.array(distances))

    def folded_axis(axis):
        # Distance of each index to the nearest edge, i.e. |frequency| index
        n = shape[axis]
        k = jnp.arange(n, dtype=jnp.float64)
        return jnp.minimum(k, n - k) * inv_dist[axis]

    if len(shape) == 1:
        # A single axis needs no square/sqrt round trip
        lengths = folded_axis(0)
    else:
        squared = folded_axis(0)
        squared = squared * squared
        for axis in range(1, len(shape)):
            along = folded_axis(axis)
            # Outer sum builds the full N-d grid of squared lengths
            squared = jnp.expand_dims(squared, axis=-1) + along * along
        lengths = jnp.sqrt(squared)

    rtol = 1e-12
    sorted_unique = jnp.unique(lengths)
    top = sorted_unique[-1]
    # Appending 2 * top guarantees the largest length always survives
    gaps = jnp.diff(jnp.append(sorted_unique, 2 * top))
    unique_lengths = sorted_unique[gaps > rtol * top]

    # Midpoints between neighbours act as bin edges -> nearest unique length
    midpoints = 0.5 * (unique_lengths[:-1] + unique_lengths[1:])
    idx = jnp.searchsorted(midpoints, lengths)
    counts = jnp.bincount(idx.ravel(), minlength=unique_lengths.size)
    if jnp.any(counts == 0) or unique_lengths.shape != counts.shape:
        raise RuntimeError("invalid harmonic mode(s) encountered")
    return idx, unique_lengths, counts
# Position-space description of a regular Cartesian grid. ``harmonic_grid``
# is the only field with a default (``None``).
RegularCartesianGrid = namedtuple(
    "RegularCartesianGrid",
    "shape total_volume distances harmonic_grid",
    defaults=[None],
)
# Harmonic (Fourier) partner of a regular grid: mode-to-bin indices, mode
# multiplicities, unique mode lengths, their logarithms relative to the
# first non-zero mode, and the differences of consecutive log lengths.
RegularFourierGrid = namedtuple(
    "RegularFourierGrid",
    [
        "shape",
        "power_distributor",
        "mode_multiplicity",
        "mode_lengths",
        "relative_log_mode_lengths",
        "log_volume",
    ],
)
def _make_grid(shape, distances, harmonic_type) -> RegularCartesianGrid:
    """Create the position-space grid for the amplitude model together
    with its harmonic partner.

    Parameters
    ----------
    shape : int or tuple of int
        Position-space shape.
    distances : float or tuple of float
        Position-space distances, broadcast to one entry per axis.
    harmonic_type : str
        Kind of harmonic grid to attach; only ``"fourier"``
        (case-insensitive) is supported.

    Returns
    -------
    RegularCartesianGrid
        Grid whose ``harmonic_grid`` field holds a `RegularFourierGrid`.

    Raises
    ------
    ValueError
        If `harmonic_type` is not ``"fourier"``.
    """
    if isinstance(shape, int):
        shape = (shape, )
    else:
        shape = tuple(shape)
    distances = tuple(np.broadcast_to(distances, jnp.shape(shape)))
    total_volume = jnp.prod(jnp.array(shape) * jnp.array(distances))
    # TODO: cache results such that only references are used afterwards

    if harmonic_type.lower() != "fourier":
        raise ValueError(f"invalid `harmonic_type` {harmonic_type!r}")

    # Lengths of the modes and indices for distributing power onto them
    distributor, mode_lengths, multiplicity = get_fourier_mode_distributor(
        shape, distances
    )
    # Log of every non-zero mode length relative to the first non-zero one;
    # entry 0 (the zero mode) keeps its length.
    rel_log = mode_lengths.at[1:].set(jnp.log(mode_lengths[1:]))
    rel_log = rel_log.at[1:].add(-rel_log[1])
    assert rel_log[0] == 0.
    # Spacing between consecutive non-zero log lengths
    log_volume = rel_log[2:] - rel_log[1:-1]
    assert rel_log.shape[0] - 2 == log_volume.shape[0]

    fourier_grid = RegularFourierGrid(
        shape=shape,
        power_distributor=distributor,
        mode_multiplicity=multiplicity,
        mode_lengths=mode_lengths,
        relative_log_mode_lengths=rel_log,
        log_volume=log_volume,
    )
    return RegularCartesianGrid(
        shape=shape,
        total_volume=total_volume,
        distances=distances,
        harmonic_grid=fourier_grid,
    )
def _remove_slope(rel_log_mode_dist, x):
    """Subtract from `x` the line through the origin (in the coordinate
    `rel_log_mode_dist`) that passes through the last entry of `x`.

    After the subtraction the last entry of the result is zero, provided
    ``rel_log_mode_dist[-1]`` is non-zero.
    """
    last_value = x[-1]
    fraction = rel_log_mode_dist / rel_log_mode_dist[-1]
    return x - last_value * fraction
[docs]
def matern_amplitude(
    grid,
    scale: Callable,
    cutoff: Callable,
    loglogslope: Callable,
    renormalize_amplitude: bool,
    prefix: str = "",
    kind: str = "amplitude",
) -> Model:
    """Constructs a function computing the amplitude of a Matérn-kernel
    power spectrum.

    See
    :meth:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations_matern`
    for more details on the parameters.

    Parameters
    ----------
    grid : RegularCartesianGrid
        Grid providing ``total_volume`` and, via ``harmonic_grid``, the
        ``mode_lengths`` and ``mode_multiplicity`` of the unique modes.
    scale, cutoff, loglogslope : callable
        Models wrapped via `WrappedCall` under the names ``prefix +
        "scale"``, ``"cutoff"`` and ``"loglogslope"``. Per mode length
        ``k`` the spectrum is
        ``scale * (1 + (k / cutoff)**2)**(loglogslope / 4)``.
    renormalize_amplitude : bool
        Whether to divide the spectrum by a multiplicity-weighted norm over
        all non-zero modes. Experimental; a warning is logged once when the
        model is built.
    prefix : str
        Prefix for the names of the parameters.
    kind : str
        ``"amplitude"`` (spectrum returned as is) or ``"power"`` (the
        square root of the spectrum is returned); case-insensitive.

    Returns
    -------
    Model
        Model mapping the parameters to one value per unique mode length.
        Entry 0 (the zero mode) is fixed to ``totvol`` (``sqrt(totvol)``
        for ``kind="power"``).

    Raises
    ------
    ValueError
        If `kind` is neither ``"amplitude"`` nor ``"power"``.

    See also
    --------
    `Causal, Bayesian, & non-parametric modeling of the SARS-CoV-2 viral
    load vs. patient's age`, Guardiani, Matteo and Frank, Philipp and
    Kostić, Andrija and Edenhofer, Gordian and Roth, Jakob and Uhlmann,
    Berit and Enßlin, Torsten, `<https://arxiv.org/abs/2105.13483>`_
    `<https://doi.org/10.1371/journal.pone.0275011>`_
    """
    # Validate once at construction time so that an invalid `kind` fails
    # fast instead of at the first evaluation of the model.
    kind_lc = kind.lower()
    if kind_lc not in ("amplitude", "power"):
        raise ValueError(f"invalid kind specified {kind!r}")
    if renormalize_amplitude:
        # Logged once here rather than on every evaluation of `correlate`
        logger.warning("Renormalize amplitude is not yet tested!")

    totvol = grid.total_volume
    mode_lengths = grid.harmonic_grid.mode_lengths
    mode_multiplicity = grid.harmonic_grid.mode_multiplicity

    scale = WrappedCall(scale, name=prefix + "scale")
    ptree = scale.domain.copy()
    cutoff = WrappedCall(cutoff, name=prefix + "cutoff")
    ptree.update(cutoff.domain)
    loglogslope = WrappedCall(loglogslope, name=prefix + "loglogslope")
    ptree.update(loglogslope.domain)

    # Power of the spectrum summed in the renormalization.
    # NOTE(review): `non_parametric_amplitude` sums the 2nd (amplitude) and
    # 1st (power) powers instead of the 4th and 2nd used here -- confirm
    # which convention is intended.
    norm_exponent = 4 if kind_lc == "amplitude" else 2

    def correlate(primals: Mapping) -> jnp.ndarray:
        scl = scale(primals)
        ctf = cutoff(primals)
        slp = loglogslope(primals)

        ln_spectrum = 0.25 * slp * jnp.log1p((mode_lengths / ctf)**2)
        spectrum = jnp.exp(ln_spectrum)

        norm = 1.
        if renormalize_amplitude:
            # Skip the zero mode and weight each length by its multiplicity
            norm = jnp.sqrt(
                jnp.sum(mode_multiplicity[1:] * spectrum[1:]**norm_exponent)
            )
            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
        spectrum = scl * (jnp.sqrt(totvol) / norm) * spectrum
        spectrum = spectrum.at[0].set(totvol)
        if kind_lc == "power":
            # NOTE(review): the zero mode is set before the square root, so
            # it ends up as sqrt(totvol) here, while
            # `non_parametric_amplitude` returns totvol -- confirm.
            spectrum = jnp.sqrt(spectrum)
        return spectrum

    return Model(
        correlate, domain=ptree, init=partial(random_like, primals=ptree)
    )
[docs]
def non_parametric_amplitude(
    grid,
    fluctuations: Callable,
    loglogavgslope: Callable,
    flexibility: Optional[Callable] = None,
    asperity: Optional[Callable] = None,
    prefix: str = "",
    kind: str = "amplitude",
) -> Model:
    """Constructs a function computing the amplitude of a non-parametric power
    spectrum.

    See
    :meth:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations`
    for more details on the parameters.

    Parameters
    ----------
    grid : RegularCartesianGrid
        Grid providing ``total_volume`` and, via ``harmonic_grid``, the
        ``relative_log_mode_lengths``, ``mode_multiplicity`` and
        ``log_volume`` of the unique modes.
    fluctuations : callable
        Overall scale of the amplitude (wrapped as ``prefix +
        "fluctuations"``).
    loglogavgslope : callable
        Slope multiplying the relative log mode lengths (wrapped as
        ``prefix + "loglogavgslope"``).
    flexibility, asperity : callable, optional
        Parameters of an `IntegratedWienerProcess` modelling deviations from
        the straight line in log-log space. Only used if `flexibility` is
        given and ``grid.harmonic_grid.log_volume`` is non-empty; otherwise
        a warning is logged for each provided one.
    prefix : str
        Prefix for the names of the parameters.
    kind : str
        ``"amplitude"`` if the exponentiated log-spectrum is the amplitude,
        ``"power"`` if it is the power spectrum (its square root is taken);
        case-insensitive.

    Returns
    -------
    Model
        Model mapping the parameters to the amplitude per unique mode
        length, with entry 0 (the zero mode) fixed to ``totvol``.

    Raises
    ------
    ValueError
        If `kind` is neither ``"amplitude"`` nor ``"power"``.

    See also
    --------
    `Variable structures in M87* from space, time and frequency resolved
    interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
    and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
    Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
    `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
    """
    # Validate once at construction time so that an invalid `kind` fails
    # fast instead of at the first evaluation of the model.
    kind_lc = kind.lower()
    if kind_lc not in ("amplitude", "power"):
        raise ValueError(f"invalid kind specified {kind!r}")

    totvol = grid.total_volume
    rel_log_mode_len = grid.harmonic_grid.relative_log_mode_lengths
    mode_multiplicity = grid.harmonic_grid.mode_multiplicity
    log_vol = grid.harmonic_grid.log_volume

    fluctuations = WrappedCall(
        fluctuations, name=prefix + "fluctuations", white_init=True
    )
    ptree = fluctuations.domain.copy()
    loglogavgslope = WrappedCall(
        loglogavgslope, name=prefix + "loglogavgslope", white_init=True
    )
    ptree.update(loglogavgslope.domain)

    if flexibility is not None and (log_vol.size > 0):
        flexibility = WrappedCall(
            flexibility, name=prefix + "flexibility", white_init=True
        )
        assert rel_log_mode_len.ndim == log_vol.ndim == 1
        if asperity is not None:
            asperity = WrappedCall(
                asperity, name=prefix + "asperity", white_init=True
            )
        deviations = IntegratedWienerProcess(
            jnp.zeros((2, )),
            flexibility,
            log_vol,
            name=prefix + "spectrum",
            asperity=asperity
        )
        ptree.update(deviations.domain)
    else:
        # Previously these components were dropped silently; tell the user.
        if flexibility is not None:
            logger.warning(
                f"{prefix}flexibility is ignored: the harmonic grid has no "
                "log-volume entries (too few distinct mode lengths)"
            )
        if asperity is not None:
            logger.warning(
                f"{prefix}asperity is ignored without an active flexibility"
            )
        deviations = None

    def correlate(primals: Mapping) -> jnp.ndarray:
        flu = fluctuations(primals)
        slope = loglogavgslope(primals)
        ln_spectrum = slope * rel_log_mode_len
        if deviations is not None:
            twolog = deviations(primals)
            # Prepend zeromode
            twolog = jnp.concatenate((jnp.zeros((1, )), twolog[:, 0]))
            # Deviations must not alter the average slope set above
            ln_spectrum = ln_spectrum + _remove_slope(rel_log_mode_len, twolog)

        # Exponentiate and norm the power spectrum
        spectrum = jnp.exp(ln_spectrum)
        # Take the sqrt of the integral of the slope w/o fluctuations and
        # zero-mode while taking into account the multiplicity of each mode
        if kind_lc == "amplitude":
            norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:]**2))
            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
            amplitude = flu * (jnp.sqrt(totvol) / norm) * spectrum
        else:  # kind_lc == "power", validated above
            norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:]))
            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
            amplitude = flu * (jnp.sqrt(totvol) / norm) * jnp.sqrt(spectrum)
        amplitude = amplitude.at[0].set(totvol)
        return amplitude

    return Model(
        correlate, domain=ptree, init=partial(random_like, primals=ptree)
    )