# Source code for nifty8.operators.regridding_operator

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator


class RegriddingOperator(LinearOperator):
    """Linearly interpolates an RGSpace to an RGSpace with coarser resolution.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        domain[space] needs to be an :class:`RGSpace`.
    new_shape : tuple of int
        Shape of the space which domain[space] is replaced by. Each entry must
        be smaller than or equal to the respective entry in
        `domain[space].shape`.
    space : int
        Index of space in `domain` on which the operator shall act.
        Default is 0.
    """

    def __init__(self, domain, new_shape, space=0):
        """Validate the input and precompute interpolation indices/weights.

        Raises
        ------
        TypeError
            If ``domain[space]`` is not an :class:`RGSpace`.
        ValueError
            If `new_shape` has the wrong length, is larger than the old shape
            along any axis, or contains non-positive entries.
        """
        self._domain = DomainTuple.make(domain)
        self._space = infer_space(self._domain, space)
        dom = self._domain[self._space]
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a > b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be larger than old shape")
        if any(ii <= 0 for ii in new_shape):
            raise ValueError('New shape must not be zero or negative.')
        # The target keeps the total physical extent: the pixel distance
        # scales with old_shape/new_shape.
        newdist = tuple(dist*npix/nnew for dist, npix, nnew
                        in zip(dom.distances, dom.shape, new_shape))
        tgt = RGSpace(new_shape, newdist)
        self._target = list(self._domain)
        self._target[self._space] = tgt
        self._target = DomainTuple.make(self._target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        ndim = len(new_shape)
        # Per axis of the regridded space:
        #   _bindex: lower neighbouring input pixel of each output pixel,
        #   _uindex: upper neighbouring input pixel,
        #   _frac:   linear weight of the upper neighbour.
        self._bindex = [None] * ndim
        self._uindex = [None] * ndim
        self._frac = [None] * ndim
        for d in range(ndim):
            npix = dom.shape[d]
            # Output pixel positions measured in units of input pixels.
            tmp = np.arange(new_shape[d])*(newdist[d]/dom.distances[d])
            # Clamp so that the upper neighbour stays inside the grid. The
            # clamp must not drop below 0: for a single-pixel axis the
            # previous bound (npix-2 == -1) produced a negative index. That
            # index only worked in TIMES via Python's negative indexing, and
            # it was passed unchanged to special_add_at in ADJOINT_TIMES.
            lo = np.minimum(max(npix-2, 0), tmp.astype(np.int64))
            self._bindex[d] = lo
            # For npix >= 2 this is exactly lo+1; for npix == 1 both
            # neighbours collapse onto pixel 0 (and _frac is 0).
            self._uindex[d] = np.minimum(lo+1, npix-1)
            self._frac[d] = tmp-lo

    def apply(self, x, mode):
        """Apply the interpolation (TIMES) or its adjoint (ADJOINT_TIMES).

        The axes of ``domain[space]`` are processed one after another. Each
        step uses the precomputed neighbour indices and weights for that axis.
        """
        self._check_input(x, mode)
        v = x.val
        ndim = len(self.target.shape)
        tgtshp = self._tgt(mode).shape
        d0 = self._target.axes[self._space][0]
        for d in self._target.axes[self._space]:
            lo = self._bindex[d-d0]
            hi = self._uindex[d-d0]
            idx = (slice(None),) * d
            # Broadcast the per-pixel weights along axis d only.
            wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
            if mode == self.ADJOINT_TIMES:
                # Scatter each coarse value back onto its two fine neighbours.
                shp = list(v.shape)
                shp[d] = tgtshp[d]
                xnew = np.zeros(shp, dtype=v.dtype)
                xnew = special_add_at(xnew, d, lo, v*(1.-wgt))
                xnew = special_add_at(xnew, d, hi, v*wgt)
            else:  # TIMES
                # Gather the two fine neighbours and blend them linearly.
                xnew = v[idx + (lo,)] * (1.-wgt)
                xnew += v[idx + (hi,)] * wgt
            v = xnew
        return Field(self._tgt(mode), v)