# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Gordian Edenhofer
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .linear_operator import LinearOperator
class SliceOperator(LinearOperator):
    """Geometry preserving mask operator

    Takes a field, slices it into the desired shape and returns the values of
    the field in the sliced domain all while preserving the original distances.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    new_shape : tuple
        One entry per entry of `domain`. Each entry is an integer, a tuple of
        integers, or None, where None indicates to copy the shape of the
        original domain entry. For example ((10, 5), 100) for a DomainTuple
        with two entries, the first having shape (10, 5) and the second
        having shape 100.
    center : bool, optional
        Whether to center the slice that is selected in the input field. If
        False, the selection starts at index 0 along every axis; if True, it
        starts at `(N - n) // 2` along an axis of length `N` for a target
        length `n`.
    preserve_dist : bool, optional
        Whether to preserve the distances of sliced `RGSpace` entries. If
        False, the sliced `RGSpace` is constructed without passing explicit
        distances.

    Notes
    -----
    Only `RGSpace` and `UnstructuredDomain` entries can be shrunk; any other
    domain type may only be passed through unchanged (shape None or equal to
    its current shape).
    """

    def __init__(self, domain, new_shape, center=False, preserve_dist=True):
        self._domain = DomainTuple.make(domain)
        if len(new_shape) != len(self._domain):
            ve = (
                f"shape ({new_shape}) is incompatible with the shape of the"
                f" domain ({self._domain.shape})"
            )
            raise ValueError(ve)
        for i, shape in enumerate(new_shape):
            # None means "keep this entry as is" and is always compatible,
            # regardless of the dimensionality of the entry
            if shape is None:
                continue
            if len(np.atleast_1d(shape)) != len(self._domain[i].shape):
                ve = (
                    f"shape of subspace ({i}) is incompatible with the domain"
                )
                raise ValueError(ve)

        tgt = []
        slc_by_ax = []
        for d, shp in zip(self._domain, new_shape):
            # Compare against the shape of this entry (`d.shape`), not against
            # an axis of the flattened DomainTuple shape
            if shp is None or np.array_equal(np.atleast_1d(shp), d.shape):
                # Pass the entry through unchanged: one full slice for each of
                # its array axes so that later entries stay aligned
                tgt += [d]
                slc_by_ax += [slice(None)] * len(d.shape)
                continue

            n_pix_by_ax = np.atleast_1d(shp)
            if np.any(n_pix_by_ax > np.array(d.shape)):
                ve = (
                    f"domain axes ({d}) is smaller than the target shape"
                    f" {shp}"
                )
                raise ValueError(ve)

            dom_kw = dict()
            if isinstance(d, RGSpace):
                if preserve_dist:
                    dom_kw["distances"] = d.distances
                dom_kw["harmonic"] = d.harmonic
            elif not isinstance(d, UnstructuredDomain):
                # Some domains like HPSpace or LMSPace can not be sliced
                ve = f"{d.__class__.__name__} can not be sliced"
                raise ValueError(ve)
            tgt += [d.__class__(shp, **dom_kw)]

            for n_ax, n_pix in zip(d.shape, n_pix_by_ax):
                n_pix = int(n_pix)
                # `n_ax - n_pix` is non-negative here (checked above), so
                # integer division equals the floor of the halved difference
                start = (n_ax - n_pix) // 2 if center else 0
                slc_by_ax += [slice(start, start + n_pix)]

        self._slc_by_ax = tuple(slc_by_ax)
        self._target = DomainTuple.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            return Field.from_raw(self.target, x[self._slc_by_ax])
        # Adjoint: embed the selected values into a zero-filled array of the
        # input shape; everything outside the selection is zero
        res = np.zeros(self.domain.shape, x.dtype)
        res[self._slc_by_ax] = x
        return Field.from_raw(self.domain, res)

    def __str__(self):
        ss = (
            f"{self.__class__.__name__}"
            f"({self.domain.shape} -> {self.target.shape})"
        )
        return ss
