# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
import dataclasses
import operator
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union
from jax import numpy as jnp
from jax.tree_util import Partial, tree_map
from .likelihood import Likelihood
from .logger import logger
from .model import LazyModel
from .tree_math import ShapeWithDtype, result_type, sum, vdot
def _standard_t(nwr, dof):
res = (nwr.conj() * nwr).real / dof
return sum(tree_map(jnp.log1p, res) * (dof + 1)) / 2
def _shape_w_fixed_dtype(dtype):
    """Return a mapper taking any array-like to a `ShapeWithDtype` with the
    input's shape but the fixed `dtype`."""

    def to_swd(element):
        return ShapeWithDtype(jnp.shape(element), dtype)

    return to_swd
@Partial
def _identity(x):
return x
def _get_cov_inv_and_std_inv(
    cov_inv: Optional[Callable],
    std_inv: Optional[Callable],
    primals=None
) -> Tuple[Union[Partial, LazyModel], Union[Partial, LazyModel]]:
    """Bring the inverse covariance and its square root into canonical
    callable form, synthesizing whichever of the two is missing.

    Parameters
    ----------
    cov_inv : callable, array-like, or None
        Inverse noise covariance; a non-callable is interpreted as the
        diagonal of the inverse covariance.
    std_inv : callable, array-like, or None
        Square root of the inverse noise covariance; a non-callable is
        interpreted as the diagonal.
    primals : tree-like, optional
        Reference input (usually the data) whose shape is used to synthesize
        the missing operator from the provided one.

    Returns
    -------
    cov_inv, std_inv : Partial or LazyModel
        Callables applying the inverse covariance respectively its square
        root.
    """
    if cov_inv is None and std_inv is None:
        # Unit covariance: both operators reduce to the identity. Return
        # early so the inference branches below (which would emit misleading
        # warnings and require `primals`) are not entered.
        return _identity, _identity

    if not callable(cov_inv) and cov_inv is not None:
        msg = "assuming the specified inverse covariance is diagonal"
        logger.warning(msg)
        _cov_inv = Partial(operator.mul, cov_inv)
    elif cov_inv is None:
        wm = (
            "assuming a diagonal covariance matrix"
            ";\nsetting `cov_inv` to `std_inv(ones_like(data))**2`"
        )
        logger.warning(wm)
        # `std_inv` is not `None` here (the both-`None` case returned above)
        # but may still be a non-callable diagonal; wrap it so it can be
        # applied to the vector of ones.
        si = std_inv if callable(std_inv) else Partial(operator.mul, std_inv)
        noise_std_inv_sq = si(
            tree_map(jnp.real, tree_map(jnp.ones_like, primals))
        )**2
        _cov_inv = Partial(operator.mul, noise_std_inv_sq)
    else:
        _cov_inv = cov_inv if isinstance(cov_inv,
                                         (Partial,
                                          LazyModel)) else Partial(cov_inv)

    if not callable(std_inv) and std_inv is not None:
        msg = "assuming the specified sqrt of the inverse covariance is diagonal"
        logger.warning(msg)
        _std_inv = Partial(operator.mul, std_inv)
    elif std_inv is None:
        wm = (
            "assuming a diagonal covariance matrix"
            ";\nsetting `std_inv` to `cov_inv(ones_like(data))**0.5`"
        )
        logger.warning(wm)
        # `_cov_inv` is fully initialized by the first chain at this point
        noise_cov_inv_sqrt = tree_map(
            jnp.sqrt,
            _cov_inv(tree_map(jnp.real, tree_map(jnp.ones_like, primals)))
        )
        _std_inv = Partial(operator.mul, noise_cov_inv_sqrt)
    else:
        _std_inv = std_inv if isinstance(std_inv,
                                         (Partial,
                                          LazyModel)) else Partial(std_inv)
    assert all(
        isinstance(c, (Partial, LazyModel)) for c in (_cov_inv, _std_inv)
    )
    return _cov_inv, _std_inv
