nifty8.re.custom_map module

lmap(fun, in_axes=0, out_axes=0)
smap(fun, in_axes=0, out_axes=0, *, unroll=1)

Stupid/sequential map.

Much of JAX’s control-flow logic reduces to a simple jax.lax.scan, and this function is no exception. Unlike jax.lax.map or jax.lax.fori_loop, however, it behaves much like jax.vmap: it re-implements in_axes and out_axes and can be used in much the same way as jax.vmap. The difference is that instead of batching the input, it works through the mapped axis sequentially, one slice at a time.
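To illustrate how a map reduces to jax.lax.scan, here is a minimal sketch that maps only over the leading axis of a single array. The helper name scan_map is hypothetical and this is not NIFTy's implementation; it is essentially what jax.lax.map does, and it lacks the in_axes/out_axes handling that smap adds.

```python
import jax
import jax.numpy as jnp


def scan_map(f, xs):
    """Hypothetical helper: apply f to each xs[i] in turn, stacking results like jax.vmap(f)(xs)."""

    def body(carry, x):
        # No state is threaded between iterations; the carry is unused.
        return carry, f(x)

    # scan iterates over the leading axis of xs and stacks the per-step outputs.
    _, ys = jax.lax.scan(body, None, xs)
    return ys


f = lambda v: jnp.sin(v) * jnp.sum(v)
xs = jnp.arange(6.0).reshape(3, 2)
print(jnp.allclose(scan_map(f, xs), jax.vmap(f)(xs)))  # True
```

The result matches jax.vmap, but the three rows are processed one after another rather than as a single batched computation.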

This implementation makes no claim to being efficient. It explicitly moves axes around in the inputs and outputs, which may allocate more memory than strictly necessary and worsen the memory layout.

For the semantics of in_axes and out_axes see jax.vmap. For the semantics of unroll see jax.lax.scan.
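A more complete sketch, again hypothetical and not NIFTy's actual code, shows where the axis moves come from. Each mapped input axis is moved to the front with jnp.moveaxis so that jax.lax.scan can iterate over it, and the stacked outputs are moved from axis 0 to out_axes. The name seq_map and its simplified axis specification (in_axes as an int, None, or a per-positional-argument tuple; out_axes as a single int for every output) are assumptions; unroll is forwarded unchanged to jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def seq_map(fun, in_axes=0, out_axes=0, *, unroll=1):
    """Hypothetical vmap-like sequential map built on jax.lax.scan.

    Simplified axis spec (an assumption): in_axes is an int, None, or a tuple
    with one int/None per positional argument; out_axes is one int for all outputs.
    """

    def mapped(*args):
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes must have one entry per positional argument")
        # Move every mapped axis to the front so that scan iterates over it.
        xs = tuple(
            jnp.moveaxis(a, ax, 0) for a, ax in zip(args, axes) if ax is not None
        )
        if len({x.shape[0] for x in xs}) != 1:
            raise ValueError("need at least one mapped argument; mapped lengths must match")

        def body(carry, sliced):
            it = iter(sliced)
            # Mapped args take the current slice; unmapped args (in_axes=None)
            # are closed over and passed through unchanged.
            call_args = [a if ax is None else next(it) for a, ax in zip(args, axes)]
            return carry, fun(*call_args)

        _, ys = jax.lax.scan(body, None, xs, unroll=unroll)
        # scan stacks results along axis 0; move that axis to out_axes.
        return jax.tree_util.tree_map(lambda y: jnp.moveaxis(y, 0, out_axes), ys)

    return mapped


x = jnp.arange(12.0).reshape(3, 4)
w = jnp.array([1.0, 2.0, 3.0])
f = lambda col, v: col * v
got = seq_map(f, in_axes=(1, None), out_axes=1, unroll=2)(x, w)
ref = jax.vmap(f, in_axes=(1, None), out_axes=1)(x, w)
print(got.shape, bool(jnp.allclose(got, ref)))  # (3, 4) True
```

Each jnp.moveaxis on a non-leading axis is a transpose, which may cost an extra copy; this is the kind of overhead the efficiency caveat refers to. Larger unroll values trade compile time for fewer loop iterations, exactly as in jax.lax.scan.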