Source code for csky.plotting

# plotting.py

"""Plotting of skymaps, test statistic distributions, PDFs, etc.

TODO: This module is a complete mess.  Starting point for things to clean up:

* abandon colormaps.py; everyone should be using modern matplotlib by now.
* get rid of Plot and SkyPlot
* make SkyPlotter more flexible and generally tidy its interface
* probably get rid of plot_energy_pdf and plot_gauss_2d_angres_param
* add docstrings to everything that remains

"""


from __future__ import print_function

import copy
from cycler import cycler
import healpy
hp = healpy

try:
    from itertools import izip
    zip = izip
except ImportError:
    pass

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
pi = np.pi
import os
import sys

try:
    from .colormaps import viridis, viridis_r, plasma, plasma_r, magma, magma_r, inferno, inferno_r
except:
    pass
from . import quiet_healpy, utils

try:
    import histlite as hl
except:
    from icecube import histlite as hl

def _ensure_dir(dirname):
    """Make sure ``dirname`` exists and is a directory."""
    if not os.path.isdir(dirname):
        try:
            os.makedirs(dirname)   # throws if exists as file
        except OSError as e:
            if e.errno != os.errno.EEXIST:
                raise
    return dirname

# Piecewise-linear channel anchors for the "icecube" skymap colormap, given
# as (x, value_below_x, value_above_x) rows in the format expected by
# ``matplotlib.colors.LinearSegmentedColormap``.  All three channels are 1.0
# over [0, 0.05], so the bottom of the scale is white; at x = 1 the color is
# R = G = 1, B = 0 (yellow).
skymap_cmap = dict(
    red=((0.0, 0.0, 1.0),
         (0.05, 1.0, 1.0),
         (0.5, 0.0416, 0.0416),
         (0.6, 0.0416, 0.0416),
         (0.7, 1.0, 1.0),
         (1.0, 1.0, 1.0)),
    green=((0.0, 0.0, 1.0),
           (0.05, 1.0, 1.0),
           (0.5, 0.0416, 0.0416),
           (0.6, 0.0, 0.0),
           (0.8, 0.5, 0.5),
           (1.0, 1.0, 1.0)),
    blue=((0.0, 0.0, 1.0),
          (0.05, 1.0, 1.0),
          (0.4, 1.0, 1.0),
          (0.6, 1.0, 1.0),
          (0.7, 0.2, 0.2),
          (1.0, 0.0, 0.0)),
)

# Rebind the public name to the finished 256-entry colormap object.
skymap_cmap = mpl.colors.LinearSegmentedColormap('icecube', skymap_cmap, 256)

