Source code for nifty8.operators.operator_adapter
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .linear_operator import LinearOperator
[docs]
class OperatorAdapter(LinearOperator):
    """Class representing the inverse and/or adjoint of another operator.

    Objects of this class are created internally by `LinearOperator` whenever
    the inverse and/or adjoint of an already existing operator object is
    requested via the `LinearOperator` attributes `inverse`, `adjoint` or
    `_flip_modes()`.

    Users should never have to create instances of this class directly.

    Parameters
    ----------
    op : LinearOperator
        The operator on which the adapter will act.
    op_transform : int
        Which transformation of `op` this adapter represents:

        1) adjoint
        2) inverse
        3) adjoint inverse
    domain_dtype : dtype, optional
        Data type assumed for the domain of `op` when building the JAX
        expression of the adjoint. It is only used when `op_transform` is the
        adjoint and `op.jax_expr` is callable. Default: `float`.

    Raises
    ------
    ValueError
        If `op_transform` is not 1, 2 or 3.
    """

    def __init__(self, op, op_transform, domain_dtype=float):
        self._op = op
        self._trafo = int(op_transform)
        if self._trafo < 1 or self._trafo > 3:
            raise ValueError("invalid operator transformation")
        # Kept so that adapters derived via `_flip_modes` use the same
        # domain dtype for their JAX expression.
        self._domain_dtype = domain_dtype
        # Domain, target and capabilities of the transformed operator are
        # looked up on `op` / in the class tables using the transformation.
        self._domain = self._op._dom(1 << self._trafo)
        self._target = self._op._tgt(1 << self._trafo)
        self._capability = self._capTable[self._trafo][self._op.capability]

        try:
            import jax.numpy as jnp
            from jax import eval_shape, linear_transpose
            from jax.tree_util import tree_all, tree_map

            from ..nifty2jax import shapewithdtype_from_domain
            from ..re import Vector
        except ImportError:
            # JAX (or NIFTy's JAX bridge) is not available: no JAX expression.
            self._jax_expr = None
            return

        if callable(op.jax_expr) and self._trafo == self.ADJOINT_BIT:
            def jax_expr(y):
                # Abstract (shape/dtype only) description of op's domain,
                # wrapped in a Vector when the input is one so that the
                # pytree structures agree.
                op_domain = shapewithdtype_from_domain(op.domain, domain_dtype)
                if isinstance(y, Vector):
                    op_domain = Vector(op_domain)
                # Shape/dtype of op's output, i.e. what `y` is expected to be.
                expected = eval_shape(op.jax_expr, op_domain)
                castable = tree_map(
                    lambda a, b: jnp.can_cast(a.dtype, b.dtype), y, expected)
                if not tree_all(castable):
                    raise ValueError(
                        "wrong dtype during transposition:\n"
                        f"got {y!r} and expected {expected}")
                y = tree_map(lambda c, d: c.astype(d.dtype), y, expected)
                # Adjoint via the linear transpose: A^dagger y = conj(A^T conj(y)).
                transpose = linear_transpose(op.jax_expr, op_domain)
                y_conj = tree_map(jnp.conj, y)
                return tree_map(jnp.conj, transpose(y_conj)[0])

            self._jax_expr = jax_expr
        elif (callable(getattr(op, "_jax_expr_inv", None))
              and self._trafo == self.INVERSE_BIT):
            # The inverse swaps op's forward and inverse JAX expressions.
            self._jax_expr = op._jax_expr_inv
            self._jax_expr_inv = op._jax_expr
        else:
            self._jax_expr = None

    def _flip_modes(self, trafo):
        """Apply the transformation `trafo` on top of this adapter.

        Transformations combine by XOR of their bit flags. If they cancel
        out, the wrapped operator itself is returned. Otherwise a new adapter
        around the same operator is returned, keeping this adapter's
        `domain_dtype`.
        """
        newtrafo = trafo ^ self._trafo
        if newtrafo == 0:
            return self._op
        return OperatorAdapter(self._op, newtrafo,
                               domain_dtype=self._domain_dtype)

    def apply(self, x, mode):
        # Translate the requested mode into the matching mode of `op` using
        # the class lookup tables, then delegate to `op`.
        op_mode = self._modeTable[self._trafo][self._ilog[mode]]
        return self._op.apply(x, op_mode)

    def draw_sample(self, from_inverse=False):
        # When this adapter contains the inverse, `from_inverse` is flipped
        # before delegating. The adjoint bit alone leaves it unchanged.
        if self._trafo & self.INVERSE_BIT:
            return self._op.draw_sample(not from_inverse)
        return self._op.draw_sample(from_inverse)

    def __repr__(self):
        from ..utilities import indent
        mode = ["adjoint", "inverse", "adjoint inverse"][self._trafo - 1]
        return f"OperatorAdapter: {mode}\n" + indent(repr(self._op))