class Gaussian(Likelihood):
    """Gaussian likelihood of the data

    Parameters
    ----------
    data : tree-like structure of jnp.ndarray and float
        Data with additive noise following a Gaussian distribution.
    noise_cov_inv : callable acting on type of data
        Function applying the inverse noise covariance of the Gaussian.
    noise_std_inv : callable acting on type of data
        Function applying the square root of the inverse noise covariance.

    Notes
    -----
    If `noise_std_inv` is `None` it is inferred by assuming a diagonal noise
    covariance, i.e. by applying it to a vector of ones and taking the square
    root. If both `noise_cov_inv` and `noise_std_inv` are `None`, a unit
    covariance is assumed.

    See :class:`Likelihood` for details on the properties.
    """
    # Non-static fields so instances are traced as pytrees by JAX
    data: Any = dataclasses.field(metadata=dict(static=False))
    noise_cov_inv: Callable = dataclasses.field(metadata=dict(static=False))
    noise_std_inv: Callable = dataclasses.field(metadata=dict(static=False))

    def __init__(
        self,
        data,
        noise_cov_inv: Optional[Callable] = None,
        noise_std_inv: Optional[Callable] = None
    ):
        self.data = data
        # Normalize possibly-`None` or non-callable covariance inputs into
        # callables; see `_get_cov_inv_and_std_inv`.
        noise_cov_inv, noise_std_inv = _get_cov_inv_and_std_inv(
            noise_cov_inv, noise_std_inv, data
        )
        self.noise_cov_inv = noise_cov_inv
        self.noise_std_inv = noise_std_inv
        shp = tree_map(ShapeWithDtype.from_leave, data)
        super().__init__(domain=shp, lsm_tangents_shape=shp)

    def energy(self, primals):
        """Negative log-likelihood: ``0.5 * (d - f)^† N^{-1} (d - f)``."""
        p_res = self.data - primals
        return 0.5 * vdot(p_res, self.noise_cov_inv(p_res)).real

    def normalized_residual(self, primals):
        """Residual whitened by the square root of the inverse covariance."""
        return self.noise_std_inv(self.data - primals)

    def metric(self, primals, tangents):
        """Apply the Fisher metric, i.e. the inverse noise covariance."""
        return self.noise_cov_inv(tangents)

    def left_sqrt_metric(self, primals, tangents):
        """Apply the left square root of the metric to `tangents`."""
        return self.noise_std_inv(tangents)
class StudentT(Likelihood):
    """Student's t likelihood of the data

    Parameters
    ----------
    data : tree-like structure of jnp.ndarray and float
        Data with additive noise following a Gaussian distribution.
    dof : tree-like structure of jnp.ndarray and float
        Degree-of-freedom parameter of Student's t distribution.
    noise_cov_inv : callable acting on type of data
        Function applying the inverse noise covariance of the Gaussian.
    noise_std_inv : callable acting on type of data
        Function applying the square root of the inverse noise covariance.

    Notes
    -----
    If `noise_std_inv` is `None` it is inferred by assuming a diagonal noise
    covariance, i.e. by applying it to a vector of ones and taking the square
    root. If both `noise_cov_inv` and `noise_std_inv` are `None`, a unit
    covariance is assumed.

    See :class:`Likelihood` for details on the properties.
    """
    # Non-static fields so instances are traced as pytrees by JAX
    data: Any = dataclasses.field(metadata=dict(static=False))
    dof: Any = dataclasses.field(metadata=dict(static=False))
    noise_cov_inv: Callable = dataclasses.field(metadata=dict(static=False))
    noise_std_inv: Callable = dataclasses.field(metadata=dict(static=False))

    def __init__(
        self,
        data,
        dof,
        noise_cov_inv: Optional[Callable] = None,
        noise_std_inv: Optional[Callable] = None,
    ):
        self.data = data
        self.dof = dof
        # Normalize possibly-`None` or non-callable covariance inputs into
        # callables; see `_get_cov_inv_and_std_inv`.
        noise_cov_inv, noise_std_inv = _get_cov_inv_and_std_inv(
            noise_cov_inv, noise_std_inv, data
        )
        self.noise_cov_inv = noise_cov_inv
        self.noise_std_inv = noise_std_inv
        shp = tree_map(ShapeWithDtype.from_leave, data)
        super().__init__(domain=shp, lsm_tangents_shape=shp)

    def energy(self, primals):
        """Student's t negative log-likelihood of the whitened residual."""
        return _standard_t(self.noise_std_inv(self.data - primals), self.dof)

    def metric(self, primals, tangents):
        """Apply the Fisher metric; `(dof + 1) / (dof + 3)` is the Student's t
        Fisher-information prefactor for the location parameter."""
        return self.noise_cov_inv((self.dof + 1) / (self.dof + 3) * tangents)

    def left_sqrt_metric(self, primals, tangents):
        """Apply the left square root of the metric to `tangents`."""
        return self.noise_std_inv(
            ((self.dof + 1) / (self.dof + 3))**0.5 * tangents
        )

    def normalized_residual(self, primals):
        """Residual whitened by the left square root of the metric; the
        metric does not depend on `primals`, hence `None` is passed."""
        return self.left_sqrt_metric(None, self.data - primals)
