# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from copy import deepcopy
from numpy import allclose
from .multi_field import MultiField
from .operators.operator import _OpChain, _OpProd, _OpSum
from .operators.simple_linear_operators import FieldAdapter
from .sugar import domain_union, from_random
from .utilities import myassert
def _optimise_operator(op):
    """
    Optimise an operator tree so that identical subtrees are not computed twice.

    Identical ``_OpSum``/``_OpProd`` nodes (the same object reached more than
    once) are replaced by a ``FieldAdapter`` at every occurrence. Identical
    leaves (the same object, or ``_OpChain``s sharing their innermost
    operators) get their shared part replaced by a ``FieldAdapter`` as well.
    Each shared computation is afterwards inserted once in front of the tree via
    ``op.partial_insert(adapter.adjoint(shared_op))``.

    Works partly in place, rendering the old operator unusable.

    Parameters
    ----------
    op : Operator
        Operator tree to optimise. Its nodes are rewired in place.

    Returns
    -------
    Operator
        The optimised operator. If ``op`` contains no ``_OpSum``/``_OpProd``
        there is nothing to merge and ``op`` is returned unchanged.

    Notes
    -----
    Operators are compared only by ``id()``; distinct but mathematically equal
    operators are not merged.
    """
    # Format: list of tuples (op, parent_index, left). parent_index indexes into
    # ``nodes``; left is True for the parent's ``_op1`` and False for ``_op2``.
    # Nodes found directly in the root op have parent_index = left = None.
    nodes = []
    # Format: {id(node op): [every index in nodes at which that op was reached]}
    id_dic = {}
    # Format: {(parent_index, left)} -- operand slots holding no _OpSum/_OpProd
    leaves = set()

    # helper functions
    def readable_id():
        # Yields 'A'..'Z', 'a'..'z', 'AA'..'ZZ', 'aa'..'zz', 'AAA', ...; these are
        # prepended to FieldAdapter keys purely for readability.
        current_letter = 65  # ord('A')
        repeats = 1
        while True:
            yield chr(current_letter)*repeats
            current_letter += 1
            if current_letter == 91:
                # skip the special characters between 'Z' and 'a'
                current_letter += 6
            elif current_letter == 123:
                # past 'z': restart at 'A' with one more repetition
                current_letter = 65
                repeats += 1
    prepend_id = readable_id()

    def isnode(op):
        # Only sums and products of operators are treated as tree nodes
        return isinstance(op, (_OpSum, _OpProd))

    def left_parser(left_bool):
        # Map the "left" flag to the name of the node's operand attribute
        return '_op1' if left_bool else '_op2'

    def get_duplicate_keys(k_list, dic):
        # Append to k_list every key of dic whose value has more than one entry
        for key, value in dic.items():
            if len(value) > 1:
                k_list.append(key)

    # Main algorithm functions
    def rebuild_domains(index):
        """Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
        cond = True
        while cond:
            op = nodes[index][0]
            for attr in ('_op1', '_op2'):
                if isinstance(getattr(op, attr), _OpChain):
                    # A chain's domain is the domain of its innermost (last) operator
                    getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
            if isnode(op):
                # Some problems doing this on non-multidomains, because one side
                # becomes a multidomain and the other not
                try:
                    op._domain = domain_union((op._op1.domain, op._op2.domain))
                except AttributeError:
                    import warnings
                    warnings.warn('Operator should be defined on a MultiDomain')
            # Continue with the parent; the walk ends once the parent index is None
            index = nodes[index][1]
            cond = type(index) is int

    def recognize_nodes(op, active_node, left):
        # Register every _OpSum/_OpProd contained in ``op`` (op itself, or any
        # element of an _OpChain) as a child of nodes[active_node] on side
        # ``left``. If nothing is registered, the slot is a leaf.
        isleaf = True
        if isinstance(op, _OpChain):
            for sub_op in op._ops:
                if isnode(sub_op):
                    nodes.append((sub_op, active_node, left))
                    isleaf = False
        elif isnode(op):
            nodes.append((op, active_node, left))
            isleaf = False
        if isleaf:
            leaves.add((active_node, left))

    def equal_nodes(op):
        # BFS-Algorithm which fills the nodes list and id_dic dictionary
        # Does not scan equal subtrees multiple times
        list_index_traversed = 0
        recognize_nodes(op, None, None)
        while list_index_traversed < len(nodes):
            # Visit node
            active_node = nodes[list_index_traversed][0]
            # Check whether this very object has been reached already
            try:
                id_dic[id(active_node)] = id_dic[id(active_node)] + [list_index_traversed]
                match = True
            except KeyError:
                id_dic[id(active_node)] = [list_index_traversed]
                match = False
            # Only expand the operands of the first occurrence
            if not match:
                recognize_nodes(active_node._op1, list_index_traversed, True)
                recognize_nodes(active_node._op2, list_index_traversed, False)
            list_index_traversed += 1

    # Indices into nodes whose domains must be recomputed after rewiring
    edited = set()

    def equal_leaves(leaves):
        # Merge leaves whose operator sequences start identically. Returns the
        # merge keys in creation order and a dict mapping each key to
        # [shared operator, FieldAdapter that now replaces it].
        id_leaf = {}

        # Find matching leaves
        def write_to_dic(leaf, leaf_op_id):
            try:
                id_leaf[leaf_op_id] = id_leaf[leaf_op_id] + (leaf,)
            except KeyError:
                id_leaf[leaf_op_id] = (leaf,)

        for leaf in leaves:
            parent = nodes[leaf[0]][0]
            attr = left_parser(leaf[1])
            leaf_op = getattr(parent, attr)
            if isinstance(leaf_op, _OpChain):
                # Key = ids of the chain's operators from the innermost one
                # outwards, up to and including the first non-FieldAdapter
                leaf_op_id = ''
                for i in reversed(leaf_op._ops):
                    leaf_op_id += str(id(i))
                    if not isinstance(i, FieldAdapter):
                        # Do not optimise leaves which only have equal FieldAdapters
                        write_to_dic(leaf, leaf_op_id)
                        break
            else:
                if not isinstance(leaf_op, FieldAdapter):
                    write_to_dic(leaf, str(id(leaf_op)))

        # Unroll their OpChain and see how far they are equal
        key_list_leaf = []
        same_leaf = {}
        get_duplicate_keys(key_list_leaf, id_leaf)
        for key in key_list_leaf:
            # Operator sequences of all matching leaves, innermost operator first
            to_compare = []
            for leaf in id_leaf[key]:
                parent = nodes[leaf[0]][0]
                attr = left_parser(leaf[1])
                leaf_op = getattr(parent, attr)
                if isinstance(leaf_op, _OpChain):
                    to_compare.append(tuple(reversed(leaf_op._ops)))
                else:
                    to_compare.append((leaf_op,))
            # first_difference ends up as the length of the common prefix of
            # all sequences, capped at the shortest sequence's length
            first_difference = 1
            max_diff = min(len(i) for i in to_compare)
            if max_diff != 1:
                compare_iterator = iter(to_compare)
                first = next(compare_iterator)
                while all(first[first_difference] == rest[first_difference] for rest in compare_iterator):
                    first_difference += 1
                    if first_difference >= max_diff:
                        break
                    # all() consumed the iterator; restart it for the next position
                    compare_iterator = iter(to_compare)
                    first = next(compare_iterator)
            # Compose the common prefix into one operator (innermost applied first)
            common_op = to_compare[0][:first_difference]
            res_op = common_op[0]
            for ops in common_op[1:]:
                res_op = ops @ res_op
            same_leaf[key] = [res_op, FieldAdapter(res_op.target, next(prepend_id) + str(id(res_op)))]
            # Replace the common prefix in every matching leaf by the adapter
            for leaf in id_leaf[key]:
                parent = nodes[leaf[0]][0]
                edited.add(id_dic[id(parent)][0])
                attr = left_parser(leaf[1])
                leaf_op = getattr(parent, attr)
                if isinstance(leaf_op, _OpChain):
                    if first_difference == len(leaf_op._ops):
                        setattr(parent, attr, same_leaf[key][1])
                    else:
                        leaf_op._ops = leaf_op._ops[:-first_difference] + (same_leaf[key][1],)
                else:
                    setattr(parent, attr, same_leaf[key][1])
        return key_list_leaf, same_leaf

    equal_nodes(op)
    if not nodes:
        # No _OpSum/_OpProd in the tree, so nothing can be merged. Without this
        # guard the root itself is recorded as the leaf (None, None) and
        # equal_leaves would fail on ``nodes[None]`` with a TypeError.
        return op

    # Cut leaves: merge identical leaves until no more duplicates are found
    key_temp = []
    key_list_op, same_op = equal_leaves(leaves)
    cond = True
    while cond:
        key_temp, same_op_temp = equal_leaves(leaves)
        key_list_op += key_temp
        same_op.update(same_op_temp)
        cond = len(same_op_temp) > 0
        key_temp.clear()

    # Cut subtrees: replace every occurrence of a node reached more than once
    key_list_node = []
    key_list_subtrees = []
    same_node = {}
    same_subtrees = {}
    subtree_leaves = set()
    get_duplicate_keys(key_list_node, id_dic)
    for key in key_list_node:
        same_node[key] = [nodes[id_dic[key][0]][0],
                          FieldAdapter(nodes[id_dic[key][0]][0].target, next(prepend_id) + str(key))]
        for node_indices in id_dic[key]:
            edited.add(node_indices)
            parent = nodes[nodes[node_indices][1]][0]
            attr = left_parser(nodes[node_indices][2])
            if isinstance(getattr(parent, attr), _OpChain):
                # Replace the chain's innermost operator by the adapter.
                # NOTE(review): recognize_nodes registers a node found at ANY
                # position of an _OpChain, but this always swaps out _ops[-1];
                # confirm duplicated nodes only ever sit at the end of a chain.
                getattr(parent, attr)._ops = getattr(parent, attr)._ops[:-1] + (same_node[key][1],)
            else:
                setattr(parent, attr, same_node[key][1])
            # Nodes have been replaced - treat replacements now as leaves
            subtree_leaves.add((nodes[node_indices][1], nodes[node_indices][2]))
        # Merge identical operators that are applied on top of the new adapter
        cond = True
        while cond:
            key_temp1, same_temp = equal_leaves(subtree_leaves)
            key_temp = key_temp1 + key_temp
            same_subtrees.update(same_temp)
            cond = len(same_temp) > 0
        key_list_subtrees += key_temp + [key, ]
        key_temp.clear()
        subtree_leaves.clear()
    same_subtrees.update(same_node)

    # Fix the domains invalidated by plugging in FieldAdapters
    for index in edited:
        rebuild_domains(index)
    if isinstance(op, _OpChain):
        op._domain = op._ops[-1].domain

    # Insert trees before leaves
    for key in key_list_subtrees:
        op = op.partial_insert(same_subtrees[key][1].adjoint(same_subtrees[key][0]))
    for key in reversed(key_list_op):
        op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
    return op
def optimise_operator(op):
    """
    Merges redundant operations in the tree structure of an operator.

    For example it is ensured that for ``f@x + x`` the operator ``x`` is only computed once.
    It is supposed to be used on the whole operator chain before doing minimisation.

    Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not their linear pendants
    ``ChainOp`` and ``SumOperator``.

    Parameters
    ----------
    op : Operator
        Operator with a tree structure. The optimisation is applied to a deep
        copy, so ``op`` itself is left intact.

    Returns
    -------
    op_optimised : Operator
        Operator with same input/output, but optimised tree structure.

    Notes
    -----
    Operators are compared only by id, so best results are achieved when the following code

    >>> from nifty8 import UniformOperator, DomainTuple
    >>> uni1 = UniformOperator(DomainTuple.scalar_domain())
    >>> uni2 = UniformOperator(DomainTuple.scalar_domain())
    >>> op = (uni1 + uni2)*(uni1 + uni2)

    is replaced by something comparable to

    >>> uni = UniformOperator(DomainTuple.scalar_domain())
    >>> uni_add = uni + uni
    >>> op = uni_add * uni_add

    After optimisation the operator is as fast as

    >>> op = (2*uni)**2

    As a safety check, the original and the optimised operator are both
    evaluated on one random input from ``op.domain``. Their outputs must agree
    to a relative tolerance of 1e-10. The check uses ``myassert``, which
    presumably raises on mismatch.
    """
    # _optimise_operator rewires the tree in place, so work on a copy.
    op_optimised = _optimise_operator(deepcopy(op))

    # Evaluate each operator exactly once: every call runs the whole tree, and
    # the previous code re-evaluated both operators once per output key.
    test_field = from_random(op.domain)
    expected = op(test_field)
    result = op_optimised(test_field)
    if isinstance(expected, MultiField):
        expected_val = expected.val
        result_val = result.val
        for key in expected.keys():
            myassert(allclose(expected_val[key], result_val[key], rtol=1e-10))
    else:
        myassert(allclose(expected.val, result.val, rtol=1e-10))
    return op_optimised