nifty8.re.model module

class Initializer(call_or_struct)

Bases: object

domain = ShapeWithDtype(shape=(2,), dtype=<class 'jax.numpy.uint32'>)
property stupid
property target
class LazyModel(domain=<class 'nifty8.re.model.NoValue'>, target=<class 'nifty8.re.model.NoValue'>, init=<class 'nifty8.re.model.NoValue'>)

Bases: object

__init__(domain=<class 'nifty8.re.model.NoValue'>, target=<class 'nifty8.re.model.NoValue'>, init=<class 'nifty8.re.model.NoValue'>)
property domain
property init: Initializer
property target
class Model(call: ~typing.Callable | None = None, *, domain=<class 'nifty8.re.model.NoValue'>, target=<class 'nifty8.re.model.NoValue'>, init=<class 'nifty8.re.model.NoValue'>, white_init=False)

Bases: LazyModel

Join a callable with a domain, target, and an init method.

Given a domain and a callable, this class automatically derives the target. It also instantiates a default initializer if these are not set explicitly. More importantly, it uses metaprogramming to register the class as a PyTree in JAX. By default, all properties are hidden from JAX. The exceptions are properties explicitly marked as non-static via dataclasses.field(metadata=dict(static=False)).
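The static/non-static convention can be illustrated with plain `dataclasses`. The sketch below splits an instance's fields into two parts. Fields marked `static=False` become the parts JAX would treat as leaves. All other fields become static auxiliary data.

The class `ScaledModel` and the helpers `is_static`, `tree_flatten`, and `tree_unflatten` are hypothetical. In nifty8.re this split is handled by the metaprogramming described above. A JAX-based version would hand such flatten/unflatten functions to `jax.tree_util.register_pytree_node`.

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass
class ScaledModel:
    # Hypothetical example class, not part of nifty8.re.
    # Exposed to JAX as a leaf because static=False is set explicitly.
    scale: float = field(default=1.0, metadata=dict(static=False))
    # No "static" metadata entry: hidden from JAX by default.
    name: str = "scaled"

    def __call__(self, x):
        return self.scale * x


def is_static(f: dataclasses.Field) -> bool:
    # A field is static (hidden) unless its metadata says static=False.
    return f.metadata.get("static", True)


def tree_flatten(obj):
    fields = dataclasses.fields(obj)
    # Dynamic fields become the "children" (leaves) JAX would trace.
    children = tuple(getattr(obj, f.name) for f in fields if not is_static(f))
    # Static fields are carried along as hashable auxiliary data.
    aux = tuple((f.name, getattr(obj, f.name)) for f in fields if is_static(f))
    return children, aux


def tree_unflatten(cls, aux, children):
    dynamic_names = [f.name for f in dataclasses.fields(cls) if not is_static(f)]
    kwargs = dict(aux)
    kwargs.update(zip(dynamic_names, children))
    return cls(**kwargs)


m = ScaledModel(scale=2.0)
children, aux = tree_flatten(m)
print(children)  # (2.0,)
print(aux)       # (('name', 'scaled'),)
rebuilt = tree_unflatten(ScaledModel, aux, (3.0,))
print(rebuilt(1.5))  # 4.5
```

Only `scale` travels as a leaf. This is why, in such a scheme, a transformation like `jax.grad` would see it while `name` stays fixed.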

__init__(call: ~typing.Callable | None = None, *, domain=<class 'nifty8.re.model.NoValue'>, target=<class 'nifty8.re.model.NoValue'>, init=<class 'nifty8.re.model.NoValue'>, white_init=False)

Wrap a callable and associate it with a domain.

Parameters:
  • call (callable) – Method acting on objects of type domain.

  • domain (tree-like structure of ShapeWithDtype, optional) – PyTree of objects with a shape and dtype attribute. Inferred from init if not specified.

  • target (tree-like structure of ShapeWithDtype, optional) – PyTree of objects with a shape and dtype attribute akin to the output of call. Inferred from call and domain if not set.

  • init (callable, optional) – Initialization method taking a PRNG key as first argument and creating an object of type domain. Inferred from domain assuming a white normal prior if not set.

  • white_init (bool, optional) – If True, the domain is set to a white normal prior. Defaults to False.
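The inference rules in the parameter list can be mimicked without JAX to show the data flow. The sketch below is a hypothetical, dependency-free stand-in:

- `ShapeWithDtype` is re-declared as a small `NamedTuple`.
- `white_normal_init` draws standard-normal values for every entry of a dict-shaped domain.
- `infer_target` evaluates `call` on one such draw and records the output length.

nifty8.re itself operates on JAX arrays and PRNG keys. It likely infers shapes abstractly (for example with `jax.eval_shape`) rather than by running the function, so treat this only as an illustration.

```python
import math
import random
from typing import Callable, Dict, List, NamedTuple, Tuple


class ShapeWithDtype(NamedTuple):
    # Hypothetical stand-in for nifty8.re's ShapeWithDtype.
    shape: Tuple[int, ...]
    dtype: str = "float64"


def white_normal_init(seed: int, domain: Dict[str, ShapeWithDtype]) -> Dict[str, List[float]]:
    """Draw i.i.d. standard-normal values (flattened) for each entry of `domain`."""
    rng = random.Random(seed)
    return {
        name: [rng.gauss(0.0, 1.0) for _ in range(math.prod(swd.shape))]
        for name, swd in domain.items()
    }


def infer_target(call: Callable, domain, init) -> ShapeWithDtype:
    """Infer the output shape by evaluating `call` on one sample from `init`."""
    out = call(init(0, domain))
    return ShapeWithDtype(shape=(len(out),))


# Hypothetical model: elementwise exp of the "xi" input plus an offset.
domain = {"xi": ShapeWithDtype(shape=(4,))}
call = lambda p: [math.exp(v) + 1.0 for v in p["xi"]]

sample = white_normal_init(42, domain)
target = infer_target(call, domain, white_normal_init)
print(len(sample["xi"]))  # 4
print(target)             # ShapeWithDtype(shape=(4,), dtype='float64')
```

Seeding the generator explicitly plays the role that the PRNG key plays in the real `init` signature: the same seed always yields the same draw.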

class ModelMeta(name, bases, dict_, /, **kwargs)

Bases: ABCMeta

Register all derived classes as PyTrees in JAX using metaprogramming.

For any dataclasses.Field property with a metadata entry named “static”, the property is hidden from JAX if the value is True and exposed to JAX if it is False. Properties without such an entry are hidden by default.
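A metaclass can attach this behaviour to every subclass at class-creation time, so derived classes need no extra boilerplate. The sketch below is a hypothetical, simplified analogue of ModelMeta. Like ModelMeta, `SketchMeta` derives from `ABCMeta`. For each new class it:

1. turns the class into a dataclass;
2. records which fields carry `static=False`;
3. stores the class in a plain-dict `REGISTRY`.

Where the registry entry is made, a JAX-based implementation would instead call `jax.tree_util.register_pytree_node(cls, flatten, unflatten)`. That call is omitted so the example runs without JAX.

```python
import abc
import dataclasses

# Hypothetical registry standing in for JAX's PyTree registry.
REGISTRY = {}


class SketchMeta(abc.ABCMeta):
    """Simplified, hypothetical analogue of a PyTree-registering metaclass."""

    def __new__(mcs, name, bases, dict_, /, **kwargs):
        cls = super().__new__(mcs, name, bases, dict_, **kwargs)
        # dataclass() modifies and returns the same class object.
        cls = dataclasses.dataclass(cls)
        dynamic = tuple(
            f.name
            for f in dataclasses.fields(cls)
            if f.metadata.get("static", True) is False
        )
        # A JAX implementation would register flatten/unflatten functions here.
        REGISTRY[cls] = dynamic
        return cls


class Base(metaclass=SketchMeta):
    pass


class Linear(Base):
    # Hypothetical derived model: `weight` is exposed, `label` stays static.
    weight: float = dataclasses.field(default=1.0, metadata=dict(static=False))
    label: str = "linear"


print(REGISTRY[Linear])  # ('weight',)
print(REGISTRY[Base])    # ()
```

Because the hook lives in the metaclass, `Linear` is processed automatically just by subclassing `Base`.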

class NoValue

Bases: object

class WrappedCall(call: Callable, *, name=None, shape=(), dtype=None, white_init=False, target=<class 'nifty8.re.model.NoValue'>)

Bases: Model

__init__(call: ~typing.Callable, *, name=None, shape=(), dtype=None, white_init=False, target=<class 'nifty8.re.model.NoValue'>)

Transform call so that, instead of acting on input directly, it selects name from input and acts on input[name].

Parameters:
  • call (callable) – Callable to wrap.

  • name (hashable, optional) – New name of the input on which call acts.

  • shape (tuple or tree-like structure of ShapeWithDtype) – Shape of the old input on which call acts. This can also be an arbitrary shape-dtype structure, in which case dtype is ignored. Defaults to a scalar.

  • dtype (dtype or tree-like structure of ShapeWithDtype) – If shape is a tuple, this is the dtype of the old input on which call acts. This is redundant if shape already encodes the dtype.

See Model for details on the remaining arguments.
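At its core, the input selection described for WrappedCall composes call with a lookup by name. The minimal pure-Python sketch below illustrates this. The helper `wrap_call` is hypothetical. It ignores the shape/dtype bookkeeping and initializer handling that WrappedCall also performs.

```python
from typing import Any, Callable, Hashable, Mapping


def wrap_call(call: Callable, name: Hashable) -> Callable:
    """Return a function that applies `call` to input[name] instead of to input."""
    def wrapped(inp: Mapping[Hashable, Any]):
        return call(inp[name])
    return wrapped


# The original callable acts on a bare scalar.
double = lambda x: 2.0 * x
# The wrapped version acts on a dict and picks out the "amplitude" entry.
wrapped = wrap_call(double, "amplitude")

params = {"amplitude": 1.5, "offset": 3.0}
print(wrapped(params))  # 3.0
```

Selecting by name lets several wrapped callables share one dictionary of inputs, each reading only its own entry.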