Source code for nifty8.nifty2jax

# Copyright(C) 2013-2021 Max-Planck-Society
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

import operator
from functools import partial, reduce
from typing import Any, Callable, Tuple, Union, Dict
from warnings import warn

from . import re as jft
from .domain_tuple import DomainTuple
from .field import Field
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.operator import Operator
from .sugar import makeField


[docs] def spaces_to_axes(domain, spaces) -> Union[Tuple, int, None]: """Converts spaces in a domain to axes of the underlying NumPy array.""" if spaces is None: return None domain = DomainTuple.make(domain) axes = tuple(domain.axes[sp_index] for sp_index in spaces) axes = reduce(operator.add, axes) if len(axes) > 0 else axes return axes
[docs] def shapewithdtype_from_domain( domain, dtype ) -> Union[jft.ShapeWithDtype, Dict[str, jft.ShapeWithDtype]]: if isinstance(dtype, dict): dtp_fallback = float # Fallback to `float` for unspecified keys k2dtp = dtype else: dtp_fallback = dtype k2dtp = {} if isinstance(domain, MultiDomain): parameter_tree = {} for k, dom in domain.items(): parameter_tree[k] = jft.ShapeWithDtype( dom.shape, k2dtp.get(k, dtp_fallback) ) elif isinstance(domain, DomainTuple): parameter_tree = jft.ShapeWithDtype(domain.shape, dtype) else: raise TypeError(f"incompatible domain {domain!r}") return parameter_tree
[docs] def wrap_nifty_call(op, target_dtype=float) -> Callable[[Any], jft.Vector]: from jax.experimental.host_callback import call if callable(op.jax_expr): warn("wrapping operator that has a callable `.jax_expr`") def nifty_call(x): # Minimal parts that must run outside of JAX x = makeField(op.domain, x) return op(x).val # TODO: define custom JVP and VJP rules pt = shapewithdtype_from_domain(op.target, target_dtype) hcb_call = partial(call, nifty_call, result_shape=pt) def wrapped_call(x) -> jft.Vector: x = x.tree if isinstance(x, jft.Vector) else x return jft.Vector(hcb_call(x)) return wrapped_call
[docs] def convert( nifty_obj: Union[Operator, DomainTuple, MultiDomain], dtype=float ) -> Union[jft.Model, jft.Vector, jft.ShapeWithDtype, Dict[str, jft.ShapeWithDtype]]: if not isinstance(nifty_obj, (Operator, DomainTuple, MultiDomain)): raise TypeError(f"invalid input type {type(nifty_obj)!r}") if isinstance(nifty_obj, (Field, MultiField)): return jft.Vector(nifty_obj.val) elif isinstance(nifty_obj, (DomainTuple, MultiDomain)): return shapewithdtype_from_domain(nifty_obj, dtype) else: expr = nifty_obj.jax_expr parameter_tree = shapewithdtype_from_domain(nifty_obj.domain, dtype) if not callable(expr): # TODO: implement conversion via host_callback and custom_vjp raise NotImplementedError("Sorry, not yet done :(") return jft.Model(expr, domain=parameter_tree)