# Source code for nifty8.re.tree_math.vector

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

import operator
from copy import deepcopy
from pprint import pformat

from jax import numpy as jnp
from jax.tree_util import (
    register_pytree_node_class, tree_leaves, tree_map, tree_structure
)

from .vector_math import matmul, max, min, size, sum


def _copy(obj):
    return obj.copy() if hasattr(obj, "copy") else deepcopy(obj)


def _value_op(op, name=None):
    def value_call(lhs, *args, **kwargs):
        return op(lhs.tree, *args, **kwargs)

    name = op.__name__ if name is None else name
    value_call.__name__ = f"__{name}__"
    return value_call


def _unary_op(op, name=None):
    def unary_call(lhs):
        return tree_map(op, lhs)

    name = op.__name__ if name is None else name
    unary_call.__name__ = f"__{name}__"
    return unary_call


def _broadcast_binary_op(op, lhs, rhs):
    from itertools import repeat

    ts_lhs = tree_structure(lhs)
    ts_rhs = tree_structure(rhs)
    # Catch non-objects scalars and 0d array-likes with a `ndim` property
    if jnp.isscalar(lhs) or getattr(lhs, "ndim", -1) == 0:
        lhs = ts_rhs.unflatten(repeat(lhs, ts_rhs.num_leaves))
    elif jnp.isscalar(rhs) or getattr(rhs, "ndim", -1) == 0:
        rhs = ts_lhs.unflatten(repeat(rhs, ts_lhs.num_leaves))
    elif ts_lhs.num_nodes != ts_rhs.num_nodes:
        ve = f"invalid binary operation {op} for {ts_lhs!r} and {ts_rhs!r}"
        raise ValueError(ve)
    return tree_map(op, lhs, rhs)


def _binary_op(op, name=None):
    def binary_call(lhs, rhs):
        return _broadcast_binary_op(op, lhs, rhs)

    name = op.__name__ if name is None else name
    binary_call.__name__ = f"__{name}__"
    return binary_call


def _rev_binary_op(op, name=None):
    def binary_call(lhs, rhs):
        return _broadcast_binary_op(op, rhs, lhs)

    name = op.__name__ if name is None else name
    binary_call.__name__ = f"__r{name}__"
    return binary_call


def _fwd_rev_binary_op(op, name=None):
    """Return the ``(forward, reflected)`` pair of dunder methods for ``op``.

    Parameters
    ----------
    op : callable
        Binary function applied leaf-wise.
    name : str, optional
        Base name forwarded to both builders; defaults to ``op.__name__``.
    """
    forward = _binary_op(op, name=name)
    reflected = _rev_binary_op(op, name=name)
    return forward, reflected


