# -*- coding: utf-8 -*-
"""User functions to streamline working with grids of OLS and LME
model summaries and sets of models."""

import itertools
import copy
import warnings
import re
from cycler import cycler as cy
from collections import defaultdict
import pprint as pp
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import fitgrid

# enforce some common structure for summary dataframes
# scraped out of different fit objects.
# _TIME is a place holder and replaced by the grid.time value on the fly

INDEX_NAMES = ['_TIME', 'model', 'beta', 'key']

# each model, beta combination has all these values,
# some are per-beta, some are per-model.
# Kept in sorted order: _check_summary_df() compares this list with the
# (sorted) levels of the summary index "key" level. The two CI labels are
# the ones produced for the default 95% interval (ci_alpha=0.05).
KEY_LABELS = [
    '2.5_ci',
    '97.5_ci',
    'AIC',
    'DF',
    'Estimate',
    'P-val',
    'SE',
    'SSresid',
    'T-stat',
    'has_warning',
    'logLike',
    'sigma2',
    'warnings',
]

# special treatment for per-model values ... broadcast to all params
PER_MODEL_KEY_LABELS = [
    'AIC',
    'SSresid',
    'has_warning',
    'logLike',
    'sigma2',
    'warnings',
]

def summarize(
    epochs_fg,
    modeler,
    LHS,
    RHS,
    parallel=False,
    n_cores=2,
    quiet=False,
    **kwargs,
):
    """Fit the data with one or more model formulas and return summary information.

    Convenience wrapper, useful for keeping memory use manageable when
    gathering betas and fit measures for a stack of models: each fitted
    grid is only referenced while its summary is scraped.

    Parameters
    ----------
    epochs_fg : fitgrid.epochs.Epochs
        as returned by `fitgrid.epochs_from_dataframe()` or
        `fitgrid.from_hdf()`, *NOT* a `pandas.DataFrame`.

    modeler : {'lm', 'lmer'}
        class of model to fit, `lm` for OLS, `lmer` for linear
        mixed-effects. Note: the RHS formula language must match the modeler.

    LHS : list of str
        the data columns to model

    RHS : model formula or list of model formulas to fit
        see the Python package `patsy` docs for `lm` formula language
        and the R library `lme4` docs for the `lmer` formula language.

    parallel : bool
        If True, model fitting is distributed to multiple cores

    n_cores : int
        number of cores to use. See what works, but be considerate if
        running on a shared machine.

    quiet : bool
        passed to the modeler to control the progress display,
        default=False

    **kwargs : key=value arguments passed to the modeler, optional

    Returns
    -------
    summary_df : `pandas.DataFrame`
        indexed by the epochs time column, `model`, `beta`, and `key`,
        where the keys are `2.5_ci`, `97.5_ci`, `AIC`, `DF`, `Estimate`,
        `P-val`, `SE`, `SSresid`, `T-stat`, `has_warning`, `logLike`,
        `sigma2`, and `warnings`.

    Raises
    ------
    TypeError
        if `epochs_fg` is not a `fitgrid.epochs.Epochs`
    ValueError
        if `modeler` is not 'lm' or 'lmer', or `RHS` is empty

    Examples
    --------
    >>> lm_formulas = [
    ...     '1 + fixed_a + fixed_b + fixed_a:fixed_b',
    ...     '1 + fixed_a + fixed_b',
    ...     '1 + fixed_a',
    ...     '1 + fixed_b',
    ...     '1',
    ... ]
    >>> lm_summary_df = fitgrid.utils.summarize(
    ...     epochs_fg,
    ...     'lm',
    ...     LHS=['MiPf', 'MiCe', 'MiPa', 'MiOc'],
    ...     RHS=lm_formulas,
    ...     parallel=True,
    ...     n_cores=4,
    ... )

    >>> lmer_formulas = [
    ...     '1 + fixed_a + (1 + fixed_a | random_a) + (1 | random_b)',
    ...     '1 + fixed_a + (1 | random_a) + (1 | random_b)',
    ...     '1 + fixed_a + (1 | random_a)',
    ... ]
    >>> lmer_summary_df = fitgrid.utils.summarize(
    ...     epochs_fg,
    ...     'lmer',
    ...     LHS=['MiPf', 'MiCe', 'MiPa', 'MiOc'],
    ...     RHS=lmer_formulas,
    ...     parallel=True,
    ...     n_cores=12,
    ...     REML=False,
    ... )
    """

    warnings.warn(
        'fitgrid summaries are in early days, subject to change', FutureWarning
    )

    # modicum of guarding: a DataFrame is the most common mistake, so give
    # a hint how to convert it
    msg = None
    if isinstance(epochs_fg, pd.DataFrame):
        msg = (
            "Convert dataframe to fitgrid epochs with "
            "fitgrid.epochs_from_dataframe()"
        )
    elif not isinstance(epochs_fg, fitgrid.epochs.Epochs):
        msg = f"epochs_fg must be a fitgrid.Epochs not {type(epochs_fg)}"
    if msg is not None:
        raise TypeError(msg)

    # select the modeler and its matching summary scraper
    if modeler == 'lm':
        _modeler = fitgrid.lm
        _scraper = _lm_get_summaries_df
    elif modeler == 'lmer':
        _modeler = fitgrid.lmer
        _scraper = _lmer_get_summaries_df
    else:
        raise ValueError("modeler must be 'lm' or 'lmer'")

    # promote RHS scalar str to singleton list
    RHS = np.atleast_1d(RHS).tolist()
    if not RHS:
        # otherwise pd.concat([]) fails with an unhelpful message below
        raise ValueError("RHS must be a model formula or a list of formulas")

    # fit one model at a time and keep only its scraped summary
    summaries = []
    for _rhs in RHS:
        fitted = _modeler(
            epochs_fg,
            LHS=LHS,
            RHS=_rhs,
            parallel=parallel,
            n_cores=n_cores,
            quiet=quiet,
            **kwargs,
        )
        summaries.append(_scraper(fitted))
        del fitted  # drop the grid reference before fitting the next model

    summary_df = pd.concat(summaries)
    _check_summary_df(summary_df, epochs_fg)
    return summary_df
