Source code for nifty8.plot

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import os
from datetime import datetime as dt
from itertools import product
from warnings import warn

import numpy as np

from .domain_tuple import DomainTuple
from .domains.gl_space import GLSpace
from .domains.hp_space import HPSpace
from .domains.power_space import PowerSpace
from .domains.rg_space import RGSpace
from .field import Field
from .minimization.iteration_controllers import EnergyHistory
from .multi_field import MultiField
from .utilities import check_object_identity, myassert

# relevant properties:
# - x/y size
# - x/y/z log
# - x/y/z min/max
# - colorbar/colormap
# - axis on/off
# - title
# - axis labels
# - labels


def _mollweide_helper(xsize):
    xsize = int(xsize)
    ysize = xsize//2
    res = np.full(shape=(ysize, xsize), fill_value=np.nan, dtype=np.float64)
    xc, yc = (xsize-1)*0.5, (ysize-1)*0.5
    u, v = np.meshgrid(np.arange(xsize), np.arange(ysize))
    u, v = 2*(u-xc)/(xc/1.02), (v-yc)/(yc/1.02)

    mask = np.where((u*u*0.25 + v*v) <= 1.)
    t1 = v[mask]
    theta = 0.5*np.pi-(
        np.arcsin(2/np.pi*(np.arcsin(t1) + t1*np.sqrt((1.-t1)*(1+t1)))))
    phi = -0.5*np.pi*u[mask]/np.maximum(np.sqrt((1-t1)*(1+t1)), 1e-6)
    phi = np.where(phi < 0, phi+2*np.pi, phi)
    return res, mask, theta, phi


