# Source code for csky.conf

# conf.py

"""Configuration tools for csky."""

import sys
_PY2 = sys.version_info[0] == 2

if _PY2:
    import funcsigs
else:
    import inspect

import copy
import functools
import numpy as np

from . import pdf, inj, hyp, llh, trial, selections, analysis, utils

def get_signature(f, is_class=True):
    """Return the call signature of ``f``.

    Args:
        f: function or class to inspect
        is_class (bool): only consulted on Python 2, where ``funcsigs`` is
            applied to ``f.__init__`` instead of ``f`` when this flag is true
            and ``f`` has an ``__init__`` attribute

    Returns:
        the signature object (``inspect.signature(f)`` on Python 3,
        ``funcsigs.signature(...)`` on Python 2)
    """
    if not _PY2:
        return inspect.signature(f)
    use_init = hasattr(f, '__init__') and is_class
    return funcsigs.signature(f.__init__ if use_init else f)

def do_bind(f, conf, is_class=True):
    """Partially bind ``f`` by applying configuration ``conf``.

    Every key of ``conf`` that names a parameter of ``f`` is bound.  If
    ``conf`` also contains an entry keyed by ``f`` itself, the matching items
    of that entry are bound on top, taking precedence.

    Args:
        f: class or callable to bind
        conf (dict): configuration mapping
        is_class (bool): forwarded to :func:`get_signature`

    Returns:
        tuple: ``(constructor, final_kw, accepted)``, where ``constructor`` is
        a :class:`functools.partial` of ``f``, ``final_kw`` is the dict of
        keyword arguments that were bound, and ``accepted`` maps every
        parameter that declares a default to that default value.
    """
    empty = (funcsigs if _PY2 else inspect).Parameter.empty
    sig = get_signature(f, is_class=is_class)
    params = sig.parameters

    accepted = {}
    for name, param in params.items():
        if param.default is not empty:
            accepted[name] = param.default

    final_kw = {key: conf[key] for key in conf if key in params}
    if f in conf:
        # per-callable overrides: conf[f] wins over the generic keys
        per_f = conf[f]
        final_kw.update({key: per_f[key] for key in per_f if key in params})

    bound = sig.bind_partial(**final_kw)
    constructor = functools.partial(f, **bound.arguments)
    return constructor, final_kw, accepted
def get_obj(cls, conf, subana=None, is_class=True, **kw):
    """Get object of type ``cls``.

    Args:
        cls: class (or, with ``is_class=False``, a plain callable) to call
        conf (dict): configuration; entries naming parameters of ``cls`` are
            bound via :func:`do_bind`
        subana: optional per-dataset analysis. When given, it fills the
            ``ana``, ``bg_ev`` (``subana.bg_data``) and ``sig_ev``
            (``subana.sig``) arguments that ``cls`` accepts and that are not
            already present in ``kw``.
        is_class (bool): forwarded to :func:`do_bind` / :func:`get_signature`
        **kw: explicit keyword arguments; they override values from ``conf``

    Returns:
        The constructed object.  When it accepts attribute assignment it is
        tagged with ``_csky_conf`` (all keyword arguments applied) and
        ``_csky_defaults`` (the signature's default values), which
        :func:`describe` reports.
    """
    constr, final_kw, accepted = do_bind(cls, conf, is_class=is_class)
    if subana is not None:
        params = get_signature(cls, is_class=is_class).parameters
        if 'ana' in params:
            kw['ana'] = kw.get('ana', subana)
        if 'bg_ev' in params:
            kw['bg_ev'] = kw.get('bg_ev', subana.bg_data)
        if 'sig_ev' in params:
            kw['sig_ev'] = kw.get('sig_ev', subana.sig)
    try:
        out = constr(**kw)
    except Exception:
        print('exception raised trying to construct {}'.format(cls))
        raise
    final_kw.update(kw)
    try:
        out._csky_conf = final_kw
        out._csky_defaults = accepted
    except AttributeError:
        # Some factories return objects that cannot carry attributes, e.g.
        # get_utf_prior(use_prior=False) returns None.  The tags only feed
        # describe(), so skip them instead of failing the construction.
        pass
    return out
