# -*- coding: utf-8 -*-
"""User functions to streamline working with grids of OLS and LME
model summaries and sets of models."""
import itertools
import copy
import warnings
import re
from cycler import cycler as cy
from collections import defaultdict
import pprint as pp
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import fitgrid
# Common structure enforced on every summary dataframe, regardless of
# which kind of fit object the values were scraped from.
# '_TIME' is a placeholder: _update_INDEX_NAMES() swaps it for the grid's
# time label on the fly, the remaining names are used verbatim.
INDEX_NAMES = ['_TIME', 'model', 'beta', 'key']
# Each (model, beta) combination carries all of these key values;
# some are per-beta, some are per-model (see PER_MODEL_KEY_LABELS).
# NOTE: _check_summary_df() compares this list element-wise against the
# last level of the summary index, so the order here is significant.
KEY_LABELS = [
    '2.5_ci',
    '97.5_ci',
    'AIC',
    'DF',
    'Estimate',
    'P-val',
    'SE',
    'SSresid',
    'T-stat',
    'has_warning',
    'logLike',
    'sigma2',
    'warnings',
]
# Per-model values get special treatment: they are broadcast to (i.e.,
# repeated for) every beta of the model.
PER_MODEL_KEY_LABELS = [
    'AIC',
    'SSresid',
    'has_warning',
    'warnings',
    'logLike',
    'sigma2',
]
def summarize(
    epochs_fg,
    modeler,
    LHS,
    RHS,
    parallel=False,
    n_cores=2,
    quiet=False,
    **kwargs,
):
    """Fit the data with one or more model formulas and return summary information.

    Convenience wrapper, useful for keeping memory use manageable when
    gathering betas and fit measures for a stack of models: each model is
    fit, its summary is scraped, and only the summary is retained.

    Parameters
    ----------
    epochs_fg : fitgrid.epochs.Epochs
       as returned by `fitgrid.epochs_from_dataframe()` or
       `fitgrid.from_hdf()`, *NOT* a `pandas.DataFrame`.
    modeler : {'lm', 'lmer'}
       class of model to fit, `lm` for OLS, `lmer` for linear mixed-effects.
       Note: the RHS formula language must match the modeler.
    LHS : list of str
       the data columns to model
    RHS : model formula or list of model formulas to fit
       see the Python package `patsy` docs for `lm` formula language
       and the R library `lme4` docs for the `lmer` formula language.
       A single formula string is promoted to a one-item list.
    parallel : bool
       If True, model fitting is distributed to multiple cores
    n_cores : int
       number of cores to use. See what works, but golden rule if running
       on a shared machine.
    quiet : bool
       passed through to the modeler unchanged, default=False
       (presumably True suppresses the progress bar, see the modeler docs).
    **kwargs : key=value arguments passed to the modeler, optional

    Returns
    -------
    summary_df : `pandas.DataFrame`
        indexed by the epochs time label, `model`, `beta`, and `key`,
        where the keys are `2.5_ci`, `97.5_ci`, `AIC`, `DF`, `Estimate`,
        `P-val`, `SE`, `SSresid`, `T-stat`, `has_warning`, `logLike`,
        `sigma2`, `warnings`.

    Raises
    ------
    TypeError
        if `epochs_fg` is a `pandas.DataFrame` or otherwise not fitgrid Epochs
    ValueError
        if `modeler` is not 'lm' or 'lmer', or `RHS` has no formulas

    Examples
    --------
    >>> lm_formulas = [
        '1 + fixed_a + fixed_b + fixed_a:fixed_b',
        '1 + fixed_a + fixed_b',
        '1 + fixed_a',
        '1 + fixed_b',
        '1',
    ]
    >>> lm_summary_df = fitgrid.utils.summarize(
        epochs_fg,
        'lm',
        LHS=['MiPf', 'MiCe', 'MiPa', 'MiOc'],
        RHS=lm_formulas,
        parallel=True,
        n_cores=4
    )
    >>> lmer_formulas = [
        '1 + fixed_a + (1 + fixed_a | random_a) + (1 | random_b)',
        '1 + fixed_a + (1 | random_a) + (1 | random_b)',
        '1 + fixed_a + (1 | random_a)',
    ]
    >>> lmer_summary_df = fitgrid.utils.summarize(
        epochs_fg,
        'lmer',
        LHS=['MiPf', 'MiCe', 'MiPa', 'MiOc'],
        RHS=lmer_formulas,
        parallel=True,
        n_cores=12,
        REML=False
    )
    """
    warnings.warn(
        'fitgrid summaries are in early days, subject to change', FutureWarning
    )

    # modicum of guarding: a raw dataframe is the most common mistake, so it
    # gets its own hint about how to convert it
    if isinstance(epochs_fg, pd.DataFrame):
        raise TypeError(
            "Convert dataframe to fitgrid epochs with "
            "fitgrid.epochs_from_dataframe()"
        )
    if not isinstance(epochs_fg, fitgrid.epochs.Epochs):
        raise TypeError(
            f"epochs_fg must be a fitgrid.Epochs not {type(epochs_fg)}"
        )

    # select the modeler and the matching summary scraper
    if modeler == 'lm':
        _modeler = fitgrid.lm
        _scraper = _lm_get_summaries_df
    elif modeler == 'lmer':
        _modeler = fitgrid.lmer
        _scraper = _lmer_get_summaries_df
    else:
        raise ValueError("modeler must be 'lm' or 'lmer'")

    # promote RHS scalar str to singleton list
    RHS = np.atleast_1d(RHS).tolist()
    if not RHS:
        # otherwise pd.concat() fails below with an unhelpful message
        raise ValueError("RHS must contain at least one model formula")

    # fit and scrape one formula at a time so that only the summary of each
    # fitted grid is kept around, not the grid itself
    summaries = []
    for _rhs in RHS:
        fitted_grid = _modeler(
            epochs_fg,
            LHS=LHS,
            RHS=_rhs,
            parallel=parallel,
            n_cores=n_cores,
            quiet=quiet,
            **kwargs,
        )
        summaries.append(_scraper(fitted_grid))

    summary_df = pd.concat(summaries)
    _check_summary_df(summary_df, epochs_fg)
    return summary_df