def _rgb_data(spectral_cube):
    _xyz = np.array(
          [[0.000160, 0.000662, 0.002362, 0.007242, 0.019110,
            0.043400, 0.084736, 0.140638, 0.204492, 0.264737,
            0.314679, 0.357719, 0.383734, 0.386726, 0.370702,
            0.342957, 0.302273, 0.254085, 0.195618, 0.132349,
            0.080507, 0.041072, 0.016172, 0.005132, 0.003816,
            0.015444, 0.037465, 0.071358, 0.117749, 0.172953,
            0.236491, 0.304213, 0.376772, 0.451584, 0.529826,
            0.616053, 0.705224, 0.793832, 0.878655, 0.951162,
            1.014160, 1.074300, 1.118520, 1.134300, 1.123990,
            1.089100, 1.030480, 0.950740, 0.856297, 0.754930,
            0.647467, 0.535110, 0.431567, 0.343690, 0.268329,
            0.204300, 0.152568, 0.112210, 0.081261, 0.057930,
            0.040851, 0.028623, 0.019941, 0.013842, 0.009577,
            0.006605, 0.004553, 0.003145, 0.002175, 0.001506,
            0.001045, 0.000727, 0.000508, 0.000356, 0.000251,
            0.000178, 0.000126, 0.000090, 0.000065, 0.000046,
            0.000033],
           [0.000017, 0.000072, 0.000253, 0.000769, 0.002004,
            0.004509, 0.008756, 0.014456, 0.021391, 0.029497,
            0.038676, 0.049602, 0.062077, 0.074704, 0.089456,
            0.106256, 0.128201, 0.152761, 0.185190, 0.219940,
            0.253589, 0.297665, 0.339133, 0.395379, 0.460777,
            0.531360, 0.606741, 0.685660, 0.761757, 0.823330,
            0.875211, 0.923810, 0.961988, 0.982200, 0.991761,
            0.999110, 0.997340, 0.982380, 0.955552, 0.915175,
            0.868934, 0.825623, 0.777405, 0.720353, 0.658341,
            0.593878, 0.527963, 0.461834, 0.398057, 0.339554,
            0.283493, 0.228254, 0.179828, 0.140211, 0.107633,
            0.081187, 0.060281, 0.044096, 0.031800, 0.022602,
            0.015905, 0.011130, 0.007749, 0.005375, 0.003718,
            0.002565, 0.001768, 0.001222, 0.000846, 0.000586,
            0.000407, 0.000284, 0.000199, 0.000140, 0.000098,
            0.000070, 0.000050, 0.000036, 0.000025, 0.000018,
            0.000013],
           [0.000705, 0.002928, 0.010482, 0.032344, 0.086011,
            0.197120, 0.389366, 0.656760, 0.972542, 1.282500,
            1.553480, 1.798500, 1.967280, 2.027300, 1.994800,
            1.900700, 1.745370, 1.554900, 1.317560, 1.030200,
            0.772125, 0.570060, 0.415254, 0.302356, 0.218502,
            0.159249, 0.112044, 0.082248, 0.060709, 0.043050,
            0.030451, 0.020584, 0.013676, 0.007918, 0.003988,
            0.001091, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000]])

    MATRIX_SRGB_D65 = np.array(
            [[3.2404542, -1.5371385, -0.4985314],
             [-0.9692660,  1.8760108,  0.0415560],
             [0.0556434, -0.2040259,  1.0572252]])

    def _gammacorr(inp):
        mask = np.zeros(inp.shape, dtype=np.float64)
        mask[inp <= 0.0031308] = 1.
        r1 = 12.92*inp
        a = 0.055
        r2 = (1 + a) * (np.maximum(inp, 0.0031308) ** (1/2.4)) - a
        return r1*mask + r2*(1.-mask)

    def lambda2xyz(lam):
        lammin = 380.
        lammax = 780.
        lam = np.asarray(lam, dtype=np.float64)
        lam = np.clip(lam, lammin, lammax)

        idx = (lam-lammin)/(lammax-lammin)*(_xyz.shape[1]-1)
        ii = np.maximum(0, np.minimum(79, int(idx)))
        w1 = 1.-(idx-ii)
        w2 = 1.-w1
        c = w1*_xyz[:, ii] + w2*_xyz[:, ii+1]
        return c

    def getxyz(n):
        E0, E1 = 1./700., 1./400.
        E = E0 + np.arange(n)*(E1-E0)/(n-1)
        res = np.zeros((3, n), dtype=np.float64)
        for i in range(n):
            res[:, i] = lambda2xyz(1./E[i])
        return res

    def to_logscale(arr, lo, hi):
        res = arr.clip(lo, hi)
        res = np.log(res/hi)
        tmp = np.log(hi/lo)
        res += tmp
        res /= tmp
        return res

    shp = spectral_cube.shape[:-1]+(3,)
    spectral_cube = spectral_cube.reshape((-1, spectral_cube.shape[-1]))
    xyz = getxyz(spectral_cube.shape[-1])
    xyz_data = np.tensordot(spectral_cube, xyz, axes=[-1, -1])
    xyz_data /= xyz_data.max()
    xyz_data = to_logscale(xyz_data, max(1e-3, xyz_data.min()), 1.)
    rgb_data = xyz_data.copy()
    for x in range(xyz_data.shape[0]):
        rgb_data[x] = _gammacorr(np.matmul(MATRIX_SRGB_D65, xyz_data[x]))
    rgb_data = rgb_data.clip(0., 1.)
    return rgb_data.reshape(shp)


def _find_closest(A, target):
    # A must be sorted
    idx = np.clip(A.searchsorted(target), 1, len(A)-1)
    idx -= target - A[idx-1] < A[idx] - target
    return idx


def _makeplot(name, block=True, dpi=None):
    """Show the current matplotlib figure or write it to a file.

    Parameters
    ----------
    name : str or None
        If None, the figure is shown on screen. Otherwise it is the output
        file name; its extension must be `.pdf`, `.png` or `.svg`.
    block : bool
        Forwarded to `plt.show`. The figure is only closed after showing
        when `block` is true.
    dpi : float, optional
        Resolution passed to `plt.savefig`; omitted if None.

    Raises
    ------
    ValueError
        If `name` has an unsupported extension.
    """
    import matplotlib.pyplot as plt

    # Interactive display
    if name is None:
        plt.show(block=block)
        if block:
            plt.close()
        return

    # File output
    if os.path.splitext(name)[1] not in (".pdf", ".png", ".svg"):
        raise ValueError("file format not understood")
    save_options = {} if dpi is None else {'dpi': float(dpi)}
    plt.savefig(name, **save_options)
    plt.close()


def _limit_xy(**kwargs):
    """Apply optional `xmin`, `xmax`, `ymin`, `ymax` limits to the current axes.

    Limits that are not given as keyword arguments keep the value currently
    reported by `plt.axis()`. All other keyword arguments are ignored.
    """
    import matplotlib.pyplot as plt

    limit_names = ("xmin", "xmax", "ymin", "ymax")
    x1, x2, y1, y2 = plt.axis()
    limits = tuple(kwargs.pop(key, current)
                   for key, current in zip(limit_names, (x1, x2, y1, y2)))
    plt.axis(limits)