def overlay(conf, new_conf=None):
    """Override elements of ``conf`` with other arguments.

    Returns a new dict holding the items of ``conf`` updated with those of
    ``new_conf``.  Where both hold a dict under the same key, the two dicts
    are merged recursively; otherwise the value from ``new_conf`` wins.
    Neither input is modified.  The merge is shallow: values that are not
    replaced (including nested dicts present on only one side) are shared
    with the inputs, not copied.

    Args:
        conf (dict): base configuration
        new_conf (dict, optional): overriding configuration; ``None`` (the
            default) means no overrides

    Returns:
        dict: the merged configuration
    """
    out = {}
    out.update(conf)
    if new_conf is None:
        return out
    for key in new_conf:
        try:
            new_value = new_conf[key]
            if key in conf and isinstance(new_value, dict) and isinstance(conf[key], dict):
                out[key] = overlay(conf[key], new_value)
            else:
                out[key] = new_value
        except Exception:
            print('error handling overlay of key "{}"'.format(key))
            raise
    return out
def get_utf_prior(ana, n_src=1, dt_max=None, use_prior=True, use_seasons=False, box=False):
    """Build the prior function used for untriggered-flare (``utf``) fits.

    Args:
        ana: iterable of per-season objects, each with ``data.mjd`` and
            ``livetime`` (in seconds); presumably ordered in time, since
            consecutive season starts are used as sorted bin edges (TODO confirm)
        n_src (int): forwarded to
            ``pdf.UntriggeredTimePDFRatioModel.get_t0_dt``
        dt_max (float, optional): if given, used as the normalizing livetime
            (in days, the same unit as the season livetimes) instead of the
            season or total livetime
        use_prior (bool): if false, no prior is built and ``None`` is returned
        use_seasons (bool): when ``dt_max`` is None, normalize by the livetime
            of the season containing the first ``t0`` instead of the total
            livetime
        box (bool): if false, the livetime-corrected durations are doubled
            before summing

    Returns:
        ``None`` if ``use_prior`` is false, otherwise ``prior(**params)``,
        which returns ``min(1, sum(dt) / livetime)``.

    Raises:
        ValueError: if ``ana`` contains no seasons
    """
    if not use_prior:
        return None
    season_boundaries = []
    season_livetimes = []
    for (i, a) in enumerate(ana):
        mi = a.data.mjd.min()
        if i == 0:
            mi -= 0.1
        season_boundaries.append(mi)
        season_livetimes.append(a.livetime / 86400)  # sec->day
    if not season_livetimes:
        raise ValueError('get_utf_prior requires at least one season in ana')
    # closing edge: end of the last season, padded like the first start
    season_boundaries.append(a.data.mjd.max() + 0.1)
    total_livetime = np.sum(season_livetimes)
    t_lo, t_hi = min(season_boundaries), max(season_boundaries)
    last_season = len(season_livetimes) - 1

    def prior(**params):
        t0, dt = pdf.UntriggeredTimePDFRatioModel.get_t0_dt(n_src=n_src, **params)
        frac_livetime = utils.get_frac_during_livetime_single(ana, t0, dt, 5)
        dt = dt * frac_livetime
        t0 = np.clip(t0, t_lo, t_hi)
        if dt_max is None:
            if use_seasons:
                i_season = np.searchsorted(season_boundaries, np.atleast_1d(t0)[0]) - 1
                # t0 at the lower edge (where every early t0 is clipped to)
                # gives -1, which would wrap around to the LAST season
                i_season = int(np.clip(i_season, 0, last_season))
                livetime = season_livetimes[i_season]
            else:
                livetime = total_livetime
        else:
            livetime = dt_max
        F = 1. / livetime
        if not box:
            dt = 2 * dt
        ontime_livetime = np.sum(dt)
        return min(1, F * ontime_livetime)

    return prior
