# conf.py
"""Configuration tools for csky."""
import sys
_PY2 = sys.version_info[0] == 2
if _PY2:
import funcsigs
else:
import inspect
import copy
import functools
import numpy as np
from . import pdf, inj, hyp, llh, trial, selections, analysis, utils
def get_signature(f, is_class=True):
    """Return the call signature of ``f``.

    On Python 3 this is simply ``inspect.signature(f)`` and ``is_class`` has
    no effect.  On Python 2 the ``funcsigs`` backport is used; there, if
    ``is_class`` is true and ``f`` has an ``__init__`` attribute, the
    signature of ``f.__init__`` is returned instead of that of ``f``.
    """
    if not _PY2:
        return inspect.signature(f)
    inspect_init = hasattr(f, '__init__') and is_class
    target = f.__init__ if inspect_init else f
    return funcsigs.signature(target)
[docs]
def do_bind(f, conf, is_class=True):
    """Partially bind ``f`` by applying configuration ``conf``.

    Returns a 3-tuple ``(bound, final_kw, accepted)``:

    * ``bound`` -- ``functools.partial`` of ``f`` with the selected keyword
      arguments applied;
    * ``final_kw`` -- the keyword arguments taken from ``conf``;
    * ``accepted`` -- ``{name: default}`` for every parameter of ``f`` that
      declares a default value.

    Arguments are taken from top-level entries of ``conf`` whose key matches
    a parameter name of ``f``.  If ``f`` itself is a key of ``conf``, the
    matching entries of ``conf[f]`` override the top-level ones.
    """
    sig = get_signature(f, is_class=is_class)
    params = sig.parameters
    no_default = funcsigs._empty if _PY2 else inspect._empty

    accepted = {}
    for param in params.values():
        if param.default is not no_default:
            accepted[param.name] = param.default

    final_kw = {}
    for key in conf:
        if key in params:
            final_kw[key] = conf[key]

    if f in conf:
        specific = conf[f]
        final_kw.update(
            (key, specific[key]) for key in specific if key in params)

    bound_args = sig.bind_partial(**final_kw)
    return functools.partial(f, **bound_args.arguments), final_kw, accepted
[docs]
def get_obj(cls, conf, subana=None, is_class=True, **kw):
    """Get object of type ``cls``.

    ``cls`` is partially bound with the matching entries of ``conf`` (see
    :func:`do_bind`) and then called with ``**kw``.

    Args:
        cls: class or callable to invoke.
        conf: configuration mapping handed to :func:`do_bind`.
        subana: optional object used to fill in defaults.  When given, and
            when the signature of ``cls`` has the corresponding parameter,
            ``ana`` defaults to ``subana``, ``bg_ev`` to ``subana.bg_data``
            and ``sig_ev`` to ``subana.sig``.  Values already present in
            ``kw`` win.
        is_class: forwarded to :func:`get_signature` / :func:`do_bind`.
        **kw: extra keyword arguments for the call; they take precedence
            over anything bound from ``conf``.

    Returns:
        Whatever ``cls(...)`` returns.  Unless that is ``None``, the result
        is tagged with ``_csky_conf`` (the keyword arguments actually used)
        and ``_csky_defaults`` (the defaults declared by ``cls``).

    Raises:
        Any exception raised by the call; a short notice naming ``cls`` is
        printed before it propagates.
    """
    constr, final_kw, accepted = do_bind(cls, conf, is_class=is_class)
    if subana is not None:
        signature = get_signature(cls, is_class=is_class)
        if 'ana' in signature.parameters:
            kw['ana'] = kw.get('ana', subana)
        if 'bg_ev' in signature.parameters:
            kw['bg_ev'] = kw.get('bg_ev', subana.bg_data)
        if 'sig_ev' in signature.parameters:
            kw['sig_ev'] = kw.get('sig_ev', subana.sig)
    # Only the call itself is guarded, so the notice below is printed for
    # construction failures and nothing else.
    try:
        out = constr(**kw)
    except Exception:
        print('exception raised trying to construct {}'.format(cls))
        raise
    final_kw.update(kw)
    # A callable may legitimately return None (get_utf_prior does so when
    # use_prior is false); None cannot carry attributes, so skip tagging.
    if out is not None:
        out._csky_conf = final_kw
        out._csky_defaults = accepted
    return out
