nifty8.operators.jax_operator module
- class JaxLikelihoodEnergyOperator(domain, func, transformation=None, sampling_dtype=None)
Bases: LikelihoodEnergyOperator
Wrap a JAX function as a NIFTy likelihood energy operator.
- Parameters:
domain (DomainTuple or MultiDomain) – Domain of the operator.
func (callable) – The JAX function that is evaluated by the operator. It has to be implemented in terms of jax.numpy calls. If domain is a MultiDomain, func takes a dict as argument. It must map to a scalar.
transformation (Operator, optional) – Coordinate transformation to Euclidean space.
sampling_dtype (dtype, optional) – The dtype that shall be used for drawing samples from the metric of the likelihood.
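As a sketch of the func contract, assuming domain is a MultiDomain with the hypothetical keys "mu" and "logsigma", the function below takes a dict of arrays and returns a scalar: a Gaussian negative log-likelihood with unknown noise level, up to an additive constant. The data values and key names are illustrative only and do not come from NIFTy.

```python
import jax
import jax.numpy as jnp

# Hypothetical observed data, fixed when func is built.
data = jnp.array([1.0, 2.0])

def neg_log_likelihood(x):
    """Gaussian negative log-likelihood (up to an additive constant).

    x is a dict of arrays mirroring a MultiDomain input with the
    hypothetical keys "mu" (mean) and "logsigma" (log standard deviation).
    """
    mu, logsigma = x["mu"], x["logsigma"]
    residual = (data - mu) * jnp.exp(-logsigma)
    # Summing over all entries makes the function map to a scalar.
    return 0.5 * jnp.sum(residual**2) + jnp.sum(logsigma)

x = {"mu": jnp.zeros(2), "logsigma": jnp.zeros(2)}
energy = jax.jit(neg_log_likelihood)(x)
print(energy.shape, float(energy))  # () 2.5
```

Only jax.numpy operations are used, so the function can be traced and differentiated by JAX, which is what the requirement on func above is about.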
- apply(x)
Applies the operator to a Field, MultiField or Linearization.
- Parameters:
x (nifty8.field.Field, nifty8.multi_field.MultiField, or nifty8.linearization.Linearization) – Input on which the operator shall act. Needs to be defined on domain. If x is a nifty8.linearization.Linearization, apply returns a new nifty8.linearization.Linearization containing the result of the operator application as well as its Jacobian, evaluated at x.
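For a scalar energy, the Jacobian carried by the returned Linearization is the gradient. In plain JAX the same pair (value and gradient at a point) is returned by jax.value_and_grad. The sketch below is only an analogy to show what the Linearization contains; it is not a description of NIFTy's internal implementation, and the unit-noise Gaussian energy is an illustrative assumption.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, -2.0, 0.5])

def energy(x):
    # Gaussian energy with unit noise: 0.5 * |d - x|^2
    return 0.5 * jnp.sum((data - x) ** 2)

x0 = jnp.zeros(3)
# Value and gradient at x0, evaluated together.
value, grad = jax.value_and_grad(energy)(x0)
# Analytically, the gradient of 0.5 * |d - x|^2 is x - d.
print(float(value))  # 2.625
```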
- get_transformation()
The coordinate transformation that maps into a coordinate system in which the metric of a likelihood is the Euclidean metric.
- Returns:
np.dtype or dict of np.dtype – The dtype(s) of the target space of the transformation.
Operator – The transformation that maps from domain into the Euclidean target space.
Note
This Euclidean target space is the disjoint union of the Euclidean target spaces of all summands. Therefore, the keys of MultiDomains are prefixed with an index and DomainTuples are converted to MultiDomains with the index as the key.
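To make the role of the transformation concrete: for a Gaussian likelihood in the mean m with fixed diagonal noise variance var, one transformation with the required property is m ↦ m / sqrt(var). Its Jacobian J satisfies J^T J = diag(1/var), which equals the metric of the likelihood (here the Hessian of the negative log-likelihood in m). The check below verifies this numerically; the likelihood and the variance values are illustrative assumptions and not taken from NIFTy's implementation.

```python
import jax
import jax.numpy as jnp

var = jnp.array([0.5, 2.0, 4.0])   # assumed fixed noise variances
data = jnp.array([1.0, 0.0, -1.0])

def neg_log_likelihood(m):
    # Gaussian negative log-likelihood in the mean m (up to a constant).
    return 0.5 * jnp.sum((data - m) ** 2 / var)

def transformation(m):
    # Candidate map into Euclidean coordinates: whiten by the noise std.
    return m / jnp.sqrt(var)

m0 = jnp.zeros(3)
metric = jax.hessian(neg_log_likelihood)(m0)  # metric of the likelihood
jac = jax.jacfwd(transformation)(m0)          # Jacobian of the transformation
pullback = jac.T @ jac                        # Euclidean metric pulled back

print(jnp.allclose(metric, pullback))  # True
```

In these coordinates the likelihood metric becomes the identity, which is what makes drawing samples from it straightforward.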
- class JaxLinearOperator(domain, target, func, domain_dtype=None, func_T=None)
Bases: LinearOperator
Wrap a JAX function as a NIFTy linear operator.
- Parameters:
domain (DomainTuple or MultiDomain) – Domain of the operator.
target (DomainTuple or MultiDomain) – Target of the operator.
func (callable) – The JAX function that is evaluated by the operator. It has to be implemented in terms of jax.numpy calls. If domain is a MultiDomain, func takes a dict as argument, and likewise for the target.
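func_T (callable, optional) – The JAX function that implements the transposed action of the operator. If None, JAX computes the transposed action automatically, which requires domain_dtype. Note that this is the transposed action, not the adjoint action; for complex-valued operators the two differ by complex conjugation. Default: None.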
domain_dtype (dtype or dict of dtype, optional) – Dtype of the domain. Needs to be set if func_T is None; otherwise it has no effect. If domain is a MultiDomain, domain_dtype must be a dictionary. Default: None.
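The difference between the transposed and the adjoint action matters for complex-valued operators. The sketch below, assuming a small illustrative complex matrix A as the linear function, builds a transposed function with jax.linear_transpose and recovers the adjoint by conjugating input and output. Whether NIFTy uses exactly this call when func_T is None is an assumption here; the point is the relation A^H y = conj(A^T conj(y)).

```python
import jax
import jax.numpy as jnp

# Illustrative complex 2x3 matrix; func maps C^3 -> C^2.
A = jnp.array([[1 + 2j, 0.5 - 1j, 3 + 0j],
               [2 - 1j, -1 + 0.5j, 1j]], dtype=jnp.complex64)

def func(x):
    return A @ x

# jax.linear_transpose only needs the shape and dtype of the input,
# which is why a domain dtype is required for the automatic transpose.
x_spec = jax.ShapeDtypeStruct((3,), jnp.complex64)
_transpose = jax.linear_transpose(func, x_spec)

def func_T(y):
    # Transposed action A^T y (no complex conjugation).
    return _transpose(y)[0]

def adjoint(y):
    # Adjoint action A^H y, obtained from the transpose by conjugation.
    return jnp.conj(func_T(jnp.conj(y)))

x = jnp.array([1 - 1j, 0.5 + 2j, -1 + 0j], dtype=jnp.complex64)
y = jnp.array([1 + 1j, 2 - 0.5j], dtype=jnp.complex64)

print(jnp.allclose(func_T(y), A.T @ y))          # True
print(jnp.allclose(adjoint(y), A.conj().T @ y))  # True
# Defining property of the adjoint: <y, A x> == <A^H y, x>
print(jnp.allclose(jnp.vdot(y, func(x)), jnp.vdot(adjoint(y), x)))  # True
```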
Note
It is the user’s responsibility to ensure that func is actually a linear function. This can be double-checked with nifty8.extra.check_linear_operator.
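Before wrapping a function, a quick plain-JAX spot check of linearity on random inputs can catch mistakes early. The helper looks_linear below is hypothetical and not part of NIFTy; nifty8.extra.check_linear_operator remains the tool for the wrapped operator. A passing spot check does not prove linearity, but a failing one disproves it.

```python
import jax
import jax.numpy as jnp

def looks_linear(func, shape, seed=0, rtol=1e-5, atol=1e-5):
    """Hypothetical helper: test func(a*x + b*y) == a*func(x) + b*func(y)
    for random real inputs x, y and random scalars a, b."""
    kx, ky, ka, kb = jax.random.split(jax.random.PRNGKey(seed), 4)
    x = jax.random.normal(kx, shape)
    y = jax.random.normal(ky, shape)
    a = jax.random.normal(ka, ())
    b = jax.random.normal(kb, ())
    lhs = func(a * x + b * y)
    rhs = a * func(x) + b * func(y)
    return bool(jnp.allclose(lhs, rhs, rtol=rtol, atol=atol))

print(looks_linear(jnp.cumsum, (5,)))         # True
print(looks_linear(lambda x: x**2, (5,)))     # False
```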
- apply(x, mode)
Applies the Operator to a given x, in a specified mode.
- Parameters:
x (nifty8.field.Field) – The input Field, defined on the Operator’s domain or target, depending on mode.
mode (int) –
TIMES: normal application
ADJOINT_TIMES: adjoint application
INVERSE_TIMES: inverse application
ADJOINT_INVERSE_TIMES or INVERSE_ADJOINT_TIMES: adjoint inverse application
- Returns:
The processed Field defined on the Operator’s target or domain, depending on mode.
- Return type:
nifty8.field.Field
- class JaxOperator(domain, target, func)
Bases: Operator
Wrap a JAX function as a NIFTy operator.
- Parameters:
domain (DomainTuple or MultiDomain) – Domain of the operator.
target (DomainTuple or MultiDomain) – Target of the operator.
func (callable) – The JAX function that is evaluated by the operator. It has to be implemented in terms of jax.numpy calls. If domain is a MultiDomain, func takes a dict as argument, and likewise for the target.
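For a MultiDomain domain and target, func receives a dict of arrays and should return a dict keyed like the target. The sketch below uses the hypothetical input keys "amplitude" and "phase" and output keys "real" and "imag"; the non-linear map itself is only an illustration.

```python
import jax
import jax.numpy as jnp

def polar_to_cartesian(x):
    # Input dict mirrors a MultiDomain with keys "amplitude" and "phase".
    amp = jnp.exp(x["amplitude"])  # log-amplitude -> positive amplitude
    phase = x["phase"]
    # Output dict mirrors a target MultiDomain with keys "real" and "imag".
    return {"real": amp * jnp.cos(phase), "imag": amp * jnp.sin(phase)}

x = {"amplitude": jnp.zeros(4), "phase": jnp.linspace(0.0, jnp.pi, 4)}
out = jax.jit(polar_to_cartesian)(x)
print(sorted(out), out["real"].shape)  # ['imag', 'real'] (4,)
```

Following the constructor signature above, such a function would be passed as func together with matching domain and target MultiDomains.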
- apply(x)
Applies the operator to a Field, MultiField or Linearization.
- Parameters:
x (nifty8.field.Field, nifty8.multi_field.MultiField, or nifty8.linearization.Linearization) – Input on which the operator shall act. Needs to be defined on domain. If x is a nifty8.linearization.Linearization, apply returns a new nifty8.linearization.Linearization containing the result of the operator application as well as its Jacobian, evaluated at x.
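For a general (non-scalar) operator, the Jacobian in the returned Linearization is a linear map that can be applied forwards and, via its transpose, backwards. In plain JAX these two directions correspond to jax.jvp and jax.vjp. The sketch below illustrates this on a small illustrative function; it is an analogy and not a description of NIFTy's internals.

```python
import jax
import jax.numpy as jnp

def func(x):
    # Illustrative non-linear map R^3 -> R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x0 = jnp.array([1.0, 2.0, 0.0])
dx = jnp.array([0.1, -0.2, 0.3])  # tangent in the domain
dy = jnp.array([1.0, -1.0])       # cotangent in the target

# Forward: value and Jacobian-vector product J dx.
value, jdx = jax.jvp(func, (x0,), (dx,))
# Backward: vector-Jacobian product J^T dy.
value_b, vjp_fn = jax.vjp(func, x0)
(jt_dy,) = vjp_fn(dy)

jac = jax.jacfwd(func)(x0)  # explicit 2x3 Jacobian, for comparison
print(jnp.allclose(jdx, jac @ dx), jnp.allclose(jt_dy, jac.T @ dy))  # True True
```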