Source code for nifty8.minimization.sample_list

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras, Philipp Frank

import os
import pickle
import re
import time
from warnings import warn

from .. import utilities
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..operators.operator import Operator
from ..utilities import get_MPI_params_from_comm, shareRange


[docs] class SampleListBase: """Base class for storing lists of fields representing samples. This class suits as a base class for storing lists of in most cases posterior samples. It is intended to be used to hold the minimization state of an inference run and comes with a variety of convenience functions like computing the mean or standard deviation of the output of a given operator over the sample list. Parameters ---------- comm : MPI communicator or None If not `None`, :class:`SampleListBase` can gather samples across multiple MPI tasks. If `None`, :class:`SampleListBase` is not a distributed object. domain : Domainoid (can be DomainTuple, MultiDomain, dict, Domain or list of Domains) The domain on which the samples are defined. n_local_samples : int Number of samples present on the current MPI task. Note ---- A class inheriting from :class:`SampleListBase` needs to call the constructor of `SampleListBase` and needs to implement :attr:`n_local_samples` and :attr:`local_item()`. """
[docs] def __init__(self, comm, domain, n_local_samples): from ..sugar import makeDomain self._comm = comm self._domain = makeDomain(domain) utilities.check_MPI_equality(self._domain, comm) self.local_indices = _compute_local_indices(n_local_samples, comm) self._active_comm = comm if comm is not None: # 0 corresponds to empty tasks # 1 corresponds to tasks with samples color = 0 if len(self.local_indices) == 0 else 1 # If the same key is assigned to multiple processes, the conflict is # resolved according to the rank in the original communicator key = 0 self._active_comm = comm.Split(color, key) self._n_samples = utilities.allreduce_sum([self.n_local_samples], self.comm)
@property def n_local_samples(self): """int: Number of local samples.""" return len(self.local_indices) @property def n_samples(self): """Return number of samples across all MPI tasks.""" return self._n_samples
[docs] def local_item(self, i): """Return ith local sample.""" raise NotImplementedError
[docs] def local_iterator(self): """Return an iterator over all local samples.""" for i in range(self.n_local_samples): yield self.local_item(i)
@property def comm(self): """MPI communicator or None: The communicator used for the SampleListBase.""" return self._comm @property def MPI_master(self): return utilities.get_MPI_params_from_comm(self.comm)[2] @property def domain(self): """DomainTuple or MultiDomain: the domain on which the samples are defined.""" return self._domain
[docs] def save_to_hdf5(self, file_name, op=None, samples=False, mean=False, std=False, overwrite=False): """Write sample list to HDF5 file. This function writes sample lists to HDF5 files that contain two groups: `samples` and `stats`. `samples` contain the sublabels `0`, `1`, ... that number the labels and `stats` contains the sublabels `mean` and `standard deviation`. If `self.domain` is an instance of :class:`~nifty8.multi_domain.MultiDomain`, these sublabels refer themselves to subgroups. For :class:`~nifty8.field.Field`, the sublabel refers to an HDF5 data set. If quanitities are not requested (e.g. by setting `mean=False`), the respective sublabels are not present in the HDF5 file. If `op` is an :class:`~nifty8.operators.operator.Operator`, the operator string representation, its domain and its target are written to the attributes of the HDF5 file. Parameters ---------- file_name : str File name of output hdf5 file. op : callable or None Callable that is applied to each item in the :class:`SampleListBase` before it is returned. Can be an :class:`~nifty8.operators.operator.Operator` or any other callable that takes a :class:`~nifty8.field.Field` as an input. Default: None. samples : bool If True, samples are written into hdf5 file. mean : bool If True, mean of samples is written into hdf5 file. std : bool If True, standard deviation of samples is written into hdf5 file. overwrite : bool If True, a potentially existing file with the same file name as `file_name`, is overwritten. """ import h5py if os.path.isfile(file_name): if self.MPI_master and overwrite: os.remove(file_name) if not overwrite: raise RuntimeError(f"File {file_name} already exists. Delete it or use " "`overwrite=True`") if not (samples or mean or std): raise ValueError("Neither samples nor mean nor standard deviation shall be written.") if self.MPI_master: f = h5py.File(file_name, "w") if isinstance(op, Operator): f.attrs["nifty operator string representation"] = str(op) f.attrs["nifty operator domain"] = repr(op.domain) f.attrs["nifty operator target"] = repr(op.target) f.attrs["nifty domain"] = repr(op.target) else: f.attrs["nifty domain"] = repr(self.domain) else: f = utilities.Nop() # TODO Add some meta information (e.g. k_length for PowerSpace, distances for RGSpace) if samples: grp = f.create_group("samples") for ii, ss in enumerate(self.iterator(op)): _field2hdf5(grp, ss, str(ii)) if std: grp = f.create_group("stats") m, v = self.sample_stat(op) else: if mean: grp = f.create_group("stats") m = self.average(op) if mean: _field2hdf5(grp, m, "mean") if std: _field2hdf5(grp, v.sqrt(), "standard deviation") f.close() _barrier(self.comm)
[docs] def save_to_fits(self, file_name_base, op=None, samples=False, mean=False, std=False, overwrite=False): """Write sample list to FITS file. This function writes properties of a sample list to a FITS file. This is supported if and only if the target of `op` is a two-dimensional RGSpace. Parameters ---------- file_name_base : str File name base of output FITS file, i.e. without `.fits` extension. op : callable or None Callable that is applied to each item in the :class:`SampleListBase` before it is returned. Can be an :class:`~nifty8.operators.operator.Operator` or any other callable that takes a :class:`~nifty8.field.Field` as an input. Default: None. samples : bool If True, samples are written into hdf5 file. mean : bool If True, mean of samples is written into hdf5 file. std : bool If True, standard deviation of samples is written into hdf5 file. overwrite : bool If True, a potentially existing file with the same file name as `file_name`, is overwritten. """ # TODO Add support for 3d and 4d fields # https://fits.gsfc.nasa.gov/fits_standard.html if not (samples or mean or std): raise ValueError("Neither samples nor mean nor standard deviation shall be written.") if mean or std: m, v = self.sample_stat(op) if mean: self._save_fits_2d(m, file_name_base + "_mean.fits", overwrite) if std: self._save_fits_2d(v.sqrt(), file_name_base + "_std.fits", overwrite) if samples: for ii, ss in enumerate(self.iterator(op)): self._save_fits_2d(ss, file_name_base + f"_sample_{ii}.fits", overwrite)
def _save_fits_2d(self, fld, file_name, overwrite): import astropy.io.fits as pyfits from astropy.time import Time from ..domain_tuple import DomainTuple dom = fld.domain if not isinstance(dom, DomainTuple) or len(dom) != 1 or len(dom[0].shape) != 2: raise ValueError("FITS file export is only supported from 2d-fields. " f"Current domain:\n{dom}") h = pyfits.Header() h["DATE-MAP"] = Time(time.time(), format="unix").iso.split()[0] h["CRVAL1"] = h["CRVAL2"] = 0 h["CRPIX1"] = h["CRPIX2"] = 0 h["CUNIT1"] = h["CUNIT2"] = "deg" h["CDELT1"], h["CDELT2"] = -dom[0].distances[0], dom[0].distances[1] h["CTYPE1"] = "RA---SIN" h["CTYPE2"] = "DEC---SIN" h["EQUINOX"] = 2000 hdu = pyfits.PrimaryHDU(fld.val[:, :].T, header=h) hdulist = pyfits.HDUList([hdu]) if self.MPI_master: hdulist.writeto(file_name, overwrite=overwrite)
[docs] def iterator(self, op=None): """Return iterator over all potentially distributed samples. Parameters ---------- op : callable or None Callable that is applied to each item in the :class:`SampleListBase` before it is returned. Can be an :class:`~nifty8.operators.operator.Operator` or any other callable that takes a :class:`~nifty8.field.Field` as an input. Default: None. Note ---- Calling this function involves MPI communication if `comm != None`. """ op = _none_to_id(op) if self.comm is not None: for itask in range(self.comm.Get_size()): for i in range(_bcast(self.n_local_samples, self._comm, itask)): ss = self.local_item(i) if itask == self._comm.Get_rank() else None yield op(_bcast(ss, self._comm, itask)) else: for ss in self.local_iterator(): yield op(ss)
[docs] def average(self, op=None): """Compute average over all potentially distributed samples. Parameters ---------- op : callable or None Callable that is applied to each item in the :class:`SampleListBase` before it is averaged. Note ---- Calling this function involves MPI communication if `comm != None`. Note ---- Calling this function involves allocating arrays for all samples and averaging them afterwards. If the number of local samples is big and `op` is not None, this leads to much temporary memory usage. If the output of `op` is just a :class:`~nifty8.field.Field` or :class:`~nifty8.multi_field.MultiField`, :attr:`sample_stat()` can be used in order to compute the average memory efficiently. """ res = self._prepare_average(op) n = self.n_samples return utilities.allreduce_sum(res, self.comm) / n
def _average_tuple(self, op): """Compute simultaneous average over all potentially distributed samples. Parameters ---------- op : callable Callable that is applied to each item in the :class:`SampleListBase` before it is averaged. `op` shall return a tuple of whose individual averages are computed and returned as tuple. Note ---- Calling this function involves MPI communication if `comm != None`. """ n = self.n_samples res = None if len(self.local_indices) > 0: res = self._prepare_average(op) res = _list_transpose(res) res = tuple(utilities.allreduce_sum(rr, self._active_comm) / n for rr in res) if self._active_comm is None and self._comm is None: return res if len(self.local_indices) > 0: task_list = self._comm.allgather(self._comm.Get_rank()) else: task_list = self._comm.allgather(-1) root = next(filter((-1).__ne__, task_list)) return self._comm.bcast(res, root=root) def _prepare_average(self, op): op = _none_to_id(op) res = [op(ss) for ss in self.local_iterator()] return res
[docs] def sample_stat(self, op=None): """Compute mean and variance of samples after applying `op`. Parameters ---------- op : callable or None Callable that is applied to each item in the :class:`SampleListBase` before it is used to compute mean and variance. Returns ------- tuple A tuple with two items: the mean and the variance. """ from ..probing import StatCalculator if self.n_samples == 1: res = self.average(op) return res, 0*res sc = StatCalculator() for ss in self.iterator(op): sc.add(ss) return sc.mean, sc.var
[docs] def save(self, file_name_base, overwrite=False): """Serialize SampleList and write it to disk. Parameters ---------- file_name_base : str File name of the output file without extension. The actual file name will have the extension ".pickle". overwrite : bool Existing files are overwritten. Note ---- If the instance of :class:`SampleListBase` is distributed, each MPI task writes its own file. """ raise NotImplementedError
[docs] @classmethod def load(cls, file_name_base, comm=None): """Deserialize SampleList from files on disk. Parameters ---------- file_name_base : str File name of the input file without extension. The actual file name will have the extension ".pickle". comm : MPI communicator or None If not `None`, each MPI task reads its own input file. Note ---- `file_name_base` needs to be the same string that has been used for saving the :class:`SampleListBase`. """ raise NotImplementedError
@classmethod def _list_local_sample_files(cls, file_name_base, comm=None): """List all sample files that are relevant for the local task. All sample files that correspond to `file_name_base` are searched and selected based on which rank the local task has. Note ---- This function makes sure that the all file numbers between 0 and the maximal found number are present. If this is not the case, a `RuntimeError` is raised. Note ---- This function distributes the samples according to the standard NIFTy distribution scheme (see `ift.utilities.shareRange`). """ base_dir, base_file = os.path.split(os.path.abspath(file_name_base)) files = [ff for ff in os.listdir(base_dir) if re.match(f"{base_file}.[0-9]+.pickle", ff)] if len(files) == 0: raise RuntimeError(f"No files matching `{file_name_base}.*.pickle`") n_samples = max(list(map(lambda x: int(x.split(".")[-2]), files))) + 1 ntask, rank, _ = get_MPI_params_from_comm(comm) local_indices = range(*shareRange(n_samples, ntask, rank)) files = [f"{file_name_base}.{ii}.pickle" for ii in local_indices] for ff in files: if not os.path.isfile(ff): raise RuntimeError(f"File {ff} not found") return files
[docs] class ResidualSampleList(SampleListBase):
[docs] def __init__(self, mean, residuals, neg, comm=None): """Store samples in terms of a mean and a residual deviation thereof. Parameters ---------- mean : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` Mean of the sample list. residuals : list of :class:`nifty8.field.Field` or list of :class:`nifty8.multi_field.MultiField` List of residuals from the mean. If it is a list of `MultiField`, the domain of the residuals can be a subdomain of the domain of mean. This results in adding just a zero in respective `MultiField` entries. neg: list of bool This list has to have the same length as `residuals`. If an entry is `True`, the respective residual is subtracted and not added. comm : MPI communicator or None If not `None`, samples can be gathered across multiple MPI tasks. If `None`, :class:`ResidualSampleList` is not a distributed object. """ self._m = mean self._r = tuple(residuals) self._n = tuple(neg) super(ResidualSampleList, self).__init__(comm, mean.domain, len(self._r)) if len(self._r) != len(self._n): raise ValueError("Residuals and neg need to have the same length.") if any(elem == 0 for elem in utilities.allreduce_sum([[len(self._r)]], comm)): warn("ResidualSampleList: (Partially empty) sample list detected.") if len(self._r) > 0: r_dom = self._r[0].domain if not all(rr.domain is r_dom for rr in self._r): raise ValueError("All residuals must have the same domain.") if isinstance(r_dom, MultiDomain): try: self._m.extract(r_dom) except: raise ValueError("`residual.domain` must be a subdomain of `mean.domain`.") if not all(isinstance(nn, bool) for nn in neg): raise TypeError("All entries in neg need to be bool.")
[docs] def local_item(self, i): return self._m.flexible_addsub(self._r[i], self._n[i])
[docs] def at(self, mean): """Instantiate `ResidualSampleList` with the same residuals as `self`. The mean is updated. Note ---- If `self.domain` is a :class:`~nifty8.multi_domain.MultiDomain`, the old and new mean are combined with :attr:`~nifty8.multi_field.MultiField.union` beforehand. This means that only the multi field entries present in `mean` are updated. Returns ------- ResidualSampleList Sample list with updated mean. """ if isinstance(self._m, MultiField) and self.domain is not mean.domain: mean = MultiField.union([self._m, mean]) return ResidualSampleList(mean, self._r, self._n, self.comm)
[docs] def save(self, file_name_base, overwrite=False): for ii, isample in enumerate(self.local_indices): obj = [self._r[ii], self._n[ii]] fname = _sample_file_name(file_name_base, isample) _save_to_disk(fname, obj, overwrite) if self.MPI_master: _save_to_disk(f"{file_name_base}.mean.pickle", self._m, overwrite) _barrier(self.comm)
[docs] @classmethod def load(cls, file_name_base, comm=None): _barrier(comm) mean = _load_from_disk(f"{file_name_base}.mean.pickle") files = cls._list_local_sample_files(file_name_base, comm) tmp = [_load_from_disk(ff) for ff in files] res = [aa[0] for aa in tmp] neg = [aa[1] for aa in tmp] return cls(mean, res, neg, comm=comm)
[docs] @classmethod def load_mean(cls, file_name_base): return _load_from_disk(f"{file_name_base}.mean.pickle")
[docs] class SampleList(SampleListBase):
[docs] def __init__(self, samples, comm=None, domain=None): """Store samples as a plain list. This is a minimalist implementation of :class:`SampleListBase`. It just serves as a (potentially MPI-distributed) wrapper of a list of samples. Parameters ---------- samples : list of :class:`nifty8.field.Field` or list of :class:`nifty8.multi_field.MultiField` List of samples. comm : MPI communicator or None If not `None`, samples can be gathered across multiple MPI tasks. If `None`, :class:`ResidualSampleList` is not a distributed object. domain : DomainTuple, MultiDomain or None Sets the domain of the `SampleList`. If `samples` is non-empty and `domain` is not None, `domain` has to coincide with the domain of the samples. Default: None. """ from ..sugar import makeDomain if domain is None: if len(samples) == 0: raise ValueError("Need to pass `domain` to instantiate empty `SampleList`.") else: domain = samples[0].domain else: domain = makeDomain(domain) super(SampleList, self).__init__(comm, domain, len(samples)) self._s = samples for ss in self._s: utilities.check_object_identity(ss.domain, self.domain)
[docs] def local_item(self, i): return self._s[i]
[docs] def save(self, file_name_base, overwrite=False): nsample = self.n_samples for isample in range(nsample): if isample in self.local_indices: obj = self._s[isample-self.local_indices[0]] fname = _sample_file_name(file_name_base, isample) _save_to_disk(fname, obj, overwrite=True) _barrier(self.comm)
[docs] @classmethod def load(cls, file_name_base, comm=None): from ..logger import logger _barrier(comm) foo = "{file_name_base}.mean.pickle" if os.path.isfile(foo): logger.warning(f"{foo} is present. Most probably you intended to " "call `ift.ResidualSampleList.load()`.") files = cls._list_local_sample_files(file_name_base, comm) samples = [_load_from_disk(ff) for ff in files] dom = None if comm is not None: if comm.Get_rank() == 0: dom = samples[0].domain dom = comm.bcast(dom, 0) return cls(samples, comm=comm, domain=dom)
def _none_to_id(obj): """If the input is None, replace it with identity map. If input is `Operator`, append `force` in order to make sure that the function is evaluated if there is a chance that the domains match. Otherwise return input. """ if obj is None: return lambda x: x if isinstance(obj, Operator): return lambda x: obj.force(x) return obj def _bcast(obj, comm, root): """Broadcast python object from given root. Parameters ---------- obj : object The object to be broadcasted. comm : MPI communicator MPI communicator used for the broadcasting. root : int MPI task number from which the object shall be sent. """ data = obj if comm.Get_rank() == root else None return comm.bcast(data, root=root) def _sample_file_name(file_name_base, isample): """Return sample-unique file name. This file name that can be used to uniquely write samples potentially from all MPI tasks. Parameters ---------- file_name_base : str User-defined first part of file name. isample : int Number of sample """ if not isinstance(isample, int): raise TypeError return f"{file_name_base}.{isample}.pickle" def _load_from_disk(file_name): with open(file_name, "rb") as f: obj = pickle.load(f) return obj def _save_to_disk(file_name, obj, overwrite=False): if overwrite and os.path.isfile(file_name): os.remove(file_name) with open(file_name, "wb") as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def _field2hdf5(file_handle, obj, name): if not isinstance(name, str): raise TypeError if isinstance(obj, MultiField): grp = file_handle.create_group(name) for kk, fld in obj.items(): _field2hdf5(grp, fld, kk) return if not isinstance(obj, Field): raise TypeError file_handle.create_dataset(name, data=obj.val) def _barrier(comm): if comm is not None: comm.Barrier() def _list_transpose(list_of_lists): ny = len(list_of_lists[0]) return [[elem[ii] for elem in list_of_lists] for ii in range(ny)] def _compute_local_indices(n_local, comm): if comm is None: return range(n_local) n_locals = comm.allgather(n_local) start = sum(n_locals[:comm.Get_rank()]) return range(start, start + n_local)