# ------------------------------------------------------------
# private-ish summary helpers for scraping summary info from fits
# ------------------------------------------------------------
def _check_summary_df(summary_df, fg_obj):
    """Check summary dataframe structure, and against the fitgrid object if any.

    Parameters
    ----------
    summary_df : pandas.DataFrame
        summary data to validate
    fg_obj : fitgrid.epochs.Epochs, LMFitGrid, LMERFitGrid, or None
        if not None, the summary index names must match the object's
        ``time`` attribute followed by ``INDEX_NAMES[1:]``

    Raises
    ------
    ValueError
        if `summary_df` is not a non-empty dataframe with the index names
        and key labels given by ``INDEX_NAMES`` and ``KEY_LABELS``, or if it
        disagrees with `fg_obj`
    AssertionError
        if `fg_obj` is given but is not one of the supported fitgrid types
    RuntimeError
        if the summary has no "warnings" key
    """
    # ------------------------------------------------------------
    # structural checks, the first failure is reported. Raise right away:
    # the later checks read summary_df.index and would die with an
    # unrelated exception on a malformed summary.
    error_msg = None
    if not isinstance(summary_df, pd.DataFrame):
        error_msg = "summary data is not a pandas.DataFrame"
    elif not len(summary_df):
        error_msg = "summary data frame is empty"
    elif not summary_df.index.names[1:] == INDEX_NAMES[1:]:
        # first name is _TIME, set from user epochs data
        error_msg = (
            f"summary index names do not match INDEX_NAMES: {INDEX_NAMES}"
        )
    elif list(summary_df.index.levels[-1]) != KEY_LABELS:
        # list comparison also covers levels of a different length, where an
        # element-wise == raises a "lengths must match" error instead
        error_msg = (
            f"summary index key levels do not match KEY_LABELS: {KEY_LABELS}"
        )
    if error_msg:
        raise ValueError(error_msg)

    # ------------------------------------------------------------
    # does summary of an object agree with its object? fg_obj can be
    # fitgrid.Epochs, LMFitGrid or LMERFitGrid, they all have a time attribute.
    # Test for None explicitly, the truth value of these objects is not
    # something to rely on.
    if fg_obj is not None:
        fg_types = (
            fitgrid.epochs.Epochs,
            fitgrid.fitgrid.LMFitGrid,
            fitgrid.fitgrid.LMERFitGrid,
        )
        assert isinstance(fg_obj, fg_types)

        expected_names = [fg_obj.time] + INDEX_NAMES[1:]
        if not summary_df.index.names == expected_names:
            raise ValueError(
                f"summary fitgrid object index mismatch: "
                f"summary_df.index.names: {summary_df.index.names} "
                f"fitgrd object: {expected_names}"
            )

    # ------------------------------------------------------------
    # summaries without per-model warning strings cannot be processed
    if "warnings" not in summary_df.index.unique("key"):
        msg = (
            "Summaries are from fitgrid version < 0.5.0, use that version or re-fit the"
            f" models with this one fitgrid.utils.summarize() v{fitgrid.__version__}"
        )
        raise RuntimeError(msg)
def _update_INDEX_NAMES(lxgrid, index_names):
    """Return a copy of index_names with '_TIME' replaced by the grid time label.

    The input list is left untouched; its first item must be the
    '_TIME' placeholder.
    """
    placeholder = index_names[0]
    assert placeholder == '_TIME'
    # shallow copy so the caller's (module-level) names are not modified
    grid_index_names = copy.copy(index_names)
    grid_index_names[0] = lxgrid.time
    return grid_index_names
def _stringify_lmer_warnings(fg_lmer):
    """Create grid w/ _ separated string of lme4::lmer warning list items, else "".

    Parameters
    ----------
    fg_lmer : fitgrid.LMERFitGrid

    Returns
    -------
    warning_string_grid : pandas.DataFrame
        same index and columns as ``fg_lmer._grid``. Each cell holds the
        sorted, "_"-joined warning names that are flagged for that cell in
        ``fitgrid.utils.lmer.get_lmer_warnings(fg_lmer)``, or "" if none are.
    """
    warning_grids = fitgrid.utils.lmer.get_lmer_warnings(
        fg_lmer
    )  # dict of indicator dataframes

    warning_string_grid = pd.DataFrame(
        np.full(fg_lmer._grid.shape, ""),
        index=fg_lmer._grid.index.copy(),
        columns=fg_lmer._grid.columns.copy(),
    )

    # collect multiple warnings into single sorted "_" separated strings
    # on a tidy time x channel grid
    for warning, warning_grid in warning_grids.items():
        for idx, row_vals in warning_grid.iterrows():
            # Series.items(), not .iteritems() which was removed in pandas 2.0
            for jdx, col_val in row_vals.items():
                if not col_val:
                    continue
                so_far = warning_string_grid.loc[idx, jdx]
                # split, add, sort, reassemble
                names = so_far.split("_") if len(so_far) else []
                warning_string_grid.loc[idx, jdx] = "_".join(
                    sorted(names + [warning])
                )
    return warning_string_grid