# ------------------------------------------------------------
# private-ish summary helpers for scraping summary info from fits
# ------------------------------------------------------------
def _check_summary_df(summary_df, fg_obj):
    """Check summary dataframe structure, and against a fitgrid object if any.

    Parameters
    ----------
    summary_df : pandas.DataFrame
        summary as returned by `summarize()` or the `_l*_get_summaries_df`
        scrapers
    fg_obj : fitgrid.epochs.Epochs, LMFitGrid, LMERFitGrid, or None
        if not None, the summary index names must be
        ``[fg_obj.time, 'model', 'beta', 'key']``

    Raises
    ------
    ValueError
        if the summary is not a non-empty DataFrame with the standard
        index names and key levels, or disagrees with `fg_obj`
    TypeError
        if `fg_obj` is not one of the fitgrid types above
    RuntimeError
        if the summary has no "warnings" key (fitgrid < 0.5.0 summaries)
    """

    # fatal structural problems, checked in order and raised at once so each
    # check can rely on the previous ones having passed
    if not isinstance(summary_df, pd.DataFrame):
        raise ValueError("summary data is not a pandas.DataFrame")

    if not len(summary_df):
        raise ValueError("summary data frame is empty")

    # first name is _TIME, set from user epochs data
    if list(summary_df.index.names[1:]) != INDEX_NAMES[1:]:
        raise ValueError(
            f"summary index names do not match INDEX_NAMES: {INDEX_NAMES}"
        )

    # list comparison: an elementwise Index == list comparison raises a
    # pandas length error instead of reporting the mismatch
    if list(summary_df.index.levels[-1]) != KEY_LABELS:
        raise ValueError(
            f"summary index key levels do not match KEY_LABELS: {KEY_LABELS}"
        )

    # does the summary of an object agree with its object?
    if fg_obj is not None:
        # Epochs, LMFitGrid and LMERFitGrid all have a time attribute
        fg_types = (
            fitgrid.epochs.Epochs,
            fitgrid.fitgrid.LMFitGrid,
            fitgrid.fitgrid.LMERFitGrid,
        )
        if not isinstance(fg_obj, fg_types):
            raise TypeError(
                f"fg_obj must be a fitgrid Epochs, LMFitGrid, or LMERFitGrid "
                f"not {type(fg_obj)}"
            )
        expected_names = [fg_obj.time] + INDEX_NAMES[1:]
        if list(summary_df.index.names) != expected_names:
            raise ValueError(
                f"summary fitgrid object index mismatch: "
                f"summary_df.index.names: {summary_df.index.names} "
                f"fitgrid object: {expected_names}"
            )

    # summaries from older fitgrid versions lack the "warnings" key
    if "warnings" not in summary_df.index.unique("key"):
        msg = (
            "Summaries are from fitgrid version < 0.5.0, use that version or re-fit the"
            f" models with this one fitgrid.utils.summarize() v{fitgrid.__version__}"
        )
        raise RuntimeError(msg)


def _update_INDEX_NAMES(lxgrid, index_names):
    """Return a copy of `index_names` with the '_TIME' placeholder replaced
    by the grid time column name `lxgrid.time`."""
    assert index_names[0] == '_TIME'
    _index_names = copy.copy(index_names)
    _index_names[0] = lxgrid.time
    return _index_names


def _stringify_lmer_warnings(fg_lmer):
    """Create a time x channel grid of "_" separated lme4::lmer warnings.

    Each cell holds the sorted, "_" joined names of the warnings flagged
    for that cell by `fitgrid.utils.lmer.get_lmer_warnings()`, else "".
    """
    # dict of warning name -> indicator dataframe
    warning_grids = fitgrid.utils.lmer.get_lmer_warnings(fg_lmer)

    warning_string_grid = pd.DataFrame(
        np.full(fg_lmer._grid.shape, ""),
        index=fg_lmer._grid.index.copy(),
        columns=fg_lmer._grid.columns.copy(),
    )

    # collect multiple warnings into single sorted "_" separated strings
    # on a tidy time x channel grid
    for warning, warning_grid in warning_grids.items():
        for idx, row_vals in warning_grid.iterrows():
            # Series.items(); Series.iteritems() was removed in pandas 2.0
            for jdx, col_val in row_vals.items():
                if not col_val:
                    continue
                current = warning_string_grid.loc[idx, jdx]
                if len(current) == 0:
                    warning_string_grid.loc[idx, jdx] = warning
                else:
                    # split, sort, reassemble
                    warning_string_grid.loc[idx, jdx] = "_".join(
                        sorted(current.split("_") + [warning])
                    )
    return warning_string_grid


