Source code for nifty8.field

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from functools import reduce

import numpy as np

from . import utilities
from .domain_tuple import DomainTuple
from .ducc_dispatch import vdot
from .operators.operator import Operator


class Field(Operator):
    """The discrete representation of a continuous field over multiple spaces.

    Stores data arrays and carries all the needed meta-information (i.e. the
    domain) for operators to be able to operate on them.

    Parameters
    ----------
    domain : DomainTuple
        The domain of the new Field.
    val : numpy.ndarray
        This object's shape must match the domain shape
        After construction, the object will no longer be writeable!

    Notes
    -----
    If possible, do not invoke the constructor directly, but use one of the
    many convenience functions for instantiation!
    """
    _scalar_dom = DomainTuple.scalar_domain()

    def __init__(self, domain, val):
        """Creates a Field on `domain` holding `val`.

        Scalars are broadcast to the shape of `domain`; other non-NumPy
        inputs whose shape matches `domain` are converted with
        `numpy.asarray`.

        Raises
        ------
        TypeError
            If `domain` is not a DomainTuple, or if `val` is neither a
            numpy.ndarray, a scalar, nor an object of matching shape.
        ValueError
            If the shapes of `val` and `domain` differ.
        """
        if not isinstance(domain, DomainTuple):
            raise TypeError("domain must be of type DomainTuple")
        if not isinstance(val, np.ndarray):
            if np.isscalar(val):
                val = np.broadcast_to(val, domain.shape)
            elif np.shape(val) == domain.shape:
                # If NumPy thinks the shapes are equal, attempt to convert to
                # NumPy. This is especially helpful for JAX DeviceArrays.
                val = np.asarray(val)
            else:
                raise TypeError("val must be of type numpy.ndarray")
        if domain.shape != val.shape:
            raise ValueError(f"shape mismatch between val and domain\n{domain.shape}\n{val.shape}")
        self._domain = domain
        # The array is stored without copying and then flagged read-only, so
        # an ndarray handed in by the caller becomes read-only as well.
        self._val = val
        self._val.flags.writeable = False

    @staticmethod
    def scalar(val):
        """Returns a Field holding `val` on the scalar domain."""
        return Field(Field._scalar_dom, val)

    # prevent implicit conversion to bool
    def __nonzero__(self):
        raise TypeError("Field does not support implicit conversion to bool")

    def __bool__(self):
        raise TypeError("Field does not support implicit conversion to bool")

    @staticmethod
    def full(domain, val):
        """Creates a Field with a given domain, filled with a constant value.

        Parameters
        ----------
        domain : Domain, tuple of Domain, or DomainTuple
            Domain of the new Field.
        val : float/complex/int scalar
            Fill value. Data type of the field is inferred from val.

        Returns
        -------
        Field
            The newly created Field.

        Raises
        ------
        TypeError
            If `val` is not a scalar or not an arithmetic scalar.
        """
        if not np.isscalar(val):
            raise TypeError("val must be a scalar")
        if not (np.isreal(val) or np.iscomplex(val)):
            raise TypeError("need arithmetic scalar")
        domain = DomainTuple.make(domain)
        return Field(domain, val)

    @staticmethod
    def from_raw(domain, arr):
        """Returns a Field constructed from `domain` and `arr`.

        Parameters
        ----------
        domain : DomainTuple, tuple of Domain, or Domain
            The domain of the new Field.
        arr : numpy.ndarray
            The data content to be used for the new Field.
            Its shape must match the shape of `domain`.
        """
        return Field(DomainTuple.make(domain), arr)

    def cast_domain(self, new_domain):
        """Returns a field with the same data, but a different domain

        Parameters
        ----------
        new_domain : Domain, tuple of Domain, or DomainTuple
            The domain for the returned field. Must be shape-compatible to
            `self`.

        Returns
        -------
        Field
            Field defined on `new_domain`, but with the same data as `self`.
        """
        return Field(DomainTuple.make(new_domain), self._val)

    @staticmethod
    def from_random(domain, random_type='normal', dtype=np.float64, **kwargs):
        """Draws a random field with the given parameters.

        Parameters
        ----------
        domain : DomainTuple
            The domain of the output random Field.
        random_type : 'pm1', 'normal', or 'uniform'
            The random distribution to use. This is the name of the
            generator function looked up on the `Random` class.
        dtype : type
            The datatype of the output random Field.
            If the datatype is complex, each real and imaginary part
            have variance 1
        **kwargs
            Additional keyword arguments, passed on unchanged to the
            generator function.

        Returns
        -------
        Field
            The newly created Field.
        """
        from .random import Random
        domain = DomainTuple.make(domain)
        generator_function = getattr(Random, random_type)
        arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
        return Field(domain, arr)

    @property
    def val(self):
        """numpy.ndarray : the array storing the field's entries.

        Notes
        -----
        The returned array is read-only.
        """
        return self._val

    def val_rw(self):
        """numpy.ndarray : a copy of the array storing the field's entries.
        """
        return self._val.copy()

    @property
    def dtype(self):
        """type : the data type of the field's entries"""
        return self._val.dtype

    @property
    def domain(self):
        """DomainTuple : the field's domain"""
        return self._domain

    @property
    def shape(self):
        """tuple of int : the concatenated shapes of all sub-domains"""
        return self._domain.shape

    @property
    def size(self):
        """int : total number of pixels in the field"""
        return self._domain.size

    @property
    def real(self):
        """Field : The real part of the field

        For a field that is not complex-valued, the field itself is returned.
        """
        if utilities.iscomplextype(self.dtype):
            return Field(self._domain, self._val.real)
        return self

    @property
    def imag(self):
        """Field : The imaginary part of the field

        Raises
        ------
        ValueError
            If the field is not complex-valued.
        """
        if not utilities.iscomplextype(self.dtype):
            raise ValueError(".imag called on a non-complex Field")
        return Field(self._domain, self._val.imag)

    def scalar_weight(self, spaces=None):
        """Returns the uniform volume element for a sub-domain of `self`.

        Parameters
        ----------
        spaces : int, tuple of int or None
            Indices of the sub-domains of the field's domain to be considered.
            If `None`, the entire domain is used.

        Returns
        -------
        float or None
            If the requested sub-domain has a uniform volume element, it is
            returned. Otherwise, `None` is returned.
        """
        return self._domain.scalar_weight(spaces)

    def total_volume(self, spaces=None):
        """Returns the total volume of the field's domain or of a subspace of it.

        Parameters
        ----------
        spaces : int, tuple of int or None
            Indices of the sub-domains of the field's domain to be considered.
            If `None`, the total volume of the whole domain is returned.

        Returns
        -------
        float
            the total volume of the requested (sub-)domain.
        """
        return self._domain.total_volume(spaces)

    def weight(self, power=1, spaces=None):
        """Weights the pixels of `self` with their individual pixel volumes.

        Parameters
        ----------
        power : number
            The pixel values get multiplied with their volume-factor**power.

        spaces : None, int or tuple of int
            Determines on which sub-domain the operation takes place.
            If None, the entire domain is used.

        Returns
        -------
        Field
            The weighted field.

        Notes
        -----
        The scaling is applied in place to a copy of the data, so the result
        keeps the data type of `self`. NumPy refuses in-place casts from
        floating point to integer, hence integer-valued fields cannot be
        weighted with non-integer volume factors.
        """
        aout = self.val_rw()

        spaces = utilities.parse_spaces(spaces, len(self._domain))

        # Scalar volume factors are collected in `fct` and applied once at the
        # end; array-valued ones are applied immediately.
        fct = 1.
        for ind in spaces:
            wgt = self._domain[ind].dvol
            if np.isscalar(wgt):
                fct *= wgt
            else:
                # Reshape the weights so that they broadcast along the axes
                # belonging to sub-domain `ind` only.
                new_shape = np.ones(len(self.shape), dtype=np.int64)
                new_shape[self._domain.axes[ind][0]:
                          self._domain.axes[ind][-1]+1] = wgt.shape
                wgt = wgt.reshape(new_shape)
                aout *= wgt**power
        fct = fct**power
        if fct != 1.:
            aout *= fct

        return Field(self._domain, aout)

    def outer(self, x):
        """Computes the outer product of 'self' with x.

        Parameters
        ----------
        x : :class:`nifty8.field.Field`

        Returns
        -------
        Field
            Defined on the product space of self.domain and x.domain.

        Raises
        ------
        TypeError
            If `x` is not a Field.
        """
        if not isinstance(x, Field):
            raise TypeError("The multiplier must be an instance of " +
                            "the Field class")
        from .operators.outer_product_operator import OuterProduct
        return OuterProduct(x.domain, self)(x)

    def vdot(self, x, spaces=None):
        """Computes the dot product of 'self' with x.

        Parameters
        ----------
        x : :class:`nifty8.field.Field`
            x must be defined on the same domain as `self`.

        spaces : None, int or tuple of int
            The dot product is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field
            A Field on the scalar domain for full dot products, otherwise a
            Field on the sub-domains that were not contracted.

        Raises
        ------
        TypeError
            If `x` is not a Field.
        """
        if not isinstance(x, Field):
            raise TypeError("The dot-partner must be an instance of " +
                            "the Field class")

        utilities.check_object_identity(x._domain, self._domain)

        ndom = len(self._domain)
        spaces = utilities.parse_spaces(spaces, ndom)

        if len(spaces) == ndom:
            return Field.scalar(np.array(vdot(self._val, x._val)))
        # If we arrive here, we have to do a partial dot product.
        # For the moment, do this the explicit, non-optimized way
        return (self.conjugate()*x).sum(spaces=spaces)

    def s_vdot(self, x):
        """Computes the dot product of 'self' with x.

        Parameters
        ----------
        x : :class:`nifty8.field.Field`
            x must be defined on the same domain as `self`.

        Returns
        -------
        float or complex
            The dot product

        Raises
        ------
        TypeError
            If `x` is not a Field.
        """
        if not isinstance(x, Field):
            raise TypeError("The dot-partner must be an instance of " +
                            "the Field class")

        utilities.check_object_identity(x._domain, self._domain)

        return vdot(self._val, x._val)

    def norm(self, ord=2):
        """Computes a vector norm of the flattened field values.

        Parameters
        ----------
        ord : int
            Order of the norm, passed on to `numpy.linalg.norm`.
            Accepted values: 1, 2, ..., np.inf. Default: 2 (the L2-norm).

        Returns
        -------
        float
            The norm of order `ord` of the field values.
        """
        return np.linalg.norm(self._val.reshape(-1), ord=ord)

    def conjugate(self):
        """Returns the complex conjugate of the field.

        Returns
        -------
        Field
            The complex conjugated field. For a field that is not
            complex-valued, the field itself is returned.
        """
        if utilities.iscomplextype(self._val.dtype):
            return Field(self._domain, self._val.conjugate())
        return self

    # ---General unary/contraction methods---

    def __pos__(self):
        return self

    def __neg__(self):
        return Field(self._domain, -self._val)

    def __abs__(self):
        return Field(self._domain, abs(self._val))

    def _contraction_helper(self, op, spaces):
        """Applies the ndarray reduction named `op` over `spaces`.

        Parameters
        ----------
        op : str
            Name of the reduction method of numpy.ndarray (e.g. 'sum').
        spaces : None, int or tuple of int
            Sub-domains to contract. If None, all entries are reduced.

        Returns
        -------
        Field
            A scalar Field if nothing is left after the contraction,
            otherwise a Field on the remaining sub-domains.
        """
        if spaces is None:
            return Field.scalar(getattr(self._val, op)())

        spaces = utilities.parse_spaces(spaces, len(self._domain))

        axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces)

        # Flatten the per-sub-domain axis tuples into one tuple of axes.
        if len(axes_list) > 0:
            axes_list = reduce(lambda x, y: x+y, axes_list)

        # perform the contraction on the data
        data = getattr(self._val, op)(axis=axes_list)

        # check if the result is scalar or if a result_field must be constr.
        if np.isscalar(data):
            return Field.scalar(data)
        else:
            return_domain = tuple(dom
                                  for i, dom in enumerate(self._domain)
                                  if i not in spaces)

            return Field(DomainTuple.make(return_domain), data)

    def scale(self, factor):
        """Returns `factor*self`, or `self` itself if `factor` equals 1."""
        if factor == 1:
            return self
        return factor*self

    def sum(self, spaces=None):
        """Sums up over the sub-domains given by `spaces`.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The summation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.

        Returns
        -------
        Field
            The result of the summation.
        """
        return self._contraction_helper('sum', spaces)

    def s_sum(self):
        """Returns the sum over all entries

        Returns
        -------
        scalar
            The result of the summation.
        """
        return self._val.sum()

    def integrate(self, spaces=None):
        """Integrates over the sub-domains given by `spaces`.

        Integration is performed by summing over `self` multiplied by its
        volume factors.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The summation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.

        Returns
        -------
        Field
            The result of the integration.
        """
        swgt = self.scalar_weight(spaces)
        if swgt is not None:
            # Uniform volume element: sum first, scale the result once.
            res = self.sum(spaces)
            res = res*swgt
            return res
        tmp = self.weight(1, spaces=spaces)
        return tmp.sum(spaces)

    def s_integrate(self):
        """Integrates over the Field.

        Integration is performed by summing over `self` multiplied by its
        volume factors.

        Returns
        -------
        Scalar
            The result of the integration.
        """
        swgt = self.scalar_weight()
        if swgt is not None:
            return self.s_sum()*swgt
        tmp = self.weight(1)
        return tmp.s_sum()

    def prod(self, spaces=None):
        """Computes the product over the sub-domains given by `spaces`.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field
            The result of the product.
        """
        return self._contraction_helper('prod', spaces)

    def s_prod(self):
        """Returns the product over all entries as a scalar."""
        return self._val.prod()

    def all(self, spaces=None):
        """Applies the `all` reduction over the sub-domains in `spaces`.

        Returns
        -------
        Field
            The result of the operation.
        """
        return self._contraction_helper('all', spaces)

    def s_all(self):
        """Returns the result of the `all` reduction over all entries."""
        return self._val.all()

    def any(self, spaces=None):
        """Applies the `any` reduction over the sub-domains in `spaces`.

        Returns
        -------
        Field
            The result of the operation.
        """
        return self._contraction_helper('any', spaces)

    def s_any(self):
        """Returns the result of the `any` reduction over all entries."""
        return self._val.any()

