# Source code for nifty8.operators.field_zero_padder
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
class FieldZeroPadder(LinearOperator):
    """Operator which applies zero-padding to one of the subdomains of its
    input field.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    new_shape : int, or list or tuple of int
        The new dimensions of the subdomain which is zero-padded.
        No entry must be smaller than the corresponding dimension in the
        operator's domain. A single int is accepted for one-dimensional
        subdomains.
    space : int
        The index of the subdomain to be zero-padded. If None, it is set to 0
        if domain contains exactly one space. domain[space] must be an RGSpace.
    central : bool
        If `False`, padding is performed at the end of the domain axes,
        otherwise in the middle: for an axis of length n, the first
        ``n//2 + 1`` entries stay at the start of the axis, the last
        ``n//2`` entries are moved to its end, and the zeros go in between.

    Raises
    ------
    TypeError
        If ``domain[space]`` is not an RGSpace.
    ValueError
        If `new_shape` has the wrong number of entries or any entry is
        smaller than the corresponding original dimension.

    Notes
    -----
    When doing central padding on an axis with an even length, the "central"
    entry should in principle be split up; this is currently not done.
    In that case the entry at index ``n//2`` is copied to both halves by the
    forward operation, and both contributions are summed by the adjoint.
    """

    def __init__(self, domain, new_shape, space=0, central=False):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        self._central = central
        dom = self._domain[self._space]
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        # Allow a bare scalar for 1-D subdomains (RGSpace accepts this too).
        if np.isscalar(new_shape):
            new_shape = (new_shape,)
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be smaller than old shape")
        target = list(self._domain)
        # Same pixel distances and harmonic flag, only the shape grows.
        target[self._space] = RGSpace(new_shape, dom.distances, dom.harmonic)
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        v = x.val
        tgtshp = self._tgt(mode).shape
        # Domain and target differ only in the padded subdomain, which
        # occupies the same axes in both; process those axes one at a time.
        for d in self._target.axes[self._space]:
            if v.shape[d] == tgtshp[d]:  # nothing to do
                continue
            if mode == self.TIMES:
                v = self._pad_axis(v, d, tgtshp[d], self._central)
            else:  # ADJOINT_TIMES
                v = self._unpad_axis(v, d, tgtshp[d], self._central)
        return Field(self._tgt(mode), v)

    @staticmethod
    def _pad_axis(v, d, new_len, central):
        """Return a new array equal to `v` zero-padded to `new_len` on axis `d`."""
        idx = (slice(None),) * d
        shp = list(v.shape)
        shp[d] = new_len
        res = np.zeros(shp, dtype=v.dtype)
        if central:
            nyquist = v.shape[d]//2
            # First nyquist+1 entries stay at the start of the axis.
            lo = idx + (slice(0, nyquist+1),)
            res[lo] = v[lo]
            # Last `nyquist` entries go to the end of the axis. This is written
            # as a reversed slice so that it is empty when nyquist == 0;
            # slice(-nyquist, None) would then select the whole axis.
            hi = idx + (slice(None, -(nyquist+1), -1),)
            res[hi] = v[hi]
        else:
            res[idx + (slice(0, v.shape[d]),)] = v
        return res

    @staticmethod
    def _unpad_axis(v, d, new_len, central):
        """Adjoint of `_pad_axis`: shrink `v` to length `new_len` on axis `d`."""
        idx = (slice(None),) * d
        if not central:
            # Simply drop the trailing padding; this returns a view of `v`.
            return v[idx + (slice(0, new_len),)]
        shp = list(v.shape)
        shp[d] = new_len
        res = np.zeros(shp, dtype=v.dtype)
        nyquist = new_len//2
        lo = idx + (slice(0, nyquist+1),)
        res[lo] = v[lo]
        # Same reversed-slice trick as in `_pad_axis` (empty for nyquist == 0).
        hi = idx + (slice(None, -(nyquist+1), -1),)
        # `+=` so that for even `new_len` the entry at index nyquist collects
        # both copies written by the forward padding.
        res[hi] += v[hi]
        return res