Source code for nifty8.operators.linear_interpolation

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from functools import reduce
from operator import add

import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator

from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..sugar import makeDomain
from .linear_operator import LinearOperator


[docs] class LinearInterpolator(LinearOperator): """Multilinear interpolation for points in an RGSpace Parameters ---------- domain : RGSpace sampling_points : numpy.ndarray Positions at which to interpolate, shape (dim, ndata), Notes ----- Positions that are not within the RGSpace are wrapped according to periodic boundary conditions. This reflects the general property of RGSpaces to be tori topologically. """
[docs] def __init__(self, domain, sampling_points): self._domain = makeDomain(domain) for dom in self.domain: if not isinstance(dom, RGSpace): raise TypeError dims = [len(dom.shape) for dom in self.domain] # FIXME This needs to be removed as soon as the bug below is fixed. if dims.count(dims[0]) != len(dims): raise TypeError("This is a bug. Please extend" "LinearInterpolator's functionality!") shp = sampling_points.shape if not (isinstance(sampling_points, np.ndarray) and len(shp) == 2): raise TypeError n_dim, n_points = shp if n_dim != reduce(add, dims): raise TypeError self._target = makeDomain(UnstructuredDomain(n_points)) self._capability = self.TIMES | self.ADJOINT_TIMES self._build_mat(sampling_points, n_points)
def _build_mat(self, sampling_points, N_points): ndim = sampling_points.shape[0] mg = np.mgrid[(slice(0, 2),)*ndim] mg = np.array(list(map(np.ravel, mg))) dist = [list(dom.distances) for dom in self.domain] # FIXME This breaks as soon as not all domains have the same number of # dimensions. dist = np.array(dist).reshape(-1, 1) pos = sampling_points/dist excess = pos - np.floor(pos) pos = np.floor(pos).astype(np.int64) max_index = np.array(self.domain.shape).reshape(-1, 1) data = np.zeros((len(mg[0]), N_points)) ii = np.zeros((len(mg[0]), N_points), dtype=np.int64) jj = np.zeros((len(mg[0]), N_points), dtype=np.int64) for i in range(len(mg[0])): factor = np.prod( np.abs(1 - mg[:, i].reshape(-1, 1) - excess), axis=0) data[i, :] = factor fromi = (pos + mg[:, i].reshape(-1, 1)) % max_index ii[i, :] = np.arange(N_points) jj[i, :] = np.ravel_multi_index(fromi, self.domain.shape) self._mat = coo_matrix((data.reshape(-1), (ii.reshape(-1), jj.reshape(-1))), (N_points, np.prod(self.domain.shape))) self._mat = aslinearoperator(self._mat)
[docs] def apply(self, x, mode): self._check_input(x, mode) x_val = x.val if mode == self.TIMES: res = self._mat.matvec(x_val.reshape(-1)) else: res = self._mat.rmatvec(x_val).reshape(self.domain.shape) return Field(self._tgt(mode), res)