# Source code for nifty8.operators.counting_operator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2022 Max-Planck-Society, Philipp Arras
from .endomorphic_operator import EndomorphicOperator
from .operator import Operator
class CountingOperator(Operator):
    """Identity operator that counts how often it and its Jacobian are used.

    The operator returns its input unchanged (domain and target are the
    same), so it can be inserted into an operator chain to find out how
    often that chain is evaluated. Four counters are kept separately:

    - :attr:`count_apply`: applications to non-Linearization input,
    - :attr:`count_apply_lin`: applications to a Linearization,
    - :attr:`count_jac`: forward applications of the Jacobian,
    - :attr:`count_jac_adj`: adjoint applications of the Jacobian.

    Parameters
    ----------
    domain
        Domain (and target) of the operator. It may be anything accepted by
        ``makeDomain``.
    """

    def __init__(self, domain):
        from ..sugar import makeDomain
        self._domain = self._target = makeDomain(domain)
        self._count_apply = 0
        self._count_apply_lin = 0
        # The Jacobian is a separate identity operator with its own
        # forward/adjoint counters. The same instance is attached to every
        # Linearization produced by `apply`, so all Jacobian calls accumulate
        # in one place.
        self._derivative = _JacCountingOperator(self._domain)

    def apply(self, x):
        """Return ``x`` unchanged and record the call.

        Parameters
        ----------
        x
            Input on ``self.domain``, either a Linearization or plain input.

        Returns
        -------
        The input value. For Linearization input, the returned object carries
        the counting Jacobian, so later Jacobian applications are counted too.
        """
        from ..sugar import is_linearization
        self._check_input(x)
        if is_linearization(x):
            self._count_apply_lin += 1
            # Keep the value and attach the counting Jacobian.
            # NOTE(review): the incoming Jacobian is dropped here. This
            # presumes the caller passes a trivial Jacobian and chains the
            # real one afterwards; confirm against Operator.__call__.
            x = x.new(x.val, self._derivative)
        else:
            self._count_apply += 1
        return x

    @property
    def count_apply(self):
        """int: Number of applications to non-Linearization input."""
        return self._count_apply

    @property
    def count_apply_lin(self):
        """int: Number of applications to Linearization input."""
        return self._count_apply_lin

    @property
    def count_jac(self):
        """int: Number of forward (TIMES) applications of the Jacobian."""
        return self._derivative._count_times

    @property
    def count_jac_adj(self):
        """int: Number of adjoint applications of the Jacobian."""
        return self._derivative._count_adjoint_times

    def __repr__(self):
        return f"CountingOperator({self._domain!r})"

    def report(self):
        """Return a human-readable, multi-line summary of all four counters."""
        s = [f"* apply: \t\t{self.count_apply:>7}",
             f"* apply Linearization: \t{self.count_apply_lin:>7}",
             f"* Jacobian: \t\t{self.count_jac:>7}",
             f"* Adjoint Jacobian: \t{self.count_jac_adj:>7}"]
        return "\n".join(s)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        # Do not fold the operator away for constant input. Instead, return
        # this operator composed with an InsertionOperator holding the
        # constant, presumably so that evaluations are still counted.
        from .simplify_for_const import InsertionOperator
        return None, self @ InsertionOperator(self.domain, c_inp)
class _JacCountingOperator(EndomorphicOperator):
    """Identity endomorphism that tallies its forward and adjoint calls.

    It is used as the Jacobian of ``CountingOperator``. Every call returns
    the input unchanged; the only effect is incrementing a counter.
    """

    def __init__(self, domain):
        from ..sugar import makeDomain
        # Counters first; the domain and capability setup below are
        # independent of them.
        self._count_times = 0
        self._count_adjoint_times = 0
        self._domain = makeDomain(domain)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Return ``x`` as-is after counting the application in ``mode``."""
        self._check_input(x, mode)
        # Only TIMES and ADJOINT_TIMES are advertised in _capability, so any
        # non-TIMES mode reaching this point is the adjoint.
        if mode != self.TIMES:
            self._count_adjoint_times += 1
        else:
            self._count_times += 1
        return x