[docs]
def overlay(conf, new_conf={}):
    """Return a new dict: ``conf`` with ``new_conf`` laid on top.

    When a key is present in both mappings and both values are dicts, the
    two dicts are merged recursively.  Otherwise the value from ``new_conf``
    replaces the one from ``conf``.  Neither input is modified; values that
    are not merged are shared by reference with the inputs.
    """
    merged = dict(conf)
    for key in new_conf:
        try:
            incoming = new_conf[key]
            mergeable = (
                key in conf
                and isinstance(incoming, dict)
                and isinstance(conf[key], dict))
            if mergeable:
                incoming = overlay(conf[key], incoming)
            merged[key] = incoming
        except BaseException:
            print('error handling overlay of key "{}"'.format(key))
            raise
    return merged
def get_utf_prior(ana, n_src=1, dt_max=None, use_prior=True, use_seasons=False, box=False):
    """Build a prior callable for untriggered-flare fits.

    Args:
        ana: iterable of per-season objects; each must provide ``data.mjd``
            (with ``min()``/``max()``) and ``livetime`` in seconds.  Must be
            non-empty when ``use_prior`` is true.
        n_src: forwarded to
            ``pdf.UntriggeredTimePDFRatioModel.get_t0_dt``.
        dt_max: if not None, used as the reference livetime (in days)
            instead of the livetime computed from ``ana``.
        use_prior: if false, no prior is built and ``None`` is returned.
        use_seasons: if true (and ``dt_max`` is None), the reference
            livetime is that of the season containing ``t0[0]`` rather than
            the summed livetime of all seasons.
        box: if false, ``dt`` is doubled before it is summed.

    Returns:
        ``None`` if ``use_prior`` is false, otherwise a function
        ``prior(**params)`` returning
        ``min(1, sum(dt) / reference_livetime)``.
    """
    if not use_prior:
        return None

    # Season start times (first one pulled back by 0.1 day) plus one final
    # entry just past the end of the last season.
    season_boundaries = []
    season_livetimes = []
    for (i, a) in enumerate(ana):
        mi = a.data.mjd.min()
        if i == 0:
            mi -= 0.1
        season_boundaries.append(mi)
        season_livetimes.append(a.livetime / 86400)  # sec->day
    season_boundaries.append(a.data.mjd.max() + 0.1)
    total_livetime = np.sum(season_livetimes)

    def prior(**params):
        t0, dt = pdf.UntriggeredTimePDFRatioModel.get_t0_dt(n_src=n_src, **params)
        frac_livetime = utils.get_frac_during_livetime_single(ana, t0, dt, 5)
        dt = dt * frac_livetime
        t0 = np.clip(t0, min(season_boundaries), max(season_boundaries))
        if dt_max is None:
            if use_seasons:
                # searchsorted returns 0 when t0 sits exactly on the lowest
                # boundary (which clipping makes possible); without the
                # floor the index would be -1 and silently select the
                # *last* season.
                i_season = max(np.searchsorted(season_boundaries, t0[0]) - 1, 0)
                livetime = season_livetimes[i_season]
            else:
                livetime = total_livetime
        else:
            livetime = dt_max
        F = 1. / livetime
        if not box:
            dt = 2 * dt
        ontime_livetime = np.sum(dt)
        return min(1, F * ontime_livetime)

    return prior