# def _unstringify_lmer_warnings(lmer_summaries):
#     """convert stringfied lmer warning grid back into dict of indicator grids as in get_lmer_warnings()"""
#     string_warning_grid = lmer_summaries.query("key=='warnings'")
#     warnings = []
#     for warning in np.unique(string_warning_grid):
#         if len(warning) > 0:
#             warnings += warning.split("_")
#     warning_grids = {}
#     for warning in sorted(warnings):
#         warning_grids[warning] = string_warning_grid.applymap(
#             lambda x: 1 if warning in x else 0
#         )
#     return warning_grids
def _lm_get_summaries_df(fg_ols, ci_alpha=0.05):
    """scrape fitgrid.LMFitgrid OLS info into a tidy dataframe

    Parameters
    ----------
    fg_ols : fitgrid.LMFitGrid
    ci_alpha : float {.05}
       alpha for confidence interval

    Returns
    -------
    summaries_df : pd.DataFrame
       index.names = [`_TIME`, `model`, `beta`, `key`] where
       `_TIME` is the `fg_ols.time` and columns are the `fg_ols` columns

    Raises
    ------
    AttributeError
        if a required fit attribute comes back as None
    ValueError
        if the grid AIC does not agree with the AIC recomputed from llf

    Notes
    -----
    The `summaries_df` row and column indexes are munged to match
    fitgrid.lmer._get_summaries_df()
    """
    # set time column from the grid, always index.names[0]
    _index_names = _update_INDEX_NAMES(fg_ols, INDEX_NAMES)
    _time = _index_names[0]

    # grab and tidy the formula RHS, runs of whitespace become one space
    rhs = fg_ols.tester.model.formula.split('~')[1].strip()
    rhs = re.sub(r"\s+", " ", rhs)

    # fitgrid returns them in the last column of the index
    param_names = fg_ols.params.index.get_level_values(-1).unique()

    # ------------------------------------------------------------
    # per-model values, mapped like so (key, OLS_attribute)
    model_key_attrs = [
        ("DF", "df_resid"),
        ("AIC", "aic"),
        ("logLike", 'llf'),
        ("SSresid", 'ssr'),
        ("sigma2", 'mse_resid'),
    ]
    model_vals = []
    for key, attr in model_key_attrs:
        vals = getattr(fg_ols, attr)
        if vals is None:
            raise AttributeError(f"model: {rhs} attribute: {attr}")
        vals = vals.copy()  # ! don't mod the _grid
        vals['key'] = key
        model_vals.append(vals)

    # statsmodels result wrappers have different versions of llf!
    aics = (-2 * fg_ols.llf) + 2 * (fg_ols.df_model + fg_ols.k_constant)
    if not np.allclose(fg_ols.aic, aics):
        msg = (
            "uh oh ...statsmodels OLS aic and llf calculations have changed."
            " please report an issue to fitgrid"
        )
        raise ValueError(msg)

    # handle warnings: OLS has none, so has_warning is False everywhere.
    # The template already carries the 'key' column, it is overwritten next.
    template = model_vals[0]
    has_warnings = pd.DataFrame(
        np.zeros(template.shape).astype('bool'),
        columns=template.columns,
        index=template.index,
    )
    has_warnings['key'] = 'has_warning'
    model_vals.append(has_warnings)

    # build empty warning strings to match has_warnings == False. Named
    # warning_strs so the imported warnings module is not shadowed.
    warning_strs = pd.DataFrame(
        np.full(has_warnings.shape, "", dtype=object),
        columns=has_warnings.columns,
        index=has_warnings.index,
    )
    warning_strs["key"] = "warnings"
    model_vals.append(warning_strs)

    model_vals = pd.concat(model_vals)

    # constants across the model
    model_vals['model'] = rhs

    # replicate the model info for each beta
    # ... horribly redundant but mighty handy when slicing later
    pmvs = []
    for p in param_names:
        pmv = model_vals.copy()
        pmv['beta'] = p
        pmvs.append(pmv)
    pmvs = pd.concat(pmvs).reset_index().set_index(_index_names)

    # ------------------------------------------------------------
    # per-beta point estimates, mapped like so (key, OLS_attribute)
    sv_attrs = [
        ('Estimate', 'params'),  # coefficient value
        ('SE', 'bse'),
        ('P-val', 'pvalues'),
        ('T-stat', 'tvalues'),
    ]
    summaries = []
    for key, attr in sv_attrs:
        attr_vals = getattr(fg_ols, attr)
        if attr_vals is None:
            raise AttributeError(f"not found: {attr}")
        attr_vals = attr_vals.copy()  # ! don't mod the _grid
        attr_vals.index = attr_vals.index.set_names('beta', level=-1)
        attr_vals['model'] = rhs
        attr_vals['key'] = key

        # update list of beta bundles
        summaries.append(attr_vals.reset_index().set_index(_index_names))

    # special handling for confidence interval, the bounds are labeled
    # by percentile, e.g., 2.5_ci, 97.5_ci for ci_alpha=0.05
    ci_bounds = [
        f"{bound:.1f}_ci"
        for bound in [100 * (1 + (b * (1 - ci_alpha))) / 2.0 for b in [-1, 1]]
    ]
    cis = fg_ols.conf_int(alpha=ci_alpha)
    cis.index = cis.index.rename([_time, 'beta', 'key'])
    # level must be passed by keyword, it is keyword-only in pandas >= 2.0
    cis.index = cis.index.set_levels(ci_bounds, level='key')
    cis['model'] = rhs
    summaries.append(cis.reset_index().set_index(_index_names))

    summaries_df = pd.concat(summaries)

    # add the parameter model info
    summaries_df = pd.concat([summaries_df, pmvs]).sort_index()
    _check_summary_df(summaries_df, fg_ols)
    return summaries_df