def _lm_get_summaries_df(fg_ols, ci_alpha=0.05):
    """Scrape fitgrid.LMFitgrid OLS info into a tidy dataframe.

    Parameters
    ----------
    fg_ols : fitgrid.LMFitGrid

    ci_alpha : float {.05}
       alpha for confidence interval

    Returns
    -------
    summaries_df : pd.DataFrame
       index.names = [`_TIME`, `model`, `beta`, `key`] where `_TIME` is
       the `fg_ols.time` and columns are the `fg_ols` columns

    Raises
    ------
    AttributeError
        if a needed fit attribute is missing
    ValueError
        if statsmodels' AIC no longer matches its log likelihood formula

    Notes
    -----
    The `summaries_df` row and column indexes are munged to match
    `_lmer_get_summaries_df()`. Per-model values (DF, AIC, logLike,
    SSresid, sigma2, has_warning, warnings) are repeated for every beta.
    """

    # set time column from the grid, always index.names[0]
    _index_names = _update_INDEX_NAMES(fg_ols, INDEX_NAMES)
    _time = _index_names[0]

    # grab and tidy the formula RHS
    rhs = fg_ols.tester.model.formula.split('~')[1].strip()
    rhs = re.sub(r"\s+", " ", rhs)

    # fitgrid returns them in the last column of the index
    param_names = fg_ols.params.index.get_level_values(-1).unique()

    # per-model values mapped like so (key, OLS attribute)
    model_key_attrs = [
        ("DF", "df_resid"),
        ("AIC", "aic"),
        ("logLike", 'llf'),
        ("SSresid", 'ssr'),
        ("sigma2", 'mse_resid'),
    ]
    model_vals = []
    for key, attr in model_key_attrs:
        vals = getattr(fg_ols, attr, None)
        if vals is None:
            raise AttributeError(f"model: {rhs} attribute: {attr}")
        vals = vals.copy()  # don't modify the grid attribute
        vals['key'] = key
        model_vals.append(vals)

    # statsmodels result wrappers have different versions of llf!
    aics = (-2 * fg_ols.llf) + 2 * (fg_ols.df_model + fg_ols.k_constant)
    if not np.allclose(fg_ols.aic, aics):
        msg = (
            "uh oh ...statsmodels OLS aic and llf calculations have changed."
            " please report an issue to fitgrid"
        )
        raise ValueError(msg)

    # OLS fits have no warnings: has_warning is all False ...
    template = model_vals[0]
    has_warnings = pd.DataFrame(
        np.zeros(template.shape, dtype=bool),
        columns=template.columns,
        index=template.index,
    )
    has_warnings['key'] = 'has_warning'
    model_vals.append(has_warnings)

    # ... and the matching warning strings are empty
    warning_strs = pd.DataFrame(
        "", index=has_warnings.index, columns=has_warnings.columns
    )
    warning_strs["key"] = "warnings"
    model_vals.append(warning_strs)

    model_vals = pd.concat(model_vals)
    model_vals['model'] = rhs  # constant across the model

    # replicate the model info for each beta
    # ... horribly redundant but mighty handy when slicing later
    pmvs = (
        pd.concat([model_vals.assign(beta=p) for p in param_names])
        .reset_index()
        .set_index(_index_names)
    )

    # per-beta point estimates mapped like so (key, OLS attribute)
    sv_attrs = [
        ('Estimate', 'params'),  # coefficient value
        ('SE', 'bse'),
        ('P-val', 'pvalues'),
        ('T-stat', 'tvalues'),
    ]
    summaries = []
    for key, attr in sv_attrs:
        attr_vals = getattr(fg_ols, attr, None)
        if attr_vals is None:
            raise AttributeError(f"not found: {attr}")
        attr_vals = attr_vals.copy()  # ! don't mod the _grid
        attr_vals.index = attr_vals.index.set_names('beta', level=-1)
        attr_vals['model'] = rhs
        attr_vals['key'] = key
        summaries.append(attr_vals.reset_index().set_index(_index_names))

    # special handling for confidence interval, e.g., 2.5_ci, 97.5_ci
    ci_bounds = [
        f"{bound:.1f}_ci"
        for bound in [100 * (1 + (b * (1 - ci_alpha))) / 2.0 for b in [-1, 1]]
    ]
    cis = fg_ols.conf_int(alpha=ci_alpha)
    cis.index = cis.index.rename([_time, 'beta', 'key'])
    # level is keyword-only in pandas >= 2.0
    cis.index = cis.index.set_levels(ci_bounds, level='key')
    cis['model'] = rhs
    summaries.append(cis.reset_index().set_index(_index_names))

    # add the per-model info and sort
    summaries_df = pd.concat(summaries + [pmvs]).sort_index()
    _check_summary_df(summaries_df, fg_ols)
    return summaries_df