def mpl_tex_rc(sans=False):
    """Configure matplotlib to render all text through LaTeX.

    Turns on ``text.usetex`` and loads amsmath, amssymb, amsthm and bm in the
    LaTeX preamble.

    Args:
        sans (bool): if True, select the sans-serif font family and also load
            ``sansmath`` plus sans-serif symbol fonts so that math matches the
            surrounding text; otherwise select the serif family.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    preamble = [
        r'\usepackage{amsmath}',
        r'\usepackage{amssymb}',
        r'\usepackage{amsthm}',
        r'\usepackage{bm}',
    ]
    plt.rc('text', usetex=True)
    if sans:
        plt.rc('font', family='sans-serif')
        # One LaTeX command per line (these were previously fused together by
        # implicit string concatenation due to missing commas).
        preamble += [
            r'\usepackage{sansmath}',
            r'\SetSymbolFont{operators}   {sans}{OT1}{cmss} {m}{n}',
            r'\SetSymbolFont{letters}     {sans}{OML}{cmbrm}{m}{it}',
            r'\SetSymbolFont{symbols}     {sans}{OMS}{cmbrs}{m}{n}',
            r'\SetSymbolFont{largesymbols}{sans}{OMX}{iwona}{m}{n}',
            r'\sansmath',
        ]
    else:
        plt.rc('font', family='serif')
    # matplotlib >= 3.3 expects a single string here (list values were
    # deprecated, then rejected).  Older releases split a string value on
    # commas, and none of these lines contains one, so a newline-joined
    # string works in both cases.
    mpl.rcParams['text.latex.preamble'] = '\n'.join(preamble)

def saving(plot_dir, basename, fig=None, exts='png pdf', **kw):
    """Save a figure as ``<plot_dir>/<basename>.<ext>`` for each extension.

    Args:
        plot_dir (str): output directory; passed to ``utils.ensure_dir``
            before anything is written (presumably creating it if missing).
        basename (str): file name without extension.
        fig: figure to save; when None, the current pyplot figure is used.
        exts (str): whitespace-separated list of file extensions.
        **kw: forwarded unchanged to every ``fig.savefig`` call.
    """
    utils.ensure_dir(plot_dir)
    stem = '{}/{}'.format(plot_dir, basename)
    print('-> {}'.format(stem))
    sys.stdout.flush()
    target = plt.gcf() if fig is None else fig
    for suffix in exts.split():
        target.savefig('{}.{}'.format(stem, suffix), **kw)


class Plot(object):

    """Base class for plots.

    Owns one matplotlib figure (``self.fig``), which is cleared on
    construction, and remembers its pyplot figure number (``self.fignum``)
    when it has one.
    """

    def __init__(self, fig=None):
        """Set up and clear the figure.

        Args:
            fig: None to create a new pyplot figure, an ``int`` to select (or
                create) that figure number via ``plt.figure``, or an existing
                figure object to reuse.
        """
        if fig is None:
            fig = plt.figure()
        elif isinstance(fig, int):
            fig = plt.figure(fig)
        self.fig = fig
        self.fig.clf()
        # Figures created through pyplot carry a ``number``; figures built
        # directly do not.  Only a missing attribute is tolerated here.
        try:
            self.fignum = self.fig.number
        except AttributeError:
            self.fignum = None

    def save(self, dir, basename, exts='png pdf'):
        """Save the figure as ``<dir>/<basename>.<ext>`` for each extension.

        Args:
            dir (str): output directory, created if missing.
            basename (str): file name without extension.
            exts (str): whitespace-separated list of file extensions.
        """
        _ensure_dir(dir)
        for ext in exts.split():
            filename = '{}/{}.{}'.format(dir, basename, ext)
            print('-> {} ...'.format(filename))
            self.fig.savefig(filename)

    def close(self):
        """Close the figure via ``plt.close``."""
        plt.close(self.fig)
class SkyPlot(Plot):

    """Skymap plotter built on ``healpy.mollview`` (Mollweide projection)."""

    def __init__(self, fig=None, m=None, rot=None, coord='C', *a, **kw):
        """Draw the HEALPix map ``m`` into the figure.

        Args:
            fig: passed to :meth:`Plot.__init__`.
            m: map passed straight to ``healpy.mollview``.
            rot: rotation passed to ``healpy.mollview``; defaults to 180 when
                the output frame (last character of ``coord``) is ``'C'``,
                otherwise 0.
            coord (str): healpy coordinate spec, e.g. ``'C'`` or ``'GC'``.
            *a, **kw: forwarded to ``healpy.mollview``.  ``unit`` and
                ``title`` default to ``''``, ``format`` to ``'%.1f'`` and
                ``cmap`` to ``'afmhot'``; ``cmap`` may be a name or a
                colormap object.
        """
        Plot.__init__(self, fig)
        self.m = m
        self.coord = coord
        if rot is None:
            rot = 180 if coord[-1] == 'C' else 0
        self.rot = rot
        self.a = a
        self.kw = kw
        self.kw.setdefault('unit', '')
        self.kw.setdefault('title', '')
        self.kw.setdefault('format', '%.1f')
        self.kw.setdefault('cmap', 'afmhot')
        if isinstance(self.kw['cmap'], str):
            # Copy so that set_under below cannot modify a colormap shared
            # through matplotlib's global registry.
            self.kw['cmap'] = copy.copy(plt.get_cmap(self.kw['cmap']))
        # Values below the color range are drawn white.
        self.kw['cmap'].set_under('w')
        healpy.mollview(
            m, fig=self.fignum, rot=self.rot, coord=self.coord, *a, **kw
        )
        self.mollax = self.fig.get_axes()[0]

    def graticule(self, wpad=None, *a, **kw):
        """Overlay a coordinate grid with edge labels.

        Args:
            wpad (float): if given and the map axes' left edge is closer than
                ``wpad`` (figure fraction) to the figure edge, move it to
                ``wpad`` and set its width to ``1 - 2*wpad``.
            *a, **kw: forwarded to ``healpy.graticule``; ``alpha`` defaults
                to 0.5.

        Raises:
            NotImplementedError: if the output frame (last character of
                ``self.coord``) is neither ``'C'`` nor ``'G'``.
        """
        # Make this plot's figure current; healpy draws on the current figure.
        plt.figure(self.fignum)
        kw.setdefault('alpha', .5)
        healpy.graticule(*a, **kw)
        usetex = mpl.rcParams['text.usetex']
        frame = self.coord[-1]
        if frame == 'C':
            if usetex:
                labels = r'\textbf{0h} \textbf{24h}'
            else:
                labels = '0h 24h'
        elif frame == 'G':
            if usetex:
                labels = r'\textbf{--180}$^\circ$ \textbf{+180}$^\circ$'
            else:
                labels = r'$-180^\circ$ $+180^\circ$'
        else:
            raise NotImplementedError(
                'coord "{}" not yet supported'.format(self.coord))
        labels = labels.split()
        # Longitude labels at the left and right edges of the map.
        lons = self.rot - 180, self.rot + 179.9
        healpy.projtext(
            lons[0], 0, labels[0], lonlat=True, ha='left', va='center',
            withdash=True, dashpad=2, dashlength=.01, dashdirection=1
        )
        healpy.projtext(
            lons[1], 0, labels[1], lonlat=True, ha='right', va='center',
            withdash=True, dashpad=8, dashlength=.01,
        )
        # Latitude labels along the right edge.
        for lat in [-60, -30, 30, 60]:
            lon = 179.9 + self.rot
            healpy.projtext(
                lon, 1.1 * lat, format(lat, '+.0f'), lonlat=True,
                ha='right', va='center', withdash=True,
                dashpad=abs(lat), dashlength=.01
            )
        if wpad is not None:
            bounds = list(self.mollax.get_position().bounds)
            if bounds[0] < wpad:
                bounds[0] = wpad
                bounds[2] = 1 - 2*wpad
            self.mollax.set_position(bounds)
            plt.draw()

    def colorbar(self, unit=''):
        """Draw a horizontal colorbar below the map.

        When ``norm='log'`` was passed to the constructor, ticks are placed
        at whole decades inside the finite data range of ``self.m``.

        Args:
            unit (str): colorbar label; omitted if empty.

        Returns:
            The matplotlib colorbar.
        """
        log = self.kw.get('norm', '') == 'log'
        kw = dict(
            orientation='horizontal', fraction=.1, shrink=.5, pad=.05,
        )
        if log:
            vmin = 10**np.ceil(np.log10(np.nanmin(self.m)))
            vmax = 10**np.floor(np.log10(np.nanmax(self.m)))
            kw.update(dict(
                format=mpl.ticker.LogFormatterMathtext(),
                ticks=10**np.arange(np.log10(vmin), np.log10(vmax) + 1)
            ))
        cb = self.fig.colorbar(
            self.mollax.get_images()[0], ax=self.mollax, **kw
        )
        if unit:
            cb.set_label(unit)
        return cb

    def show_gp(self, **kw):
        """Draw the Galactic plane (b = 0) via ``healpy.projplot``.

        Args:
            **kw: forwarded to ``healpy.projplot`` (``lonlat`` and ``coord``
                are overwritten).
        """
        lon = np.linspace(0, 360, 1000)
        lat = np.zeros_like(lon)
        kw['lonlat'] = True
        kw['coord'] = 'G' + self.coord[-1]
        healpy.projplot(lon, lat, **kw)

    def show_gc(self, **kw):
        """Mark the Galactic center (l = b = 0) via ``healpy.projscatter``.

        Args:
            **kw: forwarded to ``healpy.projscatter`` (``lonlat`` and
                ``coord`` are overwritten).
        """
        kw['lonlat'] = True
        kw['coord'] = 'G' + self.coord[-1]
        healpy.projscatter(0, 0, **kw)
class SkyPlotter(object):

    """Skymap plotter using matplotlib directly for projections.

    Angles follow the healpy convention: ``theta`` is colatitude and ``phi``
    is longitude, both in radians.  Maps and overlays are drawn onto a
    caller-supplied matplotlib axes.  ``self.projection`` is stored for the
    caller's use; no method of this class reads it.
    """

    def __init__(self, coord='C', projection='aitoff', pc_kw=None, cb_kw=None):
        """Store plotting defaults.

        Args:
            coord (str): frame of the plot, ``'C'`` or ``'G'``.
            projection (str): matplotlib projection name, stored only.
            pc_kw (dict): default ``pcolormesh`` kwargs for :meth:`plot_map`.
            cb_kw (dict): default ``colorbar`` kwargs for :meth:`plot_map`;
                ``orientation``, ``shrink`` and ``pad`` default to
                ``'horizontal'``, 0.5 and 0.08.

        Raises:
            NotImplementedError: if ``coord`` is not ``'C'`` or ``'G'``.
        """
        self.coord = coord
        self.projection = projection
        # Copy so the defaults filled in below never leak into the caller's
        # dict (or into a shared default argument).
        self.pc_kw = dict(pc_kw or {})
        self.cb_kw = dict(cb_kw or {})
        if self.coord not in ['C', 'G']:
            raise NotImplementedError('coord "{}" not yet supported'.format(self.coord))
        self.cb_kw.setdefault('orientation', 'horizontal')
        self.cb_kw.setdefault('shrink', .5)
        self.cb_kw.setdefault('pad', .08)

    def thetaphi_to_mpl(self, theta, phi):
        """Convert healpy angles to matplotlib geographic-axes coordinates.

        Returns:
            (x, y): ``x = pi - phi`` with values above ``pi`` shifted down by
            ``2*pi`` (so ``x`` decreases as ``phi`` increases), and
            ``y = pi/2 - theta`` (latitude).  Both are at least 1-D arrays.
        """
        theta, phi = np.atleast_1d(theta), np.atleast_1d(phi)
        x = pi - phi
        x[x > pi] -= 2*pi
        y = pi/2 - theta
        return x, y

    def plot_gp(self, ax, color='.5', s=.3, strip=0., **kw):
        """Scatter-plot the Galactic plane (b = 0) onto ``ax``.

        Args:
            ax: matplotlib axes.
            color, s: marker color and size for ``ax.scatter``.
            strip (float): if positive, also draw two lines offset by
                ``+strip`` and ``-strip`` radians in Galactic colatitude.
            **kw: forwarded to every ``ax.scatter`` call.
        """
        gal_lon = np.linspace(-pi, pi, 3000)
        theta_b = pi/2 * np.ones_like(gal_lon)
        if self.coord == 'C':
            r = healpy.Rotator(coord='GC')
            theta, phi = r(theta_b, gal_lon)
            if strip > 0:
                theta_up, phi_up = r(theta_b + strip, gal_lon)
                theta_down, phi_down = r(theta_b - strip, gal_lon)
        elif self.coord == 'G':
            theta, phi = theta_b, gal_lon
            if strip > 0:
                # Already in the plot frame: only colatitude shifts.
                theta_up, phi_up = theta + strip, gal_lon
                theta_down, phi_down = theta - strip, gal_lon
        else:
            raise ValueError('bad coord {}'.format(self.coord))
        x, y = self.thetaphi_to_mpl(theta, phi)
        ax.scatter(x, y, color=color, marker='.', s=s, **kw)
        if strip > 0:
            x_up, y_up = self.thetaphi_to_mpl(theta_up, phi_up)
            x_down, y_down = self.thetaphi_to_mpl(theta_down, phi_down)
            ax.scatter(x_up, y_up, color=color, marker='.', s=s, **kw)
            ax.scatter(x_down, y_down, color=color, marker='.', s=s, **kw)

    def plot_sgp(self, ax, color='.5', s=.3, **kw):
        """Scatter-plot the supergalactic plane onto ``ax``.

        Requires ``icecube.astro`` for the frame conversion.

        Args:
            ax: matplotlib axes.
            color, s: marker color and size for ``ax.scatter``.
            **kw: forwarded to ``ax.scatter``.
        """
        from icecube import astro
        # Presumably supergalactic (longitude, latitude) in radians along
        # the plane; the longitude grid slightly overshoots 360 degrees.
        sg_lon = np.linspace(0., 361., 1500)*np.pi/180.
        sg_lat = np.zeros(len(sg_lon))
        if self.coord == 'C':
            lon, lat = astro.supergal_to_equa(sg_lon, sg_lat)
        elif self.coord == 'G':
            lon, lat = astro.supergal_to_gal(sg_lon, sg_lat)
        else:
            raise ValueError('bad coord {}'.format(self.coord))
        theta, phi = np.pi/2 - lat, lon
        x, y = self.thetaphi_to_mpl(theta, phi)
        ax.scatter(x, y, color=color, marker='.', s=s, **kw)

    def plot_gc(self, ax, color='.5', s=15, **kw):
        """Mark the Galactic center (l = b = 0) on ``ax`` with ``ax.scatter``."""
        gal_lon = 0
        theta_b = pi/2
        if self.coord == 'C':
            r = healpy.Rotator(coord='GC')
            theta, phi = r(theta_b, gal_lon)
        elif self.coord == 'G':
            theta, phi = theta_b, gal_lon
        else:
            raise ValueError('bad coord {}'.format(self.coord))
        x, y = self.thetaphi_to_mpl(theta, phi)
        ax.scatter(x, y, color=color, s=s, **kw)

    def rotate_map(self, m, **kw):
        """Resample HEALPix map ``m`` through a ``healpy.Rotator``.

        Each output pixel takes the value of ``m`` interpolated (via
        ``healpy.get_interp_val``) at that pixel's direction after applying
        ``healpy.rotator.Rotator(**kw)``.
        """
        nside = healpy.get_nside(m)
        r = healpy.rotator.Rotator(**kw)
        theta, phi = healpy.pix2ang(nside, np.arange(len(m)))
        theta_rot, phi_rot = r(theta, phi)
        return healpy.get_interp_val(m, theta_rot, phi_rot)

    def map_to_latlonz(self, m, N=1000):
        """Sample HEALPix map ``m`` on a regular latitude/longitude grid.

        Returns:
            (lat, lon, Z): ``lat`` has ``N`` values from ``-pi/2`` to
            ``pi/2``, ``lon`` has ``2*N`` values from ``-pi`` to ``pi``, and
            ``Z`` has shape ``(N, 2*N)`` and is suitable for ``pcolormesh``.
        """
        x = np.linspace(pi, -pi, 2*N)
        y = np.linspace(pi, 0, N)
        X, Y = np.meshgrid(x, y)
        r = healpy.rotator.Rotator(rot=(-180, 0, 0))
        YY, XX = r(Y.ravel(), X.ravel())
        pix = healpy.ang2pix(healpy.get_nside(m), YY, XX)
        Z = np.reshape(m[pix], X.shape)
        lon = x[::-1]
        lat = pi/2 - y
        return lat, lon, Z

    def plot_map(self, ax, m, unit='', n_ticks=5, pc_kw=None, cb_kw=None,
                 ticks=None, log=False, titleticks=False, nohr=False):
        """Draw HEALPix map ``m`` on ``ax`` with ``pcolormesh`` plus a colorbar.

        Args:
            ax: matplotlib axes (presumably created with a geographic
                projection such as ``self.projection``).
            m: HEALPix map; if None nothing is drawn and None is returned.
            unit (str): colorbar label; omitted if empty.
            n_ticks (int): number of evenly spaced colorbar ticks used when
                ``ticks`` is None and ``log`` is False.
            pc_kw (dict): ``pcolormesh`` kwargs overriding the defaults given
                at construction.
            cb_kw (dict): ``colorbar`` kwargs overriding the defaults given
                at construction.
            ticks: explicit colorbar ticks; ignored when ``log`` is True.
            log (bool): use ``LogNorm``, taking ``vmin``/``vmax`` from the
                merged ``pcolormesh`` kwargs.
            titleticks (bool): blank the latitude labels above +70 degrees
                (presumably to leave room for a title).
            nohr (bool): suppress the ``0h``/``24h`` edge annotations.

        Returns:
            (pc, cb): the ``pcolormesh`` artist and its colorbar.
        """
        if m is None:
            return
        lat, lon, Z = self.map_to_latlonz(m)
        kw = copy.deepcopy(self.pc_kw)
        kw.update(pc_kw or {})
        if log:
            vmin = kw.pop('vmin', None)
            vmax = kw.pop('vmax', None)
            kw['norm'] = mpl.colors.LogNorm(vmin, vmax)
        pc = ax.pcolormesh(lon, lat, Z, **kw)

        def yfmt(n, *a):
            # Latitude tick labels in signed degrees; the equator is unlabeled.
            n = np.degrees(n)
            if titleticks and n > 70:
                return ''
            fmt = r'${:+.0f}^\circ$'
            return fmt.format(n) if n else ''

        ax.xaxis.set_ticks(np.radians(np.arange(-180, 180, 30)))
        ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda *a: ''))
        ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(yfmt))
        if not nohr:
            ann_kw = dict(xycoords='axes fraction', textcoords='offset pixels',
                          verticalalignment='center')
            ax.annotate(r'0h', xy=(1, .5), xytext=(10, 0),
                        horizontalalignment='left', **ann_kw)
            ax.annotate(r'24h', xy=(0, .5), xytext=(-10, 0),
                        horizontalalignment='right', **ann_kw)
        kw = copy.deepcopy(self.cb_kw)
        kw.update(cb_kw or {})
        cb = ax.figure.colorbar(pc, ax=ax, **kw)
        if unit:
            cb.set_label(unit)
        if not log:
            if ticks is None:
                vmin, vmax = pc.get_clim()
                ticks = np.linspace(vmin, vmax, n_ticks)
            cb.set_ticks(ticks)
        return pc, cb
def ud_grade_interp(m, nside):
    """Resample HEALPix map ``m`` to resolution ``nside`` by interpolation.

    Each output pixel takes the value of ``m`` interpolated with
    ``healpy.get_interp_val`` at that pixel's position on the grid returned by
    ``SkyScanner.get_healpix_grid(nside)`` (presumably (ra, dec) in radians).
    """
    from .trial import SkyScanner
    new_ra, new_dec = SkyScanner.get_healpix_grid(nside)
    return hp.get_interp_val(m, pi/2 - new_dec, new_ra)


def plot_energy_pdf(ax, ana, gamma, bins=400, range=None, **kw):
    """Plot ``ana.energy_pdf_ratio_model`` as a 2D histogram.

    The model is evaluated on a ``bins`` x ``bins`` grid over
    (sin(dec), log10(E)) at ``gamma`` (presumably the spectral index).

    Args:
        ax: matplotlib axes.
        ana: analysis object providing ``energy_pdf_ratio_model``.
        gamma: value passed as ``gamma=`` to the evaluated model.
        bins (int): bins per axis.
        range: histogram range; defaults to the model's ``range``.
        **kw: forwarded to ``hl.plot2d``.

    Returns:
        Whatever ``hl.plot2d`` returns.
    """
    pdf = ana.energy_pdf_ratio_model

    def f(sd, lE):
        return pdf(utils.Events(sindec=sd, log10energy=lE))(gamma=gamma)[0]

    if range is None:
        range = pdf.range
    h = hl.hist_from_eval(f, vectorize=False, bins=(bins, bins), range=range)
    return hl.plot2d(ax, h, **kw)


def plot_gauss_2d_angres_param(sigma_param, bins=400, range=None, figscale=3, **kw):
    """Plot a 2D Gaussian angular-resolution parameterization.

    Rows show the declination error, right-ascension error (both converted
    from radians to degrees) and normalization.  Columns show the base
    histograms, the smoothed histograms (only if ``sp.hdec`` is a different
    object from ``sp.hdec_base``) and the fitted functions (only if
    ``sp.sdec`` is not None).

    Args:
        sigma_param: parameterization providing ``hdec_base``, ``hdec``,
            ``hra_base``, ``hra``, ``hnorm``, ``sdec``, ``sra``, ``snorm``
            and ``range``.
        bins: binning used to evaluate the fitted functions.
        range: range used to evaluate the fitted functions; defaults to
            ``sigma_param.range``.
        figscale (float): figure size per panel, in inches.
        **kw: forwarded to ``hl.plot2d`` (the normalization row drops any
            ``vmin`` and uses ``vmax=1``).

    Returns:
        (fig, axs, out): the figure, a 2-D array of axes, and a same-shaped
        object array of ``hl.plot2d`` results (None for unused panels).
    """
    sp = sigma_param
    if range is None:
        range = sp.range
    smoothed_bins = int(sp.hdec_base is not sp.hdec)
    fitted = int(sp.sdec is not None)
    ncol = 1 + smoothed_bins + fitted
    nrow = 3
    # squeeze=False keeps axs 2-D even when only one column is needed.
    fig, axs = plt.subplots(nrow, ncol, figsize=(figscale * ncol, figscale * nrow),
                            squeeze=False)
    axs = np.array(axs)
    out = np.empty_like(axs)

    # dec
    i = 0
    out[i, 0] = hl.plot2d(axs[i, 0], sp.hdec_base * 180/pi, **kw)
    if smoothed_bins:
        j = smoothed_bins
        out[i, j] = hl.plot2d(axs[i, j], sp.hdec * 180/pi, **kw)
    if fitted:
        def fdec(sd, lE):
            return sp.sdec(sd, lE)
        hdec = hl.hist_from_eval(fdec, vectorize=False, bins=bins, range=range)
        j = smoothed_bins + fitted
        out[i, j] = hl.plot2d(axs[i, j], hdec * 180/pi, **kw)

    # ra
    i = 1
    out[i, 0] = hl.plot2d(axs[i, 0], sp.hra_base * 180/pi, **kw)
    if smoothed_bins:
        j = smoothed_bins
        out[i, j] = hl.plot2d(axs[i, j], sp.hra * 180/pi, **kw)
    if fitted:
        def fra(sd, lE):
            return sp.sra(sd, lE)
        hra = hl.hist_from_eval(fra, vectorize=False, bins=bins, range=range)
        j = smoothed_bins + fitted
        out[i, j] = hl.plot2d(axs[i, j], hra * 180/pi, **kw)

    # norm: there is no "base" panel, so column 0 stays empty when smoothed.
    nkw = copy.deepcopy(kw)
    nkw.pop('vmin', 0)
    nkw['vmax'] = 1
    i, j = 2, smoothed_bins
    out[i, j] = hl.plot2d(axs[i, j], sp.hnorm, **nkw)
    if fitted:
        def fnorm(sd, lE):
            return sp.snorm(sd, lE)
        hnorm = hl.hist_from_eval(fnorm, vectorize=False, bins=bins, range=range)
        j = smoothed_bins + fitted
        out[i, j] = hl.plot2d(axs[i, j], hnorm, **nkw)

    for ax in np.ravel(axs):
        ax.set_xlabel(r'$\sin(\delta_\mathsf{reco})$')
        ax.set_ylabel(r'$\log_{10}(E_\mathsf{reco})$')
    for o in out[0]:
        o['colorbar'].set_label(r'estimated $\sigma_\delta~[^\circ]$')
    for o in out[1]:
        o['colorbar'].set_label(r'estimated $\sigma_\alpha~[^\circ]$')
    for (o, ax) in zip(out[2], axs[2]):
        if o is None:
            ax.set_visible(False)
        else:
            o['colorbar'].set_label(r'normalization')
    return fig, axs, out


soft_colors = ['#004466', '#d06050', '#2aca80', '#dd9388', '#caca68']
friendly_colors = ['#184b68', '#cf4d30', '#62badb', '#e797b4', '#eec9b4', '#f7dede']
# matplotlib's default color cycle, with entries 1 and 3 swapped and
# truncated to eight colors.
mpl_colors_orig = np.array(mpl.rcParamsDefault['axes.prop_cycle'].by_key()['color'])
mpl_colors = mpl_colors_orig[[0, 3, 2, 1, 4, 5, 6, 7]]


def mrichman_mpl(tex=True, sans=True, colors=mpl_colors):
    """Apply a personal matplotlib style to the global ``rcParams``.

    Sets the color cycle to ``colors``, dotted grid lines, 2-pt lines, white
    figure/savefig backgrounds, opaque legends and a figure/savefig DPI of
    120/150.

    Args:
        tex (bool): if True, also call :func:`mpl_tex_rc`.
        sans (bool): passed to :func:`mpl_tex_rc` as ``sans``.
        colors: color cycle to install (previously this argument was ignored).
    """
    # Feature-detect instead of comparing version strings lexically.
    if 'axes.prop_cycle' in mpl.rcParams:
        mpl.rcParams['axes.prop_cycle'] = cycler('color', colors)
    else:
        mpl.rcParams['axes.color_cycle'] = colors
    mpl.rcParams['grid.linestyle'] = ':'
    mpl.rcParams['lines.linewidth'] = 2
    mpl.rcParams['figure.facecolor'] = mpl.rcParams['savefig.facecolor'] = 'w'
    mpl.rcParams['legend.framealpha'] = 1
    # my laptop display is quad-HD
    mpl.rcParams['figure.dpi'] = 120
    mpl.rcParams['savefig.dpi'] = 150
    if tex:
        mpl_tex_rc(sans=sans)