[docs]
def get_llh(a, conf, **kw):
    """Get a :class:`csky.llh.LLHModel`.

    Assembles the space, optional time, optional energy and any additional
    PDF ratio models named in ``conf`` and wraps them in an
    :class:`csky.llh.LLHModel` built through :func:`get_obj`.

    Args:
        a: sub-analysis object.  This function reads ``a.acc_param`` and
            ``a.energy_pdf_ratio_model`` and calls
            ``a.get_custom_flux_acc_parameterization`` /
            ``a.get_custom_flux_energy_pdf_ratio_model``.
        conf: configuration mapping.  Keys read here: ``'energy'`` (default
            ``'fit'``), ``'flux'``, ``'keep_pdfs'`` (default False),
            ``'space'`` (default ``'ps'``), ``'time'`` and ``'extended'``
            (default True, transient time only).  Keys that are
            ``pdf.PDFRatioModel`` subclasses request additional models.
        **kw: accepted but not used here.

    Returns:
        The constructed likelihood model.

    Raises:
        ValueError: if ``conf['energy']`` is truthy but is neither ``'fit'``
            nor ``'customflux'``.
    """
    # need to check energy first: if 'customflux', use custom acceptance to match
    energy = conf.get('energy', 'fit')
    flux = conf.get('flux', None)
    keep_pdfs = conf.get('keep_pdfs', False)
    pdfs_args = {'keep_pdfs': keep_pdfs}
    if energy == 'customflux' and flux is not None:
        flux = np.atleast_1d(flux)
        acc_param = [a.get_custom_flux_acc_parameterization(f) for f in flux]
        if len(acc_param) == 1:
            acc_param = acc_param[0]
    else:
        flux = np.atleast_1d(flux)
        if len(flux) > 1:
            # one copy of the default parameterization per flux
            acc_param = [a.acc_param for f in flux]
        else:
            acc_param = a.acc_param
    conf = overlay(conf, dict(acc_param=acc_param))

    # SPACE
    space_names = dict(
        ps=pdf.PointSourceSpacePDFRatioModel,
        fitps=pdf.FitPointSourceSpacePDFRatioModel,
        template=pdf.TemplateSpacePDFRatioModel,
        prior=pdf.PriorSpacePDFRatioModel,
        generic=pdf.AccWeightedGenericPDFRatioModel)
    space = conf.get('space', 'ps')
    space = space_names.get(space, space)
    space_model = get_obj(space, conf, a)

    # TIME
    time_names = dict(
        utf=pdf.UntriggeredTimePDFRatioModel,
        lc=pdf.BinnedTimePDFRatioModel,
        transient=pdf.TransientTimePDFRatioModel)
    time_name = conf.get('time', None)
    time = time_names.get(time_name, time_name)
    if time is None:
        models = [space_model]
    else:
        time_model = get_obj(time, conf, a)
        spacetime_model = pdf.SpaceTimePDFRatioModel(space_model, time_model)
        models = [spacetime_model]

    # ENERGY
    if energy:
        if energy == 'fit':
            energy_model = a.energy_pdf_ratio_model
            energy = type(energy_model)
        elif energy == 'customflux':
            flux = np.atleast_1d(conf['flux'])
            if len(flux) >= 2:
                # keep pdfs for stacking
                print("loading energy pdfs of ", len(flux)," fluxes")
                energy_model = []
                for f in flux:
                    energy_pdf = a.get_custom_flux_energy_pdf_ratio_model(f)
                    energy_pdf.keep_pdfs = keep_pdfs
                    energy_model.append(energy_pdf)
                energy = type(energy_pdf)
            else:
                # NOTE(review): this passes the length-<=1 array itself,
                # whereas the multi-flux branch passes individual elements
                # -- confirm the callee accepts both.
                energy_model = a.get_custom_flux_energy_pdf_ratio_model(flux)
                energy = type(energy_model)
        else:
            # No energy model can be built for this value; fail with a
            # clear message instead of tripping over an unbound variable.
            raise ValueError(
                "unable to find energy conf: expected 'fit', 'customflux' "
                "or a false value, got {!r}".format(energy))
        # A sized energy_model is a collection of models (one per flux);
        # anything without a length is a single model.
        try:
            n_energy_models = len(energy_model)
        except TypeError:
            models.append(energy_model)
        else:
            if n_energy_models > 1 and keep_pdfs:
                print("keeping energy models for all fluxes, make sure intended")
            models += energy_model

    # ANY OTHERS!?
    for key in conf:
        if isinstance(key, type):
            if pdf.PDFRatioModel in key.mro():
                if key not in (space, time, energy):
                    model = get_obj(key, conf, a)
                    models.append(model)

    if len(models) >= 2:
        prm = pdf.MultiPDFRatioModel(*models, **pdfs_args)
    else:
        prm = models[0]

    if time_name == "transient":
        extended = conf.get('extended', True)
        lm = get_obj(llh.LLHModel, conf, a, pdf_ratio_model=prm, N_bg=time_model.nb,
                     extended=extended)
    else:
        lm = get_obj(llh.LLHModel, conf, a, pdf_ratio_model=prm, keep_pdfs=keep_pdfs)
    return lm