def _lmer_get_summaries_df(fg_lmer):
    """Scrape a single model fitgrid.LMERFitGrid into a standard summary format.

    Note: some values are fitgrid attributes (via pymer), others are derived.

    Parameters
    ----------
    fg_lmer : fitgrid.LMERFitGrid

    Returns
    -------
    summaries_df : pd.DataFrame
       index.names = [`fg_lmer.time`, `model`, `beta`, `key`], columns are
       the grid columns. Per-model values are repeated for every beta.
    """

    def scrape_sigma2(fg_lmer):
        # sigma2 is extracted from fg_lmer.ranef_var ...
        # residuals should be in the last row of ranef_var at each Time
        ranef_var = fg_lmer.ranef_var

        # name the None index levels on a new frame rather than renaming
        # the grid attribute's index in place
        assert ranef_var.index.names == [fg_lmer.time, None, None]
        ranef_var = ranef_var.rename_axis([fg_lmer.time, 'key', 'value'])
        assert 'Residual' == ranef_var.index.get_level_values(1).unique()[-1]
        assert list(ranef_var.index.get_level_values(2).unique()) == [
            'Name',
            'Var',
            'Std',
        ]

        # slice out the Residual Variance at each time point
        # and drop all but the Time indexes to make Time x Chan
        return ranef_var.query('key=="Residual" and value=="Var"').reset_index(
            ['key', 'value'], drop=True
        )

    # set time column from the grid, always index.names[0]
    _index_names = _update_INDEX_NAMES(fg_lmer, INDEX_NAMES)
    _time = _index_names[0]

    # look these up directly
    pymer_attribs = ['AIC', 'has_warning', 'logLike']

    # calculate or extract from other attributes
    derived_attribs = {
        # since pymer4 0.7.1 the Lmer model.resid are renamed
        # model.residuals and come back as a well-behaved
        # dataframe of floats rather than rpy2 objects
        "SSresid": lambda lmer: lmer.residuals.apply(lambda x: x ** 2)
        .groupby([lmer.time])
        .sum(),
        'sigma2': scrape_sigma2,
        "warnings": _stringify_lmer_warnings,
    }

    # grab and tidy the formula RHS from the first grid cell
    rhs = fg_lmer.tester.formula.split('~')[1].strip()
    rhs = re.sub(r"\s+", "", rhs)

    # coef estimates and stats ... these are 2-D
    summaries_df = fg_lmer.coefs.copy()  # don't mod the original
    summaries_df.index.names = [_time, 'beta', 'key']
    summaries_df = summaries_df.query("key != 'Sig'").copy()  # drop the stars
    summaries_df.index = summaries_df.index.remove_unused_levels()
    summaries_df.insert(0, 'model', rhs)
    summaries_df.set_index('model', append=True, inplace=True)
    summaries_df.reset_index(['key', 'beta'], inplace=True)

    # scrape AIC and other useful 1-D fit attributes, collected in a list
    # and concatenated once (DataFrame.append was removed in pandas 2.0)
    betas = summaries_df['beta'].unique()
    pieces = [summaries_df]
    for attrib in pymer_attribs + list(derived_attribs):
        # lookup or calculate model measures
        if attrib in pymer_attribs:
            attrib_df = getattr(fg_lmer, attrib).copy()
        else:
            attrib_df = derived_attribs[attrib](fg_lmer)

        attrib_df.insert(0, 'model', rhs)
        attrib_df.insert(1, 'key', attrib)
        attrib_df = attrib_df.set_index('model', append=True)

        # propagate attributes to each beta ... wasteful but tidy
        # when grouping by beta
        for beta in betas:
            beta_attrib = attrib_df.copy()
            beta_attrib.insert(0, 'beta', beta)
            pieces.append(beta_attrib)

    summaries_df = (
        pd.concat(pieces).reset_index().set_index(_index_names).sort_index()
    )
    _check_summary_df(summaries_df, fg_lmer)
    return summaries_df


def _get_AICs(summary_df):
    """Collect AICs, AIC_min deltas, and lmer warnings from summary_df.

    Parameters
    ----------
    summary_df : multi-indexed pandas.DataFrame
       as returned by `fitgrid.summary.summarize()`

    Returns
    -------
    aics : multi-indexed pandas pd.DataFrame
       index names [time, 'model', 'channel'], columns `AIC`,
       `has_warning`, `warnings`, and `min_delta` (AIC minus the lowest
       AIC of all models at the same time and channel)
    """

    # AIC and lmer warnings are 1 per model, pull from the first
    # model coefficient only, e.g., (Intercept)
    aic_cols = ["AIC", "has_warning", "warnings"]
    aics = []
    # iterate index order; groupby('model') would sort models alphabetically
    for model in summary_df.index.unique('model'):
        model_data = summary_df.query("model==@model")
        first_param = model_data.index.get_level_values('beta').unique()[0]
        aic = pd.DataFrame(
            summary_df.loc[pd.IndexSlice[:, model, first_param, aic_cols], :]
            .stack(-1)
            .unstack("key")
            .reset_index(["beta"], drop=True),
            columns=aic_cols,
        )
        aic.index.names = [*aic.index.names[:-1], "channel"]
        aics.append(aic)

    AICs = pd.concat(aics)
    assert set(summary_df.index.unique('model')) == set(
        AICs.index.unique('model')
    )

    # sort except model, channel
    AICs.sort_index(
        axis=0,
        level=[
            name for name in AICs.index.names if name not in ['model', 'channel']
        ],
        sort_remaining=False,
        inplace=True,
    )

    # time label is the first index level, may not be fitgrid.defaults.TIME
    assert list(AICs.index.names) == [*summary_df.index.names[:2], "channel"]

    # AIC_min delta for the fitted models at each (time, channel):
    # levels 0 and 2 are time and channel
    aic_vals = AICs['AIC'].astype(float)
    AICs['min_delta'] = aic_vals - aic_vals.groupby(level=[0, 2]).transform(
        'min'
    )

    warnings.warn('fitgrid AICs are in early days, subject to change', FutureWarning)
    assert set(summary_df.index.unique('model')) == set(
        AICs.index.unique('model')
    )
    return AICs