def get_llh(a, conf, **kw):
    """Get a :class:`csky.llh.LLHModel`.

    Builds the PDF-ratio model for one dataset ``a`` from ``conf``. The
    pieces are:

    - a space term;
    - an optional time term, combined with the space term into a
      SpaceTimePDFRatioModel;
    - optional energy term(s);
    - any ``pdf.PDFRatioModel`` subclasses used as keys of ``conf``.

    The pieces are combined and wrapped in ``llh.LLHModel`` via
    :func:`get_obj`.

    Args:
        a: per-dataset analysis object (provides ``acc_param``,
            ``energy_pdf_ratio_model`` and the custom-flux helpers used below)
        conf (dict): configuration; keys read here include ``energy``,
            ``flux``, ``keep_pdfs``, ``space``, ``time`` and ``extended``
        **kw: accepted but not used in this function

    Returns:
        the ``llh.LLHModel`` built by :func:`get_obj`
    """
    # need to check energy first: if 'customflux', use custom acceptance to match
    energy = conf.get('energy', 'fit')
    flux = conf.get('flux', None)
    keep_pdfs = conf.get('keep_pdfs', False)
    pdfs_args = {'keep_pdfs': keep_pdfs}
    if energy == 'customflux' and flux is not None:
        # one acceptance parameterization per flux; unwrapped when there is only one
        flux = np.atleast_1d(flux)
        acc_param = []
        for f in flux:
            acc_param.append(a.get_custom_flux_acc_parameterization(f))
        if len(acc_param) == 1:
            acc_param = acc_param[0]
    else:
        # flux=None gives array([None]) (length 1), so this falls back to a.acc_param
        flux = np.atleast_1d(flux)
        if len(flux) >1:
            # the dataset's own acceptance, repeated once per flux
            acc_param = []
            for f in flux:
                acc_param.append(a.acc_param)
        else:
            acc_param = a.acc_param
    conf = overlay(conf, dict(acc_param=acc_param))
    # SPACE
    space_names = dict(
        ps=pdf.PointSourceSpacePDFRatioModel,
        fitps=pdf.FitPointSourceSpacePDFRatioModel,
        template=pdf.TemplateSpacePDFRatioModel,
        prior=pdf.PriorSpacePDFRatioModel,
        generic=pdf.AccWeightedGenericPDFRatioModel)
    space = conf.get('space', 'ps')
    # short names map to classes; anything else (e.g. a class) is used as-is
    space = space_names.get(space, space)
    space_model = get_obj(space, conf, a)
    # TIME
    time_names = dict(
        utf=pdf.UntriggeredTimePDFRatioModel,
        lc=pdf.BinnedTimePDFRatioModel,
        transient=pdf.TransientTimePDFRatioModel)
    time_name = conf.get('time', None)
    time = time_names.get(time_name, time_name)
    if time is None:
        models = [space_model]
    else:
        time_model = get_obj(time, conf, a)
        spacetime_model = pdf.SpaceTimePDFRatioModel(space_model, time_model)
        models = [spacetime_model]
    # ENERGY
    if energy:
        # `energy` is rebound to the model's type so the "ANY OTHERS" loop
        # below does not add the same class a second time
        if energy == 'fit':
            energy_model = a.energy_pdf_ratio_model
            energy = type(energy_model)
        elif energy == 'customflux':
            flux = conf['flux']
            flux = np.atleast_1d(flux)
            # keep pdfs for stacking
            energy_model = []
            if len(flux) >= 2:
                print("loading energy pdfs of ", len(flux)," fluxes")
                for f in flux:
                    energy_pdf = a.get_custom_flux_energy_pdf_ratio_model(f)
                    energy_pdf.keep_pdfs = keep_pdfs
                    energy_model.append(energy_pdf)
                energy = type(energy_pdf)
            else:
                # NOTE(review): passes the length-1 array `flux` rather than
                # flux[0], unlike the multi-flux branch above -- confirm the
                # method accepts an array
                energy_model = a.get_custom_flux_energy_pdf_ratio_model(flux)
                energy = type(energy_model)
        else:
            # NOTE(review): energy_model is never assigned on this path, so the
            # try/except below ends in a NameError (UnboundLocalError)
            print("WARNING: unable to find energy conf")
        # A list of energy models is spliced into `models` only when keep_pdfs
        # is set.  A single model presumably has no len(), so it lands in the
        # except branch and is appended.
        # NOTE(review): a multi-flux list with keep_pdfs=False is added by
        # neither branch -- confirm that is intended.
        try:
            if len(energy_model) > 1 and keep_pdfs:
                print("keeping energy models for all fluxes, make sure intended")
                models += energy_model
        except:
            models.append(energy_model)
    # ANY OTHERS!?
    # PDFRatioModel subclasses used as conf keys become extra terms, unless
    # that class was already used for space/time/energy above
    for key in conf:
        if isinstance(key, type):
            if pdf.PDFRatioModel in key.mro():
                if key not in (space, time, energy):
                    model = get_obj(key, conf, a)
                    models.append(model)
    if len(models) >= 2:
        prm = pdf.MultiPDFRatioModel(*models, **pdfs_args)
    else:
        prm = models[0]
    if( time_name=="transient" ):
        # time_model is bound here: 'transient' always maps to a time class above
        extended = conf.get('extended', True)
        lm = get_obj(llh.LLHModel, conf, a, pdf_ratio_model=prm, N_bg=time_model.nb, extended=extended)
    else:
        lm = get_obj(llh.LLHModel, conf, a, pdf_ratio_model=prm, keep_pdfs=conf.get('keep_pdfs', False))
    return lm