def _register_cmaps():
    """Register NIFTy's custom colormaps with matplotlib.

    The colormaps "Planck-like", "High Energy", "Faraday Map",
    "Faraday Uncertainty" and "Plus Minus" are registered via
    `matplotlib.colormaps.register`. A marker attribute stored on this
    function makes later calls return immediately.
    """
    import matplotlib as mpl
    from matplotlib.colors import LinearSegmentedColormap

    if hasattr(_register_cmaps, "_cmaps_registered"):
        if _register_cmaps._cmaps_registered:
            return
    else:
        _register_cmaps._cmaps_registered = True

    # (name, segment data) pairs in registration order.
    colormap_specs = (
        ("Planck-like",
         {'red':   ((0., 0., 0.), (.4, 0., 0.), (.5, 1., 1.),
                    (.7, 1., 1.), (.8, .83, .83), (.9, .67, .67),
                    (1., .5, .5)),
          'green': ((0., 0., 0.), (.2, 0., 0.), (.3, .3, .3),
                    (.4, .7, .7), (.5, 1., 1.), (.6, .7, .7),
                    (.7, .3, .3), (.8, 0., 0.), (1., 0., 0.)),
          'blue':  ((0., .5, .5), (.1, .67, .67), (.2, .83, .83),
                    (.3, 1., 1.), (.5, 1., 1.), (.6, 0., 0.),
                    (1., 0., 0.))}),
        ("High Energy",
         {'red':   ((0., 0., 0.), (.167, 0., 0.), (.333, .5, .5),
                    (.5, 1., 1.), (1., 1., 1.)),
          'green': ((0., 0., 0.), (.5, 0., 0.), (.667, .5, .5),
                    (.833, 1., 1.), (1., 1., 1.)),
          'blue':  ((0., 0., 0.), (.167, 1., 1.), (.333, .5, .5),
                    (.5, 0., 0.), (1., 1., 1.))}),
        ("Faraday Map",
         {'red':   ((0., .35, .35), (.1, .4, .4), (.2, .25, .25),
                    (.41, .47, .47), (.5, .8, .8), (.56, .96, .96),
                    (.59, 1., 1.), (.74, .8, .8), (.8, .8, .8),
                    (.9, .5, .5), (1., .4, .4)),
          'green': ((0., 0., 0.), (.2, 0., 0.), (.362, .88, .88),
                    (.5, 1., 1.), (.638, .88, .88), (.8, .25, .25),
                    (.9, .3, .3), (1., .2, .2)),
          'blue':  ((0., .35, .35), (.1, .4, .4), (.2, .8, .8),
                    (.26, .8, .8), (.41, 1., 1.), (.44, .96, .96),
                    (.5, .8, .8), (.59, .47, .47), (.8, 0., 0.),
                    (1., 0., 0.))}),
        ("Faraday Uncertainty",
         {'red':   ((0., 1., 1.), (0.1, .8, .8), (.2, .65, .65),
                    (.41, .6, .6), (.5, .7, .7), (.56, .96, .96),
                    (.59, 1., 1.), (.74, .8, .8), (.8, .8, .8),
                    (.9, .5, .5), (1., .4, .4)),
          'green': ((0., .9, .9), (.362, .95, .95), (.5, 1., 1.),
                    (.638, .88, .88), (.8, .25, .25), (.9, .3, .3),
                    (1., .2, .2)),
          'blue':  ((0., 1., 1.), (.1, .8, .8), (.2, 1., 1.),
                    (.41, 1., 1.), (.44, .96, .96), (.5, .7, .7),
                    (.59, .42, .42), (.8, 0., 0.), (1., 0., 0.))}),
        ("Plus Minus",
         {'red':   ((0., 1., 1.), (.1, .96, .96), (.2, .84, .84),
                    (.3, .64, .64), (.4, .36, .36), (.5, 0., 0.),
                    (1., 0., 0.)),
          'green': ((0., .5, .5), (.1, .32, .32), (.2, .18, .18),
                    (.3, .8, .8),  (.4, .2, .2), (.5, 0., 0.),
                    (.6, .2, .2), (.7, .8, .8), (.8, .18, .18),
                    (.9, .32, .32), (1., .5, .5)),
          'blue':  ((0., 0., 0.), (.5, 0., 0.), (.6, .36, .36),
                    (.7, .64, .64), (.8, .84, .84), (.9, .96, .96),
                    (1., 1., 1.))}),
    )

    for cmap_name, segment_data in colormap_specs:
        mpl.colormaps.register(
            cmap=LinearSegmentedColormap(cmap_name, segment_data))


