#!/usr/bin/env python3
import numpy as np
import nifty7 as ift
from collections import namedtuple
# Container for the operators that make up a density-estimator model.
# `density` is the operator mapping latent parameters to the density; the
# remaining fields are populated by `build_density_estimator` from the
# operator dictionary returned by `ift.density_estimator`.
Model = namedtuple(
    "Model",
    "density correlated_field select_subset "
    "amplitude_total_offset normalized_amplitudes",
)
def build_density_estimator(
    shape,
    pixel_distances=None,
    pad=1.0,
    cf_fluctuations=None,
    cf_azm_uniform=None,
):
    """Construct the model for a density estimator

    Parameters
    ----------
    shape : int or tuple of int or tuple of tuple of int
        Shape of the to-be-modeled data. Note, the level of nesting matters!
        While `((128, 128), )` denotes a shared kernel for both axes,
        `(128, 128)` denotes separate kernels for both axes. Plain integers
        may be mixed with nested tuples, e.g. `((128, 128), 64)`; numpy
        integers are accepted wherever an int is.
    pixel_distances : float or tuple of float, optional
        Distances between individual pixels. A single real number (or
        `None`) is used for every space; a tuple must provide exactly one
        entry per space, i.e. one per top-level element of `shape` after
        normalization.
    pad : float or tuple of float, optional
        Padding factor.
    cf_fluctuations : dict or tuple of dict, optional
        Parameters for the Matern kernel of the density field.
    cf_azm_uniform : tuple of float, optional
        Parameters for the a-priori uniform amplitude total offset.

    Returns
    -------
    model : Model
        Collection of operators that represent the model. Most importantly
        `model` features an attribute `.density` which corresponds to the
        NIFTy representation of the density estimator.

    Raises
    ------
    ValueError
        If a sequence of `pixel_distances` does not have one entry per space.
    """
    integer_types = (int, np.integer)
    real_types = (int, float, np.integer, np.floating)

    # Cast shape to type tuple[tuple[int]]: a bare int becomes a single 1D
    # space and every bare int inside a tuple becomes its own 1D space.
    if isinstance(shape, integer_types):
        shape = ((shape, ), )
    else:
        shape = tuple(
            (el, ) if isinstance(el, integer_types) else tuple(el)
            for el in shape
        )

    # Likewise for pixel_distances but with a lower nesting level. Any real
    # scalar is broadcast (not only `float`, so `pixel_distances=1` works).
    if pixel_distances is None or isinstance(pixel_distances, real_types):
        pixel_distances = (pixel_distances, ) * len(shape)
    else:
        pixel_distances = tuple(pixel_distances)
        if len(pixel_distances) != len(shape):
            # `zip` below would otherwise silently drop surplus entries
            raise ValueError(
                f"expected {len(shape)} pixel distance(s), one per space; "
                f"got {len(pixel_distances)}"
            )

    position_space = ift.DomainTuple.make(
        [
            ift.RGSpace(shape=shp, distances=dist)
            for shp, dist in zip(shape, pixel_distances)
        ]
    )
    density, model_operators = ift.density_estimator(
        position_space,
        pad=pad,
        cf_fluctuations=cf_fluctuations,
        cf_azm_uniform=cf_azm_uniform,
    )
    return Model(density=density, **model_operators)
def draw_synthetic_sample(model, seed=60, pretty=False):
    """Draw synthetic data (i.e. produce a prior sample in data space)

    Parameters
    ----------
    model : Model
        Collection of operators representing the model; only `model.density`
        is used.
    seed : int or np.random.SeedSequence, optional
        Random state with which to draw the synthetic truth and the noise.
        If a `SeedSequence` is given, two children are spawned from it, so
        repeated calls with the same object produce different samples.
    pretty : bool or str or iterable of str, optional
        Suffixes of latent parameters which are artificially dampened
        (scaled by 1e-3) as to produce untrue prior samples that look
        prettier. If a boolean and true then a sane set of latent parameters
        is chosen. A single string is treated as one suffix.

    Returns
    -------
    synth_data : np.ndarray
        Array of integers for the synthetic Poissonian counts.
    synth_truth : np.ndarray
        Array of the true expectation value underlying the Poissonian counts.
    """
    if isinstance(pretty, bool):
        dampened_keys = (
            {"zeromode", "scale", "cutoff", "loglogslope"} if pretty else set()
        )
    elif isinstance(pretty, str):
        # `set("scale")` would split the string into single characters
        dampened_keys = {pretty}
    else:
        dampened_keys = set(pretty)
    # `str.endswith` accepts a tuple; an empty tuple never matches
    dampened_suffixes = tuple(dampened_keys)

    # Use a dedicated seed to draw synthetic samples to ease reproducibility
    if isinstance(seed, np.random.SeedSequence):
        sseq = seed
    else:
        sseq = np.random.SeedSequence(seed)
    # One child seeds the Poisson noise, the other the latent prior draw.
    # `spawn(2)` yields the same children as two consecutive `spawn(1)` calls.
    noise_sseq, latent_sseq = sseq.spawn(2)
    rng = np.random.default_rng(noise_sseq)

    with ift.random.Context(latent_sseq):
        latent_sample = ift.from_random(model.density.domain).val
    synth_position = {
        key: 1e-3 * val if key.endswith(dampened_suffixes) else val
        for key, val in latent_sample.items()
    }
    synth_position = ift.MultiField.from_raw(
        model.density.domain, synth_position
    )

    synth_truth = model.density(synth_position).val
    synth_data = rng.poisson(synth_truth)
    return synth_data, synth_truth