[docs]
def get_injs(a, llh_model, conf, do_inj=True, llh_conf={}, **kw):  ## to be modified
    """Get :class:`csky.inj.Injector` instances for TRUTH, background, and signal.

    Args:
        a: sub-analysis object; ``a.data`` (and ``a.grl`` when the GRL
            randomizer is requested) are read here.
        llh_model: likelihood model; its ``pdf_ratio_model`` is used.
        conf: injection configuration.  Keys read here: ``'extra_keep'``,
            ``'time'``, ``'full_sky'``, ``'template'``, ``'gpbg_conf'``,
            ``'randomize'``, ``'bg_weight_names'``, ``'sig'``, ``'sig_kw'``
            and ``'space'``.
        do_inj: if false, no signal injector is built and ``None`` is
            returned in its place.
        llh_conf: configuration used to build the ``inj.DecBandSelector``.
            Not modified here.
        **kw: accepted but not used here.

    Returns:
        Tuple ``(truth, bg, sig)``.

    Raises:
        ValueError: if none of the requested randomizers is one of the base
            randomizers, or if a transient signal injector is requested
            while ``conf['time']`` is not ``'transient'``.

    Note:
        When ``'gpbg_conf'`` is in ``conf``, that nested dict is modified in
        place (``'llh_kw'`` and ``'template_model'`` are set on it).
    """
    prm = llh_model.pdf_ratio_model
    keep = list(set(prm.keep + conf.get('extra_keep', [])))
    data = a.data
    use_time = conf.get('time', None) == "transient"
    full_sky = conf.get('full_sky', False)
    # Resolve the time model once, up front: it is needed by the background
    # (and possibly signal) injector whenever use_time is set, including
    # when the template / gpbg branch below skips the time-window selection.
    time_model = prm.get_time_model() if use_time else None

    if 'template' in conf or 'gpbg_conf' in conf:
        selected_data = data
    elif use_time:
        selector = get_obj(inj.DecBandSelector, llh_conf)
        t_selector = inj.TimeWindowSelector(time_model.src_mjd_min, time_model.src_mjd_max)
        selected_data = t_selector(selector(a.data))
    else:
        selector = get_obj(inj.DecBandSelector, llh_conf)
        selected_data = selector(data)

    truth = get_obj(inj.DataInjector, conf, a, data=selected_data, keep=keep, randomizers=[])

    randomize = conf.get('randomize', ['grl' if 'time' in conf else 'ra'])
    randomize = np.atleast_1d(randomize)
    randomizer_names = dict(
        ra=inj.RARandomizer,
        dec=inj.DecRandomizer,
        energy=inj.EnergyRandomizer,
        grl=inj.MJDGRLRandomizer)
    randomizers = []
    for r in randomize:
        r = randomizer_names.get(r, r)
        rkw = {}
        if r is inj.MJDGRLRandomizer:
            rkw['grl'] = a.grl
            # NOTE: `keep` is the same list object already handed to the
            # truth injector above.
            keep.append('azimuth')
        randomizers.append(get_obj(r, conf, **rkw))

    # safety check to make sure at least one of the base randomizers is given
    base_randomizers = (
        inj.RARandomizer, inj.MJDGRLRandomizer, inj.MJDShuffleRandomizer
    )
    if not any(isinstance(r, base_randomizers) for r in randomizers):
        raise ValueError(
            'Must specify at least one of {}, but given are {}'.format(
                base_randomizers, randomizers))

    if use_time:
        bg = get_obj(inj.DataOnOffInjector, conf, a, time_model=time_model, keep=keep,
                     randomizers=[inj.MJDShuffleRandomizer()], full_sky=full_sky)
    elif 'gpbg_conf' in conf:
        print("WARNING: using gp template injecting gp events as bkg")
        gp_conf = conf['gpbg_conf']
        gp_conf['llh_kw'] = dict(conf=gp_conf)
        gp_llh_model = get_llh(a, gp_conf)
        template_model = gp_llh_model.pdf_ratio_model
        gp_conf['template_model'] = template_model
        bg = get_obj(inj.TemplateBGInjector, gp_conf, a,
                     keep=keep, randomizers=randomizers, mcbg=gp_conf['mcbg'])
    elif 'bg_weight_names' in conf:
        bg = get_obj(inj.MCBackgroundInjector, conf, a,
                     keep=keep, randomizers=randomizers)
    else:
        bg = get_obj(inj.DataInjector, conf, a, data=selected_data, keep=keep,
                     randomizers=randomizers)

    if do_inj:
        sig_names = dict(
            ps=inj.PointSourceInjector,
            tw=inj.PointSourceTimeWindowInjector,
            lc=inj.PointSourceBinnedTimeInjector,
            transient=inj.TransientInjector,
            template=inj.TemplateInjector,
            prior=inj.SpatialPriorInjector,
            rednoise=inj.PointSourceRedNoiseInjector)
        sig = conf.get('sig', 'ps')
        sig = sig_names.get(sig, sig)
        sig_kw = dict(keep=keep)
        sig_kw.update(conf.get('sig_kw', {}))
        if sig is inj.TemplateInjector:
            template_model = prm.acc_weighted_model
            if isinstance(template_model, pdf.SpaceTimePDFRatioModel):
                template_model = template_model.space_model
            sig_kw['template_model'] = template_model
        if sig is inj.TransientInjector:
            if time_model is None:
                # The transient injector needs the time model, which only
                # exists for time='transient'.
                raise ValueError(
                    "sig='transient' requires time='transient' so that a "
                    "time model is available")
            use_prior = conf.get('space', 'ps') == "prior"
            sig = get_obj(sig, conf, a, time_model=time_model,
                          spatial_prior=use_prior, **sig_kw)
        else:
            sig = get_obj(sig, conf, a, **sig_kw)
    else:
        sig = None
    return truth, bg, sig