def summaries_fdr_control(
    model_summary_df,
    method="BY",
    rate=0.05,
    plot_pvalues=True,
):
    r"""False discovery rate control for non-zero betas in model summary dataframes

    The family of tests for FDR control is assumed to be **all and only**
    the channels, models, and :math:`\hat{\beta}_i` in the summary
    dataframe. The user must select the appropriate family of tests by
    slicing or stacking summary dataframes before running the FDR
    calculator.

    Parameters
    ----------
    model_summary_df : pandas.DataFrame
        As returned by `fitgrid.utils.summary.summarize`.

    method : str {"BY", "BH"}
       BY (default) is from Benjamini and Yekutieli [1]_, BH is Benjamini
       and Hochberg [2]_.

    rate : float {0.05}
       The target rate for controlling false discoveries.

    plot_pvalues : bool {True, False}
       Display a plot of the family of $p$-values and critical value for
       FDR control.

    Returns
    -------
    fdr_specs : dict
        `method`, `rate`, critical p-value `crit_p` (0.0 when no test
        passes), `n_pvals`, and the `models`, `betas`, and `channels` in
        the family of tests.

    fig, ax : matplotlib Figure and Axes, or None, None
        the $p$-value plot when `plot_pvalues` is True

    Notes
    -----
    With the :math:`m` p-values sorted ascending, `crit_p` is the largest
    :math:`p_{(k)}` such that :math:`p_{(k)} \le \frac{k}{m\,c(m)} q` for
    :math:`k = 1, \ldots, m`, where :math:`q` is `rate`, :math:`c(m) = 1`
    for BH and :math:`c(m) = \sum_{i=1}^{m} 1/i` for BY. Tests with
    :math:`p \le` `crit_p` are the FDR controlled discoveries.

    References
    ----------

    .. [1] Benjamini, Y., & Yekutieli, D. (2001). The control of the false
       discovery rate in multiple testing under dependency. The Annals of
       Statistics, 29, 1165-1188.

    .. [2] Benjamini, Y., & Hochberg, Y. (1995). Controlling the false
       discovery rate: A practical and powerful approach to multiple
       testing. Journal of the Royal Statistical Society. Series B
       (Methodological), 57, 289-300.
    """

    _check_summary_df(model_summary_df, None)

    # fetch and sort the family of p-values
    pvals_df = model_summary_df.query("key == 'P-val'")
    pvals = np.sort(pvals_df.to_numpy().flatten().astype(float))
    m = len(pvals)

    if method == 'BH':
        # Benjamini & Hochberg ... restricted
        c_m = 1.0
    elif method == 'BY':
        # Benjamini & Yekutieli general case
        c_m = np.sum(1.0 / np.arange(1, m + 1))
    else:
        raise ValueError("method must be 'BH' or 'BY'")

    # step-up procedure with 1-based ranks k = 1 .. m
    ranks = np.arange(1, m + 1)
    passed = np.nonzero(pvals <= ranks / (m * c_m) * rate)[0]
    if passed.size:
        crit_p_idx = int(passed.max())  # 0-based position in sorted pvals
        crit_p = float(pvals[crit_p_idx])
    else:
        crit_p = 0.0
        crit_p_idx = 0

    fdr_specs = {
        "method": method,
        "rate": rate,
        "crit_p": crit_p,
        "n_pvals": m,
        "models": list(pvals_df.index.unique('model')),
        "betas": list(pvals_df.index.unique('beta')),
        "channels": list(pvals_df.columns),
    }

    fig, ax = None, None
    if plot_pvalues:
        fig, ax = plt.subplots()
        ax.set_title("Distribution of $p$-values")
        ax.plot(np.arange(m), pvals, color="k")
        # horizontal guide in data coordinates from x=0 to the critical p
        ax.plot([0, crit_p_idx], [crit_p, crit_p], ls="--", color="k")
        ax.axvline(crit_p_idx, ymax=0.5, ls="--", color="k")
        ax.annotate(
            xy=(crit_p_idx, 0.525),
            text=f"critical $p$={crit_p:0.5f} for {method} FDR {rate}",
            ha="left",
        )
        ax.text(
            x=0.0,
            y=-0.15,
            s=pp.pformat(fdr_specs, compact=True),
            va="top",
            ha="left",
            transform=ax.transAxes,
            wrap=True,
        )

    return fdr_specs, fig, ax