def _lmer_get_summaries_df(fg_lmer):
    """scrape a single model fitgrid.LMERFitGrid into a standard summary format

    Note: some values are fitgrid attributes (via pymer), others are derived

    Parameters
    ----------
    fg_lmer : fitgrid.LMERFitGrid

    Returns
    -------
    summaries_df : pd.DataFrame
        index.names = [`_TIME`, `model`, `beta`, `key`] where `_TIME` is
        `fg_lmer.time`, sorted by index
    """

    def scrape_sigma2(fg_lmer):
        # sigma2 is extracted from fg_lmer.ranef_var ...
        # residuals should be in the last row of ranef_var at each Time
        ranef_var = fg_lmer.ranef_var

        # set the None index names
        assert ranef_var.index.names == [fg_lmer.time, None, None]
        ranef_var.index.set_names([fg_lmer.time, 'key', 'value'], inplace=True)
        assert 'Residual' == ranef_var.index.get_level_values(1).unique()[-1]
        assert all(
            ['Name', 'Var', 'Std']
            == ranef_var.index.get_level_values(2).unique()
        )

        # slice out the Residual Variance at each time point
        # and drop all but the Time indexes to make Time x Chan
        sigma2 = ranef_var.query(
            'key=="Residual" and value=="Var"'
        ).reset_index(['key', 'value'], drop=True)
        return sigma2

    # set time column from the grid, always index.names[0]
    _index_names = _update_INDEX_NAMES(fg_lmer, INDEX_NAMES)
    _time = _index_names[0]

    # look these up directly
    pymer_attribs = ['AIC', 'has_warning', 'logLike']

    #  x=lmer_fg calculate or extract from other attributes
    derived_attribs = {
        # since pymer4 0.7.1 the Lmer model.resid are renamed
        # model.residuals and come back as a well-behaved
        # dataframe of floats rather than rpy2 objects
        "SSresid": lambda lmer: lmer.residuals.apply(lambda x: x ** 2)
        .groupby([fg_lmer.time])
        .sum(),
        'sigma2': lambda x: scrape_sigma2(x),
        "warnings": lambda x: _stringify_lmer_warnings(x),
    }

    # grab and tidy the formula RHS from the first grid cell,
    # all whitespace is removed
    rhs = fg_lmer.tester.formula.split('~')[1].strip()
    rhs = re.sub(r"\s+", "", rhs)

    # coef estimates and stats ... these are 2-D
    summaries_df = fg_lmer.coefs.copy()  # don't mod the original
    summaries_df.index.names = [_time, 'beta', 'key']
    summaries_df = summaries_df.query("key != 'Sig'")  # drop the stars
    summaries_df.index = summaries_df.index.remove_unused_levels()

    summaries_df.insert(0, 'model', rhs)
    summaries_df.set_index('model', append=True, inplace=True)
    summaries_df.reset_index(['key', 'beta'], inplace=True)

    # the betas do not change while the per-model values are added
    betas = summaries_df['beta'].unique()

    # scrape AIC and other useful 1-D fit attributes. The pieces are
    # collected and concatenated once at the end: DataFrame.append() was
    # removed in pandas 2.0 and re-copied the growing frame on every call.
    pieces = [summaries_df]
    for attrib in pymer_attribs + list(derived_attribs.keys()):
        # lookup or calculate model measures
        if attrib in pymer_attribs:
            attrib_df = getattr(fg_lmer, attrib).copy()
        else:
            attrib_df = derived_attribs[attrib](fg_lmer)

        attrib_df.insert(0, 'model', rhs)
        attrib_df.insert(1, 'key', attrib)

        # propagate attributes to each beta ... wasteful but tidy
        # when grouping by beta
        for beta in betas:
            beta_attrib = attrib_df.copy().set_index('model', append=True)
            beta_attrib.insert(0, 'beta', beta)
            pieces.append(beta_attrib)

    summaries_df = (
        pd.concat(pieces)
        .reset_index()
        .set_index(_index_names)
        .sort_index()
    )
    _check_summary_df(summaries_df, fg_lmer)
    return summaries_df
def _get_AICs(summary_df):
    """collect AICs, AIC_min deltas, and lmer warnings from summary_df

    Parameters
    ----------
    summary_df : multi-indexed pandas.DataFrame
       as returned by `fitgrid.summary.summarize()`

    Returns
    -------
    aics : multi-indexed pandas pd.DataFrame
       indexed by the summary time label, `model`, and `channel` with
       columns `AIC`, `has_warning`, `warnings`, and `min_delta`, the
       difference between the model AIC and the smallest AIC among the
       models at that time and channel.

    Warns
    -----
    FutureWarning
       the AIC dataframe format is subject to change
    """
    # AIC and lmer warnings are 1 per model, pull from the first
    # model coefficient only, e.g., (Intercept)
    aic_cols = ["AIC", "has_warning", "warnings"]
    aics = []

    # iterate the index rather than groupby('model'), which processes
    # models in alphabetical sort order
    for model in summary_df.index.unique('model'):
        model_data = summary_df.query("model==@model")
        first_param = model_data.index.get_level_values('beta').unique()[0]
        aic = pd.DataFrame(
            summary_df.loc[pd.IndexSlice[:, model, first_param, aic_cols], :]
            .stack(-1)
            .unstack("key")
            .reset_index(["beta"], drop=True),
            columns=aic_cols,
        )
        # the stacked column labels are the channels
        aic.index.names = aic.index.names[:-1] + ["channel"]
        aics += [aic]

    AICs = pd.concat(aics)
    assert set(summary_df.index.unique('model')) == set(
        AICs.index.unique('model')
    )

    # sort except model, channel
    AICs.sort_index(
        axis=0,
        level=[
            name
            for name in AICs.index.names
            if name not in ['model', 'channel']
        ],
        sort_remaining=False,
        inplace=True,
    )

    # calculate AIC_min for the fitted models at each time, channel
    AICs['min_delta'] = np.inf  # init to float

    # time label is the first index level, may not be fitgrid.defaults.TIME
    assert AICs.index.names == summary_df.index.names[:2] + ["channel"]
    for time in AICs.index.get_level_values(0).unique():
        for chan in AICs.index.get_level_values('channel').unique():
            slicer = pd.IndexSlice[time, :, chan]
            # the AIC column shares its origin with strings and bools so
            # cast to float, otherwise min_delta is assigned object values
            model_aics = AICs.loc[slicer, 'AIC'].astype(float)
            AICs.loc[slicer, 'min_delta'] = model_aics - min(model_aics)

    # actually issue the warning, a bare FutureWarning(...) expression
    # builds the exception object and silently discards it
    warnings.warn(
        'fitgrid AICs are in early days, subject to change', FutureWarning
    )
    assert set(summary_df.index.unique('model')) == set(
        AICs.index.unique('model')
    )
    return AICs