[docs]
def get_trial_runner(conf={}, inj_conf={}, **kw):
    """Get a :class:`csky.trial.TrialRunner`.

    The effective configuration is ``CONF`` overlaid with ``conf`` and then
    with ``kw``; the injection configuration is that result overlaid with
    ``inj_conf``.

    Args:
        conf: likelihood / general configuration (not modified at top
            level; see the note on nested dicts below).
        inj_conf: overrides that apply to the injectors only.
        **kw: further overrides; ``inj`` additionally controls whether a
            signal injector is built (default: ``ana.has_sig``).

    Returns:
        The trial runner built via :func:`get_obj`.

    Raises:
        ValueError: if no ``'ana'`` is available in the configuration.

    Note:
        ``overlay`` copies only the dict levels it merges.  Nested dicts
        passed in by the caller -- notably ``'fitter_args'`` and
        ``'gpbg_conf'`` -- are therefore updated in place below.
    """
    conf = overlay(overlay(CONF, conf), kw)
    ana = conf.get('ana')
    if ana is None:
        raise ValueError(
            'must provide \'ana\' in cy.CONF, conf argument, or as keyword argument ana=___')
    if 'time' in conf:
        conf.setdefault('t0_min', ana.mjd_min)
        conf.setdefault('t0_max', ana.mjd_max)
        if conf['time'] in ('utf', pdf.UntriggeredTimePDFRatioModel):
            # Build the default prior only when none was supplied.
            if 'prior' not in conf:
                conf['prior'] = get_obj(get_utf_prior, conf, ana, is_class=False)
    if 'bins_energy' not in conf and 'corona_flux' not in conf:
        conf.setdefault('flux', hyp.PowerLawFlux(2))
    inj_conf = overlay(conf, inj_conf)
    if 'template' in conf and 'gpbg_conf' not in conf:
        conf.setdefault('space', 'template')
        inj_conf.setdefault('sig', 'template')
    conf['llh_kw'] = dict(conf=conf)
    if 'src' in conf:
        # can have different weights here than injected truth src
        conf['llh_kw']['src'] = conf['src']
    do_inj = kw.get('inj', ana.has_sig)
    conf['inj_kw'] = dict(conf=inj_conf, do_inj=do_inj, llh_conf=conf)
    if 'gpbg_conf' in conf:
        gp_conf = conf['gpbg_conf']
        gp_conf.setdefault('space', 'template')
        gp_conf.setdefault('sig', 'template')
        gp_conf['inj_conf'].setdefault('sig', 'template')
        gp_conf['inj_kw'] = dict(conf=gp_conf['inj_conf'], do_inj=do_inj, llh_conf=gp_conf)
    # NOTE(review): `fa` may be the caller's own dict, so 'seeder'/'prior'
    # are written into it.  Left as is because other code may rely on
    # seeing those entries -- confirm before copying.
    fa = conf.get('fitter_args', {})
    seeder = conf.get('seeder', None)
    if seeder is not None:
        fa['seeder'] = seeder
    prior = conf.get('prior', None)
    if prior is not None:
        fa['prior'] = prior
    conf['fitter_args'] = fa
    return get_obj(
        trial.TrialRunner, conf,
        get_llh=get_llh, get_injs=get_injs)