class SplitOperator(LinearOperator):
    """Split a single field into a multi-field

    Takes a field, selects the desired entries for each multi-field key and
    puts the result into a multi-field. Along sliced axes, the domain will
    be replaced by an UnstructuredDomain as no distance measures are preserved.

    Note, slices may intersect, i.e. slices may reference the same input
    multiple times if the `intersecting_slices` option is set. However, a
    single field in the output may not contain the same part of the input more
    than once.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    slices_by_key : dict{key: tuple}
        The key-value pairs of which the values indicate the parts to be
        selected. The result will be a multi-field with the given keys as
        entries and the selected slices of the domain as values. Each tuple
        holds at most one item per entry of `domain`; missing trailing items,
        `None` and `slice(None)` select the whole entry. A boolean array with
        the same shape as the entry selects the masked values. For
        one-dimensional entries, a `slice`, a sequence/array of integer
        indices, or an integer (which drops the entry) may be used as well.
    intersecting_slices : bool, optional
        Tells the operator whether slices may contain intersections. If true,
        the adjoint accumulates (`+=`) the contributions of all keys, which is
        a little less efficient. Set this parameter to `False` to gain a
        little more efficiency when no two keys select the same input.

    Notes
    -----
    Selection is done with plain NumPy indexing. Combining several
    array-valued indices (or an integer with an array-valued index) within
    one key therefore follows NumPy's advanced-indexing rules, i.e. the
    indices are broadcast together instead of being applied per axis, which
    the constructed target domain does not account for.
    """

    def __init__(self, domain, slices_by_key, intersecting_slices=True):
        self._domain = DomainTuple.make(domain)
        self._intersec_slc = intersecting_slices
        tgt = dict()
        self._k_slc = dict()
        for k, slc in slices_by_key.items():
            if len(slc) > len(self._domain):
                ve = f"slice at key {k!r} has more dimensions than the input"
                raise ValueError(ve)
            k_tgt = []
            k_slc_by_ax = []
            for i, d in enumerate(self._domain):
                s = slc[i] if i < len(slc) else None
                if s is None or (isinstance(s, slice) and s == slice(None)):
                    # Keep the entry; one full slice per array axis so that
                    # multi-dimensional entries do not misalign later ones
                    k_tgt += [d]
                    k_slc_by_ax += [slice(None)] * len(d.shape)
                elif isinstance(s, np.ndarray) and s.dtype == np.dtype(bool):
                    # A boolean mask consumes all axes of the entry at once,
                    # hence it must match the entry's full shape
                    if s.shape != d.shape:
                        raise ValueError(
                            f"shape mismatch between desired slice {s.shape}"
                            f" and the shape of the domain {d.shape}"
                        )
                    k_tgt += [UnstructuredDomain(int(s.sum()))]
                    k_slc_by_ax += [s]
                elif isinstance(
                    s, (slice, int, np.integer, tuple, list, np.ndarray)
                ):
                    # These index exactly one array axis; on a
                    # multi-dimensional entry the result could never match
                    # the constructed target domain
                    if len(d.shape) != 1:
                        ve = (
                            f"slice at key {k!r} selects from the"
                            f" multi-dimensional domain entry {i} with shape"
                            f" {d.shape}; only None, slice(None) or a boolean"
                            " mask of the same shape are supported there"
                        )
                        raise ValueError(ve)
                    if isinstance(s, slice):
                        # `slice.indices` resolves defaults, negative bounds,
                        # out-of-range stops and arbitrary steps exactly as
                        # NumPy does for a single axis
                        n_sel = len(range(*s.indices(d.shape[0])))
                        k_tgt += [UnstructuredDomain(n_sel)]
                    elif isinstance(s, (tuple, list, np.ndarray)):
                        k_tgt += [UnstructuredDomain(len(s))]
                    # An integer index removes the axis, so it contributes no
                    # entry to the target domain
                    k_slc_by_ax += [s]
                else:
                    ve = f"invalid type for specifying a slice; got {s}"
                    raise ValueError(ve)
            tgt[k] = DomainTuple.make(k_tgt)
            self._k_slc[k] = tuple(k_slc_by_ax)
        self._target = MultiDomain.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            res = {k: x[slc] for k, slc in self._k_slc.items()}
            return MultiField.from_raw(self.target, res)
        # Note, not-selected parts must be zero. Hence, using the quicker
        # `np.empty` method is unfortunately not possible. The common dtype of
        # all keys is used so that e.g. a complex-valued key is not cast into
        # the dtype of a real-valued first key.
        dtypes = [v.dtype for v in x.values()]
        dtype = np.result_type(*dtypes) if dtypes else np.float64
        res = np.zeros(self.domain.shape, dtype)
        for k, slc in self._k_slc.items():
            if self._intersec_slc:
                # Mind the `+` here for coping with intersections
                res[slc] += x[k]
            else:
                res[slc] = x[k]
        return Field.from_raw(self.domain, res)

    def __str__(self):
        return f"{self.__class__.__name__} {self._target.keys()!r} <-"