def summaries_fdr_control(
    model_summary_df,
    method="BY",
    rate=0.05,
    plot_pvalues=True,
):
    r"""False discovery rate control for non-zero betas in model summary dataframes

    The family of tests for FDR control is assumed to be **all and
    only** the channels, models, and :math:`\hat{\beta}_i` in the
    summary dataframe.  The user must select the appropriate family of
    tests by slicing or stacking summary dataframes before running the
    FDR calculator.

    Parameters
    ----------
    model_summary_df : pandas.DataFrame
        As returned by `fitgrid.utils.summary.summarize`.
    method : str {"BY", "BH"}
        BY (default) is from Benjamini and Yekatuli [1]_, BH is Benjamini and
        Hochberg [2]_.
    rate : float {0.05}
        The target rate for controlling false discoveries.
    plot_pvalues : bool {True, False}
        Display a plot of the family of $p$-values and critical value for FDR control.

    Returns
    -------
    fdr_specs : dict
        ``method``, ``rate``, ``crit_p``, ``n_pvals``, ``models``, ``betas``,
        and ``channels`` of the test family. ``crit_p`` is the largest
        $p$-value that meets the step-up criterion, tests with
        $p$ <= ``crit_p`` are the discoveries; it is 0.0 if no test does.
    fig, ax : matplotlib Figure and Axes, or None, None
        the $p$-value plot if `plot_pvalues` is True

    Raises
    ------
    ValueError
        if `method` is not 'BH' or 'BY'

    References
    ----------
    .. [1] Benjamini, Y., & Yekutieli, D. (2001). The control of
           the false discovery rate in multiple testing under
           dependency.The Annals of Statistics, 29, 1165-1188.
    .. [2] Benjamini, Y., & Hochberg, Y. (1995). Controlling the
           false discovery rate: A practical and powerful approach to
           multiple testing. Journal of the Royal Statistical
           Society. Series B (Methodological), 57, 289-300.
    """
    _check_summary_df(model_summary_df, None)

    pvals_df = model_summary_df.query("key == 'P-val'")  # fetch pvals
    # summary columns also hold strings, so the values may arrive as objects
    pvals = np.sort(pvals_df.to_numpy().flatten().astype(float))
    m = len(pvals)

    if method == 'BH':
        # Benjamini & Hochberg ... restricted
        c_m = 1
    elif method == 'BY':
        # Benjamini & Yekatuli general case
        c_m = np.sum([1 / i for i in range(1, m + 1)])
    else:
        raise ValueError("method must be 'BH' or 'BY'")

    # step-up procedure: find the largest rank k with p_(k) <= k/(m c_m) * rate.
    # The ranks are 1-based, counting from 0 tests each p-value against
    # the threshold for the rank below it.
    ranks = np.arange(1, m + 1)
    passing = np.flatnonzero(pvals <= ranks / (m * c_m) * rate)
    if passing.size > 0:
        # position of the critical p-value in the sorted family
        crit_p_idx = int(passing.max())
        crit_p = float(pvals[crit_p_idx])
    else:
        crit_p = 0.0
        crit_p_idx = 0

    fdr_specs = {
        "method": method,
        "rate": rate,
        "crit_p": crit_p,
        "n_pvals": m,
        "models": list(pvals_df.index.unique('model')),
        "betas": list(pvals_df.index.unique('beta')),
        "channels": list(pvals_df.columns),
    }

    fig, ax = None, None
    if plot_pvalues:
        fig, ax = plt.subplots()
        ax.set_title("Distribution of $p$-values")
        ax.plot(np.arange(m), pvals, color="k")
        # hlines takes data coordinates; axhline xmax is an axes fraction
        # so it cannot be used to stop the line at the critical index
        ax.hlines(
            crit_p, xmin=0, xmax=crit_p_idx, linestyles="--", colors="k"
        )
        ax.axvline(crit_p_idx, ymax=0.5, ls="--", color="k")
        ax.annotate(
            xy=(crit_p_idx, 0.525),
            text=f"critcal $p$={crit_p:0.5f} for {method} FDR {rate}",
            ha="left",
        )
        ax.text(
            x=0.0,
            y=-0.15,
            s=pp.pformat(fdr_specs, compact=True),
            va="top",
            ha="left",
            transform=ax.transAxes,
            wrap=True,
        )
    return fdr_specs, fig, ax