#     def min(self, spaces=None):
#         """Determines the minimum over the sub-domains given by `spaces`.
#
#         Parameters
#         ----------
#         spaces : None, int or tuple of int (default: None)
#             The operation is only carried out over the sub-domains in this
#             tuple. If None, it is carried out over all sub-domains.
#
#         Returns
#         -------
#         Field
#             The result of the operation.
#         """
#         return self._contraction_helper('min', spaces)
#
#     def max(self, spaces=None):
#         """Determines the maximum over the sub-domains given by `spaces`.
#
#         Parameters
#         ----------
#         spaces : None, int or tuple of int (default: None)
#             The operation is only carried out over the sub-domains in this
#             tuple. If None, it is carried out over all sub-domains.
#
#         Returns
#         -------
#         Field
#             The result of the operation.
#         """
#         return self._contraction_helper('max', spaces)

    def mean(self, spaces=None):
        """Determines the mean over the sub-domains given by `spaces`.

        ``x.mean(spaces)`` is equivalent to
        ``x.integrate(spaces)/x.total_volume(spaces)``.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.

        Returns
        -------
        Field
            The result of the operation.
        """
        if self.scalar_weight(spaces) is not None:
            return self._contraction_helper('mean', spaces)
        # MR FIXME: not very efficient
        # MR FIXME: do we need "spaces" here?
        tmp = self.weight(1, spaces)
        return tmp.sum(spaces)*(1./tmp.total_volume(spaces))

    def s_mean(self):
        """Determines the field mean

        ``x.s_mean()`` is equivalent to
        ``x.s_integrate()/x.total_volume()``.

        Returns
        -------
        scalar
            The result of the operation.
        """
        return self.s_integrate()/self.total_volume()

    def var(self, spaces=None):
        """Determines the variance over the sub-domains given by `spaces`.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field
            The result of the operation.
        """
        if self.scalar_weight(spaces) is not None:
            return self._contraction_helper('var', spaces)
        # MR FIXME: not very efficient or accurate
        m1 = self.mean(spaces)
        from .operators.contraction_operator import ContractionOperator
        op = ContractionOperator(self._domain, spaces)
        # Bring the mean back onto the full domain so it can be subtracted.
        m1 = op.adjoint_times(m1)
        if utilities.iscomplextype(self.dtype):
            sq = abs(self-m1)**2
        else:
            sq = (self-m1)**2
        return sq.mean(spaces)

    def s_var(self):
        """Determines the field variance

        Returns
        -------
        scalar
            The result of the operation.
        """
        if self.scalar_weight() is not None:
            return self._val.var()
        # MR FIXME: not very efficient or accurate
        m1 = self.s_mean()
        if utilities.iscomplextype(self.dtype):
            sq = abs(self-m1)**2
        else:
            sq = (self-m1)**2
        return sq.s_mean()

    def std(self, spaces=None):
        """Determines the standard deviation over the sub-domains given by
        `spaces`.

        ``x.std(spaces)`` is equivalent to ``sqrt(x.var(spaces))``.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field
            The result of the operation.
        """
        if self.scalar_weight(spaces) is not None:
            return self._contraction_helper('std', spaces)
        return self.var(spaces).ptw("sqrt")

    def s_std(self):
        """Determines the standard deviation of the Field.

        ``x.s_std()`` is equivalent to ``sqrt(x.s_var())``.

        Returns
        -------
        scalar
            The result of the operation.
        """
        if self.scalar_weight() is not None:
            return self._val.std()
        return np.sqrt(self.s_var())

    def __repr__(self):
        return "<nifty8.Field>"

    def __str__(self):
        return "nifty8.Field instance\n- domain = " + \
               self._domain.__str__() + \
               "\n- val = " + repr(self._val)

    def extract(self, dom):
        """Checks `dom` against the field's domain with
        `utilities.check_object_identity` and returns the field itself."""
        utilities.check_object_identity(dom, self._domain)
        return self

    def extract_part(self, dom):
        """Checks `dom` against the field's domain with
        `utilities.check_object_identity` and returns the field itself."""
        utilities.check_object_identity(dom, self._domain)
        return self

    def unite(self, other):
        """Returns ``self+other``."""
        return self+other

    def flexible_addsub(self, other, neg):
        """Returns ``self-other`` if `neg` is true, else ``self+other``."""
        return self-other if neg else self+other

    def _binary_op(self, other, op):
        """Applies the ndarray special method named `op` to `self` and `other`.

        Returns a new Field on the domain of `self`, or `NotImplemented` if
        `other` is neither a Field nor a scalar.
        """
        # if other is a field, make sure that the domains match
        f = getattr(self._val, op)
        if isinstance(other, Field):
            utilities.check_object_identity(other._domain, self._domain)
            return Field(self._domain, f(other._val))
        if np.isscalar(other):
            return Field(self._domain, f(other))
        return NotImplemented

    def _prep_args(self, args, kwargs):
        """Validates the extra arguments of a pointwise call and unwraps them.

        `None` and scalars are passed through unchanged; every other
        argument must expose ``jac is None`` and is replaced by its `_val`.

        Returns
        -------
        tuple, dict
            The unwrapped positional and keyword arguments.

        Raises
        ------
        TypeError
            If an argument is neither `None`, a scalar, nor an object whose
            `jac` attribute is `None`.
        """
        for arg in args + tuple(kwargs.values()):
            if arg is None or np.isscalar(arg):
                continue
            # Objects without a `jac` attribute (e.g. raw arrays) are rejected
            # with the same TypeError instead of leaking an AttributeError.
            if not hasattr(arg, "jac") or arg.jac is not None:
                raise TypeError("bad argument")
        argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
                        for arg in args)
        kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
                     for key, val in kwargs.items()}
        return argstmp, kwargstmp

    def ptw(self, op, *args, **kwargs):
        """Applies the pointwise function registered as `op` in `ptw_dict`.

        Parameters
        ----------
        op : str
            Key into `ptw_dict`.
        *args, **kwargs
            Extra arguments for the function; see `_prep_args` for the
            accepted types.

        Returns
        -------
        Field
            The result, defined on the domain of `self`.
        """
        from .pointwise import ptw_dict
        argstmp, kwargstmp = self._prep_args(args, kwargs)
        return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))

    def ptw_with_deriv(self, op, *args, **kwargs):
        """Applies the second function registered as `op` in `ptw_dict`.

        That function returns a pair of arrays, presumably the function
        value and its derivative (not verifiable from this file).

        Returns
        -------
        tuple of Field
            Both arrays wrapped as Fields on the domain of `self`.
        """
        from .pointwise import ptw_dict
        argstmp, kwargstmp = self._prep_args(args, kwargs)
        tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
        return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
def func(op):
    """Return a method that forwards the operator named `op` to
    ``Field._binary_op``."""
    def func2(self, other):
        return self._binary_op(other, op)
    return func2


# Install the arithmetic and comparison special methods on Field.
for op in ("__add__", "__radd__",
           "__sub__", "__rsub__",
           "__mul__", "__rmul__",
           "__truediv__", "__rtruediv__",
           "__floordiv__", "__rfloordiv__",
           "__pow__", "__rpow__",
           "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"):
    setattr(Field, op, func(op))


def func(op):
    """Return a method that rejects the in-place operator it is installed
    as; Field data is read-only by design."""
    def func2(self, other):
        raise TypeError(
            "In-place operations are deliberately not supported")
    return func2


# Every in-place special method on Field raises TypeError.
for op in ("__iadd__", "__isub__", "__imul__", "__idiv__",
           "__itruediv__", "__ifloordiv__", "__ipow__"):
    setattr(Field, op, func(op))