def get_injs(a, llh_model, conf, do_inj=True, llh_conf={}, **kw): ## to be modified
    """Get :class:`csky.inj.Injector` instances for TRUTH, background, and signal

    Args:
        a: per-dataset analysis object (provides ``data``, ``grl``, ...)
        llh_model: LLH model whose ``pdf_ratio_model`` supplies the ``keep``
            fields, the time model and (for template signal) the template
        conf (dict): injection configuration; keys read here include
            ``time``, ``full_sky``, ``template``, ``gpbg_conf``,
            ``extra_keep``, ``randomize``, ``bg_weight_names``, ``sig``,
            ``sig_kw`` and ``space``
        do_inj (bool): if false, no signal injector is built
        llh_conf (dict): configuration used for the declination-band selector
        **kw: accepted but not used in this function

    Returns:
        tuple: ``(truth, bg, sig)``; ``sig`` is ``None`` when ``do_inj`` is false

    Raises:
        ValueError: if none of the base randomizers (RA, MJD-GRL, MJD-shuffle)
            is configured
    """
    prm = llh_model.pdf_ratio_model
    # NOTE(review): set() makes the order of `keep` hash-dependent
    keep = list(set(prm.keep + conf.get('extra_keep', [])))
    data = a.data
    use_time = conf.get('time', None)=="transient"
    full_sky = conf.get('full_sky', False)
    if 'template' in conf or 'gpbg_conf' in conf:
        # template / galactic-plane background: no spatial pre-selection.
        # NOTE(review): time_model is not set on this path, so combining it
        # with time='transient' raises NameError further down
        selected_data = data
    elif( use_time ):
        time_model = prm.get_time_model()
        selector = get_obj(inj.DecBandSelector, llh_conf)
        t_selector = inj.TimeWindowSelector(time_model.src_mjd_min, time_model.src_mjd_max)
        selected_data = t_selector(selector(a.data))
    else:
        selector = get_obj(inj.DecBandSelector, llh_conf)
        selected_data = selector(data)
    truth = get_obj(inj.DataInjector, conf, a, data=selected_data, keep=keep, randomizers=[])
    randomize = conf.get('randomize', ['grl' if 'time' in conf else 'ra'])
    randomize = np.atleast_1d(randomize)
    randomizers = []
    for r in randomize:
        randomizer_names = dict(
            ra=inj.RARandomizer,
            dec=inj.DecRandomizer,
            energy=inj.EnergyRandomizer,
            grl=inj.MJDGRLRandomizer)
        r = randomizer_names.get(r, r)
        rkw = {}
        if r is inj.MJDGRLRandomizer:
            rkw['grl'] = a.grl
            # NOTE(review): mutates the same list already handed to `truth`
            # above -- confirm whether truth should also keep 'azimuth'
            keep.append('azimuth')
        randomizers.append(get_obj(r, conf, **rkw))
    # safety check to make sure at least one of the base randomizers is given
    base_randomizers = (
        inj.RARandomizer, inj.MJDGRLRandomizer, inj.MJDShuffleRandomizer
    )
    if not any(isinstance(r, base_randomizers) for r in randomizers):
        raise ValueError(
            'Must specify at least one of {}, but given are {}'.format(
                base_randomizers, randomizers))
    if use_time:
        bg = get_obj(inj.DataOnOffInjector, conf, a, time_model=time_model, keep=keep, randomizers=[inj.MJDShuffleRandomizer()], full_sky=full_sky)
    elif 'gpbg_conf' in conf:
        print("WARNING: using gp template injecting gp events as bkg")
        # gp_conf is modified in place (and ends up referring to itself via llh_kw)
        gp_conf = conf['gpbg_conf']
        gp_conf['llh_kw'] = dict(conf=gp_conf)
        gp_llh_model=get_llh(a, gp_conf)
        template_model=gp_llh_model.pdf_ratio_model
        #template_model = gp_prm.acc_weighted_model
        gp_conf['template_model'] = template_model
        bg = get_obj(inj.TemplateBGInjector, gp_conf, a, keep=keep, randomizers=randomizers, mcbg = gp_conf['mcbg'])
    elif 'bg_weight_names' in conf:
        bg = get_obj(inj.MCBackgroundInjector, conf, a, keep=keep, randomizers=randomizers)
    else:
        bg = get_obj(inj.DataInjector, conf, a, data=selected_data, keep=keep, randomizers=randomizers)
    if do_inj:
        sig_names = dict(
            ps=inj.PointSourceInjector,
            tw=inj.PointSourceTimeWindowInjector,
            lc=inj.PointSourceBinnedTimeInjector,
            transient=inj.TransientInjector,
            template=inj.TemplateInjector,
            prior=inj.SpatialPriorInjector,
            rednoise=inj.PointSourceRedNoiseInjector)
        sig = conf.get('sig', 'ps')
        sig = sig_names.get(sig, sig)
        sig_kw = dict(keep=keep)
        sig_kw.update(conf.get('sig_kw', {}))
        if sig is inj.TemplateInjector:
            # inject from the spatial part only when the LLH is space x time
            template_model = prm.acc_weighted_model
            if isinstance(template_model, pdf.SpaceTimePDFRatioModel):
                template_model = template_model.space_model
            sig_kw['template_model'] = template_model
        if sig is inj.TransientInjector:
            # NOTE(review): time_model only exists when time == 'transient'
            use_prior = conf.get('space', 'ps')=="prior"
            sig = get_obj(sig, conf, a, time_model=time_model, spatial_prior=use_prior, **sig_kw)
        else:
            sig = get_obj(sig, conf, a, **sig_kw)
    else:
        sig = None
    return truth, bg, sig
