Source code for nifty8.operators.sum_operator

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import operator
from collections import defaultdict

from ..sugar import domain_union
from ..utilities import indent
from .block_diagonal_operator import BlockDiagonalOperator
from .linear_operator import LinearOperator


[docs] class SumOperator(LinearOperator): """Class representing sums of operators. Notes ----- This operator has to be called using the `make` method. """
[docs] def __init__(self, ops, neg, dom, tgt, _callingfrommake=False): if not _callingfrommake: raise NotImplementedError self._ops = ops self._neg = neg self._domain = dom self._target = tgt self._capability = self.TIMES | self.ADJOINT_TIMES for op in ops: self._capability &= op.capability try: from ..re import unite def joined_jax_expr(x): res = None for op, n in zip(ops, neg): tmp = op.jax_expr(x) if res is None: res = -tmp if n is True else tmp else: o = operator.sub if n is True else operator.add res = unite(res, tmp, op=o) return res self._jax_expr = joined_jax_expr except ImportError: self._jax_expr = None
[docs] @staticmethod def simplify(ops, neg): from .diagonal_operator import DiagonalOperator from .scaling_operator import ScalingOperator # unpack SumOperators opsnew = [] negnew = [] for op, ng in zip(ops, neg): if isinstance(op, SumOperator): opsnew += op._ops if ng: negnew += [not n for n in op._neg] else: negnew += list(op._neg) # FIXME: this needs some more work to keep the domain and target unchanged! # elif isinstance(op, NullOperator): # pass else: opsnew.append(op) negnew.append(ng) ops = opsnew neg = negnew # sort operators according to domains sorted = defaultdict(list) for op, ng in zip(ops, neg): sorted[(op.domain, op.target)].append((op, ng)) xxops = [] xxneg = [] for opset in sorted.values(): # collect ScalingOperators sum = 0. opsnew = [] negnew = [] dtype = [] for op, ng in opset: if isinstance(op, ScalingOperator): sum += op._factor * (-1 if ng else 1) dtype.append(op._dtype) else: opsnew.append(op) negnew.append(ng) lastdom = opset[0][0].domain del(opset) if len(dtype) > 0: # Propagate sampling dtypes only if they are all the same if all(dtype[0] == ss for ss in dtype): dtype = dtype[0] else: dtype = None if sum != 0.: # try to absorb the factor into a DiagonalOperator for i in range(len(opsnew)): if isinstance(opsnew[i], DiagonalOperator): if opsnew[i]._dtype != dtype: continue sum *= (-1 if negnew[i] else 1) opsnew[i] = opsnew[i]._add(sum) sum = 0. break if sum != 0 or len(opsnew) == 0: # have to add the scaling operator at the end opsnew.append(ScalingOperator(lastdom, sum, dtype)) negnew.append(False) del(dtype, sum, lastdom) ops = opsnew neg = negnew # Step 4: combine DiagonalOperators where possible processed = [False] * len(ops) opsnew = [] negnew = [] for i in range(len(ops)): if not processed[i]: if isinstance(ops[i], DiagonalOperator): op = ops[i] opneg = neg[i] for j in range(i+1, len(ops)): if isinstance(ops[j], DiagonalOperator) and ops[i]._dtype == ops[j]._dtype: op = op._combine_sum(ops[j], opneg, neg[j]) opneg = False processed[j] = True opsnew.append(op) negnew.append(opneg) else: opsnew.append(ops[i]) negnew.append(neg[i]) ops = opsnew neg = negnew # combine BlockDiagonalOperators where possible processed = [False] * len(ops) opsnew = [] negnew = [] for i in range(len(ops)): if not processed[i]: if isinstance(ops[i], BlockDiagonalOperator): op = ops[i] opneg = neg[i] for j in range(i+1, len(ops)): if isinstance(ops[j], BlockDiagonalOperator): op = op._combine_sum(ops[j], opneg, neg[j]) opneg = False processed[j] = True opsnew.append(op) negnew.append(opneg) else: opsnew.append(ops[i]) negnew.append(neg[i]) xxops += opsnew xxneg += negnew dom = domain_union([op.domain for op in xxops]) tgt = domain_union([op.target for op in xxops]) return xxops, xxneg, dom, tgt
[docs] @staticmethod def make(ops, neg): """Build a SumOperator (or something simpler if possible) Parameters ---------- ops: list of LinearOperator Individual operators of the sum. neg: list of bool Same length as ops. If True then the corresponding operator gets a minus in the sum. """ ops = tuple(ops) neg = tuple(neg) if len(ops) == 0: raise ValueError("ops is empty") if len(ops) != len(neg): raise ValueError("length mismatch between ops and neg") ops, neg, dom, tgt = SumOperator.simplify(ops, neg) if len(ops) == 1: return -ops[0] if neg[0] else ops[0] return SumOperator(ops, neg, dom, tgt, _callingfrommake=True)
@property def adjoint(self): return self.make([op.adjoint for op in self._ops], self._neg)
[docs] def apply(self, x, mode): self._check_mode(mode) res = None for op, neg in zip(self._ops, self._neg): tmp = op.apply(x.extract(op._dom(mode)), mode) if res is None: res = -tmp if neg else tmp else: res = res.flexible_addsub(tmp, neg) return res
[docs] def draw_sample(self, from_inverse=False): if from_inverse: raise NotImplementedError( "cannot draw from inverse of this operator") res = None for op in self._ops: from .simple_linear_operators import NullOperator if isinstance(op, NullOperator): continue tmp = op.draw_sample(from_inverse) res = tmp if res is None else res.unite(tmp) return res
[docs] def __repr__(self): subs = "\n".join(sub.__repr__() for sub in self._ops) return "SumOperator:\n"+indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp): f = [] o = [] for op in self._ops: tf, to = op.simplify_for_constant_input( c_inp.extract_part(op.domain)) f.append(tf) o.append(to) from ..multi_domain import MultiDomain if not isinstance(self._target, MultiDomain): fullop = None for to, n in zip(o, self._neg): op = to if not n else -to fullop = op if fullop is None else fullop + op return None, fullop from .simplify_for_const import ConstCollector cc = ConstCollector() fullop = None for tf, to, n in zip(f, o, self._neg): cc.add(tf, to.target) op = to if not n else -to fullop = op if fullop is None else fullop + op return cc.constfield, fullop