def _extract_list_kwargs(kwargs, keys, n):
    tmp = {}
    for kk in keys:
        val = kwargs.pop(kk, None)
        tmp[kk] = val if isinstance(val, list) else n*[val]
    with_legend = "label" in keys and any(ll is not None for ll in tmp["label"])
    return [{kk: vv[i] for kk, vv in tmp.items()} for i in range(n)], with_legend


def _plot_history(f, ax, **kwargs):
    """Scatter-plot one or more energy histories into `ax`.

    Parameters
    ----------
    f : list of EnergyHistory
        Histories to plot; each needs `time_stamps` and `energy_values`.
    ax : matplotlib axes
        Axes that receive the scatter plots, title and labels.
    **kwargs
        Recognised options: `label`, `alpha`, `color`, `s` (each a single
        value or a list with one entry per history), `title`, `xlabel`,
        `ylabel`, `xscale`, `yscale`, `skip_timestamp_conversion` (use the
        raw time stamps as x values instead of converting them to dates),
        `plot_energy_differences` (plot differences of consecutive
        energies, positive and negative ones with different markers) and
        the axis limits `xmin`, `xmax`, `ymin`, `ymax`.

    Raises
    ------
    TypeError
        If an entry of `f` is not an `EnergyHistory`.
    """
    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter, date2num

    for fld in f:
        if not isinstance(fld, EnergyHistory):
            raise TypeError

    def _per_history(key):
        # A scalar option applies to every history; a list is used as is.
        val = kwargs.pop(key, None)
        return val if isinstance(val, list) else [val] * len(f)

    label = _per_history("label")
    alpha = _per_history("alpha")
    color = _per_history("color")
    size = _per_history("s")
    ax.set_title(kwargs.pop("title", ""))
    ax.set_xlabel(kwargs.pop("xlabel", ""))
    ax.set_ylabel(kwargs.pop("ylabel", ""))

    skip_timestamp_conversion = kwargs.pop("skip_timestamp_conversion", False)
    energy_differences = kwargs.pop("plot_energy_differences", False)

    plt.xscale(kwargs.pop("xscale", "linear"))
    default_yscale = 'linear' if not energy_differences else 'log'
    plt.yscale(kwargs.pop("yscale", default_yscale))

    mi, ma = np.inf, -np.inf

    for i, fld in enumerate(f):
        # Keep the per-series options in their own dict: `kwargs` still
        # holds the axis limits that `_limit_xy` needs after the loop.
        scatter_kwargs = {'alpha': alpha[i], 's': size[i], 'color': color[i]}

        if skip_timestamp_conversion:
            xcoord = fld.time_stamps
        else:
            xcoord = date2num([dt.fromtimestamp(ts) for ts in fld.time_stamps])

        if not energy_differences:
            ycoord = fld.energy_values
            ax.scatter(xcoord, ycoord, label=label[i], **scatter_kwargs)
        else:
            E = np.array(fld.energy_values)
            dE = E[1:] - E[:-1]
            # A difference belongs to the later of its two time stamps.
            xcoord = np.array(xcoord[1:])
            idx_pos = (dE > 0)
            idx_neg = (dE < 0)
            label_pos = label[i] + ' (pos)' if label[i] is not None else None
            label_neg = label[i] + ' (neg)' if label[i] is not None else None
            ax.scatter(xcoord[idx_pos], dE[idx_pos], marker='^',
                       label=label_pos, **scatter_kwargs)
            ax.scatter(xcoord[idx_neg], dE[idx_neg], marker='v',
                       label=label_neg, **scatter_kwargs)

        # An empty series contributes nothing to the x range.
        if len(xcoord) > 0:
            mi, ma = min(min(xcoord), mi), max(max(xcoord), ma)

    # Only set limits if at least one point was seen; otherwise the bounds
    # are still infinite and matplotlib would reject them.
    if np.isfinite(mi) and np.isfinite(ma):
        delta = (ma-mi)*0.05
        if delta == 0.:
            delta = 1.
        ax.set_xlim((mi-delta, ma+delta))
    if not skip_timestamp_conversion:
        xfmt = DateFormatter('%H:%M')
        ax.xaxis.set_major_formatter(xfmt)
    _limit_xy(**kwargs)
    if label != ([None]*len(f)):
        plt.legend(loc="upper right")