[docs]
def get_multiflare_trial_runner(conf={}, **kw):
    """Get a :class:`csky.trial.MultiflareTrialRunner`.

    The effective configuration is, in increasing precedence: ``CONF``, the
    multiflare defaults defined below, ``conf``, then ``kw``.  A box-shaped,
    per-season untriggered-flare prior is always installed as
    ``conf['prior']``, and an all-sky trial runner for the first source is
    built with ``cut_n_sigma=np.inf`` and passed along as ``tr_all``.
    """
    multiflare_defaults = dict(
        time='utf',
        box=True,
        box_mode='post',
        sig='tw',
        sig_kw=dict(t0=0, dt=0),  # dummy values
        muonflag=False,
        dt_max=400,
        fitter_args=dict(gamma=np.r_[1, 1:4.01:.5, 4], _seed_with_prior=False),
        threshold=1000,
    )
    layered = overlay(CONF, multiflare_defaults)
    layered = overlay(layered, conf)
    conf = overlay(layered, kw)
    conf['prior'] = get_utf_prior(conf.get('ana'), use_seasons=True, box=True)
    srcs = conf['src']
    tr_all = get_trial_runner(conf, src=srcs[0], cut_n_sigma=np.inf)
    return get_obj(
        trial.MultiflareTrialRunner, conf,
        srclist=srcs, tr_all=tr_all, conf_doc=conf)
[docs]
def get_sky_scan_trial_runner(conf={}, inj_conf={}, multiflare=False, src_tr=False, **kw):
    """Get a :class:`csky.trial.SkyScanTrialRunner`.

    Two factory callbacks are handed to the sky-scan runner:

    * ``get_tr(src, ana, **kw)`` builds the per-position trial runner (a
      multiflare one if ``multiflare`` is true).  If ``src_tr`` is not
      ``False``, the source is rebuilt from ``src.ra`` / ``src.dec`` and
      every other key of ``src_tr`` is copied onto it.
    * ``get_selector(src, cut_n_sigma=None)`` builds an
      ``inj.DecBandSelector``; ``cut_n_sigma`` falls back to
      ``conf['cut_n_sigma']`` (default 5).
    """
    conf = overlay(overlay(CONF, conf), kw)
    inj_conf = overlay(conf, inj_conf)
    coordinate_keys = ["ra", "ra_deg", "dec", "dec_deg"]

    def get_tr(src, ana, **kw):
        if multiflare:
            return get_multiflare_trial_runner(
                conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)
        if src_tr is not False:
            src = utils.Sources(ra=src.ra, dec=src.dec)
            for key in src_tr.keys():
                if key in coordinate_keys:
                    continue
                src[key] = src_tr[key]
        return get_trial_runner(
            conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)

    def get_selector(src, cut_n_sigma=None):
        if cut_n_sigma is None:
            cut_n_sigma = conf.get('cut_n_sigma', 5)
        if not isinstance(src, utils.Arrays):
            src = utils.Sources(dec=src)
        return inj.DecBandSelector(src, cut_n_sigma=cut_n_sigma)

    return get_obj(
        trial.SkyScanTrialRunner, conf,
        get_tr=get_tr, get_selector=get_selector)