def get_trial_runner(conf={}, inj_conf={}, **kw):
    """Get a :class:`csky.trial.TrialRunner`.

    Configuration is layered with increasing priority: ``CONF``, ``conf``,
    then ``kw``.  ``inj_conf`` is layered on top of the result to form the
    injection configuration.  Neither ``conf`` nor ``inj_conf`` is modified
    at the top level.

    Args:
        conf (dict): configuration overrides; ``ana`` must be provided here,
            in ``CONF``, or in ``kw``
        inj_conf (dict): overrides applied only to the injection configuration
        **kw: highest-priority configuration.  ``inj`` (if given) sets whether
            signal injectors are built; it defaults to ``ana.has_sig``.

    Returns:
        the :class:`csky.trial.TrialRunner` built by :func:`get_obj`

    Raises:
        ValueError: if no ``ana`` is configured
    """
    conf = overlay(overlay(CONF, conf), kw)
    ana = conf.get('ana')
    if ana is None:
        raise ValueError(
            'must provide \'ana\' in cy.CONF, conf argument, or as keyword argument ana=___')
    if 'time' in conf:
        conf.setdefault('t0_min', ana.mjd_min)
        conf.setdefault('t0_max', ana.mjd_max)
        if conf['time'] in ('utf', pdf.UntriggeredTimePDFRatioModel):
            conf.setdefault('prior', get_obj(get_utf_prior, conf, ana, is_class=False))
    if 'bins_energy' not in conf and 'corona_flux' not in conf:
        conf.setdefault('flux', hyp.PowerLawFlux(2))
    inj_conf = overlay(conf, inj_conf)
    if 'template' in conf and 'gpbg_conf' not in conf:
        conf.setdefault('space', 'template')
        inj_conf.setdefault('sig', 'template')
    conf['llh_kw'] = dict(conf=conf)
    if 'src' in conf:
        # can have different weights here than injected truth src
        conf['llh_kw']['src'] = conf['src']
    conf['inj_kw'] = dict(conf=inj_conf, do_inj=kw.get('inj', ana.has_sig), llh_conf=conf)
    if 'gpbg_conf' in conf:
        # Modified in place on purpose: overlay() copies only the top level,
        # so inj_conf normally shares this same dict and sees these defaults.
        gpbg_conf = conf['gpbg_conf']
        gpbg_conf.setdefault('space', 'template')
        gpbg_conf.setdefault('sig', 'template')
        gpbg_conf['inj_conf'].setdefault('sig', 'template')
        gpbg_conf['inj_kw'] = dict(
            conf=gpbg_conf['inj_conf'],
            do_inj=kw.get('inj', ana.has_sig),
            llh_conf=gpbg_conf)
    # Copy: fitter_args may be the very dict stored in CONF or passed in by
    # the caller.  Adding seeder/prior to it would leak into later calls.
    fa = dict(conf.get('fitter_args', {}))
    seeder = conf.get('seeder', None)
    if seeder is not None:
        fa['seeder'] = seeder
    prior = conf.get('prior', None)
    if prior is not None:
        fa['prior'] = prior
    conf['fitter_args'] = fa
    return get_obj(
        trial.TrialRunner, conf, get_llh=get_llh, get_injs=get_injs)