def _plot1D(f, ax, **kwargs):
    """Draw line plots of fields living on one `RGSpace` or `PowerSpace`.

    Parameters
    ----------
    f : list of Field
        Fields to plot; the domain of the first one selects the plot type.
    ax : matplotlib axes
        Axes that receive the legend (if any label is given).
    **kwargs
        `label`, `alpha`, `color`, `linewidth` (single value or one entry
        per field), `xscale` (power spectra only), `yscale`, and the axis
        limits understood by `_limit_xy`.

    Raises
    ------
    RuntimeError
        If the domain is neither an `RGSpace` nor a `PowerSpace`.
    """
    import matplotlib.pyplot as plt

    space = f[0].domain[0]
    line_options, show_legend = _extract_list_kwargs(
        kwargs, ("label", "alpha", "color", "linewidth"), len(f))

    if isinstance(space, RGSpace):
        plt.yscale(kwargs.pop("yscale", "linear"))
        positions = np.arange(space.shape[0], dtype=np.float64)
        positions = positions*space.distances[0]
        for fld, options in zip(f, line_options):
            plt.plot(positions, fld.val, **options)
    elif isinstance(space, PowerSpace):
        plt.xscale(kwargs.pop("xscale", "log"))
        plt.yscale(kwargs.pop("yscale", "log"))
        k_values = space.k_lengths
        for fld, options in zip(f, line_options):
            values = fld.val_rw()
            # The first bin is overwritten with its neighbour's value
            # before plotting on the (by default) logarithmic axes.
            values[0] = values[1]
            plt.plot(k_values, values, **options)
    else:
        raise RuntimeError("This point should never be reached")

    _limit_xy(**kwargs)
    if show_legend:
        ax.legend(loc="upper right")


def plottable1D(f):
    """Check whether a list of fields can be drawn as a 1D line plot.

    Parameters
    ----------
    f : list of Field
        Fields to check; must contain at least one entry.

    Returns
    -------
    bool
        True if the domain of the first field consists of exactly one
        one-dimensional `RGSpace` or `PowerSpace` and all fields share
        that domain.
    """
    dom = f[0].domain
    # Check the number of sub-domains before indexing: a domain without
    # entries would otherwise raise IndexError instead of reporting False.
    if len(dom) != 1:
        return False
    if not isinstance(dom[0], (RGSpace, PowerSpace)):
        return False
    if len(dom.shape) != 1:
        return False
    return all(dom == el.domain for el in f)
def plottable2D(fld, f_space=1):
    """Check whether a field can be drawn as a 2D image or sky map.

    Parameters
    ----------
    fld : Field
        Field to check.
    f_space : int
        Index (0 or 1) of the frequency space within a two-entry domain.

    Returns
    -------
    bool
        True if the domain is a `DomainTuple` with one or two entries,
        the position space is a two-dimensional `RGSpace`, an `HPSpace`
        or a `GLSpace`, and - for two entries - the space at `f_space` is
        a one-dimensional `RGSpace`.
    """
    dom = fld.domain
    if not isinstance(dom, DomainTuple) or len(dom) > 2:
        return False
    if f_space not in [0, 1]:
        return False

    n_spaces = len(dom)
    if n_spaces == 0:
        return False

    x_space = 0
    if n_spaces == 2:
        x_space = 1 - f_space
        freq_dom = dom[f_space]
        if not (isinstance(freq_dom, RGSpace) and len(freq_dom.shape) == 1):
            return False

    pos_dom = dom[x_space]
    if isinstance(pos_dom, RGSpace):
        return len(pos_dom.shape) == 2
    return isinstance(pos_dom, (HPSpace, GLSpace))