def plot_betas(
    summary_df,
    LHS=None,
    models=None,
    betas=None,
    interval=None,
    beta_plot_kw=None,
    show_se=True,
    show_warnings=True,
    fdr_kw=None,
    fig_kw=None,
    df_func=None,
    scatter_size=75,
    **kwargs,
):
    """Plot model parameter estimates for model, beta, and channel LHS

    The time course of estimated betas and standard errors is plotted by
    channel for the models, betas, and channels in the data frame.
    Channels, models, betas and time intervals may be selected from the
    summary dataframe. Plots are marked with model fit warnings by default
    and may be tagged to indicate differences from 0 controlled for false
    discovery rate (FDR).

    Parameters
    ----------
    summary_df : pd.DataFrame
       as returned by fitgrid.utils.summary.summarize

    LHS : list of str or None
       column names of the data, None or [] (default) = all channels

    models : list of str or None
       select model or models to display, None or [] (default) = all models

    betas : list of str or None
       select beta or betas to plot, None or [] (default) = all betas

    interval : [start, stop] list of two ints or None
       time interval to plot, None or [] (default) = all times

    beta_plot_kw : dict or None
       keyword arguments passed to matplotlib.axes.plot()

    show_se : bool
       toggle display of standard error shading (default = True)

    show_warnings : bool
       toggle display of model warnings (default = True)

    fdr_kw : dict or None (default)
       One or more keyword arguments passed to ``summaries_fdr_control()``
       to tag plots for FDR controlled differences from 0.

    fig_kw : dict or None
       keyword args passed to pyplot.subplots(), figsize defaults to (8, 3).
       The dict is not modified.

    df_func : {None, function}
       toggle degrees of freedom line plot via function, e.g.,
       ``np.log10``, ``lambda x: x``

    scatter_size : float
       scatterplot marker size for FDR (default = 75) and warnings
       (= 1.5 scatter_size)

    Returns
    -------
    figs : list of matplotlib.Figure
       one figure per model, channel, and beta

    Raises
    ------
    ValueError
       if a selection is not found in the summary or is malformed

    Note
    ----
    The FDR family of tests is given by all channels, models, betas, and
    times in the summary data frame regardless of which of these are
    selected for plotting. To specify a different family of tests,
    construct a summary dataframe with all and only the tests for that
    family before plotting the betas.
    """

    _check_summary_df(summary_df, None)

    # None defaults avoid shared mutable defaults; [] and {} still work
    LHS = [] if LHS is None else LHS
    models = [] if models is None else models
    betas = [] if betas is None else betas
    interval = [] if interval is None else interval
    beta_plot_kw = {} if beta_plot_kw is None else beta_plot_kw
    fdr_kw = {} if fdr_kw is None else fdr_kw
    # new dict: never modify the caller's fig_kw
    fig_kw = {"figsize": (8, 3), **({} if fig_kw is None else fig_kw)}

    # fitgrid < 0.5.0
    for kwarg in ["figsize", "fdr", "alpha", "s"]:
        if kwarg in kwargs:
            msg = (
                f"keyword {kwarg} is deprecated in fitgrid 0.5.0, has no effect and "
                "will be removed. See fitgrid.utils.summary.plot_betas() documentation."
            )
            warnings.warn(msg, FutureWarning)

    # ------------------------------------------------------------
    # validate kwargs, the last error found is reported
    error_msg = None

    # LHS defaults to all channels
    if isinstance(LHS, str):
        LHS = [LHS]
    if LHS == []:
        LHS = list(summary_df.columns)

    if not all(isinstance(col, str) for col in LHS):
        error_msg = "LHS must be a list of channel name strings"

    for channel in LHS:
        if channel not in summary_df.columns:
            error_msg = f"channel {channel} not found in the summary columns"

    # model, beta
    for key, vals in {"model": models, "beta": betas}.items():
        if vals and not (
            isinstance(vals, list) and all(isinstance(itm, str) for itm in vals)
        ):
            error_msg = f"{key} must be a list of strings"

        unique_vals = list(summary_df.index.unique(key))
        for val in vals:
            if val not in unique_vals:
                error_msg = (
                    f"{val} not found, check the summary index: "
                    f"name={key}, labels={unique_vals}"
                )

    # validate interval against the full time range, sorted or not
    if interval:
        times = summary_df.index.get_level_values(0)
        t_min, t_max = times.min(), times.max()
        if not (
            isinstance(interval, list)
            and len(interval) == 2
            and all(isinstance(t, (int, np.integer)) for t in interval)
            and interval[0] < interval[1]
            and interval[0] >= t_min
            and interval[1] <= t_max
        ):
            error_msg = (
                "interval must be a list of increasing integers "
                f"in the summary time range between {t_min} and {t_max}."
            )

    # fail on any error
    if error_msg:
        raise ValueError(error_msg)

    # ------------------------------------------------------------
    # filter summary for selections, if any
    model_summary_df = summary_df  # a reference may be all we need
    if not LHS == list(summary_df.columns):
        model_summary_df = summary_df[LHS].copy()

    if models:
        model_summary_df = model_summary_df.query("model in @models").copy()

    if betas:
        model_summary_df = model_summary_df.query("beta in @betas").copy()

    if interval:
        model_summary_df.sort_index(inplace=True)
        model_summary_df = model_summary_df.loc[
            # Index = time, model, beta, key
            pd.IndexSlice[interval[0] : interval[1], :, :, :],
            :,
        ]

    models = list(model_summary_df.index.unique("model"))
    _time = model_summary_df.index.names[0]

    # ------------------------------------------------------------
    # optional FDR calc
    if fdr_kw:
        # the family of tests for FDR is given by the summary data, *not*
        # which slices happen to be selected for plotting.
        fdr_specs, fdr_fig, fdr_ax = summaries_fdr_control(
            summary_df, **fdr_kw
        )
        if not summary_df.equals(model_summary_df):
            fdr_msg = (
                "FDR test family is for **ALL** models, betas, and channels in "
                "the summary dataframe not just those selected for plotting."
            )
            warnings.warn(fdr_msg)
        print(pp.pformat(fdr_specs, compact=True))

    # ------------------------------------------------------------
    # one fixed style per warning kind in the whole summary so the same
    # warning looks the same in every figure
    warning_kinds = sorted(
        {
            kind
            for wrn in np.unique(summary_df.query("key=='warnings'"))
            if len(wrn)
            for kind in wrn.split("_")
        }
    )
    warning_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    warning_cycler = cy(color=warning_colors) + cy(
        marker=Line2D.filled_markers[: len(warning_colors)]
    )
    # calling a Cycler gives an endless cycle, so styles repeat rather than
    # running out when there are more warning kinds than colors
    warning_styles = dict(zip(warning_kinds, warning_cycler()))

    # ------------------------------------------------------------
    # set up figures
    figs = []
    for model, col in itertools.product(models, LHS):

        # select beta for this model
        for beta in model_summary_df.query("model == @model").index.unique(
            "beta"
        ):

            f, ax_beta = plt.subplots(nrows=1, ncols=1, **fig_kw)

            # unstack this beta as a column for plotting
            fg_beta = (
                model_summary_df.loc[pd.IndexSlice[:, model, beta], col]
                .unstack(level='key')
                .reset_index(_time)  # time label for this model_summary_df
            )

            fg_beta.plot(
                x=_time,
                y='Estimate',
                ax=ax_beta,
                color='black',
                alpha=0.5,
                label=beta,
                **beta_plot_kw,
            )

            # optional +/- SE band
            if show_se:
                beta_hat = fg_beta['Estimate']
                ax_beta.fill_between(
                    x=fg_beta[_time],
                    y1=(beta_hat + fg_beta["SE"]).astype(float),
                    y2=(beta_hat - fg_beta["SE"]).astype(float),
                    alpha=0.2,
                    color='black',
                )

            # optional (transformed) degrees of freedom
            if df_func is not None:
                func_name = getattr(df_func, "__name__", str(df_func))
                # element-wise: the column may be object dtype
                fg_beta['DF_'] = fg_beta['DF'].apply(lambda x: df_func(x))
                fg_beta.plot(
                    x=_time, y='DF_', ax=ax_beta, label=f"{func_name}(df)"
                )

            # FDR controlled differences from 0, crit_p itself included
            if fdr_kw:
                fdr_mask = fg_beta["P-val"] <= fdr_specs["crit_p"]
                ax_beta.scatter(
                    fg_beta[_time][fdr_mask],
                    fg_beta['Estimate'][fdr_mask],
                    color='black',
                    zorder=3,
                    label=f"{fdr_specs['method']} FDR p <= crit {fdr_specs['crit_p']:0.2}",
                    s=scatter_size,
                )

            # warnings
            if show_warnings and any(len(wrn) for wrn in fg_beta["warnings"]):
                warn_strs = sorted(
                    {
                        kind
                        for wrn in fg_beta["warnings"].unique()
                        if len(wrn) > 0
                        for kind in wrn.split("_")
                    }
                )
                for warn_str in warn_strs:
                    # separate warnings by 1/4 major tick interval
                    sep = np.abs(
                        (ax_beta.get_yticks()[:2] * [0.25, -0.25]).sum()
                    )
                    warn_offset = (warning_kinds.index(warn_str) + 1) * sep
                    # exact match on the "_" separated kinds, not substrings
                    warn_mask = fg_beta["warnings"].apply(
                        lambda x, w=warn_str: w in x.split("_")
                    )
                    ax_beta.scatter(
                        fg_beta[_time][warn_mask],
                        fg_beta['Estimate'][warn_mask] + warn_offset,
                        zorder=4,
                        label=warn_str,
                        **warning_styles[warn_str],
                        alpha=0.75,
                        s=scatter_size * 1.5,
                    )

            ax_beta.axhline(y=0, linestyle='--', color='black')
            ax_beta.legend(loc='upper left', bbox_to_anchor=(0.0, -0.25))

            formula = fg_beta.index.get_level_values('model').unique()[0]
            ax_beta.set_title(f'{col} {beta}: {formula}', loc='left')

            figs.append(f)

    return figs