def fit(
    model,
    data,
    n_samples=5,
    n_max_iterations=20,
    init_position=None,
    init_seed=31617,
    kl_kw=None,
    sampling_kw=None,
    minimizer_kw=None,
):
    """Fit the model to data and return approximate posterior samples

    Parameters
    ----------
    model : Model
        Collection of operators representing the model.
    data : array-like of int
        Poissonian count data to which to fit the model.
    n_samples : int, optional
        Number of samples passed to `ift.MetricGaussianKL` in every
        iteration.
    n_max_iterations : int, optional
        Number of rounds of drawing fresh samples and minimizing the KL.
        Must be at least 1.
    init_position : ift.MultiField, optional
        Optional starting position in the form of a MultiField from NIFTy.
    init_seed : int or np.random.SeedSequence, optional
        Random state with which to draw the random starting position; only
        used if `init_position` is None.
    kl_kw : dict, optional
        Parameters for the Kullback-Leibler computation. If `mirror_samples`
        is not given it defaults to True.
    sampling_kw : dict, optional
        Parameters for the sampling iteration controller.
    minimizer_kw : dict, optional
        Parameters for the minimizer iteration controller.

    Returns
    -------
    density_samples : tuple of np.ndarray
        Samples from the approximate posterior distribution.
    lat_samples : tuple of ift.MultiField
        Latent samples compatible with the collection of operators
        representation in the model. These can be used to access intermediate
        posteriors within the model itself.

    Raises
    ------
    TypeError
        If `data` is not integer-valued or `model.density` is not an
        `ift.Operator`.
    ValueError
        If `model` has no `.density` attribute or `n_max_iterations < 1`.
    """
    # Validate inputs before constructing any NIFTy objects
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.integer):
        te = f"data of invalid type {data.dtype!r}; expected dtype `int`"
        raise TypeError(te)
    if not hasattr(model, "density"):
        raise ValueError("`model` has no `.density` attribute")
    if not isinstance(model.density, ift.Operator):
        raise TypeError(
            f"`model.density` of invalid type {type(model.density)!r}"
        )
    if n_max_iterations < 1:
        # Without at least one round there is no KL from which to sample
        raise ValueError(
            f"`n_max_iterations` must be at least 1; got {n_max_iterations!r}"
        )

    if isinstance(init_seed, np.random.SeedSequence):
        sseq = init_seed
    else:
        sseq = np.random.SeedSequence(init_seed)

    # User-supplied KL parameters override the defaults; do not mutate them
    kl_kw = {"mirror_samples": True, **(kl_kw if kl_kw is not None else {})}
    if sampling_kw is None:
        sampling_kw = {
            "name": "Sampling",
            "deltaE": 0.0,
            "iteration_limit": 200,
        }
    if minimizer_kw is None:
        minimizer_kw = {
            "name": "Minimizer",
            "deltaE": 0.0,
            "iteration_limit": 35,
        }

    # Minimization parameters
    ic_sampling = ift.AbsDeltaEnergyController(**sampling_kw)
    ic_newton = ift.AbsDeltaEnergyController(**minimizer_kw)
    ic_sampling.enable_logging()
    ic_newton.enable_logging()
    minimizer = ift.NewtonCG(ic_newton, enable_logging=True)

    # Set up likelihood and information Hamiltonian
    data = ift.Field.from_raw(model.density.target, data)
    likelihood = ift.PoissonianEnergy(data) @ model.density
    ham = ift.StandardHamiltonian(likelihood, ic_sampling)

    position = init_position
    if position is None:
        # Start from a scaled-down random draw of the latent parameters
        with ift.random.Context(sseq.spawn(1)[0]):
            position = 1e-2 * ift.from_random(model.density.domain)

    for _ in range(n_max_iterations):
        # TODO: introduce some crude stopping criteria
        # Draw new samples and minimize KL
        kl = ift.MetricGaussianKL(position, ham, n_samples, **kl_kw)
        kl, _ = minimizer(kl)
        position = kl.position

    # The KL samples are offsets around the latent mean
    lat_samples = tuple(position + smpl for smpl in kl.samples)
    density_samples = tuple(model.density(ls).val for ls in lat_samples)
    return density_samples, lat_samples