Source code for nifty8.re.lax
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
from jax import lax
_DISABLE_CONTROL_FLOW_PRIM = False
[docs]
def cond(pred, true_fun, false_fun, operand):
if _DISABLE_CONTROL_FLOW_PRIM:
if pred:
return true_fun(operand)
else:
return false_fun(operand)
else:
return lax.cond(pred, true_fun, false_fun, operand)
[docs]
def while_loop(cond_fun, body_fun, init_val):
if _DISABLE_CONTROL_FLOW_PRIM:
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
else:
return lax.while_loop(cond_fun, body_fun, init_val)
[docs]
def fori_loop(lower, upper, body_fun, init_val):
if _DISABLE_CONTROL_FLOW_PRIM:
val = init_val
for i in range(int(lower), int(upper)):
val = body_fun(i, val)
return val
else:
return lax.fori_loop(lower, upper, body_fun, init_val)