@register_pytree_node_class
class Vector:
    """Value storage for arbitrary objects with added numerics.

    The wrapped object is exposed to JAX as the single child of this pytree
    node (see :meth:`tree_flatten`). As a result, ``tree_map`` and JAX
    transformations act on its leaves and re-wrap the result in a
    :class:`Vector`. Arithmetic, bitwise and comparison operators are
    applied leaf-wise, with scalar operands broadcast to every leaf.
    """

    def __init__(self, tree):
        """Instantiates a vector.

        Parameters
        ----------
        tree : object
            Arbitrary, flatten-able objects.
        """
        self._tree = tree

    def tree_flatten(self):
        """Flatten into ``((tree,), None)``: one child, no auxiliary data."""
        return ((self._tree, ), None)

    @classmethod
    def tree_unflatten(cls, _, children):
        """Rebuild a vector from the children of :meth:`tree_flatten`."""
        return cls(*children)

    @property
    def tree(self):
        """Retrieves a **view** of the vector's values."""
        return self._tree

    def __len__(self):
        # Delegates to `vector_math.size`; presumably the total number of
        # elements over all leaves -- NOTE(review): confirm in `vector_math`.
        return size(self)

    @property
    def size(self):
        """Same value as ``len(self)``."""
        return len(self)

    @property
    def shape(self):
        """One-dimensional shape ``(len(self),)``."""
        return (len(self), )

    def copy(self):
        """Return a new vector whose leaves are copies of this one's."""
        return tree_map(_copy, self)

    def ravel(self):
        """Return ``self`` unchanged; the vector is already treated as flat."""
        return self

    def __repr__(self):
        rep = pformat(self.tree).replace("\n", "\n\t").strip()
        s = f"{self.__class__.__name__}(\n\t{rep}\n)"
        # The template contributes exactly two newlines; if the pretty-printed
        # tree added none, collapse everything onto a single line.
        s = s.replace("\n", "").replace("\t", "") if s.count("\n") <= 2 else s
        return s

    def __str__(self):
        return repr(self)

    # Truthiness of the vector is the truthiness of the wrapped tree.
    __bool__ = _value_op(bool)

    def __hash__(self):
        # Hash of all leaves; NOTE(review): leaves must themselves be
        # hashable (array leaves may not be) -- confirm intended use.
        return hash(tuple(tree_leaves(self)))

    # NOTE, this partly redundant code could be abstracted away using
    # `setattr`. However, static code analyzers will not be able to infer the
    # properties then.
    __add__, __radd__ = _fwd_rev_binary_op(operator.add)
    __sub__, __rsub__ = _fwd_rev_binary_op(operator.sub)
    __mul__, __rmul__ = _fwd_rev_binary_op(operator.mul)
    __truediv__, __rtruediv__ = _fwd_rev_binary_op(operator.truediv)
    __floordiv__, __rfloordiv__ = _fwd_rev_binary_op(operator.floordiv)
    __pow__, __rpow__ = _fwd_rev_binary_op(operator.pow)
    __mod__, __rmod__ = _fwd_rev_binary_op(operator.mod)
    __matmul__ = __rmatmul__ = matmul  # arguments of matmul commute

    def __divmod__(self, other):
        """Return ``(self // other, self % other)``, computed leaf-wise."""
        return self // other, self % other

    def __rdivmod__(self, other):
        """Return ``(other // self, other % self)`` for ``divmod(other, self)``."""
        return other // self, other % self

    __or__, __ror__ = _fwd_rev_binary_op(operator.or_, "or")
    __xor__, __rxor__ = _fwd_rev_binary_op(operator.xor)
    __and__, __rand__ = _fwd_rev_binary_op(operator.and_, "and")
    __lshift__, __rlshift__ = _fwd_rev_binary_op(operator.lshift)
    __rshift__, __rrshift__ = _fwd_rev_binary_op(operator.rshift)

    # Comparisons are leaf-wise and return a Vector of per-leaf results.
    __lt__ = _binary_op(operator.lt)
    __le__ = _binary_op(operator.le)
    __eq__ = _binary_op(operator.eq)
    __ne__ = _binary_op(operator.ne)
    __ge__ = _binary_op(operator.ge)
    __gt__ = _binary_op(operator.gt)

    __neg__ = _unary_op(operator.neg)
    __pos__ = _unary_op(operator.pos)
    __abs__ = _unary_op(operator.abs)
    __invert__ = _unary_op(operator.invert)

    conj = conjugate = _unary_op(jnp.conj)
    real = property(_unary_op(jnp.real))
    imag = property(_unary_op(jnp.imag))
    dot = matmul

    def max(self):
        """Delegate to the module-level ``max`` from ``vector_math``."""
        return max(self)

    def min(self):
        """Delegate to the module-level ``min`` from ``vector_math``."""
        return min(self)

    def sum(self):
        """Delegate to the module-level ``sum`` from ``vector_math``."""
        return sum(self)

    # Container protocol is forwarded directly to the wrapped tree.
    __getitem__ = _value_op(operator.getitem)
    __contains__ = _value_op(operator.contains)
    __iter__ = _value_op(iter)