def plot_betas(
    summary_df,
    LHS=[],
    models=[],
    betas=[],
    interval=[],
    beta_plot_kw={},
    show_se=True,
    show_warnings=True,
    fdr_kw={},
    fig_kw={},
    df_func=None,
    scatter_size=75,
    **kwargs,
):
    """Plot model parameter estimates for model, beta, and channel LHS

    The time course of estimated betas and standard errors is plotted
    by channel for the models, betas, and channels in the data
    frame. Channels, models, betas and time intervals may be selected
    from the summary dataframe. Plots are marked with model fit
    warnings by default and may be tagged to indicate differences from
    0 controlled for false discovery rate (FDR).

    Parameters
    ----------
    summary_df : pd.DataFrame
       as returned by fitgrid.utils.summary.summarize
    LHS : list of str or []
       column names of the data, [] default = all channels
    models : list of str or []
       select model or model betas to display, [] default = all models
    betas : list of str or []
       select beta or betas to plot,  [] default = all betas
    interval : [start, stop] list of two ints
       time interval to plot
    beta_plot_kw : dict
       keyword arguments passed to matplotlib.axes.plot()
    show_se : bool
       toggle display of standard error shading (default = True)
    show_warnings : bool
       toggle display of model warnings (default = True)
    fdr_kw : dict (default empty)
        One or more keyword arguments passed to ``summaries_fdr_control()``
        to tag plots for FDR controlled differences from 0.
    fig_kw : dict
       keyword args passed to pyplot.subplots(), figsize defaults to (8, 3)
    df_func : {None, function}
        toggle degrees of freedom line plot via function, e.g.,
        ``np.log10``, ``lambda x: x``
    scatter_size : float
       scatterplot marker size for FDR (default = 75) and warnings (= 1.5 scatter_size)

    Returns
    -------
    figs : list of matplotlib.Figure

    Raises
    ------
    ValueError
        if a channel, model, beta, or interval selection is malformed or
        not found in the summary

    Note
    ----
    The FDR family of tests is given by all channels, models, betas,
    and times in the summary data frame regardless of which of these
    are selected for plotting. To specify a different family of tests,
    construct a summary dataframe with all and only the tests for that
    family before plotting the betas.

    The list and dict arguments, defaults included, and `summary_df` are
    read only, they are not modified.
    """
    _check_summary_df(summary_df, None)

    # fitgrid < 0.5.0
    for kwarg in ["figsize", "fdr", "alpha", "s"]:
        if kwarg in kwargs:
            # f-string so the offending keyword is actually named
            msg = (
                f"keyword {kwarg} is deprecated in fitgrid 0.5.0, has no effect and "
                "will be removed. See fitgrid.utils.summary.plot_betas() documentation."
            )
            warnings.warn(msg, FutureWarning)

    # ------------------------------------------------------------
    # validate kwargs
    error_msg = None

    # LHS defaults to all channels
    if isinstance(LHS, str):
        LHS = [LHS]
    if LHS == []:
        LHS = list(summary_df.columns)
    if not all([isinstance(col, str) for col in LHS]):
        error_msg = "LHS must be a list of channel name strings"
    for channel in LHS:
        if channel not in summary_df.columns:
            error_msg = f"channel {channel} not found in the summary columns"

    # model, beta
    for key, vals in {"model": models, "beta": betas}.items():
        if vals and not (
            isinstance(vals, list)
            and all([isinstance(itm, str) for itm in vals])
        ):
            error_msg = f"{key} must be a list of strings"
        unique_vals = list(summary_df.index.unique(key))
        for val in vals:
            if val not in unique_vals:
                error_msg = (
                    f"{val} not found, check the summary index: "
                    f"name={key}, labels={unique_vals}"
                )

    # validate interval
    if interval:
        t_min = summary_df.index[0][0]
        t_max = summary_df.index[-1][0]
        if not (
            isinstance(interval, list)
            and all([isinstance(t, int) for t in interval])
            and interval[0] < interval[1]
            and interval[0] >= t_min
            and interval[1] <= t_max
        ):
            error_msg = (
                "interval must be a list of increasing integers "
                f"in the summary time range between {t_min} and {t_max}."
            )

    # fail on any error
    if error_msg:
        raise ValueError(error_msg)

    # ------------------------------------------------------------
    # filter summary for selections, if any
    model_summary_df = summary_df  # a reference may be all we need
    if not LHS == list(summary_df.columns):
        model_summary_df = summary_df[LHS].copy()
    if models:
        model_summary_df = model_summary_df.query("model in @models").copy()
    if betas:
        model_summary_df = model_summary_df.query("beta in @betas").copy()
    if interval:
        # sort a copy: model_summary_df may still be the caller's summary_df
        # and sorting that in place would reorder the caller's data
        model_summary_df = model_summary_df.sort_index()
        model_summary_df = model_summary_df.loc[
            # Index = time, model, beta, key
            pd.IndexSlice[
                interval[0] : interval[1],
                :,
                :,
                :,
            ],
            :,
        ]
    models = list(model_summary_df.index.unique("model"))
    _time = model_summary_df.index.names[0]

    # ------------------------------------------------------------
    # optional FDR calc
    if fdr_kw:
        # the family of tests for FDR is given by the summary data, *not*
        # which slices happen to be selected for plotting.
        fdr_specs, fdr_fig, fdr_ax = summaries_fdr_control(
            summary_df, **fdr_kw
        )
        if not summary_df.equals(model_summary_df):
            fdr_msg = (
                "FDR test family is for **ALL** models, betas, and channels in "
                "the summary dataframe not just those selected for plotting."
            )
            warnings.warn(fdr_msg)
        print(pp.pformat(fdr_specs, compact=True))

    # ------------------------------------------------------------
    # set up to plot various warnings consistently
    warning_kinds = np.unique(
        np.hstack(
            [
                w.split("_") if len(w) else []
                for w in np.unique(summary_df.query("key=='warnings'"))
            ]
        )
    )
    warning_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    warning_cycler = cy(color=warning_colors) + cy(
        marker=Line2D.filled_markers[: len(warning_colors)]
    )

    # cycle so that the styles start over instead of running out when
    # there are more kinds of warning than colors
    cy_iter = itertools.cycle(warning_cycler)
    warning_styles = defaultdict(lambda: next(cy_iter))

    # assign the styles up front, in sorted order, so that a warning gets
    # the same color and marker no matter which figure it first shows up in
    for warning_kind in warning_kinds:
        _ = warning_styles[warning_kind]

    # pyplot.subplots() keywords with the default figsize filled in, built
    # as a new dict so that neither the caller's dict nor the shared
    # default argument is modified
    subplots_kw = {"figsize": (8, 3), **fig_kw}

    # ------------------------------------------------------------
    # set up figures
    figs = list()
    for model, col in itertools.product(models, LHS):
        # select beta for this model
        for beta in model_summary_df.query("model == @model").index.unique(
            "beta"
        ):
            # start the fig, ax
            f, ax_beta = plt.subplots(nrows=1, ncols=1, **subplots_kw)

            # unstack this beta as a column for plotting
            fg_beta = (
                model_summary_df.loc[pd.IndexSlice[:, model, beta], col]
                .unstack(level='key')
                .reset_index(_time)  # time label for this model_summary_df
            )
            fg_beta.plot(
                x=_time,
                y='Estimate',
                ax=ax_beta,
                color='black',
                alpha=0.5,
                label=beta,
                **beta_plot_kw,
            )

            # optional +/- SE band
            if show_se:
                beta_hat = fg_beta['Estimate']
                ax_beta.fill_between(
                    x=fg_beta[_time],
                    y1=(beta_hat + fg_beta["SE"]).astype(float),
                    y2=(beta_hat - fg_beta["SE"]).astype(float),
                    alpha=0.2,
                    color='black',
                )

            # optional (transformed) degrees of freedom
            if df_func is not None:
                try:
                    func_name = getattr(df_func, "__name__")
                except AttributeError:
                    func_name = str(df_func)
                fg_beta['DF_'] = fg_beta['DF'].apply(lambda x: df_func(x))
                fg_beta.plot(
                    x=_time, y='DF_', ax=ax_beta, label=f"{func_name}(df)"
                )

            # FDR controlled differences from 0. crit_p is itself one of
            # the p-values that met the FDR criterion so it is included.
            if fdr_kw:
                fdr_mask = fg_beta["P-val"] <= fdr_specs["crit_p"]
                ax_beta.scatter(
                    fg_beta[_time][fdr_mask],
                    fg_beta['Estimate'][fdr_mask],
                    color='black',
                    zorder=3,
                    label=f"{fdr_specs['method']} FDR p <= crit {fdr_specs['crit_p']:0.2}",
                    s=scatter_size,
                )

            # warnings
            if show_warnings and any(
                [len(warning) for warning in fg_beta["warnings"]]
            ):
                warn_strs = np.hstack(
                    [
                        np.array(wrn.split("_"))  # lengths vary
                        for wrn in fg_beta["warnings"].unique()
                        if len(wrn) > 0
                    ]
                )
                warn_strs = sorted(np.unique(warn_strs))
                for warn_str in warn_strs:
                    # separate warnings by 1/4 major tick interval
                    sep = np.abs(
                        (ax_beta.get_yticks()[:2] * [0.25, -0.25]).sum()
                    )
                    # stack the kinds of warning above the estimate in
                    # the order they appear in warning_kinds
                    warn_offset = (
                        np.where(warning_kinds == warn_str)[0] + 1
                    ) * sep
                    warn_mask = fg_beta["warnings"].apply(
                        lambda x: warn_str in x
                    )
                    ax_beta.scatter(
                        fg_beta[_time][warn_mask],
                        fg_beta['Estimate'][warn_mask] + warn_offset,
                        zorder=4,
                        label=warn_str,
                        **warning_styles[warn_str],
                        alpha=0.75,
                        s=scatter_size * 1.5,
                    )

            ax_beta.axhline(y=0, linestyle='--', color='black')
            ax_beta.legend(loc='upper left', bbox_to_anchor=(0.0, -0.25))
            formula = fg_beta.index.get_level_values('model').unique()[0]
            ax_beta.set_title(f'{col} {beta}: {formula}', loc='left')
            figs.append(f)
    return figs