[docs]
def get_spatial_prior_trial_runner(conf={}, inj_conf={}, multiflare=False, src_tr=False, **kw):
    """Get a :class:`csky.trial.SpatialPriorTrialRunner`.

    Args:
        conf: likelihood / general configuration, overlaid on ``CONF``.
        inj_conf: overrides that apply to the injectors only.
        multiflare: if true, per-position trial runners are multiflare ones.
        src_tr: optional length-1 source whose non-coordinate keys are
            copied onto every scanned source.  ``False`` (the default) or
            ``None`` means there is nothing to copy.
        **kw: further configuration overrides.

    Returns:
        The trial runner built via :func:`get_obj`, configured with the
        ``get_tr`` and ``get_selector`` callbacks defined here.
    """
    conf = overlay(overlay(CONF, conf), kw)
    inj_conf = overlay(conf, inj_conf)
    # validation: if src_tr given, must have length 1
    if src_tr:
        assert len(src_tr) == 1, 'src_tr must have length 1'

    def merge_src_tr(src):
        # Return a copy of `src` carrying the non-coordinate keys of src_tr.
        src = copy.deepcopy(src)
        if src_tr is False or src_tr is None:
            # Nothing to merge; without this guard the default
            # src_tr=False would fail on `.keys()` below.
            return src
        N = len(src)
        for key in src_tr.keys():
            if key not in 'ra ra_deg dec dec_deg'.split():
                # align lengths for use in get_selector
                src[key] = np.repeat(src_tr[key][0], N)
        return src

    def get_tr(src, ana, **kw):
        if multiflare:
            return get_multiflare_trial_runner(conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)
        elif src_tr is not False:
            src = merge_src_tr(utils.Sources(ra=src.ra, dec=src.dec))
        return get_trial_runner(conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)

    def get_selector(src, cut_n_sigma=None):
        cut_n_sigma = conf.get('cut_n_sigma', 5) if cut_n_sigma is None else cut_n_sigma
        if not isinstance(src, utils.Arrays):
            src = utils.Sources(dec=src)
        src = merge_src_tr(src)
        return inj.DecBandSelector(src, cut_n_sigma=cut_n_sigma)

    return get_obj(trial.SpatialPriorTrialRunner, conf, get_tr=get_tr,
                   get_selector=get_selector)
[docs]
def get_analysis(repo, *args, **kw):
    """Get an :class:`csky.analysis.Analysis` instance.

    Args:
        repo (:class:`csky.selections.Repository`): the repository for loading the data
        *args: one or more data specifications, each optionally preceded by
            a string giving the desired dataset version (see
            :mod:`csky.selections`).  A specification may be a
            :class:`csky.selections.DataSpec` instance, a class (which is
            instantiated with no arguments), or an iterable of either.  A
            version string applies only to the specification that
            immediately follows it.
        **kw: other keyword arguments passed to :class:`csky.analysis.Analysis` constructor
    """
    versions, specs = [], []
    pending_version = None
    for arg in args:
        if isinstance(arg, str):
            pending_version = arg
            continue
        if isinstance(arg, selections.DataSpec):
            group = [arg]
        elif isinstance(arg, type):
            group = [arg()]
        else:
            group = [
                (spec() if isinstance(spec, type) else spec) for spec in arg]
        for spec in group:
            versions.append(pending_version)
            specs.append(spec)
        # the version string is consumed by the argument that follows it
        pending_version = None

    for (version, spec) in zip(versions, specs):
        if version is not None:
            spec.version = version
    return get_obj(analysis.Analysis, {}, repo=repo, specs=specs, **kw)