class Poissonian(Likelihood):
    """Computes the negative log-likelihood, i.e. the Hamiltonians of an
    expected count Vector constrained by Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a Vector in data space of the expectation values for the counts.

    Parameters
    ----------
    data : jnp.ndarray or tree-like structure of jnp.ndarray and float
        Data Vector with counts. Needs to have integer dtype and all values need
        to be non-negative.
    sampling_dtype : dtype, optional
        Data-type for sampling.

    See :class:`Likelihood` for details on the properties.
    """
    # Non-static field so instances are traced as pytrees by JAX
    data: Any = dataclasses.field(metadata=dict(static=False))

    def __init__(self, data, sampling_dtype=float):
        # Eagerly validate the (concrete) count data
        dtp = result_type(data)
        if not jnp.issubdtype(dtp, jnp.integer):
            raise TypeError("`data` of invalid type")
        if sum(tree_map(lambda x: jnp.any(x < 0), data)):
            raise ValueError("`data` must not be negative")
        self.data = data
        # The domain carries the requested sampling dtype, not the data dtype
        shp = tree_map(_shape_w_fixed_dtype(sampling_dtype), data)
        super().__init__(domain=shp, lsm_tangents_shape=shp)

    def energy(self, primals):
        """Negative log-likelihood ``sum(f) - d^† log(f)`` (up to `log(d!)`)."""
        return sum(primals) - vdot(tree_map(jnp.log, primals), self.data)

    def metric(self, primals, tangents):
        """Apply the Fisher metric of the Poisson distribution, ``1 / f``."""
        return tangents / primals

    def left_sqrt_metric(self, primals, tangents):
        """Apply the left square root of the metric, ``1 / sqrt(f)``."""
        return tangents / primals**0.5

    def normalized_residual(self, primals):
        """Residual whitened by the left square root of the metric."""
        return self.left_sqrt_metric(primals, self.data - primals)
class VariableCovarianceGaussian(Likelihood):
    """Gaussian likelihood of the data with a variable covariance

    Parameters
    ----------
    data : tree-like structure of jnp.ndarray and float
        Data with additive noise following a Gaussian distribution.
    iscomplex: Boolean, optional
        Whether the parameters are complex-valued.

    Notes
    -----
    **The likelihood acts on a tuple of (mean, std_inv)**.

    See :class:`Likelihood` for details on the properties.
    """
    # Non-static fields; `iscomplex` selects real vs. complex normalization
    data: Any = dataclasses.field(metadata=dict(static=False))
    iscomplex: bool = False

    def __init__(self, data, iscomplex=False):
        # TODO: make configurable whether `std_inv` or `std` is passed
        self.data = data
        self.iscomplex = iscomplex
        # Domain is (mean, std_inv); the std_inv part is real-valued.
        # NOTE(review): `data.real` assumes `data` supports `.real`
        # (array-like or Vector), not an arbitrary nested pytree — confirm.
        shp = tree_map(ShapeWithDtype.from_leave, (data, data.real))
        super().__init__(domain=shp, lsm_tangents_shape=shp)

    def energy(self, primals):
        """Negative log-likelihood including the log-determinant term.

        ``0.5 * |(d - mean) * std_inv|^2 - fct * sum(log(std_inv))`` with
        ``fct`` being 2 for complex parameters and 1 otherwise.
        """
        res = (self.data - primals[0]) * primals[1]
        fct = 1 + self.iscomplex  # bool arithmetic: 1 (real) or 2 (complex)
        return 0.5 * vdot(res,
                          res).real - fct * sum(tree_map(jnp.log, primals[1]))

    def metric(self, primals, tangents):
        """Apply the Fisher metric at `primals` to `tangents`."""
        fct = 2 * (1 + self.iscomplex)
        prim_std_inv_sq = primals[1]**2
        res = (
            prim_std_inv_sq * tangents[0], fct * tangents[1] / prim_std_inv_sq
        )
        # Rebuild the result in the same container type as `primals`
        return type(primals)(res)

    def left_sqrt_metric(self, primals, tangents):
        """Apply the left square root of the metric at `primals`."""
        fct = jnp.sqrt(2)**(1 + self.iscomplex)
        res = (primals[1] * tangents[0], fct * tangents[1] / primals[1])
        return type(primals)(res)

    def normalized_residual(self, primals):
        """Residual whitened by the (variable) inverse standard deviation."""
        return (self.data - primals[0]) * primals[1]