def plot_AICmin_deltas(
    summary_df,
    show_warnings="no_labels",
    figsize=None,
    gridspec_kw=None,
    subplot_kw=None,
):
    r"""plot FitGrid min delta AICs and fitter warnings

    Thresholds of AIC_min delta at 2, 4, 7, 10 are from Burnham &
    Anderson 2004, see Notes.

    Parameters
    ----------
    summary_df : pd.DataFrame
       as returned by fitgrid.utils.summary.summarize

    show_warnings : {"no_labels", "labels", str, list of str}
       "no_labels" (default) highlights everywhere there is any warning in
       red, the default behavior in fitgrid < v0.5.0. "labels" also lists
       all warning strings in the axes titles. A `str` or list of `str`
       selects and displays only warnings that (partial) match a model
       warning string.

    figsize : 2-ple
       pyplot.figure figure size parameter, default (12, 8)

    gridspec_kw : dict
       matplotlib.gridspec keyword args passed to
       ``pyplot.subplots(..., gridspec_kw=gridspec_kw)``

    subplot_kw : dict
       keyword args passed to ``pyplot.subplots(..., subplot_kw=subplot_kw)``

    Returns
    -------
    f, axs : matplotlib.figure.Figure, numpy.ndarray of Axes
       the figure and its (number of models x 3) axes: traces, heatmap,
       colorbar

    Raises
    ------
    ValueError
       if `show_warnings` is not one of the accepted values

    Notes
    -----
    [BurAnd2004]_ p. 270-271. Where :math:`AIC_{min}` is the lowest AIC
    value for "a set of a priori candidate models well-supported by the
    underlying science :math:`g_{i}, i = 1, 2, ..., R`",

    .. math::

       \Delta_{i} = AIC_{i} - AIC_{min}

    "is the information loss experienced if we are using fitted model
    :math:`g_{i}` rather than the best model, :math:`g_{min}` for
    inference." ...

    "Some simple rules of thumb are often useful in assessing the relative
    merits of models in the set: Models having :math:`\Delta_{i} \leq 2`
    have substantial support (evidence), those in which
    :math:`4 \leq \Delta_{i} \leq 7` have considerably less support, and
    models having :math:`\Delta_{i} > 10` have essentially no support."
    """

    def _get_warnings_grid(model, model_warnings, show_warnings):
        """Return the matching warning kinds for `model` and a 1/0 time x
        channel indicator grid of cells with any of those kinds."""

        # split the "_" separated multiple warning strings into unique types
        warning_kinds = sorted(
            {
                kind
                for wrn in np.unique(model_warnings)
                if len(wrn)
                for kind in wrn.split("_")
            }
        )

        # optionally filter by user keyword matching
        if show_warnings not in ["no_labels", "labels"]:
            user_kinds = []
            for kw_warning in show_warnings:
                found_kinds = [
                    warning_kind
                    for warning_kind in warning_kinds
                    if kw_warning in warning_kind  # string matches
                ]
                # collect the matching kinds or warn
                if found_kinds:
                    user_kinds += found_kinds
                else:
                    msg = (
                        f"show_warnings '{kw_warning}' not found in model "
                        f"{model} warnings: [{', '.join(warning_kinds)}]"
                    )
                    warnings.warn(msg)
            # update filtered kinds, dropping duplicate matches
            warning_kinds = list(dict.fromkeys(user_kinds))

        # indicator grid for matching warning kinds, cell by cell via
        # Series.map (DataFrame.applymap is deprecated in pandas >= 2.1)
        def _flag(cell):
            return 1 if any(kind in cell for kind in warning_kinds) else 0

        warnings_grid = model_warnings.apply(lambda col: col.map(_flag))
        return warning_kinds, warnings_grid

    # ------------------------------------------------------------
    # validate kwarg
    if show_warnings not in ["no_labels", "labels"]:
        # promote string to list
        show_warnings = list(np.atleast_1d(show_warnings))
        if not all(isinstance(wrn, str) for wrn in show_warnings):
            msg = (
                "show_warnings must be 'no_labels', 'labels', or a string or "
                "list of strings that partial match warnings"
            )
            raise ValueError(msg)

    # validate summary dataframe
    _check_summary_df(summary_df, None)
    _time = summary_df.index.names[0]

    # fetch the AIC min delta data
    aics = _get_AICs(summary_df)  # long format
    models = aics.index.unique('model')
    channels = aics.index.unique('channel')

    # ------------------------------------------------------------
    # figure setup
    if figsize is None:
        figsize = (12, 8)

    # reasonable default, update w/ user kwargs if any
    gspec_kw = {'width_ratios': [0.46, 0.46, 0.015]}
    if gridspec_kw:
        gspec_kw.update(gridspec_kw)

    # main figure, axes: number of models rows x 3 columns: traces, raster, colorbar
    f, axs = plt.subplots(
        len(models),  # 1 axis row per model
        3,
        squeeze=False,  # keep axes shape (1, 3), though singleton model is pointless
        figsize=figsize,
        gridspec_kw=gspec_kw,
        subplot_kw=subplot_kw,
    )

    aic_min_delta_bounds = [0, 2, 4, 7, 10]

    # colorbrewer 2.0 Blues color blind safe n=5
    pal = ['#eff3ff', '#bdd7e7', '#6baed6', '#3182bd', '#08519c']

    # plot each model on an axes row
    for i, m in enumerate(models):
        traces, heatmap, colorbar_ax = axs[i, 0], axs[i, 1], axs[i, 2]

        # ------------------------------------------------------------
        # slice this model min delta values and warnings; unstack()
        # alphanum sorts the channel index ... ugh, so reindex
        _min_deltas = (
            aics.loc[pd.IndexSlice[:, m, :], 'min_delta']
            .reset_index('model', drop=True)
            .unstack('channel')
            .astype(float)
            .reindex(columns=channels)
        )

        model_warnings = (
            aics.loc[pd.IndexSlice[:, m, :], 'warnings']
            .reset_index('model', drop=True)
            .unstack('channel')
            .reindex(columns=channels)
        )

        # fetch warnings for heatmapping
        warning_kinds, warnings_grid = _get_warnings_grid(
            m, model_warnings, show_warnings
        )

        # ------------------------------------------------------------
        # plot traces and warnings in left column, titled by the model
        # with optional list of warnings
        title_str = f"{m}"
        if warning_kinds and show_warnings != "no_labels":
            title_str += "\n" + "\n".join(warning_kinds)
        traces.set_title(title_str, loc="left")

        times = _min_deltas.reset_index()[_time]
        for chan in channels:
            traces.plot(times, _min_deltas[chan], label=chan)

            # warning mask
            chan_mask = (
                _min_deltas[chan].where(warnings_grid[chan] == 1).dropna()
            )
            traces.scatter(
                chan_mask.index,
                chan_mask,
                c="crimson",
                label=None,
            )

        if i == 0:
            # channel legend left of the first row's main plot, reversed
            handles, labels = traces.get_legend_handles_labels()
            traces.legend(
                handles[::-1],
                labels[::-1],
                loc='upper right',
                bbox_to_anchor=(-0.2, 1.0),
            )

        for y in aic_min_delta_bounds[1:]:
            traces.axhline(y=y, color='black', linestyle='dotted')

        # ------------------------------------------------------------
        # heatmap
        cmap = mpl.colors.ListedColormap(pal)
        cmap.set_over(color='#08306b')  # darkest from Blues n=7
        cmap.set_under(color='lightgray')
        norm = mpl.colors.BoundaryNorm(aic_min_delta_bounds, cmap.N)

        ylabels = _min_deltas.columns
        heatmap.yaxis.set_major_locator(
            mpl.ticker.FixedLocator(np.arange(len(ylabels)))
        )
        heatmap.yaxis.set_major_formatter(mpl.ticker.FixedFormatter(ylabels))
        im = heatmap.pcolormesh(
            _min_deltas.index,
            np.arange(len(ylabels)),
            _min_deltas.T,
            cmap=cmap,
            norm=norm,
            shading='nearest',
        )

        # any non-zero warnings are red; zeros are masked so the AIC
        # heatmap shows through
        if warnings_grid.to_numpy().max():
            assert (warnings_grid.index == _min_deltas.index).all()
            assert (warnings_grid.columns == _min_deltas.columns).all()
            heatmap.pcolormesh(
                warnings_grid.index,
                np.arange(len(ylabels)),
                np.ma.masked_equal(warnings_grid.T.to_numpy(), 0),
                shading="nearest",
                cmap=mpl.colors.ListedColormap(['red']),
            )

        mpl.colorbar.Colorbar(colorbar_ax, im, extend='max')

    return f, axs