Source code for nifty8.operators.sandwich_operator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .scaling_operator import ScalingOperator
[docs]
class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression
`bun.adjoint(cheese(bun))`.
Note
----
This operator should always be called using the `make` method.
"""
[docs]
def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
self._bun = bun
self._cheese = cheese
self._op = op
self._domain = op.domain
self._capability = op._capability
[docs]
@staticmethod
def make(bun, cheese=None, sampling_dtype=None):
"""Build a SandwichOperator (or something simpler if possible)
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
sampling_dtype :
If this operator represents the covariance of a Gaussian probabilty
distribution and cheese is `None`, `sampling_dtype` specifies if it
is real or complex Gaussian. If `sampling_dtype` and `cheese` are
`None`, the operator cannot be used as a covariance, i.e. no samples
can be drawn. Default: None.
"""
if isinstance(cheese, SandwichOperator):
old_cheese = cheese
cheese = old_cheese._cheese
bun = old_cheese._bun @ bun
if not isinstance(bun, LinearOperator):
raise TypeError("bun must be a linear operator")
if cheese is not None and not isinstance(cheese, LinearOperator):
raise TypeError("cheese must be a linear operator or None")
if cheese is None:
cheese = ScalingOperator(bun.target, 1., sampling_dtype)
if isinstance(bun, ScalingOperator):
fct = abs(bun._factor)**2
if fct == 1.:
return cheese
op = cheese.scale(fct)
else:
op = bun.adjoint @ cheese @ bun
return SandwichOperator(bun, cheese, op, _callingfrommake=True)
[docs]
def apply(self, x, mode):
return self._op.apply(x, mode)
[docs]
def draw_sample(self, from_inverse=False):
# Inverse samples from general sandwiches are not possible
if from_inverse:
if self._bun.capability & self._bun.INVERSE_TIMES:
try:
s = self._cheese.draw_sample(from_inverse)
return self._bun.inverse_times(s)
except NotImplementedError:
pass
raise NotImplementedError(
"cannot draw from inverse of this operator")
# Samples from general sandwiches
return self._bun.adjoint_times(
self._cheese.draw_sample(from_inverse))
[docs]
def get_sqrt(self):
if self._cheese is None:
return self._bun
return self._cheese.get_sqrt() @ self._bun
[docs]
def __repr__(self):
from ..utilities import indent
return "\n".join((
"SandwichOperator:",
indent("\n".join((
"Cheese:", self._cheese.__repr__(),
"Bun:", self._bun.__repr__())))))