def _plotting_args_2D(fld, f_space=1):
    """Prepare a field for 2D plotting and detect multi-frequency data.

    Parameters
    ----------
    fld : Field
        Field with one or two sub-domains (see `plottable2D`).
    f_space : int
        Index of the frequency space if the domain has two entries.

    Returns
    -------
    fld : Field
        The input field; for a frequency axis of length one, a new field
        on the position space with that axis squeezed out.
    x_space : int
        Index of the position space in the original domain.
    have_rgb : bool
        True if RGB data was computed from several frequencies.
    rgb : numpy.ndarray or None
        Output of `_rgb_data` (frequency axis moved last), else None.

    Raises
    ------
    ValueError
        If the domain has neither one nor two entries.
    """
    from .sugar import makeField

    # check for multifrequency plotting
    have_rgb, rgb = False, None
    x_space = 0
    dom = fld.domain
    if len(dom) == 1:
        x_space = 0
    elif len(dom) == 2:
        x_space = 1 - f_space
        # Only one frequency?
        if dom[f_space].shape[0] == 1:
            fld = makeField(fld.domain[x_space],
                            fld.val.squeeze(axis=dom.axes[f_space]))
        else:
            val = fld.val
            if f_space == 0:
                # `_rgb_data` expects the frequency axis to be the last one.
                val = np.moveaxis(val, 0, -1)
            rgb = _rgb_data(val)
            have_rgb = True
    else:
        # DomainTuple can only have one or two entries.
        raise ValueError('check plottable2D before using this function')
    return fld, x_space, have_rgb, rgb