def get_multiflare_trial_runner(conf={}, **kw):
    """Get a :class:`csky.trial.MultiflareTrialRunner`.

    Configuration is layered with increasing priority: ``CONF``, the
    box-flare defaults defined below, ``conf``, then ``kw``.  Any configured
    ``prior`` is replaced by a seasonal box prior from
    :func:`get_utf_prior`.  A single-source trial runner for the first entry
    of ``conf['src']`` is built (with ``cut_n_sigma=inf``) and passed along
    as ``tr_all``.

    Returns:
        the :class:`csky.trial.MultiflareTrialRunner` built by :func:`get_obj`
    """
    flare_defaults = dict(
        time='utf',
        box=True,
        box_mode='post',
        sig='tw',
        sig_kw=dict(t0=0, dt=0),  # dummy values
        muonflag=False,
        dt_max=400,
        fitter_args=dict(gamma=np.r_[1, 1:4.01:.5, 4], _seed_with_prior=False),
        threshold=1000,
    )
    merged = overlay(CONF, flare_defaults)
    merged = overlay(merged, conf)
    conf = overlay(merged, kw)
    conf['prior'] = get_utf_prior(conf.get('ana'), use_seasons=True, box=True)
    srcs = conf['src']
    tr_all = get_trial_runner(conf, src=srcs[0], cut_n_sigma=np.inf)
    return get_obj(trial.MultiflareTrialRunner, conf,
                   srclist=srcs, tr_all=tr_all, conf_doc=conf)
