# Source code for nifty8.operators.transpose_operator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2022 Max-Planck-Society
# Author: Philipp Arras
import numpy as np
from ..domain_tuple import DomainTuple
from ..sugar import makeField
from ..utilities import myassert
from .linear_operator import LinearOperator
class TransposeOperator(LinearOperator):
    """Reorder the spaces of a domain by transposing the underlying array.

    The target of the operator is the domain with its spaces permuted:
    ``DomainTuple.make(domain[i] for i in indices)``. Applying the operator
    transposes the field values so that the axes belonging to each space
    follow the new space order. Because the operator only permutes entries,
    its inverse and its adjoint are the same operation (see ``apply``).

    Parameters
    ----------
    domain : anything accepted by ``DomainTuple.make``
        Domain of the operator.
    indices : iterable of int
        Permutation of ``range(len(domain))``. Entry ``k`` names the space of
        ``domain`` that becomes space ``k`` of the target.

    Raises
    ------
    IndexError
        If the number of indices differs from the number of spaces in
        ``domain``, or if an index lies outside ``range(len(domain))``.
    ValueError
        If ``indices`` contains a repeated entry, i.e. is not a permutation.
    """

    def __init__(self, domain, indices):
        self._domain = DomainTuple.make(domain)
        indices = tuple(indices)
        n_spaces = len(self._domain)
        if len(indices) != n_spaces:
            raise IndexError("Either too many or too few indices given.")
        # Reject negative and too-large indices up front; negative ones would
        # otherwise slip through and break the axis expansion later on.
        if any(not 0 <= ind < n_spaces for ind in indices):
            raise IndexError(
                f"Space indices must lie in range({n_spaces}), got {indices}.")
        # Demand an exact permutation. Comparing only the sizes of domain and
        # target would accept repeated indices such as (0, 0) whenever the
        # spaces involved have equal size; np.transpose would then reject the
        # repeated axes only when the operator is applied.
        if sorted(indices) != list(range(n_spaces)):
            raise ValueError(
                f"List of indices not complete: {indices} is not a "
                f"permutation of range({n_spaces}).")
        self._target = DomainTuple.make(self._domain[ind] for ind in indices)
        self._capability = self._all_ops
        # Axis permutation for the forward direction (domain -> target).
        self._np_indices = _niftyspace_to_np_indices(self._domain, indices)
        # Inverse axis permutation (target -> domain), computed once here
        # instead of on every adjoint/inverse application.
        self._np_indices_inv = tuple(np.argsort(self._np_indices).tolist())
        self._indices = indices

    def apply(self, x, mode):
        """Transpose the values of ``x`` according to ``mode``.

        TIMES and ADJOINT_INVERSE_TIMES use the forward axis permutation;
        ADJOINT_TIMES and INVERSE_TIMES use its inverse.
        """
        self._check_input(x, mode)
        if mode in (self.TIMES, self.ADJOINT_INVERSE_TIMES):
            axes = self._np_indices
        else:
            axes = self._np_indices_inv
        return makeField(self._tgt(mode), np.transpose(x.val, axes))

    def __repr__(self):
        return f'Transpose (indices={self._indices})'
def _niftyspace_to_np_indices(domain, indices):
    """Expand a permutation of domain spaces into a permutation of array axes.

    The code assumes the axes of ``domain.shape`` are the concatenation of
    the shapes of the individual spaces in ``domain``. Space ``i`` therefore
    owns the axes ``offsets[i] .. offsets[i + 1] - 1``, where ``offsets`` is
    the running sum of the spaces' dimension counts. For every entry of
    ``indices``, the axes of that space are emitted in their original order,
    so the result can be passed directly to ``np.transpose``.

    Parameters
    ----------
    domain : DomainTuple
        Domain whose spaces are being reordered.
    indices : sequence of int
        New order of the spaces of ``domain``.

    Returns
    -------
    tuple of int
        Axis permutation with (after the length check) one entry per
        dimension of ``domain.shape``.
    """
    ndims_per_space = [len(space.shape) for space in domain]
    offsets = np.cumsum([0] + ndims_per_space)
    axes = tuple(
        axis
        for ind in indices
        for axis in range(offsets[ind], offsets[ind + 1])
    )
    # Sanity check: the expansion must yield exactly as many axes as the
    # full domain has dimensions.
    myassert(len(axes) == len(domain.shape))
    return axes