nifty8.re.tree_math.forest_math module

assert_arithmetics(obj, *args, **kwargs)
get_map(map) → Callable
has_arithmetics(obj, additional_attributes=())
map_forest(f: Callable, in_axes: int | Tuple = 0, out_axes: int | Tuple = 0, tree_transpose_output: bool = True, map: str | Callable = 'vmap', **kwargs) → Callable
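map_forest wraps f so that it is applied across a forest (a sequence of samples sharing one pytree structure), with map defaulting to 'vmap'. The following is only a sketch of one way to realize that idea, not the nifty8 implementation: it assumes the forest is a Python list of pytrees, uses jax.vmap over a stacked leading axis, and does not model in_axes, out_axes, or tree_transpose_output. The name map_forest_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def map_forest_sketch(f, forest):
    # Hypothetical helper: stack the samples leaf-wise along a new leading axis.
    stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *forest)
    # Apply f to every sample at once by vectorizing over that leading axis.
    batched = jax.vmap(f)(stacked)
    # Split the batched result back into one output per sample.
    n = len(forest)
    return [jax.tree_util.tree_map(lambda x: x[i], batched) for i in range(n)]
```

Stacking once and calling a single vectorized function lets JAX trace f only once, instead of calling it in a Python loop per sample.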
map_forest_mean(method, map='vmap', *args, **kwargs) → Callable
mean(forest)
mean_and_std(forest, correct_bias=True)
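mean and mean_and_std reduce a forest leaf by leaf. The sketch below illustrates that reduction and is not the library's implementation; it assumes a forest is a Python list of pytrees with identical structure, and that correct_bias=True means Bessel's correction (dividing the summed squared deviations by n - 1 instead of n). The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def forest_mean(forest):
    # Hypothetical helper: average each leaf across all samples in the forest.
    n = len(forest)
    total = jax.tree_util.tree_map(lambda *leaves: sum(leaves), *forest)
    return jax.tree_util.tree_map(lambda s: s / n, total)


def forest_mean_and_std(forest, correct_bias=True):
    # Hypothetical helper: per-leaf mean and standard deviation across samples.
    n = len(forest)
    m = forest_mean(forest)
    # Sum of squared deviations from the mean, computed leaf by leaf.
    sq_dev = jax.tree_util.tree_map(
        lambda mu, *leaves: sum((x - mu) ** 2 for x in leaves), m, *forest
    )
    # Assumed meaning of correct_bias: Bessel's correction (divide by n - 1).
    denom = n - 1 if correct_bias else n
    std = jax.tree_util.tree_map(lambda s: jnp.sqrt(s / denom), sq_dev)
    return m, std
```

Summing squared deviations from the mean, rather than computing E[x^2] - E[x]^2, avoids catastrophic cancellation when the spread is small compared with the mean.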
random_like(key: Iterable, primals, rng: Callable = jax.random.normal)
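random_like draws random values shaped like primals, with rng defaulting to a normal sampler. A minimal sketch of that pattern follows; it assumes a single PRNG key that is split into one independent subkey per leaf, floating-point leaves, and a sampler with the jax.random.normal(key, shape, dtype) calling convention. The name random_like_sketch is hypothetical and not the nifty8 API.

```python
import jax
import jax.numpy as jnp


def random_like_sketch(key, primals, rng=jax.random.normal):
    # Hypothetical helper: flatten the pytree so each leaf gets its own subkey.
    leaves, treedef = jax.tree_util.tree_flatten(primals)
    keys = jax.random.split(key, len(leaves))
    # Draw one sample per leaf with the same shape and dtype as that leaf.
    draws = [rng(k, jnp.shape(x), jnp.result_type(x)) for k, x in zip(keys, leaves)]
    # Rebuild a pytree with the original structure.
    return jax.tree_util.tree_unflatten(treedef, draws)
```

Splitting the key per leaf keeps the draws for different leaves statistically independent.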
stack(arrays, axis=0)
tree_shape(tree: T) → T
unite(x, y, op=operator.add)

Unites two Vector-like objects.

If a key is contained in both objects, the fields at that key are combined using op (addition by default).
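The merge rule described above can be illustrated with plain dictionaries. This sketch assumes that keys present in only one operand are copied unchanged and that shared keys are combined with op; unite_sketch is a hypothetical stand-in, not the nifty8 implementation, and it ignores the Vector wrapper.

```python
import operator


def unite_sketch(x, y, op=operator.add):
    # Hypothetical helper: start from a shallow copy of all entries of x.
    out = dict(x)
    for k, v in y.items():
        # Shared keys are combined with op; keys only in y are copied over.
        out[k] = op(out[k], v) if k in out else v
    return out
```

For example, unite_sketch({"a": 1, "b": 2}, {"b": 3, "c": 4}) yields {"a": 1, "b": 5, "c": 4}.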

unstack(stack, axis=0)