def _plot2D(f, ax, **kwargs):
    """Draw a field on a 2D `RGSpace`, `HPSpace` or `GLSpace`.

    Regular grids are drawn with `imshow`; spherical domains are resampled
    onto an 800 pixel wide Mollweide image first. Multi-frequency fields
    are drawn as RGB images.

    Parameters
    ----------
    f : list of Field
        Only the first entry is plotted.
    ax : matplotlib axes
        Target axes (used directly for regular grids).
    **kwargs
        `freq_space_idx`, `norm`, `aspect`, `cmap`, `vmin`, `vmax`, and for
        regular grids the axis limits understood by `_limit_xy`.

    Raises
    ------
    ValueError
        If the position space is of an unsupported type.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    f = f[0]
    dom = f.domain
    f, x_space, have_rgb, rgb = _plotting_args_2D(
        f, kwargs.pop("freq_space_idx", 1))

    # `norm` and `aspect` are only forwarded when explicitly given.
    given = kwargs.pop("norm", None)
    norm = {} if given is None else {'norm': given}

    given = kwargs.pop("aspect", None)
    aspect = {} if given is None else {'aspect': given}

    # From here on `dom` is the position space itself, not the DomainTuple.
    dom = dom[x_space]
    if not have_rgb:
        cmap = kwargs.pop("cmap", plt.rcParams['image.cmap'])

    if isinstance(dom, RGSpace):
        nx, ny = dom.shape
        dx, dy = dom.distances
        if have_rgb:
            # NOTE(review): unlike `f.val.T` below, `rgb` is passed without
            # swapping its two spatial axes - confirm the orientation.
            im = ax.imshow(
                rgb, extent=[0, nx*dx, 0, ny*dy], origin="lower", **norm,
                **aspect)
        else:
            im = ax.imshow(
                f.val.T, extent=[0, nx*dx, 0, ny*dy],
                vmin=kwargs.get("vmin"), vmax=kwargs.get("vmax"),
                cmap=cmap, origin="lower", **norm, **aspect)
            # NOTE(review): colorbar attached in the scalar case only -
            # confirm against upstream.
            cax = make_axes_locatable(ax).append_axes(
                "right", size="5%", pad=0.05)
            plt.colorbar(im, cax=cax)
        _limit_xy(**kwargs)
        return
    elif isinstance(dom, (HPSpace, GLSpace)):
        from ducc0.healpix import Healpix_Base
        from ducc0.misc import GL_thetas
        xsize = 800
        res, mask, theta, phi = _mollweide_helper(xsize)
        if have_rgb:
            # RGB canvas; pixels outside the ellipse keep the fill value 1.
            res = np.full(shape=res.shape+(3,), fill_value=1.,
                          dtype=np.float64)

        if isinstance(dom, HPSpace):
            ptg = np.empty((phi.size, 2), dtype=np.float64)
            ptg[:, 0] = theta
            ptg[:, 1] = phi
            base = Healpix_Base(int(np.sqrt(dom.size//12)), "RING")
            if have_rgb:
                res[mask] = rgb[base.ang2pix(ptg)]
            else:
                res[mask] = f.val[base.ang2pix(ptg)]
        else:
            ra = np.linspace(0, 2*np.pi, dom.nlon+1)
            dec = GL_thetas(dom.nlat)
            ilat = _find_closest(dec, theta)
            ilon = _find_closest(ra, phi)
            # The last node of `ra` (2*pi) is the same meridian as the first.
            ilon = np.where(ilon == dom.nlon, 0, ilon)
            if have_rgb:
                # `dom` is already the GLSpace here, so it must not be
                # subscripted again (same pixel index as the scalar case).
                res[mask] = rgb[ilat*dom.nlon + ilon]
            else:
                res[mask] = f.val[ilat*dom.nlon + ilon]
        plt.axis('off')
        if have_rgb:
            plt.imshow(res, origin="lower")
        else:
            plt.imshow(res, vmin=kwargs.get("vmin"), vmax=kwargs.get("vmax"),
                       norm=norm.get('norm'), cmap=cmap, origin="lower")
            # NOTE(review): colorbar attached in the scalar case only -
            # confirm against upstream.
            plt.colorbar(orientation="horizontal")
        return
    raise ValueError("Field type not(yet) supported")


def _plotHist(f, ax, **kwargs):
    """Draw histograms of the (flattened) values of all fields in `f`.

    Recognised options: `label`, `alpha`, `color`, `range` (single value
    or one entry per field) and `log`, `density`, `bins` (shared by all
    histograms; defaults False, False and 50).
    """
    add_kwargs, with_legend = _extract_list_kwargs(
        kwargs, ("label", "alpha", "color", "range"), len(f))
    add_kwargs2 = {
        "log": kwargs.pop("log", False),
        "density": kwargs.pop("density", False),
        "bins": kwargs.pop("bins", 50)
    }
    for fld, per_field in zip(f, add_kwargs):
        ax.hist(fld.val.ravel(), **per_field, **add_kwargs2)
    if with_legend:
        ax.legend(loc="upper right")


def _plot(f, ax, **kwargs):
    """Dispatch a list of plottable objects to the matching plot routine.

    Energy histories go to `_plot_history`; fields are drawn as 1D line
    plots if `plottable1D` accepts them, as 2D plots if `plottable2D`
    accepts the first field, and as histograms otherwise.
    """
    _register_cmaps()
    if isinstance(f[0], EnergyHistory):
        _plot_history(f, ax, **kwargs)
        return
    ax.set_title(kwargs.pop("title", ""))
    ax.set_xlabel(kwargs.pop("xlabel", ""))
    ax.set_ylabel(kwargs.pop("ylabel", ""))
    if plottable1D(f):
        _plot1D(f, ax, **kwargs)
        return
    if plottable2D(f[0], kwargs.get("freq_space_idx", 1)):
        _plot2D(f, ax, **kwargs)
        return
    _plotHist(f, ax, **kwargs)
class Plot:
    """Collect plot panels with `add` and render them with `output`."""

    def __init__(self):
        # Panels and their keyword arguments are stored in parallel lists.
        self._plots = []
        self._kwargs = []

    def add(self, f, **kwargs):
        """Add a figure to the current list of plots.

        Notes
        -----
        After doing one or more calls `add()`, one needs to call `output()`
        to show or save the plot.

        Parameters
        ----------
        f : :class:`nifty8.field.Field` or list of :class:`nifty8.field.Field` or None
            If `f` is a single Field, it must be defined on a single `RGSpace`,
            `PowerSpace`, `HPSpace`, `GLSpace`.
            If it is a list, all list members must be Fields defined over the
            same one-dimensional `RGSpace` or `PowerSpace`.
            If `f` is `None`, an empty panel will be displayed.

        Optional Parameters
        -------------------
        title: string
            Title of the plot.
        xlabel: string
            Label for the x axis.
        ylabel: string
            Label for the y axis.
        [xyz]min, [xyz]max: float
            Limits for the values to plot.
        cmap: string
            Color map to use for the plot (if it is a 2D plot).
        linewidth: float or list of floats
            Line width.
        label: string or list of strings
            Annotation string.
        alpha: float or list of floats
            Transparency value.
        freq_space_idx: int
            for multi-frequency plotting: index of frequency space in domain
        """
        if f is None:
            self._plots.append(None)
            self._kwargs.append({})
            return

        plottable_types = (MultiField, Field, EnergyHistory)
        if isinstance(f, plottable_types):
            f = [f]
        if not isinstance(f[0], plottable_types):
            raise TypeError("Incorrect data type. You can only add Fields or EnergyHistories, or None")

        # MultiFields are split into one panel per key; the key is
        # prepended to the panel title.
        if hasattr(f, "__len__") and all(isinstance(ff, MultiField) for ff in f):
            for key in f[0].domain.keys():
                if 'title' in kwargs:
                    panel_title = "{} {}".format(key, kwargs['title'])
                else:
                    panel_title = "{}".format(key)
                self._plots.append([ff[key] for ff in f])
                self._kwargs.append(dict(kwargs, title=panel_title))
            return

        dom = None if isinstance(f[0], EnergyHistory) else f[0].domain

        slice_into_panels = (
            isinstance(dom, DomainTuple)
            and any(isinstance(dd, RGSpace) for dd in dom)
            and "freq_space_idx" not in kwargs
            and not plottable2D(f[0], kwargs.get("freq_space_idx", 1)))
        if slice_into_panels:
            from .sugar import makeField
            dims = [len(dd.shape) for dd in dom]
            # One space is 2d, the rest is 1d
            if dims.count(2) == 1 and dims.count(1) == len(dims) - 1:
                twod_index = dims.index(2)
                other_sizes = [dd.size for ii, dd in enumerate(dom)
                               if ii != twod_index]
                # One panel per index combination of the 1d spaces; the 2d
                # space occupies two array axes, hence two full slices.
                for combo in product(*[tuple(range(ii)) for ii in other_sizes]):
                    selector = (combo[:twod_index]
                                + (slice(None), slice(None))
                                + combo[twod_index:])
                    for ifield in range(len(f)):
                        arr = f[ifield].val[selector]
                        myassert(arr.ndim == 2)
                        self._plots.append([makeField(dom[twod_index], arr)])
                        self._kwargs.append(kwargs)
                return

        self._plots.append(f)
        self._kwargs.append(kwargs)

    def output(self, **kwargs):
        """Plot the accumulated list of figures.

        Parameters
        ----------
        title: string
            Title of the full plot.
        nx, ny: int
            Number of subplots to use in x- and y-direction.
            Default: square root of the number of plots, rounded up.
        xsize, ysize: float
            Dimensions of the full plot in inches. Default: 6 per subplot
            in the respective direction.
        name: string
            If left empty, the plot will be shown on the screen,
            otherwise it will be written to a file with the given name.
            Supported extensions: .png, .pdf and .svg.
            Default: None.
        dpi: float
            Resolution used when writing the plot to a file.
        block: bool
            Override the blocking behavior of the non-interactive plotting
            mode. The plot will not be closed in this case but is left open!
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            warn("Since matplotlib is not installed, NIFTy will not generate any plots.")
            return

        nplot = len(self._plots)
        if nplot == 0:
            raise ValueError("Use .add to add plots to your plotting routine.")

        fig = plt.figure()
        if "title" in kwargs:
            plt.suptitle(kwargs.pop("title"))

        # Work out the subplot grid; a value of 0 means "choose for me".
        nx = kwargs.pop("nx", 0)
        ny = kwargs.pop("ny", 0)
        if nx == ny == 0:
            ny = int(np.ceil(np.sqrt(nplot)))
            nx = int(np.ceil(nplot/ny))
            myassert(nx*ny >= nplot)
        elif nx == 0:
            nx = int(np.ceil(nplot/ny))
        elif ny == 0:
            ny = int(np.ceil(nplot/nx))
        if nx*ny < nplot:
            raise ValueError(
                'Figure dimensions not sufficient for number of plots. '
                'Available plot slots: {}, number of plots: {}'
                .format(nx*ny, nplot))

        xsize = kwargs.pop("xsize", 6*nx)
        ysize = kwargs.pop("ysize", 6*ny)
        fig.set_size_inches(xsize, ysize)

        panels = zip(self._plots, self._kwargs)
        for slot, (panel, panel_kwargs) in enumerate(panels, start=1):
            # `None` marks an intentionally empty panel.
            if panel is None:
                continue
            ax = fig.add_subplot(ny, nx, slot)
            _plot(panel, ax, **panel_kwargs)
        fig.tight_layout()

        name = kwargs.pop("name", None)
        block = kwargs.pop("block", True)
        dpi = kwargs.pop("dpi", None)
        _makeplot(name, block=block, dpi=dpi)