nifty8.re.tree_math.vector_math module
- class ShapeWithDtype(shape: Tuple | Tuple[int] | List[int] | int, dtype=None)
Bases: object
Minimal helper class storing the shape and dtype of an object.
Notes
This class may not be transparent to JAX, as it is not meant to be flattened itself. If used in a tree-like structure, it should only be used as a leaf.
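To illustrate the leaf requirement, the sketch below uses a hypothetical stand-in class, ShapeDtypeLeaf (not part of NIFTy), that is not registered as a pytree node. jax.tree_util therefore treats each instance as a single leaf instead of flattening it.

```python
import jax


class ShapeDtypeLeaf:
    """Hypothetical stand-in for ShapeWithDtype; not registered as a pytree node."""

    def __init__(self, shape, dtype=None):
        # Assumption: a bare int is accepted as a one-dimensional shape.
        self.shape = (shape,) if isinstance(shape, int) else tuple(shape)
        self.dtype = dtype


# A tree whose leaves are shape/dtype descriptors rather than arrays.
tree = {"a": ShapeDtypeLeaf((2, 3), "float32"), "b": [ShapeDtypeLeaf(4, "int32")]}

# JAX does not descend into unregistered objects, so each descriptor is one leaf.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(len(leaves), [leaf.shape for leaf in leaves])
```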
- __init__(shape: Tuple | Tuple[int] | List[int] | int, dtype=None)
Instantiates a storage unit for shape and dtype.
- Parameters:
shape (tuple or list of int) – One-dimensional sequence of integers denoting the length of the object along each of its axes.
dtype (dtype) – Data-type of the to-be-described object.
- property dtype
Retrieves the data-type.
- classmethod from_leave(element)
Convenience method for creating an instance of ShapeWithDtype from an object.
To map a whole tree-like structure to its shapes and dtypes, use JAX’s tree_map function like so:
tree_map(ShapeWithDtype.from_leave, tree)
- Parameters:
element (leaf of a tree-like structure, e.g. an array) – Object from which to take the shape and data-type.
- Returns:
swd – Instance storing the shape and data-type of element.
- Return type:
instance of ShapeWithDtype
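A sketch of the same tree_map pattern, using JAX’s own jax.ShapeDtypeStruct as a stand-in for ShapeWithDtype so that the example runs without NIFTy; shape_dtype_of is a hypothetical helper playing the role of from_leave.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shape_dtype_of(leaf):
    # Analogue of ShapeWithDtype.from_leave, built on jax.ShapeDtypeStruct.
    return jax.ShapeDtypeStruct(jnp.shape(leaf), jnp.result_type(leaf))


tree = {
    "x": jnp.zeros((3, 2), dtype=jnp.float32),
    "y": (jnp.ones(4, dtype=jnp.int32),),
}

# Same tree structure, but every array is replaced by its shape and dtype.
skeleton = jax.tree_util.tree_map(shape_dtype_of, tree)
print(skeleton["x"].shape, skeleton["y"][0].dtype)
```

Because the resulting descriptors are leaves, the skeleton keeps exactly the structure of the input tree.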
- property ndim: int
Number of dimensions (axes).
- property shape: Tuple[int]
Retrieves the shape.
- property size: int
Total number of elements.
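Assuming ndim and size follow the usual NumPy conventions (number of axes, and product of the axis lengths), both can be derived from the shape alone. The helper ndim_and_size is hypothetical, and treating a bare int as a one-dimensional shape is an assumption based on the argument types accepted by __init__.

```python
import math


def ndim_and_size(shape):
    # Assumption: a bare int is treated as a one-dimensional shape.
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    # ndim is the number of axes; size is the product of the axis lengths.
    return len(shape), math.prod(shape)


print(ndim_and_size((2, 3, 4)))  # (3, 24)
print(ndim_and_size(5))          # (1, 5)
print(ndim_and_size(()))         # (0, 1): a scalar holds one element
```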
- all(a)
- any(a)
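all and any are listed without a description. Presumably they reduce over every element of every leaf; the following is a minimal sketch of such tree-wide boolean reductions with jax.tree_util, using the hypothetical names tree_all and tree_any. It is not NIFTy’s implementation.

```python
import jax
import jax.numpy as jnp


def tree_all(a):
    # Reduce each leaf with jnp.all, then combine the per-leaf results with AND.
    per_leaf = jax.tree_util.tree_map(jnp.all, a)
    return jax.tree_util.tree_reduce(jnp.logical_and, per_leaf, True)


def tree_any(a):
    # Reduce each leaf with jnp.any, then combine the per-leaf results with OR.
    per_leaf = jax.tree_util.tree_map(jnp.any, a)
    return jax.tree_util.tree_reduce(jnp.logical_or, per_leaf, False)


tree = {"a": jnp.array([True, True]), "b": jnp.array([[True, False]])}
print(bool(tree_all(tree)), bool(tree_any(tree)))  # False True
```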
- conj(a)
Returns the complex conjugate, component- and element-wise.
- Parameters:
a (object) – Arbitrary flatten-able object.
- Returns:
out – The complex conjugate of a, with the same shape and dtype as a.
- Return type:
object
- conjugate(a)
Returns the complex conjugate, component- and element-wise.
- Parameters:
a (object) – Arbitrary flatten-able object.
- Returns:
out – The complex conjugate of a, with the same shape and dtype as a.
- Return type:
object
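Since conjugation acts element-wise, a tree-wide version only needs to apply jnp.conjugate to every leaf. A minimal sketch, with tree_conj as a hypothetical name:

```python
import jax
import jax.numpy as jnp


def tree_conj(a):
    # Conjugate every leaf; tree structure, shapes and dtypes are unchanged.
    return jax.tree_util.tree_map(jnp.conjugate, a)


tree = {"u": jnp.array([1 + 2j, 3 - 1j]), "v": jnp.array([4.0])}
out = tree_conj(tree)
print(out["u"], out["v"])
```

Real-valued leaves come back unchanged, keeping their real dtype.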
- dot(a, b, *, precision=None)
Returns the dot product of the two vectors.
- Parameters:
a (object) – Arbitrary flatten-able object.
b (object) – Arbitrary flatten-able object.
- Returns:
out – Dot product of vectors.
- Return type:
float
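A minimal sketch of a dot product over two trees with identical structure, assuming each pair of leaves is flattened to 1-D and multiplied without complex conjugation (the docstring does not specify this), and the per-leaf results are summed. tree_dot is a hypothetical name; the result should coincide with the dot product of the two trees concatenated into single vectors.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_dot(a, b, *, precision=None):
    # Dot product of matching leaves after flattening each to 1-D (no conjugation).
    per_leaf = jax.tree_util.tree_map(
        lambda x, y: jnp.dot(x.ravel(), y.ravel(), precision=precision), a, b
    )
    # Sum the per-leaf contributions into a single scalar.
    return jax.tree_util.tree_reduce(jnp.add, per_leaf, 0.0)


a = {"p": jnp.array([1.0, 2.0]), "q": jnp.array([[3.0], [4.0]])}
b = {"p": jnp.array([5.0, 6.0]), "q": jnp.array([[7.0], [8.0]])}

# ravel_pytree concatenates all leaves (in the same order for both trees).
flat_a, _ = ravel_pytree(a)
flat_b, _ = ravel_pytree(b)
print(tree_dot(a, b), jnp.dot(flat_a, flat_b))  # both 1*5 + 2*6 + 3*7 + 4*8 = 70
```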
- matmul(a, b, *, precision=None)
Returns the dot product of the two vectors.
- Parameters:
a (object) – Arbitrary flatten-able object.
b (object) – Arbitrary flatten-able object.
- Returns:
out – Dot product of vectors.
- Return type:
float
- max(a)
- min(a)
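max and min are likewise listed without a description. Assuming they return the extreme value over all elements of all leaves, a sketch with the hypothetical helpers tree_max and tree_min:

```python
import jax
import jax.numpy as jnp


def tree_max(a):
    # Maximum of each leaf, then the maximum over those per-leaf values.
    return jnp.max(jnp.stack([jnp.max(x) for x in jax.tree_util.tree_leaves(a)]))


def tree_min(a):
    # Minimum of each leaf, then the minimum over those per-leaf values.
    return jnp.min(jnp.stack([jnp.min(x) for x in jax.tree_util.tree_leaves(a)]))


tree = {"a": jnp.array([3.0, -1.0]), "b": (jnp.array([[7.0, 2.0]]),)}
print(tree_max(tree), tree_min(tree))
```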
- norm(tree, ord=2)
Vector norm.
Notes
This function treats the input as a vector; hence, the default order ord is 2 (the Euclidean norm).
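A sketch of a tree norm, assuming every leaf is flattened to 1-D before taking its norm (tree_norm is a hypothetical name, not NIFTy’s implementation). For vector p-norms, the norm of the per-leaf norms equals the norm of all elements concatenated into one vector, since (Σᵢ ‖xᵢ‖ₚᵖ)^(1/p) = (Σᵢ Σⱼ |xᵢⱼ|ᵖ)^(1/p).

```python
import jax
import jax.numpy as jnp


def tree_norm(tree, ord=2):
    # Flatten each leaf to 1-D so jnp.linalg.norm computes a vector norm
    # (on a 2-D array, ord=2 would give the spectral norm instead).
    leaf_norms = [
        jnp.linalg.norm(x.ravel(), ord=ord) for x in jax.tree_util.tree_leaves(tree)
    ]
    # Combine the per-leaf norms with the same vector norm.
    return jnp.linalg.norm(jnp.stack(leaf_norms), ord=ord)


tree = {"a": jnp.array([3.0, 4.0]), "b": jnp.array([[0.0, 0.0], [0.0, 12.0]])}
print(tree_norm(tree))  # sqrt(3**2 + 4**2 + 12**2) = 13
```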
- sum(a)
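sum is also listed without a description. Assuming it adds up all elements of all leaves, a minimal sketch using the hypothetical helper tree_sum:

```python
import jax
import jax.numpy as jnp


def tree_sum(a):
    # Sum each leaf over all of its elements, then add the per-leaf sums.
    per_leaf = jax.tree_util.tree_map(jnp.sum, a)
    return jax.tree_util.tree_reduce(jnp.add, per_leaf, 0.0)


tree = [jnp.arange(4.0), {"c": jnp.ones((2, 3))}]
print(tree_sum(tree))  # (0 + 1 + 2 + 3) + 6 = 12
```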