nifty8.operators.jax_operator module

class JaxLikelihoodEnergyOperator(domain, func, transformation=None, sampling_dtype=None)

Bases: LikelihoodEnergyOperator

Wrap a JAX function as a NIFTy likelihood energy operator.

Parameters:
  • domain (DomainTuple or MultiDomain) – Domain of the operator.

  • func (callable) – The JAX function that is evaluated by the operator. It has to be implemented in terms of jax.numpy calls. If domain is a MultiDomain, func takes a dict as its argument. It must map to a scalar.

  • transformation (Operator, optional) – Coordinate transformation to Euclidean space.

  • sampling_dtype (dtype, optional) – The dtype used for drawing samples from the metric of the likelihood.
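As a minimal sketch, assuming a hypothetical MultiDomain with keys "signal" and "offset" and made-up data, a Gaussian negative log-likelihood written in jax.numpy could serve as func. It takes a dict and returns a scalar; JAX's own jax.grad, for instance, only accepts scalar-valued functions, which is one way to see why a scalar output is natural for an energy.

```python
import jax
import jax.numpy as jnp

# Hypothetical observed data and noise variance, made up for illustration.
data = jnp.array([1.0, 2.0, 3.0])
noise_var = 0.5


def gaussian_energy(x):
    """Gaussian negative log-likelihood (up to a constant) for a dict input.

    The keys "signal" and "offset" stand in for the keys of a hypothetical
    MultiDomain; the return value is a scalar, as required for func.
    """
    residual = data - (x["signal"] + x["offset"])
    return 0.5 * jnp.sum(residual**2) / noise_var


x0 = {"signal": jnp.zeros(3), "offset": jnp.array(0.0)}
energy = gaussian_energy(x0)          # 0.5 * (1 + 4 + 9) / 0.5 = 14.0
grad = jax.grad(gaussian_energy)(x0)  # a dict with the same keys as x0
```

Such a function could then be passed as func to JaxLikelihoodEnergyOperator together with a MultiDomain whose keys and shapes match the dict; constructing that domain is not shown here.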

__init__(domain, func, transformation=None, sampling_dtype=None)
apply(x)

Applies the operator to a Field, MultiField or Linearization.

Parameters:

x (nifty8.field.Field, nifty8.multi_field.MultiField, or nifty8.linearization.Linearization) – Input on which the operator shall act. Needs to be defined on domain. If x is a nifty8.linearization.Linearization, apply returns a new nifty8.linearization.Linearization containing the result of the operator application as well as its Jacobian, evaluated at x.
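The value-plus-Jacobian pair that a Linearization carries is conceptually similar to what JAX returns from jax.linearize (the value and a Jacobian-vector product) and jax.vjp (the transposed Jacobian). The sketch below uses plain JAX on a hypothetical elementwise function; it illustrates the concept and is not NIFTy's internal implementation.

```python
import jax
import jax.numpy as jnp


def func(x):
    # Hypothetical elementwise (hence non-linear) function.
    return 2.0 * jnp.exp(x)


x = jnp.array([0.0, 1.0])

# Value f(x) together with a function applying the Jacobian at x.
value, jac = jax.linearize(func, x)
# Vector-Jacobian product: applies the transposed Jacobian at x.
_, jac_T = jax.vjp(func, x)

jac_out = jac(jnp.array([1.0, 0.0]))        # [2, 0]
(jacT_out,) = jac_T(jnp.array([1.0, 1.0]))  # [2, 2e]
```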

get_transformation()

The coordinate transformation that maps into a coordinate system in which the metric of a likelihood is the Euclidean metric.

Returns:

  • np.dtype, or dict of np.dtype – The dtype(s) of the target space of the transformation.

  • Operator – The transformation that maps from domain into the Euclidean target space.

Note

This Euclidean target space is the disjoint union of the Euclidean target spaces of all summands. Therefore, the keys of MultiDomains are prefixed with an index and DomainTuples are converted to MultiDomains with the index as the key.
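To illustrate why index prefixes make the union disjoint, here is a pure-Python sketch with a hypothetical helper, disjoint_union_keys, that uses dicts and tuples as stand-ins for MultiDomains and DomainTuples. The exact key format NIFTy uses is not stated here; the "index_key" scheme below is an assumption made purely for illustration.

```python
def disjoint_union_keys(summand_domains):
    """Hypothetical helper sketching the key scheme of a disjoint union.

    Each entry of summand_domains is either a dict (standing in for a
    MultiDomain, mapping key -> shape) or a tuple (standing in for the
    shape of a DomainTuple). The "index_key" format is an assumption.
    """
    union = {}
    for index, dom in enumerate(summand_domains):
        if isinstance(dom, dict):
            # MultiDomain: prefix every key with the summand's index.
            for key, shape in dom.items():
                union[f"{index}_{key}"] = shape
        else:
            # DomainTuple: becomes one entry keyed by the index itself.
            union[str(index)] = dom
    return union


# Two summands both use the key "a"; the prefixes keep them apart.
union = disjoint_union_keys([{"a": (3,), "b": (2,)}, (5,), {"a": (4,)}])
# {"0_a": (3,), "0_b": (2,), "1": (5,), "2_a": (4,)}
```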

class JaxLinearOperator(domain, target, func, domain_dtype=None, func_T=None)

Bases: LinearOperator

Wrap a JAX function as a NIFTy linear operator.

Parameters:
  • domain (DomainTuple or MultiDomain) – Domain of the operator.

  • target (DomainTuple or MultiDomain) – Target of the operator.

  • func (callable) – The JAX function that is evaluated by the operator. It has to be implemented in terms of jax.numpy calls. If domain is a MultiDomain, func takes a dict as its argument; likewise, if target is a MultiDomain, func returns a dict.

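  • func_T (callable, optional) – The JAX function that implements the transposed action of the operator. If None, JAX computes the transpose automatically. Note that the transpose is not the adjoint action: for complex-valued operators, the adjoint additionally involves complex conjugation. Default: None.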

  • domain_dtype (dtype or dict of dtype, optional) – Dtype of the domain. Must be set if func_T is None; otherwise it has no effect. If domain is a MultiDomain, domain_dtype is expected to be a dictionary. Default: None.

Note

It is the user’s responsibility to ensure that func is actually linear. This can be double-checked with nifty8.extra.check_linear_operator.
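The distinction between transpose and adjoint matters for complex-valued operators. The following plain-JAX sketch uses a hypothetical diagonal map and compares a hand-written func_T with the transpose derived by jax.linear_transpose and with the adjoint. Using jax.linear_transpose here illustrates automatic transposition in JAX; it is not a statement about how JaxLinearOperator is implemented. Since jax.linear_transpose needs an example input for its shape and dtype, this is consistent with the requirement that domain_dtype be set when func_T is None.

```python
import jax
import jax.numpy as jnp

# Hypothetical complex-valued diagonal linear map x -> weights * x.
weights = jnp.array([1.0 + 2.0j, 3.0 - 1.0j], dtype=jnp.complex64)


def func(x):
    return weights * x


def func_T(y):
    # Transposed action: for a diagonal map this is the same
    # multiplication, WITHOUT complex conjugation.
    return weights * y


# Automatic transpose: jax.linear_transpose only uses the shape and dtype
# of the example input, so a zero array of the right dtype suffices.
example_input = jnp.zeros(2, dtype=jnp.complex64)
auto_T = jax.linear_transpose(func, example_input)

y = jnp.array([1.0 + 0.0j, 0.0 + 1.0j], dtype=jnp.complex64)
(auto_out,) = auto_T(y)
manual_out = func_T(y)               # [1+2j, 1+3j]
adjoint_out = jnp.conj(weights) * y  # [1-2j, -1+3j], differs from the transpose
```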

__init__(domain, target, func, domain_dtype=None, func_T=None)
apply(x, mode)

Applies the Operator to a given x in a specified mode.

Parameters:
  • x (nifty8.field.Field) – The input Field, defined on the Operator’s domain or target, depending on mode.

  • mode (int) –

    • TIMES: normal application

    • ADJOINT_TIMES: adjoint application

    • INVERSE_TIMES: inverse application

    • ADJOINT_INVERSE_TIMES or INVERSE_ADJOINT_TIMES: adjoint inverse application

Returns:

The processed Field defined on the Operator’s target or domain, depending on mode.

Return type:

nifty8.field.Field
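To make the four modes concrete, the following NumPy sketch maps each one onto a dense matrix A. The function apply_matrix and the local mode constants are hypothetical stand-ins for illustration; they are not NIFTy's API, and their numeric values are arbitrary.

```python
import numpy as np

# Local stand-ins for the mode constants; their numeric values here are
# arbitrary and not taken from NIFTy.
TIMES, ADJOINT_TIMES, INVERSE_TIMES, ADJOINT_INVERSE_TIMES = 1, 2, 4, 8


def apply_matrix(A, x, mode):
    """Hypothetical dense-matrix analogue of LinearOperator.apply(x, mode)."""
    if mode == TIMES:
        return A @ x                           # A x
    if mode == ADJOINT_TIMES:
        return A.conj().T @ x                  # A^dagger x
    if mode == INVERSE_TIMES:
        return np.linalg.solve(A, x)           # A^-1 x
    if mode == ADJOINT_INVERSE_TIMES:
        return np.linalg.solve(A.conj().T, x)  # (A^dagger)^-1 x
    raise ValueError(f"unsupported mode: {mode}")


A = np.array([[2.0, 1j], [0.0, 3.0]])
x = np.array([1.0, 1.0 + 1j])
y = np.array([0.5j, 2.0])
# Defining property of the adjoint: <A x, y> == <x, A^dagger y>.
lhs = np.vdot(apply_matrix(A, x, TIMES), y)
rhs = np.vdot(x, apply_matrix(A, y, ADJOINT_TIMES))
```

Since (A^dagger)^-1 equals (A^-1)^dagger, the two spellings ADJOINT_INVERSE_TIMES and INVERSE_ADJOINT_TIMES describe the same application.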

class JaxOperator(domain, target, func)

Bases: Operator

Wrap a JAX function as a NIFTy operator.

Parameters:
  • domain (DomainTuple or MultiDomain) – Domain of the operator.

  • target (DomainTuple or MultiDomain) – Target of the operator.

  • func (callable) – The JAX function that is evaluated by the operator. It has to be implemented in terms of jax.numpy calls. If domain is a MultiDomain, func takes a dict as its argument; likewise, if target is a MultiDomain, func returns a dict.
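When both domain and target are MultiDomains, func consumes and produces dicts. A minimal sketch, with hypothetical keys on both sides and made-up input values:

```python
import jax
import jax.numpy as jnp


def func(x):
    """Map a dict with hypothetical keys "amplitude" and "phase" to a dict
    with hypothetical keys "real" and "imag" (both sides MultiDomain-like)."""
    return {
        "real": x["amplitude"] * jnp.cos(x["phase"]),
        "imag": x["amplitude"] * jnp.sin(x["phase"]),
    }


x = {"amplitude": jnp.array([1.0, 2.0]), "phase": jnp.array([0.0, jnp.pi / 2])}
# Only jax.numpy calls are used, so the function can be jit-compiled.
out = jax.jit(func)(x)  # real ~ [1, 0], imag ~ [0, 2]
```

Such a func could then be passed to JaxOperator(domain, target, func) with a domain and target whose keys and shapes match; constructing those domains is not shown here.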

__init__(domain, target, func)
apply(x)

Applies the operator to a Field, MultiField or Linearization.

Parameters:

x (nifty8.field.Field, nifty8.multi_field.MultiField, or nifty8.linearization.Linearization) – Input on which the operator shall act. Needs to be defined on domain. If x is a nifty8.linearization.Linearization, apply returns a new nifty8.linearization.Linearization containing the result of the operator application as well as its Jacobian, evaluated at x.