Source code for nifty8.minimization.config.optimize_kl_config

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2022 Philipp Arras
# Author: Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import importlib
import os
from configparser import ConfigParser
from warnings import warn

from ...utilities import myassert

# FIXME point_estimates, constants?
# FIXME "2**lh0" looks weird. Change syntax? Use "$" for references?
# FIXME Cache operators (especially likelihoods)


[docs] class OptimizeKLConfig: """Plug config files into ift.optimize_kl. If you use this class for the first time, look at an example config file (e.g. `demos/getting_started_7_config_file.cfg`) before continuing to read the documentation. After parsing the config file the following steps are taken: - All `base` entries are interpreted. The options that are written out explicitly in a given section override what might be imported via a `base` entry. `base` inheritance does not work recursively (yet?). - Repetitions (e.g. "2*5,3*2") in `optimization.*` are expanded and replaced by an explicit list (e.g. "5,5,2,2,2"). The length of the resulting list is not allowed to be longer than `total iterations`. If it is shorter, the last value is repeated. - All optimization stages (`optimization.*`) are joined into a single stage (`optimization.0`). For this, the quantity after the point `.` in `optimization.*` is interpreted as `int` and sorted afterwards. This means that for example `optimization.02` comes after `optimization.-1`. E.g. `optimization.1.1` is not allowed. For referring to e.g. likelihood energy operators, the star `*` is used as dereferencing operator. The value after the `*` refers to a section in the config file that must be instantiable via `OptimizeKLConfig.instantiate_section(key)`. Generally, all spaces (` `) in keys are internally replaced by underscores (`_`) such that they can be used in function calls. Parameters ---------- config_parser : ConfigParser ConfigParser that contains all configuration. builders : dict Dictionary of functions that are used to instantiate e.g. operators. Example ------- For an example of a typical config file, look at `demos/getting_started_7_config_file.cfg` in the nifty repository. Note ---- Not treated by this class: export_operator_outputs, initial_position, initial_index, comm, inspect_callback, terminate_callback, return_final_position, resume Note ---- Make sure that ConfigParser is case-sensitive by setting its attribute `optionxform` to `str` before parsing. """
[docs] def __init__(self, config_parser, builders): if not isinstance(config_parser, ConfigParser): raise TypeError if config_parser.optionxform != str: warn("Consider setting `config_parser.optionxform = str`") self._cfg = config_parser self._builders = dict(builders) self._interpret_base() self._interpret_repetitions() self._join_optimization_stages()
[docs] @classmethod def from_file(cls, file_name, builders): """ Parameters ---------- file_name : str File name of the config file that is imported. builders : dict Dictionary of functions that are used to instantiate e.g. operators. """ cfg = ConfigParser() cfg.optionxform = str # make keys case-sensitive if not os.path.isfile(file_name): raise RuntimeError(f"`{file_name}` not found") cfg.read(file_name) return cls(cfg, builders)
[docs] def to_file(self, name): """Write configuration in standardized form to file. Parameters ---------- name : str Path to which the config shall be written. """ with open(name, "w") as f: self._cfg.write(f)
[docs] def optimize_kl(self, **kwargs): """Do the inference and write the config file to output directory. All additional parameters to `ift.optimize_kl` can be passed via `kwargs`. """ from ..optimize_kl import optimize_kl dct = dict(self) os.makedirs(dct["output_directory"], exist_ok=True) self.to_file(os.path.join(dct["output_directory"], "optimization.cfg")) return optimize_kl(**dct, **kwargs)
def _interpret_base(self): """Replace `base` entry in every section by the content of the section it points to.""" c = self._cfg for section in c: if "base" in c[section]: base_name = c[section]["base"] if base_name not in c: raise RuntimeError( f"the referred section `{base_name}` does not exist" ) if "base" in c[base_name]: raise RuntimeError("recursive bases not allowed for now") # Replace base entry in section by the respective values c[section] = {**c[base_name], **c[section]} del c[section]["base"] def _interpret_repetitions(self): """Expand repretitions in sections of the form `optimization.*`. For example `2*NewtonCG` expands to `NewtonCG,NewtonCG`. If fewer entries than `total iterations` are present, fill up with the last value. """ c = self._cfg # Only look at sections starting with "optimization." for optkey in filter(lambda x: x[:13] == "optimization.", c.keys()): sec = c[optkey] total_iterations = sec.getint("total iterations") for key in filter(lambda x: x != "total iterations", sec): if key == "base": raise AssertionError( "`base` must already be interpreted. This is a bug." ) # Expand multiply "*" if "," in sec[key]: # Handle spaces around "," entry_list_pre = map(lambda x: x.strip(), sec[key].split(",")) entry_list_post = [] for val in entry_list_pre: # Nothing to expand because * not present or dereferencing operator if "*" not in val or val[0] == "*": entry_list_post.append(val) continue # Multiply "*" and dereferencing "*" mixed splt = val.split("**") if len(splt) == 2 and splt[0] != "" and splt[1] != "": fac, val = splt val = "*" + val entry_list_post.extend(int(fac) * [val]) continue # actual expansion splt = val.split("*") if len(splt) != 2: raise RuntimeError( f"the expression `{val}` cannot have more than one `*`" ) fac, val = splt entry_list_post.extend(int(fac) * [val]) sec[key] = ",".join(entry_list_post) # Fill up entry_list_pre = sec[key].split(",") diff = total_iterations - len(entry_list_pre) if diff < 0: raise RuntimeError( f"The number of total iterations ({total_iterations}) is at least {-diff} too small." ) entry_list_post = entry_list_pre + diff * [entry_list_pre[-1]] myassert(len(entry_list_post) == total_iterations) sec[key] = ",".join(entry_list_post) def _join_optimization_stages(self): """Join all optimization stages into one stage. All sections of the form `optimization.*` are combined into a single section called `optimization.0`. """ c = self._cfg # Only look at sections starting with "optimization." But this time in ascending order lookup = {} for optkey in filter(lambda x: x[:13] == "optimization.", c.keys()): _, myid = optkey.split(".") lookup[int(myid)] = optkey optimization_keys = [lookup[kk] for kk in sorted(lookup)] # Sorting done. # Merge optimization sections together into "optimization.0" # Start with first one fst_key = optimization_keys[0] sec0 = c[fst_key] # Add the rest for optkey in optimization_keys[1:]: sec = c[optkey] for key in sec: if key == "total iterations": sec0["total iterations"] = str( sec0.getint("total iterations") + sec.getint("total iterations") ) continue sec0[key] = ",".join([sec0[key], sec[key]]) # has been merged into sec0 and can be deleted del c[optkey] # If user has chosen something different than optimization.0 as first stage, normalize it if fst_key != "optimization.0": tmp = c[fst_key] c["optimization.0"] = tmp del c[fst_key] def _to_callable(self, s, dtype=None): """Turn list separated by `,` into function that takes the index and returns the respective entry. Additionally all references indicated by `*` are instantiated. """ def f(iteration): val = s.split(",")[iteration].strip() if val[0] == "*": # is reference val = val[1:] val = self.instantiate_section(val) if val == "None": return None if dtype is not None: val = dtype(val) return val return f
[docs] def instantiate_section(self, sec): """Instantiate object that is described by a section in the config file. There are two mechanisms how this instantiation works: - Look up the section key in the `self._builders` dictionary and call the respective function. - If `custom function` is specified in the section, pass all other entries of the section as arguments to the referred function. Before the instantiation is performed the inputs are transformed according to the type information that is passed in the config file. By default all values have type `str`. If `bool`, `float` or `int` shall be passed, the syntax `type :: value`, e.g. `float :: 1.2`, needs to be used in the config file. """ dct = dict(self._cfg[sec]) # Instantiate all references for kk in dct: val = dct[kk] if len(val) > 1 and val[0] == "*": # is reference dct[kk] = self.instantiate_section(val[1:]) # Replace all whitespaces with _ # FIXME Is here the best place to do this? newdct = {} for kk, vv in dct.items(): newdct[kk.replace(" ", "_")] = vv dct = newdct # Parse dtype for kk, vv in dct.items(): if not isinstance(vv, str): continue tmp = tuple(map(lambda x: x.strip(), vv.split("::"))) if len(tmp) == 2: # type information available if tmp[0] == "bool": if tmp[1].lower() == "true": vv = True elif tmp[1].lower() == "false": vv = False else: ValueError(f"{tmp[1]} is not boolean") elif tmp[0] == "float": vv = float(tmp[1]) elif tmp[0] == "int": vv = int(tmp[1]) elif tmp[0] == "None": vv = None dct[kk] = vv # Plug into builder or something else if sec in self._builders: return self._builders[sec](**dct) if "custom_function" in dct: mod_name, func_name = dct.pop("custom_function").rsplit(".", 1) mod = importlib.import_module(mod_name) func = getattr(mod, func_name) return func(**dct) raise RuntimeError( f"Provide build routine for `{sec}` in builders dictionary or " "reference a `custom_function` in the config file." )
[docs] def __iter__(self): """Enable conversion to `dict` such that the result of this class can easily be passed into `ift.optimize_kl`.""" cdyn = self._cfg["optimization.0"] # static copt = self._cfg["optimization"] yield "output_directory", copt["output directory"] yield "save_strategy", copt["save strategy"] yield "plot_energy_history", True yield "plot_minisanity_history", True # dynamic yield "total_iterations", int(cdyn["total iterations"]) for key in [ "likelihood_energy", "n_samples", "transitions", "kl_minimizer", "sampling_iteration_controller", "nonlinear_sampling_minimizer", ]: key1 = key.replace("_", " ") if key == "n_samples": yield key, self._to_callable(cdyn[key1], int) else: yield key, self._to_callable(cdyn[key1])
def __str__(self): s = [] for key, val in self._cfg.items(): s += [key] s += [f" {kk}: {vv}" for kk, vv in val.items()] s += [""] return "\n".join(s) def __eq__(self, other): for a in "_cfg", "_builders": if getattr(self, a) != getattr(other, a): return False return True