Source code for mkpy.mkio

"""Module mkio.py minor tweak of NJS dig format reader and (avg) writer

Code has four main sections/functions

  1. Setup
  2. Helper functions
  3. Readers
  4. Writers

"""

# --------
# 1. Setup
# --------
import struct
import numpy as np
import gzip
import math
import os
from mkpy._mkh5 import _decompress_crw_chunk
from mkpy import get_ver

# ----------
# 2. Helpers
# ----------

# Derived from erp/include/header.h:
# '<' denotes little-endianness
_header_dtype = np.dtype(
    [
        ("magic", "<u2"),
        ("epoch_len", "<i2"),  # epoch length in msec
        ("nchans", "<i2"),  # number of channels
        ("sums", "<i2"),  # 0 = ERP, 1 = single trial
        # ^^ 8 bytes
        ("tpfuncs", "<i2"),  # number of processing funcs
        ("pp10uv", "<i2"),  # points / 10 uV
        ("verpos", "<i2"),  # positive point positive voltage, -1 => opposite
        ("odelay", "<i2"),  # ms from trigger to stim (8 video, 4 audio)
        # ^^ 16 bytes
        ("totevnt", "<i2"),  # "total log events"
        ("10usec_per_tick", "<i2"),
        ("time", "<i4"),  # "time in sample clock ticks"
        # ^^ 24 bytes
        ("cond_code", "<i2"),
        ("presam", "<i2"),  # pre-event time in epoch in msec
        ("trfuncs", "<i2"),  # number of rejection functions
        ("totrr", "<i2"),  # total raw records including rejects
        # ^^ 32 bytes
        ("totrej", "<i2"),  # total raw rejects
        ("sbcode", "<i2"),  # "subcondition number (bin number)"
        ("cprecis", "<i2"),  # channel precision in # of 256 points blocks
        ("dummy1", "<i2"),  # placeholder for ovf_errors (see header)
        # ^^ 40 bytes
        ("decfact", "<i2"),  # decimation factor used in processing
        ("dh_flag", "<i2"),  # sets time resolution (see header defines)
        ("dh_item", "<i4"),  # sequential item #
        # ^^ 48 bytes
        ("rfcnts", "<i2", (8,)),  # ndividual rejection counts 8 poss. rfs
        ("rftypes", "S64"),  # 8 char. descs for 8 poss. rfs
        ("chndes", "S128"),
        ("subdes", "S40"),
        ("sbcdes", "S40"),
        ("condes", "S40"),
        ("expdes", "S40"),
        ("pftypes", "S24"),
        ("chndes2", "S40"),
        ("flags", "<u2"),  # see flag values in header
        ("nrawrecs", "<u2"),  # raw records if this is a raw file header
        ("idxofflow", "<u2"),
        ("idxoffhi", "<u2"),
        ("chndes3", "S24"),  # channel description size
    ]
)
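
# Note: the fields above pack to exactly 512 bytes (_header_dtype.itemsize ==
# 512), which is the size of the header block _read_header() reads from the
# start of a .raw/.crw file.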

# If, say, chndes has trailing null bytes, then rec["chndes"] will give us a
# less-than-128-byte string back. But this function always gives us the full
# 128 byte string, trailing nuls and all.
def _get_full_string(record, key):
    val = record[key]
    desired_len = record.dtype.fields[key][0].itemsize
    return val + (desired_len - len(val)) * b"\x00"  # TPU forced to byte


def _gzipped(stream):
    """Return True if stream is a gzip file."""

    initial_pos = stream.tell()
    gzip_magic = b"\x1f\x8b"
    file_magic = stream.read(2)
    stream.seek(initial_pos)  # rewind back 2 bytes

    return file_magic == gzip_magic


def _get_reader_for_magic(magic):
    """Return appropriate reader function based on the magic."""

    if magic == 0x17A5:
        return _read_raw_chunk
    elif magic == 0x97A5:
        return _read_compressed_chunk
    else:
        return None
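

# Note: a magic of 0x17A5 marks uncompressed (.raw) records, read by
# _read_raw_chunk; 0x97A5 marks compressed (.crw) records, read by
# _read_compressed_chunk (both defined below).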


def _is_valid_samplerate(hz):
    """Return True if sample rate is close to an integer, False otherwise."""

    closest_integer = round(hz, 0)
    return math.isclose(hz, closest_integer, abs_tol=1e-6)


def _get_channel_names(header):
    """Extract list of channel names from header."""

    if header["nchans"] <= 16:
        dtype = "S8"
    elif header["nchans"] <= 32:
        dtype = "S4"
    else:
        raise NotImplementedError(
            "Channel name extraction for large montages not yet supported"
        )

    # return np.fromstring(_get_full_string(header, 'chndes'), dtype=dtype)
    return np.frombuffer(_get_full_string(header, "chndes"), dtype=dtype)
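

# Note: "chndes" is a fixed 128-byte field, so up to 16 channels fit as 8-byte
# names ("S8") and up to 32 channels fit as 4-byte names ("S4"); larger
# montages are not handled (see the NotImplementedError above).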


# ----------
# 3. Readers
# ----------
def _read_header(stream):
    """Read header (the first 512 bytes) from file, return a subset of it.

    Parameters
    ----------
    stream : filestream
        .raw or .crw filestream

    Returns
    -------
    (reader, header["nchans"], hz, channel_names, info) : tuple

    where
        reader : function
            _read_raw_chunk or _read_compressed_chunk
        header["nchans"] : int
            number of data channels
        hz : float
            sampling frequency in samples per second
        channel_names : NumPy array of binary strings
            channel name codes, e.g. MiPf, LLPf, etc.
        info : dict
            dictionary with keys:
                name, magic, subdesc, expdesc,
                odelay, samplerate, recordduration,
                recordsize, nrawrecs, nchans
    """

    # read header from file and build NumPy data structure
    header_str = stream.read(512)

    # fromstring deprecated b.c. strange behavior on unicode
    # header = np.fromstring(header_str, dtype=_header_dtype)[0]
    header = np.frombuffer(header_str, dtype=_header_dtype)[0]

    # determine appropriate reader function
    reader = _get_reader_for_magic(header["magic"])
    if reader is None:
        raise ValueError(f'Bad magic number: {hex(header["magic"])}.')

    # calculate and validate sample rate
    hz = 1 / (header["10usec_per_tick"] / 100_000)
    if not _is_valid_samplerate(hz):
        raise ValueError(f"File claims weird non integer sample rate: {hz}.")

    # extract channel name codes from header
    channel_names = _get_channel_names(header)

    # TPU all 16 (8-byte) or 32 (4-byte) names come back, including trailing
    # empty names when nchans is less than 16 or 32. Drop the empty names so
    # the length of channel_names agrees with the header's nchans.
    assert all([len(chn) > 0 for chn in channel_names[: header["nchans"]]])
    assert all([len(chn) == 0 for chn in channel_names[header["nchans"] :]])
    channel_names = channel_names[: header["nchans"]]

    # capture the complete header as a JSON-serializable dict, new in 0.2.4
    raw_dig_header = dict()
    for key in header.dtype.names:
        val = header[key]
        if np.isscalar(val):
            val = val.item().decode("utf-8") if isinstance(val, bytes) else val.item()
        else:
            val = val.tolist()
        raw_dig_header[key] = val

    info = dict(
        {
            "name": "dig",
            "magic": header["magic"],
            "subdesc": header["subdes"],
            "expdesc": header["expdes"],
            "odelay": header["odelay"],
            "samplerate": hz,
            "recordduration": 256 / hz,
            "recordsize": 256,
            "nrawrecs": header["nrawrecs"],
            "nchans": header["nchans"],
            "mkh5_version": get_ver(),  # new in 0.2.4
            "raw_dig_header": raw_dig_header,
        }
    )

    return reader, header["nchans"], channel_names, info


def read_raw(stream, dtype):
    """parses the bytestream from a kutaslab eeg file into usable data

    Returns
    -------
    (channel_names, np.array(all_codes, dtype=np.int16),
     np.array(record_counts, dtype=np.int16), final_data, info)

    all_codes -- a vector of event codes and record indices from the mark track
    final_data -- a np.array: samples (rows) x eeg channels (columns)
    """
    if _gzipped(stream):
        stream = gzip.GzipFile(mode="r", fileobj=stream)
    reader, nchans, channel_names, info = _read_header(stream)

    # NJS. Data is stored in a series of "chunks" -- each chunk
    # contains 256 s16 samples from each channel (the 32/64/whatever
    # analog channels, plus 1 channel for codes -- that channel being
    # first). The code channel contains a "record number" as its
    # first entry in each chunk, which simply increments by 1 each
    # time.
    all_codes = []
    data_chunks = []
    chunk_bytes = (nchans + 1) * 512
    chunkno = 0
    record_counts = []
    while True:
        read = reader(stream, nchans)
        if read is None:
            break
        (codes_chunk, data_chunk) = read
        assert len(codes_chunk) == 256
        assert data_chunk.shape == (256 * nchans,)
        assert codes_chunk[0] == chunkno
        # codes_chunk[0] = 65535  ## NJS overwrote record counter
        # track the record counts for sanity checks and later processing
        record_counts.append(codes_chunk[0])
        # clear the record count so the mark track has all and only event codes TPU
        codes_chunk[0] = 0
        all_codes += codes_chunk
        data_chunk.resize((256, nchans))
        data_chunks.append(np.array(data_chunk, dtype=dtype))
        chunkno += 1
    final_data = np.vstack(data_chunks)
    # TPU ... changed all_codes, dtype=np.uint16 -> np.int16
    all_codes = np.array(all_codes, dtype=np.int16)
    return channel_names, all_codes, record_counts, final_data, info


def _read_raw_chunk(stream, nchans):
    """reads a kutaslab .raw eeg data record bytestream, returns
    (mark track event codes, vector of eeg data)
    """
    chunk_bytes = (nchans + 1) * 512
    buf = stream.read(chunk_bytes)
    # Check for EOF:
    if not buf:
        return None
    codes_list = list(struct.unpack("<256H", buf[:512]))
    # data_chunk = np.fromstring(buf[512:], dtype="<i2")
    data_chunk = np.frombuffer(buf[512:], dtype="<i2")
    return (codes_list, data_chunk)


def _read_compressed_chunk(stream, nchans):
    """decompresses record of kutaslab .crw eeg data bytestream"""

    # Check for EOF:
    ncode_records_minus_one_buf = stream.read(1)
    if not ncode_records_minus_one_buf:
        return None

    # Code track (run length encoded):
    (ncode_records_minus_one,) = struct.unpack("<B", ncode_records_minus_one_buf)
    ncode_records = ncode_records_minus_one + 1
    code_records = []
    for i in range(ncode_records):
        code_records.append(struct.unpack("<BH", stream.read(3)))
    codes_list = []
    for (repeat_minus_one, code) in code_records:
        codes_list += [code] * (repeat_minus_one + 1)
    assert len(codes_list) == 256

    # Data bytes (delta encoded and packed into variable-length integers):
    (ncompressed_words,) = struct.unpack("<H", stream.read(2))
    compressed_data = stream.read(ncompressed_words * 2)
    data_chunk = _decompress_crw_chunk(compressed_data, ncompressed_words, nchans)
    return (codes_list, data_chunk)
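

# For reference, the on-disk layout of one compressed (.crw) chunk as parsed
# by _read_compressed_chunk above:
#
#   1 byte      number of run-length code records, minus one
#   3 bytes     per code record, "<BH" = (repeat count minus one, event code),
#               which expand to the 256 mark-track values of the chunk
#   2 bytes     number of compressed 16-bit words that follow
#   2*N bytes   data samples, delta encoded and packed into variable-length
#               integers, expanded by _decompress_crw_chunk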


def read_log(fo):
    """generator reads a kutaslab binary log, yields (code, tick, condition, flag)

    Parameters
    ----------
    fo : file object

    flag values

    # avg -x sets 0 = OK, 20 = artifact, 40 = polinv, 60 = polinv + artifact
    # cdbl -op also sets flags according to the bdf
    # 100 = data error (rare)
    """
    while True:
        event = fo.read(8)
        if not event:
            return
        # NJS
        # (code, tick_hi, tick_lo, condition, flag) \
        #     = struct.unpack("<HHHBB", event)

        # TPU ... 2-byte event codes can be negative, i.e. short
        (code, tick_hi, tick_lo, condition, flag) = struct.unpack("<hHHBB", event)
        yield (code, (tick_hi << 16 | tick_lo), condition, flag)  # NJS


def load(f_raw, f_log, dtype=np.float64, delete_channels=[], calibrate=True, **kwargs):
    # read the raw and sanity check the records ...
    channel_names, raw_codes, record_counts, data, info = read_raw(f_raw, dtype)
    assert all(record_counts == np.arange(len(record_counts)))

    # read the log
    codes_from_log = np.zeros(raw_codes.shape, dtype=raw_codes.dtype)
    for (code, tick, condition, flag) in read_log(f_log):
        codes_from_log[tick] = code
    discrepancies = codes_from_log != raw_codes
    assert (codes_from_log[discrepancies] == 0).all()
    assert (raw_codes[discrepancies] == 65535).all()

    if delete_channels:
        # fast-path: no need to do a copy if nothing to delete
        keep_channels = []
        for i in range(len(channel_names)):
            if channel_names[i] not in delete_channels:
                keep_channels.append(i)
        assert len(keep_channels) + len(delete_channels) == len(channel_names)
        data = data[:, keep_channels]
        channel_names = channel_names[keep_channels]

    if calibrate:
        calibrate_in_place(data, raw_codes, **kwargs)
    return channel_names, raw_codes, data, info
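

# Example usage (sketch; the filenames are placeholders, and calibration is
# skipped so no calibration-pulse keyword arguments are needed):
#
#     with open("subject01.crw", "rb") as f_raw:
#         with open("subject01.log", "rb") as f_log:
#             channel_names, raw_codes, data, info = load(
#                 f_raw, f_log, dtype=np.float64, calibrate=False
#             )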


# ----------
# 4. Writers
# ----------

# To write multiple "bins" to the same file, just call this function
# repeatedly on the same stream.  # NJS
def write_erp_as_avg(erp, stream):
    magic = "\xa5\x17"
    header = np.zeros(1, dtype=_header_dtype)[0]
    header["magic"] = 0x17A5
    header["verpos"] = 1

    # One avg record is always exactly 256 * cprecis samples long, with
    # cprecis = 1, 2, 3 (limitation of the data format). So we pick the
    # smallest cprecis that is <= our actual number of samples (maximum 3),
    # and then we resample to have that many samples exactly. (I.e., we try
    # to resample up when possible.) The kutaslab tools only do
    # integer-factor downsampling (decimation), and they write the decimation
    # factor to the file. I don't see how it matters for the .avg file to
    # retain the decimation information, and the file won't let us write down
    # upsampling (especially non-integer upsampling!), so we just set our
    # decimation factor to 1 and be done with it.
    if erp.data.shape[0] <= 1 * 256:
        cprecis = 1
    elif erp.data.shape[0] <= 2 * 256:
        cprecis = 2
    elif erp.data.shape[0] <= 3 * 256:
        cprecis = 3
    else:
        raise ValueError("cprecis > 3")  ## TPU

    samples = cprecis * 256
    if erp.data.shape[0] != samples:
        import scipy.signal

        resampled_data = scipy.signal.resample(erp.data, samples)
    else:
        resampled_data = erp.data
    assert resampled_data.shape == (samples, erp.data.shape[1])

    resampled_sp_10us = int(
        round((erp.times.max() - erp.times.min()) * 100.0 / samples)
    )
    epoch_len_ms = int(round(samples * resampled_sp_10us / 100.0))

    # Need to convert to s16's. To preserve as much resolution as possible,
    # we use the full s16 range, minus a bit to make sure we don't run into
    # any overflow issues.
    s16_max = 2**15 - 10
    # Same as np.abs(data).max(), but without copying the whole array:
    data_max = max(resampled_data.max(), np.abs(resampled_data.min()))
    # We have to write the conversion factor as an integer, so we round it
    # down here, and then use the *rounded* version to actually convert the
    # data.
    s16_per_10uV = int(s16_max / (data_max / 10))
    # Except that if our conversion factor itself overflows, then we have to
    # truncate it back down (and lose a bit of resolution in the process, oh
    # well):
    if s16_per_10uV > s16_max:
        s16_per_10uV = s16_max
    integer_data = np.array(
        np.round(s16_per_10uV * resampled_data / 10.0), dtype="<i2"
    )

    header["epoch_len"] = epoch_len_ms
    header["nchans"] = integer_data.shape[1]
    header["sums"] = 0  # ERP
    header["tpfuncs"] = 1  # processing function of "averaging"
    header["pftypes"] = "average"
    header["pp10uv"] = s16_per_10uV
    header["10usec_per_tick"] = resampled_sp_10us
    header["presam"] = 0 - erp.times.min()
    header["cprecis"] = cprecis
    header["decfact"] = 1

    if "num_combined_trials" in erp.metadata:
        header["totrr"] = erp.metadata["num_combined_trials"]
    if len(erp.channel_names) <= 16:
        header["chndes"] = np.asarray(erp.channel_names, dtype="S8").tostring()
    elif len(erp.channel_names) <= 32:
        header["chndes"] = np.asarray(erp.channel_names, dtype="S4").tostring()
    else:
        assert False, "Channel name writing for large montages not yet supported"
    if "experiment" in erp.metadata:
        header["expdes"] = erp.metadata["experiment"]
    if "subject" in erp.metadata:
        header["subdes"] = erp.metadata["subject"]
    if erp.name is not None:
        header["condes"] = erp.name
    header.tofile(stream)

    # avg files omit the mark track. And, all the data for a single channel
    # goes together in a single chunk, rather than interleaving all channels.
    # THIS IS DIFFERENT FROM RAW FILES!
    for i in range(integer_data.shape[1]):
        integer_data[:, i].tofile(stream)
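

# Sketch of the repeated-call pattern described above, writing one record per
# "bin" to a single .avg file ("erps" here stands in for any iterable of erp
# objects with .data, .times, .channel_names, .metadata, and .name):
#
#     with open("bins.avg", "wb") as stream:
#         for erp in erps:
#             write_erp_as_avg(erp, stream)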
if __name__ == "__main__": # stub pass