def plot_AICmin_deltas(
    summary_df,
    show_warnings="no_labels",
    figsize=None,
    gridspec_kw=None,
    subplot_kw=None,
):
    r"""Plot FitGrid min delta AICs and fitter warnings.

    One row of axes is drawn per model: the AIC min delta traces for
    each channel (left), the same values as a channel x time heatmap
    (middle), and the heatmap colorbar (right). Grid cells with
    warnings are marked in crimson on the traces and in red on the
    heatmap.

    Thresholds of AIC_min delta at 2, 4, 7, 10 are from Burnham &
    Anderson 2004, see Notes.

    Parameters
    ----------
    summary_df : pd.DataFrame
       as returned by fitgrid.utils.summary.summarize
    show_warnings : {"no_labels", "labels", str, list of str}
       "no_labels" (default) highlights everywhere there is any warning in
       red, the default behavior in fitgrid < v0.5.0. "labels" also displays
       all warning strings in the axes titles. A `str` or list of `str`
       selects and displays only the warnings that (partial) match a model
       warning string.
    figsize : 2-ple
       pyplot.figure figure size parameter, default is (12, 8)
    gridspec_kw : dict
       matplotlib.gridspec keyword args passed to ``pyplot.subplots(...,
       gridspec_kw=gridspec_kw)``, these update the default
       ``{'width_ratios': [0.46, 0.46, 0.015]}``
    subplot_kw : dict
       keyword args passed to ``pyplot.subplots(..., subplot_kw=subplot_kw)``

    Returns
    -------
    f : matplotlib.figure.Figure
    axs : array of matplotlib Axes, shape (number of models, 3)
       columns are traces, heatmap, colorbar

    Raises
    ------
    ValueError
       if `show_warnings` is not one of the keywords, a string, or a
       list of strings.

    Notes
    -----
       [BurAnd2004]_ p. 270-271. Where :math:`AIC_{min}` is the
       lowest AIC value for "a set of a priori candidate models
       well-supported by the underlying science :math:`g_{i}, i = 1,
       2, ..., R)`",

       .. math:: \Delta_{i} = AIC_{i} - AIC_{min}

       "is the information loss experienced if we are using fitted
       model :math:`g_{i}` rather than the best model, :math:`g_{min}`
       for inference." ...

       "Some simple rules of thumb are often useful in assessing the
       relative merits of models in the set: Models having
       :math:`\Delta_{i} <= 2` have substantial support (evidence), those
       in which :math:`4 <= \Delta_{i} <= 7` have considerably less support,
       and models having :math:`\Delta_{i} > 10` have essentially no
       support."

    """

    label_modes = ("no_labels", "labels")

    def _is_label_mode(value):
        """True if value is one of the show_warnings keyword strings"""
        # the isinstance guard keeps lists and arrays out of the == test
        return isinstance(value, str) and value in label_modes

    def _get_warnings_grid(model_warnings, show_warnings, model):
        """look up warnings for one model according to the user kwarg value

        Returns the list of warning kinds to report and a 0/1 indicator
        dataframe shaped like `model_warnings`, 1 where the cell's
        warning string contains any of those kinds.
        """

        # split the "_" separated multiple warning strings into sorted
        # unique kinds, skipping the empty no-warning string
        kinds = sorted(
            {
                kind
                for cell in np.unique(model_warnings)
                if len(cell)
                for kind in cell.split("_")
            }
        )

        # optionally filter by user keyword matching
        if not _is_label_mode(show_warnings):
            matched = []
            for pattern in show_warnings:
                found = [kind for kind in kinds if pattern in kind]
                # collect the matching kinds or warn
                if found:
                    matched += found
                else:
                    msg = (
                        f"show_warnings '{pattern}' not found in model "
                        f"{model} warnings: [{', '.join(kinds)}]"
                    )
                    warnings.warn(msg)
            # overlapping patterns may match the same kind, keep the first
            kinds = list(dict.fromkeys(matched))

        def _has_kind(cell):
            return int(any(kind in cell for kind in kinds))

        # elementwise indicator via Series.map on each column, this
        # sidesteps the deprecated DataFrame.applymap
        grid = model_warnings.apply(lambda col: col.map(_has_kind))
        return kinds, grid

    # ------------------------------------------------------------
    # validate kwarg
    if not _is_label_mode(show_warnings):
        # promote string to list
        show_warnings = list(np.atleast_1d(show_warnings))
        if not all(isinstance(wrn, str) for wrn in show_warnings):
            msg = (
                "show_warnings must be 'no_labels', 'labels', or a string "
                "or list of strings that partial match warnings"
            )
            raise ValueError(msg)

    # validate summary dataframe
    _check_summary_df(summary_df, None)
    _time = summary_df.index.names[0]

    # fetch the AIC min delta data
    aics = _get_AICs(summary_df)  # long format
    models = aics.index.unique('model')
    channels = aics.index.unique('channel')

    # ------------------------------------------------------------
    # figure setup
    if figsize is None:
        figsize = (12, 8)

    # reasonable default, update w/ user kwargs if any
    gspec_kw = {'width_ratios': [0.46, 0.46, 0.015]}
    if gridspec_kw:
        gspec_kw.update(gridspec_kw)

    # main figure, axes: number of models rows x 3 columns: traces, raster, colorbar
    f, axs = plt.subplots(
        len(models),  # 1 axis row per model
        3,
        squeeze=False,  # keep axes shape (1, 3), though singleton model is pointless
        figsize=figsize,
        gridspec_kw=gspec_kw,
        subplot_kw=subplot_kw,
    )

    # heatmap color scheme is the same for every model so build it once
    aic_min_delta_bounds = [0, 2, 4, 7, 10]

    # colorbrewer 2.0 Blues color blind safe n=5
    # http://colorbrewer2.org/#type=sequential&scheme=Blues&n=5
    pal = ['#eff3ff', '#bdd7e7', '#6baed6', '#3182bd', '#08519c']
    cmap = mpl.colors.ListedColormap(pal)
    cmap.set_over(color='#08306b')  # darkest from Blues n=7
    cmap.set_under(color='lightgray')
    norm = mpl.colors.BoundaryNorm(aic_min_delta_bounds, cmap.N)

    # plot each model on an axes row
    for i, m in enumerate(models):
        traces, heatmap, colorbar_ax = axs[i, 0], axs[i, 1], axs[i, 2]
        model_rows = pd.IndexSlice[:, m, :]

        # ------------------------------------------------------------
        # slice this model min delta values and warnings
        _min_deltas = (
            aics.loc[model_rows, 'min_delta']
            .reset_index('model', drop=True)
            .unstack('channel')
            .astype(float)
        )
        # unstack() alphanum sorts the channel index ... ugh
        _min_deltas = _min_deltas.reindex(columns=channels)

        model_warnings = (
            aics.loc[model_rows, 'warnings']
            .reset_index('model', drop=True)
            .unstack('channel')
        )
        model_warnings = model_warnings.reindex(columns=channels)

        # fetch warnings for heatmapping
        warning_kinds, warnings_grid = _get_warnings_grid(
            model_warnings, show_warnings, m
        )

        # ------------------------------------------------------------
        # plot traces and warnings in left column

        # left column title is model with optional list of warnings
        title_str = f"{m}"
        if warning_kinds and show_warnings != "no_labels":
            title_str += "\n" + "\n".join(warning_kinds)
        traces.set_title(title_str, loc="left")

        # x values are the same for every channel
        times = _min_deltas.reset_index()[_time]
        for chan in channels:
            traces.plot(times, _min_deltas[chan], label=chan)

            # min deltas at the grid cells flagged with a warning
            chan_mask = (
                _min_deltas[chan].where(warnings_grid[chan] == 1).dropna()
            )
            traces.scatter(
                chan_mask.index,
                chan_mask,
                c="crimson",
                label=None,
            )

        if i == 0:
            # first row only: channel legend left of the main plot, in
            # reverse plotting order. Handles are fetched from the axes
            # because Legend.legendHandles is gone in newer matplotlib.
            handles, labels = traces.get_legend_handles_labels()
            traces.legend(
                handles[::-1],
                labels[::-1],
                loc='upper right',
                bbox_to_anchor=(-0.2, 1.0),
            )

        # threshold lines, the 0 bound is not drawn
        for y in aic_min_delta_bounds[1:]:
            traces.axhline(y=y, color='black', linestyle='dotted')

        # ------------------------------------------------------------
        # heatmap
        ylabels = _min_deltas.columns
        heatmap.yaxis.set_major_locator(
            mpl.ticker.FixedLocator(np.arange(len(ylabels)))
        )
        heatmap.yaxis.set_major_formatter(mpl.ticker.FixedFormatter(ylabels))
        im = heatmap.pcolormesh(
            _min_deltas.index,
            np.arange(len(ylabels)),
            _min_deltas.T,
            cmap=cmap,
            norm=norm,
            shading='nearest',
        )

        # any non-zero warnings are red
        if warnings_grid.to_numpy().any():
            assert (warnings_grid.index == _min_deltas.index).all()
            assert (warnings_grid.columns == _min_deltas.columns).all()
            heatmap.pcolormesh(
                warnings_grid.index,
                np.arange(len(ylabels)),
                np.ma.masked_equal(warnings_grid.T.to_numpy(), 0),
                shading="nearest",
                cmap=mpl.colors.ListedColormap(['red']),
            )

        mpl.colorbar.Colorbar(colorbar_ax, im, extend='max')

    return f, axs