class VariableCovarianceStudentT(Likelihood):
    """Student's t likelihood of the data with a variable covariance

    Parameters
    ----------
    data : tree-like structure of jnp.ndarray and float
        Data with additive noise following a Gaussian distribution.
    dof : tree-like structure of jnp.ndarray and float
        Degree-of-freedom parameter of Student's t distribution.

    Notes
    -----
    **The likelihood acts on a tuple of (mean, std)**.

    See :class:`Likelihood` for details on the properties.
    """
    # Non-static fields so instances are traced as pytrees by JAX
    data: Any = dataclasses.field(metadata=dict(static=False))
    dof: Any = dataclasses.field(metadata=dict(static=False))

    def __init__(self, data, dof):
        # TODO: make configurable whether `std_inv` or `std` is passed
        self.data = data
        self.dof = dof
        # Domain is the (mean, std) tuple, both shaped like the data
        shp = tree_map(ShapeWithDtype.from_leave, (data, data))
        super().__init__(domain=shp, lsm_tangents_shape=shp)

    def energy(self, primals):
        """Student's t negative log-likelihood plus the ``sum(log(std))``
        normalization term for the variable standard deviation."""
        t = _standard_t((self.data - primals[0]) / primals[1], self.dof)
        t += sum(tree_map(jnp.log, primals[1]))
        return t

    def metric(self, primals, tangent):
        """Apply the Fisher metric at `primals`; the prefactors
        ``(dof + 1) / (dof + 3)`` and ``2 * dof / (dof + 3)`` correspond to
        the location respectively scale parameter of Student's t."""
        res = (
            tangent[0] * (self.dof + 1) / (self.dof + 3) / primals[1]**2,
            tangent[1] * 2 * self.dof / (self.dof + 3) / primals[1]**2
        )
        # Rebuild the result in the same container type as `primals`
        return type(primals)(res)

    def left_sqrt_metric(self, primals, tangents):
        """Apply the left square root of the (diagonal) metric at `primals`."""
        cov = (
            (self.dof + 1) / (self.dof + 3) / primals[1]**2,
            2 * self.dof / (self.dof + 3) / primals[1]**2
        )
        res = (cov[0]**0.5 * tangents[0], cov[1]**0.5 * tangents[1])
        return type(primals)(res)

    def normalized_residual(self, primals):
        """Residual whitened by the std and the location-metric prefactor."""
        return (self.data - primals[0]
                ) / primals[1] * ((self.dof + 1) / (self.dof + 3))**0.5
class Categorical(Likelihood):
    """Categorical likelihood of the data, equivalent to cross entropy

    Parameters
    ----------
    data : tree-like structure of jnp.ndarray and int
        Which of the categories is the realized in the data. Must agree with the
        input shape except for the shape[axis] of the leafs
    axis : int
        Leaf-axis over which the categories are formed
    sampling_dtype : dtype, optional
        Data-type for sampling.

    See :class:`Likelihood` for details on the properties.
    """
    # Non-static field so instances are traced as pytrees by JAX
    data: Any = dataclasses.field(metadata=dict(static=False))
    axis: int = -1

    def __init__(self, data, axis=-1, sampling_dtype=float):
        self.data = data
        self.axis = axis
        # The domain carries the requested sampling dtype, not the data dtype
        shp = tree_map(_shape_w_fixed_dtype(sampling_dtype), data)
        super().__init__(domain=shp, lsm_tangents_shape=shp)

    def energy(self, primals):
        """Cross entropy: negative sum of the log-softmax of `primals` at the
        category indices given by the data."""
        from jax.nn import log_softmax

        # Named `categorical_nll` to avoid shadowing the builtin `eval`
        def categorical_nll(p, d):
            logits = log_softmax(p, axis=self.axis)
            return -jnp.sum(jnp.take_along_axis(logits, d, self.axis))

        return sum(tree_map(categorical_nll, primals, self.data))

    def metric(self, primals, tangents):
        """Apply the Fisher metric of the categorical distribution, i.e.
        ``diag(p) - p p^T`` applied along `axis` with ``p = softmax(primals)``."""
        from jax.nn import softmax

        preds = tree_map(partial(softmax, axis=self.axis), primals)
        norm_term = tree_map(
            partial(jnp.sum, axis=self.axis, keepdims=True), preds * tangents
        )
        return preds * tangents - preds * sum(norm_term)

    def left_sqrt_metric(self, primals, tangents):
        """Apply a square-root factorization of the metric using
        ``sqrt(p)`` with ``p = softmax(primals)``."""
        from jax.nn import softmax

        # FIXME: not sure if this is really the square root
        sqrtp = tree_map(partial(softmax, axis=self.axis), primals)**0.5
        norm_term = tree_map(
            partial(jnp.sum, axis=self.axis, keepdims=True), sqrtp * tangents
        )
        norm_term = sum(norm_term)
        return sqrtp * (tangents - sqrtp * norm_term)