Source code for nifty8.re.tree_math.pytree_string
#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
import operator
from jax.tree_util import register_pytree_node_class, tree_map
def _unary_op(op, name=None):
    """Build a method that applies the unary ``op`` to the wrapped string.

    The generated function takes one argument exposing a ``_str`` attribute,
    calls ``op`` on that attribute and returns the result unchanged. The
    result is *not* re-wrapped, so e.g. ``str.lower`` yields a plain ``str``.

    Parameters
    ----------
    op : callable
        Function called as ``op(lhs._str)``.
    name : str, optional
        Base name for the generated method's ``__name__``; defaults to
        ``op.__name__``. A name that is already a dunder (e.g. ``"__hash__"``)
        is not wrapped a second time.

    Returns
    -------
    callable
        The generated method.
    """
    def unary_call(lhs):
        return op(lhs._str)

    name = op.__name__ if name is None else name
    # ``str.__hash__.__name__`` is already "__hash__"; strip the existing
    # dunder so the method is not named "____hash____".
    if len(name) > 4 and name.startswith("__") and name.endswith("__"):
        name = name[2:-2]
    unary_call.__name__ = f"__{name}__"
    return unary_call
def _binary_op(op, name=None):
    """Build a method applying the binary ``op`` with strings unwrapped.

    Either operand may be a ``PyTreeString``; such operands are replaced by
    their wrapped ``_str`` before calling ``op(lhs, rhs, *args)``. A ``str``
    result is wrapped back into a ``PyTreeString``; any other result (e.g. a
    ``bool`` from a comparison) is returned as is.

    Parameters
    ----------
    op : callable
        Function called as ``op(lhs, rhs, *args)``.
    name : str, optional
        Base name for the generated method's ``__name__``; defaults to
        ``op.__name__``. A name that is already a dunder is not wrapped again.

    Returns
    -------
    callable
        The generated method.
    """
    def binary_call(lhs, rhs, *args):
        lhs = lhs._str if isinstance(lhs, PyTreeString) else lhs
        rhs = rhs._str if isinstance(rhs, PyTreeString) else rhs
        # Extra positional arguments are forwarded unchanged so that methods
        # such as ``startswith(prefix, start, end)`` mirror the str signature.
        out = op(lhs, rhs, *args)
        return PyTreeString(out) if isinstance(out, str) else out

    name = op.__name__ if name is None else name
    # Avoid names like "____contains____" when ``op`` is itself a dunder.
    if len(name) > 4 and name.startswith("__") and name.endswith("__"):
        name = name[2:-2]
    binary_call.__name__ = f"__{name}__"
    return binary_call
def _rev_binary_op(op, name=None):
    """Build the reflected (``__r<name>__``) variant of a binary method.

    Python invokes a reflected operator method with ``self`` as its first
    argument, so the generated method calls ``op`` with the operands swapped.
    ``PyTreeString`` operands are unwrapped to their ``_str`` first, and a
    ``str`` result is wrapped back into a ``PyTreeString``.

    Parameters
    ----------
    op : callable
        Function called as ``op(rhs, lhs)``.
    name : str, optional
        Base name for the generated method's ``__name__``; defaults to
        ``op.__name__``. A name that is already a dunder is not wrapped again.

    Returns
    -------
    callable
        The generated reflected method.
    """
    def binary_call(lhs, rhs):
        lhs = lhs._str if isinstance(lhs, PyTreeString) else lhs
        rhs = rhs._str if isinstance(rhs, PyTreeString) else rhs
        # ``lhs`` is ``self`` here; swap so that ``other op self`` is computed.
        out = op(rhs, lhs)
        return PyTreeString(out) if isinstance(out, str) else out

    name = op.__name__ if name is None else name
    # Avoid names like "__r__add____" when given an already-dunder name.
    if len(name) > 4 and name.startswith("__") and name.endswith("__"):
        name = name[2:-2]
    binary_call.__name__ = f"__r{name}__"
    return binary_call
def _fwd_rev_binary_op(op, name=None):
    """Return the ``(forward, reflected)`` method pair for the binary ``op``.

    Shorthand for building the ``__<name>__`` and ``__r<name>__`` methods
    together from ``_binary_op`` and ``_rev_binary_op`` with the same
    ``name``.
    """
    forward = _binary_op(op, name=name)
    reflected = _rev_binary_op(op, name=name)
    return forward, reflected
[docs]
@register_pytree_node_class
class PyTreeString():
    """Wrapper that hides a ``str`` from JAX pytree traversal.

    A ``PyTreeString`` flattens to zero children and carries the string in
    its auxiliary data, so ``jax.tree_util`` functions never see the string
    as a leaf. Comparisons, ``+``, ``*`` and a few ``str`` methods are
    forwarded to the wrapped string. ``str`` results of the binary
    operations are re-wrapped; ``lower``/``upper`` return a plain ``str``.

    Parameters
    ----------
    str : str
        The string to wrap. The parameter name shadows the builtin but is
        kept so that keyword callers keep working.
    """

    def __init__(self, str):
        self._str = str

    def tree_flatten(self):
        """Flatten into no children and ``(self._str,)`` as aux data."""
        return ((), (self._str, ))

    @classmethod
    def tree_unflatten(cls, aux, _):
        """Rebuild from the aux data of :meth:`tree_flatten`.

        The (empty) children argument is ignored.
        """
        return cls(*aux)

    def __str__(self):
        return self._str

    def __repr__(self):
        return f"{self.__class__.__name__}({self._str!r})"

    # Rich comparisons against ``str`` or other ``PyTreeString`` instances.
    __lt__ = _binary_op(operator.lt)
    __le__ = _binary_op(operator.le)
    __eq__ = _binary_op(operator.eq)
    __ne__ = _binary_op(operator.ne)
    __ge__ = _binary_op(operator.ge)
    __gt__ = _binary_op(operator.gt)

    # Concatenation and repetition, including the reflected forms.
    __add__, __radd__ = _fwd_rev_binary_op(operator.add)
    __mul__, __rmul__ = _fwd_rev_binary_op(operator.mul)

    lower = _unary_op(str.lower)
    upper = _unary_op(str.upper)
    # Hash like the wrapped string so that objects equal under ``__eq__``
    # (including a plain ``str``) hash equally. The explicit ``name`` keeps
    # the method's ``__name__`` as "__hash__".
    __hash__ = _unary_op(str.__hash__, name="hash")

    startswith = _binary_op(str.startswith)
    endswith = _binary_op(str.endswith)
[docs]
def hide_strings(a):
    """Wrap every ``str`` leaf of the pytree ``a`` in a ``PyTreeString``.

    Leaves that are not strings are returned unchanged. Since a
    ``PyTreeString`` flattens to no children, the returned tree can be fed
    through ``jax.tree_util`` transformations without its strings being
    visited.
    """
    def _hide(leaf):
        if not isinstance(leaf, str):
            return leaf
        return PyTreeString(leaf)

    return tree_map(_hide, a)