nifty8.re.tree_math.vector_math module

class ShapeWithDtype(shape: Tuple | Tuple[int] | List[int] | int, dtype=None)

Bases: object

Minimal helper class storing the shape and dtype of an object.

Notes

This class may not be transparent to JAX, as it is not meant to be flattened itself. If used in a tree-like structure, it should only be used as a leaf.
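Here is a minimal sketch of what "used as a leaf" means. It assumes only standard JAX pytree behaviour: an instance of a class that is not registered with jax.tree_util is not flattened further and appears as one opaque leaf. The class Opaque below is hypothetical and merely stands in for ShapeWithDtype.

```python
import jax
import jax.numpy as jnp


class Opaque:
    """Hypothetical stand-in for an unregistered helper such as ShapeWithDtype."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


tree = {"a": jnp.zeros(3), "b": Opaque((2, 2), jnp.float32)}

# JAX does not know how to flatten Opaque, so the whole instance is one leaf.
leaves = jax.tree_util.tree_leaves(tree)
print(len(leaves))               # 2
print(type(leaves[1]).__name__)  # Opaque
```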

__init__(shape: Tuple | Tuple[int] | List[int] | int, dtype=None)

Instantiates a storage unit for shape and dtype.

Parameters:
  • shape (tuple or list of int, or int) – One-dimensional sequence of integers denoting the length of the object along each of its axes.

  • dtype (dtype) – Data-type of the to-be-described object.

property dtype

Retrieves the data-type.

classmethod from_leave(element)

Convenience method for creating an instance of ShapeWithDtype from an object.

To map a whole tree-like structure to its shapes and dtypes, use jax.tree_util.tree_map like so:

tree_map(ShapeWithDtype.from_leave, tree)

Parameters:

element (tree-like structure) – Object from which to take the shape and data-type.

Returns:

swd – Instance storing the shape and data-type of element.

Return type:

instance of ShapeWithDtype
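The tree_map pattern above can be illustrated without NIFTy by substituting JAX's own jax.ShapeDtypeStruct as the per-leaf record. This is an analogue for illustration, not ShapeWithDtype itself, and the helper leaf_to_struct is hypothetical.

```python
import jax
import jax.numpy as jnp


def leaf_to_struct(leaf):
    # Hypothetical per-leaf helper: keep only the shape and dtype of the leaf.
    return jax.ShapeDtypeStruct(jnp.shape(leaf), jnp.result_type(leaf))


tree = {"x": jnp.ones((2, 3)), "y": [jnp.zeros(4, dtype=jnp.int32)]}
structs = jax.tree_util.tree_map(leaf_to_struct, tree)

# The container structure is preserved; each array is replaced by a shape/dtype record.
print(structs["x"].shape)     # (2, 3)
print(structs["y"][0].dtype)  # int32
```

Because the per-leaf record is itself a leaf, the result mirrors the input tree one-to-one.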

property ndim: int

Number of axes, i.e. the length of shape.

property shape: Tuple[int]

Retrieves the shape.

property size: int

Total number of elements.
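The following is a simplified, hypothetical stand-in that sketches how shape, ndim and size relate. It makes two choices that are not confirmed by this page. First, converting an int or list argument into a tuple is an assumption based on the argument types accepted by the signature. Second, leaving dtype as None when it is omitted is specific to this sketch.

```python
import math
from typing import List, Tuple, Union


class MiniShapeWithDtype:
    """Hypothetical, simplified stand-in; not NIFTy's actual implementation."""

    def __init__(self, shape: Union[int, List[int], Tuple[int, ...]], dtype=None):
        # Assumption: a bare int describes a one-dimensional object.
        if isinstance(shape, int):
            shape = (shape,)
        shape = tuple(shape)
        if not all(isinstance(n, int) for n in shape):
            raise TypeError(f"invalid shape {shape!r}")
        self._shape = shape
        self._dtype = dtype

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def dtype(self):
        return self._dtype

    @property
    def ndim(self) -> int:
        # Number of axes.
        return len(self._shape)

    @property
    def size(self) -> int:
        # Total number of elements: product of the axis lengths (1 for shape ()).
        return math.prod(self._shape)


swd = MiniShapeWithDtype([3, 4], dtype="float32")
print(swd.shape, swd.ndim, swd.size)  # (3, 4) 2 12
```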

all(a)
any(a)
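all(a) and any(a) carry no description here. The sketch below assumes they reduce over every element of every leaf, the way jnp.all and jnp.any do for a single array. The names tree_all and tree_any are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_all(a):
    # Hypothetical: reduce each leaf with jnp.all, then combine the per-leaf results.
    return jnp.all(jnp.array([jnp.all(x) for x in jax.tree_util.tree_leaves(a)]))


def tree_any(a):
    # Hypothetical: reduce each leaf with jnp.any, then combine the per-leaf results.
    return jnp.any(jnp.array([jnp.any(x) for x in jax.tree_util.tree_leaves(a)]))


tree = {"a": jnp.array([True, True]), "b": jnp.array([1, 0])}
print(bool(tree_all(tree)))  # False
print(bool(tree_any(tree)))  # True
```

Under the same assumption, max, min and sum could follow the same pattern: reduce each leaf first, then reduce across leaves.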
conj(a)

Returns the complex conjugate, component- and element-wise.

Parameters:

a (object) – Arbitrary flatten-able object (pytree).

Returns:

out – The complex conjugate of a, with the same shape and dtype as a.

Return type:

object

conjugate(a)

Returns the complex conjugate, component- and element-wise.

Parameters:

a (object) – Arbitrary flatten-able object (pytree).

Returns:

out – The complex conjugate of a, with the same shape and dtype as a.

Return type:

object
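The following is a minimal sketch of a component- and element-wise conjugate. It assumes a tree is any JAX pytree and maps jnp.conj over every leaf. This keeps the tree structure as well as each leaf's shape and dtype. The name tree_conj is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_conj(a):
    # Conjugate every leaf element-wise; real leaves come back unchanged.
    return jax.tree_util.tree_map(jnp.conj, a)


a = {"u": jnp.array([1 + 2j, 3 - 1j]), "v": jnp.array(2.0)}
out = tree_conj(a)
```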

dot(a, b, *, precision=None)

Returns the dot product of the two vectors.

Parameters:
  • a (object) – Arbitrary flatten-able object (pytree).

  • b (object) – Arbitrary flatten-able object (pytree).

Returns:

out – Dot product of vectors.

Return type:

float
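One plausible reading of a dot product between two trees is sketched below. It assumes both trees share the same structure, flattens each leaf, and sums the leaf-wise products. Like jnp.dot, it does not conjugate either argument. The name tree_dot is hypothetical.

```python
import operator

import jax
import jax.numpy as jnp


def tree_dot(a, b, *, precision=None):
    # Dot product of each pair of flattened leaves ...
    leaf_dots = jax.tree_util.tree_map(
        lambda x, y: jnp.dot(jnp.ravel(x), jnp.ravel(y), precision=precision), a, b
    )
    # ... summed over all leaves into a single scalar.
    return jax.tree_util.tree_reduce(operator.add, leaf_dots, 0.0)


a = {"x": jnp.array([1.0, 2.0]), "y": jnp.array([[3.0]])}
b = {"x": jnp.array([1.0, 2.0]), "y": jnp.array([[3.0]])}
print(float(tree_dot(a, b)))  # 14.0
```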

matmul(a, b, *, precision=None)

Returns the dot product of the two vectors.

Parameters:
  • a (object) – Arbitrary flatten-able object (pytree).

  • b (object) – Arbitrary flatten-able object (pytree).

Returns:

out – Dot product of vectors.

Return type:

float

max(a)
min(a)
norm(tree, ord=2)

Vector norm.

Notes

This function treats the input as a vector, which is why the default order is ord=2 (the Euclidean norm).
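For the default ord=2, treating the whole tree as one vector means taking the square root of the summed squared magnitudes of all leaf elements. The hypothetical tree_norm2 below sketches only that case; other values of ord are not covered. The check compares it against jnp.linalg.norm applied to the concatenated, flattened leaves.

```python
import jax
import jax.numpy as jnp


def tree_norm2(tree):
    # Sum |x|^2 over every element of every leaf, then take the square root.
    squared = sum(jnp.sum(jnp.abs(x) ** 2) for x in jax.tree_util.tree_leaves(tree))
    return jnp.sqrt(squared)


tree = {"a": jnp.array([3.0]), "b": jnp.array([[4.0]])}
flat = jnp.concatenate([jnp.ravel(x) for x in jax.tree_util.tree_leaves(tree)])
print(float(tree_norm2(tree)))       # 5.0
print(float(jnp.linalg.norm(flat)))  # 5.0
```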

result_type(*trees)
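result_type(*trees) has no description here. The sketch below assumes it promotes the dtypes of all leaves across all given trees. It uses jnp.result_type, which applies JAX's type-promotion rules. The name tree_result_type is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_result_type(*trees):
    # Collect the leaves of every tree and promote their dtypes jointly.
    leaves = [leaf for t in trees for leaf in jax.tree_util.tree_leaves(t)]
    return jnp.result_type(*leaves)


a = {"x": jnp.zeros(2, dtype=jnp.int32)}
b = [jnp.zeros(3, dtype=jnp.float32)]
print(tree_result_type(a, b))  # float32
```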
shape(a)
size(a, axis: int | None = None) → int
sum(a)
vdot(a, b, *, precision=None)
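vdot is listed without a description. The sketch below assumes it mirrors jnp.vdot leaf by leaf, so each leaf is flattened and the first argument is complex-conjugated, and then sums the results. The name tree_vdot is hypothetical. The conjugation is what sets it apart from a plain jnp.dot.

```python
import operator

import jax
import jax.numpy as jnp


def tree_vdot(a, b, *, precision=None):
    # jnp.vdot flattens both leaves and conjugates the first one.
    leaf_vdots = jax.tree_util.tree_map(
        lambda x, y: jnp.vdot(x, y, precision=precision), a, b
    )
    return jax.tree_util.tree_reduce(operator.add, leaf_vdots, 0.0)


a = {"z": jnp.array([1j, 2.0])}
with_conj = tree_vdot(a, a)            # |1j|^2 + |2|^2, mathematically 5
without_conj = jnp.dot(a["z"], a["z"])  # (1j)^2 + 2^2, mathematically 3
```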
where(condition, x, y)

Selects elements from x or y based on condition, which can itself be a pytree.

Notes

If condition is not a pytree, then a partially evaluated selection is simply mapped over x and y without actually broadcasting condition.
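The two cases from the Notes can be sketched with plain JAX, assuming x and y share the same tree structure. If the condition is a single array, a partially applied jnp.where is mapped over the leaves of x and y. If the condition is a pytree, its leaves are mapped alongside those of x and y. This illustrates the idea and is not NIFTy's implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = {"p": jnp.array([1.0, 2.0]), "q": jnp.array(3.0)}
y = {"p": jnp.array([-1.0, -2.0]), "q": jnp.array(-3.0)}

# Case 1: a single condition array. jnp.where is partially evaluated with it and
# mapped over x and y, so the condition is never broadcast into a tree.
cond = jnp.array(True)
picked = jax.tree_util.tree_map(partial(jnp.where, cond), x, y)

# Case 2: the condition is a pytree with the same structure as x and y.
cond_tree = {"p": jnp.array([True, False]), "q": jnp.array(False)}
mixed = jax.tree_util.tree_map(jnp.where, cond_tree, x, y)
```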