def get_sky_scan_trial_runner(conf={}, inj_conf={}, multiflare=False, src_tr=False, **kw):
    """Get a :class:`csky.trial.SkyScanTrialRunner`.

    Args:
        conf (dict): configuration layered over ``CONF`` (then ``kw``)
        inj_conf (dict): injection-only overrides layered over ``conf``
        multiflare (bool): build per-location runners with
            :func:`get_multiflare_trial_runner` instead of
            :func:`get_trial_runner`
        src_tr: if not ``False``, a source template whose keys other than
            ra/dec are copied onto each scanned location
        **kw: highest-priority configuration

    Returns:
        the :class:`csky.trial.SkyScanTrialRunner` built by :func:`get_obj`
    """
    conf = overlay(overlay(CONF, conf), kw)
    inj_conf = overlay(conf, inj_conf)
    position_keys = ('ra', 'ra_deg', 'dec', 'dec_deg')

    def get_tr(src, ana, **kw):
        if multiflare:
            return get_multiflare_trial_runner(conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)
        if src_tr is not False:
            # keep this location's position, take everything else from src_tr
            src = utils.Sources(ra=src.ra, dec=src.dec)
            for key in src_tr.keys():
                if key not in position_keys:
                    src[key] = src_tr[key]
        return get_trial_runner(conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)

    def get_selector(src, cut_n_sigma=None):
        if cut_n_sigma is None:
            cut_n_sigma = conf.get('cut_n_sigma', 5)
        if not isinstance(src, utils.Arrays):
            src = utils.Sources(dec=src)
        return inj.DecBandSelector(src, cut_n_sigma=cut_n_sigma)

    return get_obj(trial.SkyScanTrialRunner, conf, get_tr=get_tr, get_selector=get_selector)
def get_spatial_prior_trial_runner(conf={}, inj_conf={}, multiflare=False, src_tr=False, **kw):
    """Get a :class:`csky.trial.SpatialPriorTrialRunner`.

    Args:
        conf (dict): configuration layered over ``CONF`` (then ``kw``)
        inj_conf (dict): injection-only overrides layered over ``conf``
        multiflare (bool): build per-source runners with
            :func:`get_multiflare_trial_runner` instead of
            :func:`get_trial_runner`
        src_tr: optional single-entry source template (``len(src_tr) == 1``).
            Its keys other than ra/dec are copied onto each evaluated source
            and onto the sources handed to the declination-band selector.
        **kw: highest-priority configuration

    Returns:
        the :class:`csky.trial.SpatialPriorTrialRunner` built by :func:`get_obj`
    """
    conf = overlay(overlay(CONF, conf), kw)
    inj_conf = overlay(conf, inj_conf)
    # validation: if src_tr given, must have length 1
    if src_tr:
        assert len(src_tr) == 1, 'src_tr must have length 1'

    def merge_src_tr(src):
        # src_tr defaults to False, which has no .keys(); nothing to merge then
        if not src_tr:
            return src
        src = copy.deepcopy(src)
        N = len(src)
        for key in src_tr.keys():
            if key not in 'ra ra_deg dec dec_deg'.split():
                # align lengths for use in get_selector
                src[key] = np.repeat(src_tr[key][0], N)
        return src

    def get_tr(src, ana, **kw):
        if multiflare:
            return get_multiflare_trial_runner(conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)
        elif src_tr is not False:
            src = merge_src_tr(utils.Sources(ra=src.ra, dec=src.dec))
        return get_trial_runner(conf=conf, inj_conf=inj_conf, src=src, ana=ana, **kw)

    def get_selector(src, cut_n_sigma=None):
        cut_n_sigma = conf.get('cut_n_sigma', 5) if cut_n_sigma is None else cut_n_sigma
        if not isinstance(src, utils.Arrays):
            src = utils.Sources(dec=src)
        src = merge_src_tr(src)
        return inj.DecBandSelector(src, cut_n_sigma=cut_n_sigma)

    return get_obj(trial.SpatialPriorTrialRunner, conf, get_tr=get_tr, get_selector=get_selector)
def get_analysis(repo, *args, **kw):
    """Get an :class:`csky.analysis.Analysis` instance.

    Args:
        repo (:class:`csky.selections.Repository`): the repository for loading the data
        *args: one or more data specifications.  Each can be a
            ``selections.DataSpec`` instance, a DataSpec class (instantiated
            here), or an iterable of either.  Each may optionally be preceded
            by a string giving the desired dataset version (see
            :mod:`csky.selections`), which then applies to that specification
            only.
        **kw: other keyword arguments passed to :class:`csky.analysis.Analysis` constructor
    """
    versions, specs = [], []
    pending_version = None
    for arg in args:
        if isinstance(arg, str):
            pending_version = arg
            continue
        if isinstance(arg, selections.DataSpec):
            chosen = [arg]
        elif isinstance(arg, type):
            chosen = [arg()]
        else:
            chosen = [item() if isinstance(item, type) else item for item in arg]
        for spec in chosen:
            versions.append(pending_version)
            specs.append(spec)
        pending_version = None
    for version, spec in zip(versions, specs):
        if version is not None:
            spec.version = version
    return get_obj(analysis.Analysis, {}, repo=repo, specs=specs, **kw)
