Source code for nifty7.operators.simplify_for_const

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-201 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from ..multi_domain import MultiDomain
from ..utilities import myassert
from .block_diagonal_operator import BlockDiagonalOperator
from .energy_operators import EnergyOperator
from .operator import Operator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import NullOperator


[docs]class ConstCollector(object): def __init__(self): self._const = None # MultiField on the part of the MultiDomain that could be constant self._nc = set() # NoConstant - set of keys that we know cannot be constant
[docs] def mult(self, const, fulldom): if const is None: self._nc |= set(fulldom.keys()) else: from ..multi_field import MultiField self._nc |= set(fulldom.keys()) - set(const.keys()) if self._const is None: self._const = MultiField.from_dict( {key: const[key] for key in const.keys() if key not in self._nc}) else: # we know that the domains are identical for products self._const = MultiField.from_dict( {key: self._const[key]*const[key] for key in const.keys() if key not in self._nc})
[docs] def add(self, const, fulldom): if const is None: self._nc |= set(fulldom.keys()) else: from ..multi_field import MultiField self._nc |= set(fulldom.keys()) - set(const.keys()) self._const = const if self._const is None else self._const.unite(const) self._const = MultiField.from_dict( {key: const[key] for key in const.keys() if key not in self._nc})
@property def constfield(self): return self._const
[docs]class ConstantOperator(Operator): def __init__(self, output): from ..sugar import makeDomain self._domain = makeDomain({}) self._target = makeDomain(output.domain) self._output = output
[docs] def apply(self, x): from .simple_linear_operators import NullOperator self._check_input(x) if x.jac is not None: return x.new(self._output, NullOperator(self._domain, self._target)) return self._output
[docs] def __repr__(self): tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()' return f'{tgt} <- ConstantOperator'
[docs]class ConstantEnergyOperator(EnergyOperator): def __init__(self, output): from ..field import Field from ..sugar import makeDomain self._domain = makeDomain({}) if not isinstance(output, Field): output = Field.scalar(float(output)) self._output = output
[docs] def apply(self, x): self._check_input(x) if x.jac is not None: val = self._output jac = NullOperator(self._domain, self._target) # FIXME Do we need a metric here? met = NullOperator(self._domain, self._domain) if x.want_metric else None return x.new(val, jac, met) return self._output
[docs]class InsertionOperator(Operator): def __init__(self, target, cst_field): from ..multi_field import MultiField from ..sugar import makeDomain if not isinstance(target, MultiDomain): raise TypeError if not isinstance(cst_field, MultiField): raise TypeError self._target = MultiDomain.make(target) cstdom = cst_field.domain vardom = makeDomain({kk: vv for kk, vv in self._target.items() if kk not in cst_field.keys()}) self._domain = vardom self._cst = cst_field jac = {kk: ScalingOperator(vv, 1.) for kk, vv in self._domain.items()} self._jac = BlockDiagonalOperator(self._domain, jac) + NullOperator(makeDomain({}), cstdom)
[docs] def apply(self, x): myassert(len(set(self._cst.keys()) & set(x.domain.keys())) == 0) val = x if x.jac is None else x.val val = val.unite(self._cst) if x.jac is None: return val return x.new(val, self._jac)
[docs] def __repr__(self): from ..utilities import indent subs = f'Constant: {self._cst.keys()}\nVariable: {self._domain.keys()}' return 'InsertionOperator\n'+indent(subs)