def _isin(element, sequence):
for item in sequence:
if element is item:
return True
return False
def _remove_cycles(d, visited=None):
if not isinstance(d, dict):
return d
if visited is None:
visited = []
out = {}
for (k,v) in d.items():
if not _isin(v, visited):
if not any(isinstance(v, t) for t in (bool, int, float, np.number)):
visited.append(v)
out[k] = _remove_cycles(v, visited)
return out
def _prettydict(x, d=0):
if not x:
return '{}'
out = ['{']
for (k, v) in x.items():
if isinstance(v, dict):
out.append(' {k} = {entry}'.format(k=k, entry=_prettydict(v, d+1)))
else:
if isinstance(v, utils.Sources):
v = ', '.join ([
'({:.2f}, {:.2f}, {:.2f})'.format(r,d,e)
for (r,d,e) in zip(v.ra_deg, v.dec_deg, v.extension_deg)
])
v = '[{}] (deg)'.format(v)
elif isinstance(v, utils.Events):
v = 'Events({} items)'.format(len(v))
out.append(' {k} = {v}'.format(k=k, v=v))
out += ['}']
out = [out[0]] + [d * ' ' + line for line in out[1:]]
return '\n'.join(out)
def _pushindent(s, d):
return '\n'.join(d * ' ' + l for l in s.split('\n'))
[docs]
def describe(o, visited=None, d=0, path=''):
    """Describe a csky object and any more csky objects inside it.

    Prints, for every reachable object tagged with ``_csky_conf`` (see
    :func:`get_obj`), the configuration it was built with and the defaults
    of its constructor.  Lists, tuples, object-dtype arrays and dicts are
    searched element by element.  Public attributes are followed when their
    type lives in a ``csky`` module other than ``utils``, or is a list,
    tuple or dict.

    Args:
        o: object to describe.
        visited: objects already handled (compared by identity); used
            internally to stop at repeats and cycles.
        d: recursion depth, passed along to nested calls.
        path: attribute / index path of ``o`` shown in the output.
    """
    if visited is None:
        visited = [o]
    else:
        if _isin(o, visited):
            return
        visited.append(o)

    if hasattr(o, '_csky_conf'):
        t = type(o).__name__
        conf = _remove_cycles(o._csky_conf)
        accepted = _remove_cycles(o._csky_defaults)
        str_conf = 'configured = ' + _prettydict(conf)
        str_defaults = 'defaults = ' + _prettydict(accepted)
        descr = ['{t}('.format(t=t), _pushindent(str_conf, 2), _pushindent(str_defaults, 2), ')']
        if path:
            descr[0] = '{} = {}'.format(path, descr[0])
        print('\n'.join(descr), '\n')
    elif isinstance(o, (list, tuple)):
        for (ii, element) in enumerate(o):
            p = '{}[{}]'.format(path, ii)
            try:
                describe(element, visited, d+1, p)
            except AttributeError:
                continue
    elif isinstance(o, np.ndarray) and o.dtype == np.dtype(object):
        for (ii, element) in enumerate(o):
            p = '{}[{}]'.format(path, ii)
            # best effort: an element that cannot be described is skipped
            try:
                describe(element, visited, d+1, p)
            except Exception:
                continue
        # attributes of the array itself are not inspected
        return
    elif isinstance(o, dict):
        for (kk, vv) in o.items():
            p = '{}[{}]'.format(path, kk)
            try:
                describe(vv, visited, d+1, p)
            except AttributeError:
                continue

    for k in dir(o):
        if k[0] == '_':
            continue
        # attribute access may run arbitrary property code; skip failures
        try:
            v = getattr(o, k)
        except Exception:
            continue
        t = type(v)
        module = t.__module__
        if ('csky' in module and 'utils' not in module) or t in (list, tuple, dict):
            describe(v, visited, d+1, '{}.{}'.format(path, k))
# Module-level base configuration.  The trial-runner factories in this module
# start from this dict and overlay their ``conf`` argument and keyword
# arguments on top of it.
CONF = {
    # presumably the number of CPUs used for multiprocessing -- confirm
    # against the code that consumes this key
    'mp_cpus': 1,
}