Source code for spudtr.epf

"""utilities for epoched EEG data in a pandas.DataFrame """
from pathlib import Path
import warnings
import numpy as np
import pandas as pd
import bottleneck as bn

from spudtr.filters import _design_firwin_filter, fir_filter_dt

EPOCH_ID = "epoch_id"  # default epoch ID column
TIME = "time"  # default time column


def _validate_epochs_df(epochs_df, epoch_id=EPOCH_ID, time=TIME):
    """check form and index of the epochs_df is as expected

    Parameters
    ----------
    epochs_df : pd.DataFrame

    epoch_id : str (optional, default=epf.EPOCH_ID)
        column name for epoch indexes

    time : str (optional, default=epf.TIME)
        column name for time stamps

    """
    for key, val in {"epoch_id": epoch_id, "time": time}.items():
        if val not in epochs_df.columns:
            raise ValueError(f"{key} column not found: {val}")


def _epochs_QC(epochs_df, data_streams, epoch_id=EPOCH_ID, time=TIME):
    """Quality control for spudtr format epochs, returns epochs_df on success"""

    # epochs_df must be a Pandas DataFrame.
    if not isinstance(epochs_df, pd.DataFrame):
        raise ValueError("epochs_df must be a Pandas DataFrame.")

    # data_streams must be a list of strings
    if not isinstance(data_streams, list) or not all(
        isinstance(item, str) for item in data_streams
    ):
        raise ValueError("data_streams should be a list of strings.")

    # all channels must be present as epochs_df columns
    missing_channels = set(data_streams) - set(epochs_df.columns)
    if missing_channels:
        raise ValueError(
            "data_streams should all be present in the epochs dataframe, "
            f"the following are missing: {list(missing_channels)}"
        )

    # check no duplicate column names in index and regular columns
    names = list(epochs_df.index.names) + list(epochs_df.columns)
    if len(names) != len(set(names)):
        raise ValueError("Duplicate column names not allowed.")

    # epoch_id and time must be present as columns in the epochs_df
    _validate_epochs_df(epochs_df, epoch_id=epoch_id, time=time)

    # check that the epoch_id values are the same across time groups and
    # unique within each time group. Work on our own copy so the original
    # table is left unmodified
    table = epochs_df.copy().reset_index().set_index(epoch_id).sort_index()
    assert table.index.names == [epoch_id]

    snapshots = table.groupby([time])

    # check that snapshots across epochs have equal index by transitivity
    prev_group = None
    for idx, cur_group in snapshots:
        if prev_group is not None:
            if not prev_group.index.equals(cur_group.index):
                raise ValueError(
                    f"Snapshot {idx} differs from "
                    f"previous snapshot in {epoch_id} index:\n"
                    f"Current snapshot's indices:\n"
                    f"{cur_group.index}\n"
                    f"Previous snapshot's indices:\n"
                    f"{prev_group.index}"
                )
        prev_group = cur_group

    def list_duplicates(seq):
        seen = set()
        seen_add = seen.add
        # add elements not yet seen to seen, everything else to seen_twice
        seen_twice = set(x for x in seq if x in seen or seen_add(x))
        # turn the set into a list
        return list(seen_twice)

    if not prev_group.index.is_unique:
        dupes = list_duplicates(list(prev_group.index))
        raise ValueError(
            f"Duplicate values of epoch_id in each time group not allowed:\n{dupes}"
        )
    return epochs_df
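
# A minimal sketch (toy data, not part of the module) of what _epochs_QC
# enforces: every time group must contain the same, unique epoch_id values.
#
# >>> import pandas as pd
# >>> good = pd.DataFrame({
# ...     "epoch_id": [0, 0, 1, 1],
# ...     "time": [0, 1, 0, 1],
# ...     "chan0": [1.0, 2.0, 3.0, 4.0],
# ... })
# >>> _epochs_QC(good, ["chan0"]) is good
# True
# >>> bad = good.drop(3)  # epoch 1 is now missing time == 1
# >>> _epochs_QC(bad, ["chan0"])  # raises ValueError: snapshots differ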


def _hdf_read_epochs(epochs_f, h5_group, epoch_id=EPOCH_ID, time=TIME):
    """read tabular hdf5 epochs file, return as pd.DataFrame

    .. deprecated:: 0.0.9
       Use native pandas HDF functions instead. Will be removed in 0.0.11


    Parameters
    ----------
    epochs_f : str
        name of the recorded epochs file to load

    h5_group : str
        name of h5 group key

    Returns
    -------
    df : pd.DataFrame
        the epochs data read from `h5_group` in `epochs_f`
    """
    warnings.warn(
        "_hdf_read_epochs() is unused, untested, and deprecated in spudtr.epf v0.0.9 and will be removed in v0.0.11",
        DeprecationWarning,
    )

    if h5_group is None:
        raise ValueError("an h5_group key is required")
    else:
        epochs_df = pd.read_hdf(epochs_f, h5_group)

    _validate_epochs_df(epochs_df, epoch_id=epoch_id, time=time)
    return epochs_df


def _find_subscript(times, start, stop):
    """find subscripts of the first time stamp >= start and the last <= stop

    This makes the timestamp interval closed on both ends, [start, stop],
    when slicing with pandas, and closed left, open right, [start, stop),
    when slicing with numpy.
    """
    istart = np.where(times >= start)[0]
    if len(istart) == 0:
        raise ValueError(
            "start is too large (%s), it exceeds the largest " "time value" % (start,)
        )
    istart = int(istart[0])

    istop = np.where(times <= stop)[0]
    if len(istop) == 0:
        raise ValueError(
            "stop is too small (%s), it is smaller than the "
            "smallest time value" % (stop,)
        )
    istop = int(istop[-1])
    if istart >= istop:
        raise ValueError(
            "Bad rescaling slice (%s:%s) from time values %s, %s"
            % (istart, istop, start, stop)
        )
    return istart, istop
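
# For example (made-up values): with times = np.array([-2, -1, 0, 1]),
# _find_subscript(times, -2, 0) returns (istart, istop) == (0, 2), so
# pandas-style label slicing covers time stamps [-2, 0] while
# np.arange(istart, istop) picks [-2, -1] only.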


# ------------------------------------------------------------
# user API
# ------------------------------------------------------------


def check_epochs(epochs_df, data_streams, epoch_id=EPOCH_ID, time=TIME):
    """check epochs data are in spudtr format

    Parameters
    ----------
    epochs_df : pd.DataFrame

    data_streams : list of str
        the columns containing data

    epoch_id : str, optional
        column name for the epoch index

    time : str, optional
        column name for the time stamps

    Raises
    ------
    Exception
        diagnostic for what went wrong

    """
    _ = _epochs_QC(epochs_df, data_streams, epoch_id=epoch_id, time=time)


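# Usage sketch (hypothetical data stream names): check_epochs returns
# silently on success, otherwise raises a diagnostic exception.
#
# >>> from spudtr import epf
# >>> epf.check_epochs(epochs_df, ["MiPf", "MiCe"])

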
def center_eeg(epochs_df, eeg_streams, start, stop, epoch_id=EPOCH_ID, time=TIME):
    """center (a.k.a. "baseline") EEG amplitude on mean amplitude in [start, stop)

    Parameters
    ----------
    epochs_df : pd.DataFrame
        must have epoch_id and time columns

    eeg_streams : list of str
        column names to apply the transform

    start, stop : int
        baseline interval time values, `start <= t < stop`

    epoch_id : str, optional
        column to use for the epoch index

    time : str, optional
        column to use for the time stamp index

    Returns
    -------
    centered_epochs_df : pd.DataFrame
        each epoch and channel time series centered on the [start, stop)
        interval mean amplitude

    Notes
    -----
    The `start`, `stop` values pick the smallest and largest timestamps
    in the interval, i.e., [start_stamp, stop_stamp], but since the data
    are sliced with np.arange, the upper bound is not included, i.e.,
    start_stamp <= timestamps < stop_stamp. So, for instance, start=-200,
    stop=0 would include timestamps at -200, -199, ..., -1, but not 0.

    """
    _epochs_QC(epochs_df, eeg_streams, epoch_id=epoch_id, time=time)

    # calculate the row-index vector that slices the centering interval
    # out of each epoch
    times = epochs_df[time].unique()
    n_times = len(times)
    n_epochs = len(epochs_df[epoch_id].unique())
    istart, istop = _find_subscript(times, start, stop)
    center_idxs = np.array(
        [
            np.arange(istart + (i * n_times), istop + (i * n_times))
            for i in range(n_epochs)
        ]
    ).flatten()

    # use pandas iloc index slicing, then group by epoch_id to compute
    # each epoch's mean amplitude in the centering interval
    mns = epochs_df.iloc[center_idxs, :].groupby(epoch_id)[eeg_streams].mean()

    # inflate the means to the shape of the data and subtract
    centered_epochs_df = epochs_df.copy()
    centered_epochs_df[eeg_streams] -= np.repeat(mns.to_numpy(), n_times, axis=0)
    return centered_epochs_df


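# A worked sketch (toy values, not from the test suite): centering one
# channel of two 3-sample epochs on [0, 2) subtracts each epoch's mean of
# the samples at times 0 and 1.
#
# >>> df = pd.DataFrame({
# ...     "epoch_id": [0, 0, 0, 1, 1, 1],
# ...     "time": [0, 1, 2, 0, 1, 2],
# ...     "chan0": [2.0, 4.0, 6.0, 10.0, 20.0, 30.0],
# ... })
# >>> center_eeg(df, ["chan0"], 0, 2)["chan0"].tolist()
# [-1.0, 1.0, 3.0, -5.0, 5.0, 15.0]

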
def drop_bad_epochs(epochs_df, bads_column, epoch_id=EPOCH_ID, time=TIME):
    """Quality control data slicer, excludes previously tagged artifact epochs

    All epochs tagged with a non-zero quality code in the specified
    `bads_column` at time stamp == 0 are excluded.

    Parameters
    ----------
    epochs_df : pd.DataFrame
        must have epoch_id and time columns

    bads_column : str
        column name with QC codes: non-zero == drop

    epoch_id : str, optional
        column name for epoch indexes

    time : str, optional
        column name for time stamps

    Returns
    -------
    good_epochs_df : pd.DataFrame
        subset of the epochs with code 0 in `bads_column` at time stamp == 0

    """
    _epochs_QC(epochs_df, [bads_column], epoch_id=epoch_id, time=time)

    # select the rows at time == 0 and keep the epochs whose QC code is 0
    group = epochs_df.groupby([time]).get_group(0)
    good_idx = list(group[epoch_id][group[bads_column] == 0])
    good_epochs_df = epochs_df[epochs_df[epoch_id].isin(good_idx)].copy()
    return good_epochs_df


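# Sketch (hypothetical QC column): epochs flagged non-zero in "bad_chan" at
# time == 0 are dropped, everything else survives.
#
# >>> df = pd.DataFrame({
# ...     "epoch_id": [0, 0, 1, 1],
# ...     "time": [0, 1, 0, 1],
# ...     "bad_chan": [0, 0, 8, 8],
# ... })
# >>> drop_bad_epochs(df, "bad_chan")["epoch_id"].unique()
# array([0])

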
def re_reference(epochs_df, eeg_streams, ref, ref_type, epoch_id=EPOCH_ID, time=TIME):
    """Convert EEG data recorded with a common reference to a different reference

    .. warning::

       These transforms are intended for use with common reference EEG
       data. Use with other types of data is at your own risk.

    Parameters
    ----------
    epochs_df : pd.DataFrame
        must have epoch_id and time columns

    eeg_streams : list-like of str
        the names of columns to transform

    ref : str or list-like of str
        name of the 2nd stream for a linked pair, the new common reference,
        or the complete list of streams to use for a common average reference

    ref_type : str {'linked_pair', 'new_common', 'common_average'}

    epoch_id : str, optional

    time : str, optional

    Returns
    -------
    pd.DataFrame
        a copy of epochs_df with `eeg_streams` re-referenced

    Notes
    -----
    `linked_pair`
        Transforms the EEG data to a "linked" pair reference:

        .. math:: EEG_{\\text{re-referenced}} = EEG - (0.5 \\times EEG_{ref})

        May be used to switch from an A1 left mastoid common reference to
        a common linked A1, A2 mastoid reference ("bimastoid").

    `new_common`
        Transforms EEG to a different common reference location:

        .. math:: EEG_{\\text{re-referenced}} = EEG - EEG_{ref}

        May be used to switch from an A1 common reference to a vertex or
        nose-tip reference.

    `common_average`
        Transforms EEG to a common average reference of :math:`N` EEG
        reference streams

        .. math:: EEG_{\\text{re-referenced}} = EEG - \\frac{\\sum_{i=1}^{N}{EEG_{ref[i]}}}{N}

    Examples
    --------
    Switch from A1 reference to linked mastoids

    >>> eeg_streams = ['MiPf', 'MiCe', 'MiPa', 'MiOc']
    >>> br_epochs_df = epf.re_reference(epochs_df, eeg_streams, 'A2', 'linked_pair')

    Switch to a vertex reference, MiCe

    >>> eeg_streams = ['MiPf', 'MiCe', 'MiPa', 'MiOc']
    >>> br_epochs_df = epf.re_reference(epochs_df, eeg_streams, 'MiCe', 'new_common')

    Switch to a common average reference (typically all available EEG data streams)

    >>> eeg_streams = ['MiPf', 'MiCe', 'MiPa', 'MiOc']
    >>> ref = eeg_streams
    >>> br_epochs_df = epf.re_reference(epochs_df, eeg_streams, ref, 'common_average')

    """
    _epochs_QC(epochs_df, eeg_streams, epoch_id=epoch_id, time=time)

    # for 'common_average', ref must be a list of strings with len(ref) > 1
    if ref_type == "common_average":
        if not (isinstance(ref, list) and len(ref) > 1) or not all(
            isinstance(item, str) for item in ref
        ):
            raise ValueError(
                "ref should be a list of strings with length greater than 1."
            )

    if isinstance(ref, list) and len(ref) == 1:
        ref = "".join(ref)

    if ref_type == "linked_pair":
        new_ref = epochs_df[ref] / 2.0
    elif ref_type == "new_common":
        new_ref = epochs_df[ref]
    elif ref_type == "common_average":
        new_ref = epochs_df[ref].mean(axis=1)
    else:
        raise ValueError(f"unknown reference type: ref_type={ref_type}")

    br_epochs_df = epochs_df.copy()
    for col in eeg_streams:
        br_epochs_df[col] = br_epochs_df[col] - new_ref

    return br_epochs_df


def fir_filter_epochs(
    epochs_df,
    data_columns,
    ftype=None,
    cutoff_hz=None,
    width_hz=None,
    ripple_db=None,
    window=None,
    sfreq=None,
    trim_edges=False,
    epoch_id=EPOCH_ID,
    time=TIME,
):
    """apply FIR filtering to spudtr format epoched data

    Parameters
    ----------
    epochs_df : pd.DataFrame
        must be a spudtr format epochs dataframe with epoch_id, time columns

    data_columns : list of str
        column names to apply the transform

    ftype : str {'lowpass', 'highpass', 'bandpass', 'bandstop'}
        filter type

    cutoff_hz : float or 1D array-like of floats, length 2
        1/2 amplitude cutoff frequency in Hz

    width_hz : float
        pass-to-stop transition band width (Hz), symmetric for bandpass,
        bandstop

    ripple_db : float
        ripple, in dB, e.g., 53.0, 60.0

    window : str {'kaiser', 'hamming', 'hann', 'blackman'}
        window type for firwin

    sfreq : float
        sampling frequency, e.g., 250.0, 500.0

    trim_edges : bool
        True to trim the edges of each epoch, False to return the epochs
        full length

    epoch_id : str {"epoch_id"}, optional
        column name for epoch index

    time : str {"time"}, optional
        column name for timestamps

    Returns
    -------
    pd.DataFrame
        a copy of epochs_df, with the data in `data_columns` filtered

    Notes
    -----
    All the filter parameters are mandatory; consider making a
    ``filter_params`` dictionary and expanding it like so
    ``fir_filter_epochs(..., **filter_params)``.

    By default the filtered epochs have the same length as the original.
    The `trim_edges` option returns the center interval of each epoch,
    free from distortion at the edges, but this may result in considerable
    data loss depending on the filter specifications.

    Examples
    --------
    >>> ftype = "bandpass"
    >>> cutoff_hz = [18, 35]
    >>> sfreq = 250
    >>> window = "kaiser"
    >>> width_hz = 5
    >>> ripple_db = 60
    >>> epoch_id = "epoch_id"
    >>> time = "time_ms"
    >>> filt_test_df = fir_filter_epochs(
    ...     epochs_df,
    ...     data_columns,
    ...     ftype=ftype,
    ...     cutoff_hz=cutoff_hz,
    ...     width_hz=width_hz,
    ...     ripple_db=ripple_db,
    ...     window=window,
    ...     sfreq=sfreq,
    ...     trim_edges=False,
    ...     epoch_id=epoch_id,
    ...     time=time,
    ... )

    """
    # it is crucial to enforce the spudtr epochs format because trimming
    # needs to know about epoch boundaries and times
    _epochs_QC(epochs_df, data_columns, epoch_id=epoch_id, time=time)

    _fparams = dict(
        ftype=ftype,
        cutoff_hz=cutoff_hz,
        sfreq=sfreq,
        width_hz=width_hz,
        ripple_db=ripple_db,
        window=window,
    )

    # build and apply the filter
    filt_epochs_df = fir_filter_dt(epochs_df, data_columns, **_fparams)

    # this trims the edges of *each epoch* by 1/2 the length of the filter
    if trim_edges:
        taps = _design_firwin_filter(**_fparams)
        n_edge = int(np.floor(len(taps) / 2.0))
        times = filt_epochs_df[time].unique()
        start_good = times[n_edge]  # first good sample
        stop_good = times[-(n_edge + 1)]  # last good sample
        qstr = f"{time} >= @start_good and {time} <= @stop_good"
        filt_epochs_df = filt_epochs_df.query(qstr).copy()

    return filt_epochs_df
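

# Edge-trim sketch (illustrative numbers, an assumption about the designed
# filter): if the filter works out to 183 taps, n_edge = floor(183 / 2) = 91,
# so trim_edges=True drops the first and last 91 time stamps of every epoch.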