def _isin(element, sequence): for item in sequence: if element is item: return True return False def _remove_cycles(d, visited=None): if not isinstance(d, dict): return d if visited is None: visited = [] out = {} for (k,v) in d.items(): if not _isin(v, visited): if not any(isinstance(v, t) for t in (bool, int, float, np.number)): visited.append(v) out[k] = _remove_cycles(v, visited) return out def _prettydict(x, d=0): if not x: return '{}' out = ['{'] for (k, v) in x.items(): if isinstance(v, dict): out.append(' {k} = {entry}'.format(k=k, entry=_prettydict(v, d+1))) else: if isinstance(v, utils.Sources): v = ', '.join ([ '({:.2f}, {:.2f}, {:.2f})'.format(r,d,e) for (r,d,e) in zip(v.ra_deg, v.dec_deg, v.extension_deg) ]) v = '[{}] (deg)'.format(v) elif isinstance(v, utils.Events): v = 'Events({} items)'.format(len(v)) out.append(' {k} = {v}'.format(k=k, v=v)) out += ['}'] out = [out[0]] + [d * ' ' + line for line in out[1:]] return '\n'.join(out) def _pushindent(s, d): return '\n'.join(d * ' ' + l for l in s.split('\n'))
def describe(o, visited=None, d=0, path=''):
    """Describe a csky object and any more csky objects inside it.

    For every reachable object that carries ``_csky_conf`` (as set by
    :func:`get_obj`), prints the configuration it was built with and its
    signature defaults.  The function recurses into:

    - elements of lists, tuples, dicts and object-dtype arrays;
    - public attributes whose type comes from a ``csky`` module (other than
      ``utils``) or is a list, tuple or dict.

    Args:
        o: the object to describe
        visited (list, optional): objects already handled; used to avoid
            infinite recursion through cycles
        d (int): recursion depth; only passed along to recursive calls
        path (str): attribute/index path of ``o``, used as a label in the output
    """
    if visited is None:
        visited = [o]
    else:
        if _isin(o, visited):
            return
        visited.append(o)
    if hasattr(o, '_csky_conf'):
        t = type(o).__name__
        conf = _remove_cycles(o._csky_conf)
        accepted = _remove_cycles(o._csky_defaults)
        str_conf = 'configured = ' + _prettydict(conf)
        str_defaults = 'defaults = ' + _prettydict(accepted)
        descr = ['{t}('.format(t=t), _pushindent(str_conf, 2), _pushindent(str_defaults, 2), ')']
        if path:
            descr[0] = '{} = {}'.format(path, descr[0])
        print('\n'.join(descr), '\n')
    elif isinstance(o, (list, tuple)):
        for (ii, element) in enumerate(o):
            p = '{}[{}]'.format(path, ii)
            try:
                describe(element, visited, d+1, p)
            except AttributeError:
                continue
    elif isinstance(o, np.ndarray) and o.dtype == np.dtype(object):
        for (ii, element) in enumerate(o):
            p = '{}[{}]'.format(path, ii)
            try:
                describe(element, visited, d+1, p)
            except Exception:
                continue
        # object arrays: describe the elements only, not the array's attributes
        return
    elif isinstance(o, dict):
        for (kk, vv) in o.items():
            p = '{}[{}]'.format(path, kk)
            try:
                describe(vv, visited, d+1, p)
            except AttributeError:
                continue
    for k in dir(o):
        if k[0] == '_':
            continue
        try:
            v = getattr(o, k)
        except Exception:
            # properties may raise arbitrary errors; just skip them
            continue
        t = type(v)
        module = t.__module__
        if ('csky' in module and 'utils' not in module) or t in (list, tuple, dict):
            describe(v, visited, d+1, '{}.{}'.format(path, k))
# Package-wide default configuration.  get_trial_runner and the other
# get_*_trial_runner helpers layer their arguments on top of it via overlay().
CONF = dict(mp_cpus=1)