/helper_functions/helper_functions.py
Python | 1226 lines | 1131 code | 68 blank | 27 comment | 62 complexity | 0e7e2ee7b7f26e8a83263b47a82828fc MD5 | raw file
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # based on obci.analysis.p300.analysis_offline
- # Marian Dovgialo
- import os
- import sys
- import glob
- from multiprocessing import cpu_count
- from itertools import groupby
- from operator import itemgetter
- from copy import deepcopy
- from scipy import signal, stats
- import numpy as np
- import matplotlib
- from p300_classifier_MMP.clusterisation_selector import EEGClusterisationSelector
- from .helper_data.ourcap_neighb import get_our_connectivity_sparse
- havedisplay = "DISPLAY" in os.environ
- if not havedisplay and "linux" in sys.platform:
- matplotlib.use('Agg')
- import pylab as pb
- from copy import deepcopy
- from obci.analysis.obci_signal_processing import read_manager
- from obci.analysis.obci_signal_processing.signal import read_info_source, read_data_source
- from obci.analysis.obci_signal_processing.tags import read_tags_source
- from obci.analysis.obci_signal_processing.tags.smart_tag_definition import SmartTagDurationDefinition
- from obci.analysis.obci_signal_processing.tags.tags_file_writer import TagsFileWriter
- from obci.analysis.obci_signal_processing.smart_tags_manager import SmartTagsManager
- from mne_conversions import read_manager_continious_to_mne, chtype
- import mne
- from config import main_outdir, figure_scale, fontsize
- from collections import namedtuple
# (column, row) grid position used to place a channel's subplot on a figure.
Pos = namedtuple('Pos', ['x', 'y'])
# Fixed 5x5 layout of the 10-20 EEG montage used for "anatomical" subplot
# grids; 'Null' is an unused placeholder cell, 'eog' sits in the top-left.
map1020 = {'eog': Pos(0, 0), 'Fp1': Pos(1, 0), 'Fpz': Pos(2, 0), 'Fp2': Pos(3, 0), 'Null': Pos(4, 0),
           'F7': Pos(0, 1), 'F3': Pos(1, 1), 'Fz': Pos(2, 1), 'F4': Pos(3, 1), 'F8': Pos(4, 1),
           'T3': Pos(0, 2), 'C3': Pos(1, 2), 'Cz': Pos(2, 2), 'C4': Pos(3, 2), 'T4': Pos(4, 2),
           'T5': Pos(0, 3), 'P3': Pos(1, 3), 'Pz': Pos(2, 3), 'P4': Pos(3, 3), 'T6': Pos(4, 3),
           'M1': Pos(0, 4), 'O1': Pos(1, 4), 'Oz': Pos(2, 4), 'O2': Pos(3, 4), 'M2': Pos(4, 4)}
def get_filelist(filelist):
    """Resolve a command-line file specification to a list of paths.

    Args:
        filelist: list of strings - either a single existing path,
            a single glob pattern, or an explicit list of paths.
    Returns:
        list of file paths (glob-expanded when a pattern was given).
    """
    if len(filelist) == 1:
        if os.path.exists(filelist[0]):
            # single existing file - use as is
            pass
            # print 'pracuję nad pojedyńczym plikiem', filelist[0]
        else:
            # not an existing path - treat the argument as a glob pattern
            print 'pracuję ze wzorem', filelist[0]
            filelist = glob.glob(filelist[0])
    else:
        print 'pracuję nad listą plików', filelist

    return filelist
def get_tagfilters(blok_type):
    """Return tag-filter predicates and epoch labels for an experiment block.

    Args:
        blok_type: one of None/falsy, 1, 2, 'local', 'global', 'erds',
            'wzrokowe_kot_movie', 'wzrokowe_ptak_movie', 'wzrokowe_both_movie',
            'multimodal_both_movie', 'multimodal_both'.
    Returns:
        (filters, epoch_labels) where filters is a tuple of predicates taking
        a tag dict and returning bool - or (None,) when blok_type is falsy -
        and epoch_labels is a matching tuple of labels (or None).
    Raises:
        Exception: for an unknown blok_type.
    """
    epoch_labels = None
    if not blok_type:
        return (None,), epoch_labels

    # NOTE: bare `except:` was narrowed to (KeyError, TypeError) - the only
    # errors a malformed tag dict can raise here - so that e.g.
    # KeyboardInterrupt is no longer silently swallowed.
    if blok_type == 1:
        epoch_labels = ('target', 'nontarget')

        def target_func(tag):
            try:
                return tag['desc']['blok_type'] == '1' and tag['desc']['type'] == 'target'
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['desc']['blok_type'] == '1' and tag['desc']['type'] == 'nontarget'
            except (KeyError, TypeError):
                return False
        return (target_func, nontarget_func), epoch_labels
    elif blok_type == 2:
        epoch_labels = ('target', 'nontarget')

        def target_func(tag):
            try:
                return tag['desc']['blok_type'] == '2' and tag['desc']['type'] == 'target'
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['desc']['blok_type'] == '2' and tag['desc']['type'] == 'nontarget'
            except (KeyError, TypeError):
                return False
        return (target_func, nontarget_func), epoch_labels
    elif blok_type == 'local':
        epoch_labels = ('dewiant', 'standard')

        def target_func(tag):
            try:
                return tag['desc']['type_local'] == 'dewiant'
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['desc']['type_local'] == 'standard'
            except (KeyError, TypeError):
                return False

        return (target_func, nontarget_func), epoch_labels

    elif blok_type == 'global':
        epoch_labels = ('target', 'nontarget')

        def target_func(tag):
            try:
                return tag['desc']['type_global'] == 'target'
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['desc']['type_global'] == 'nontarget'
            except (KeyError, TypeError):
                return False

        return (target_func, nontarget_func), epoch_labels

    elif blok_type == "erds":
        epoch_labels = ('hand_mvt', 'leg_mvt')

        def reka_func(tag):
            try:
                return tag['name'] == 'ERDS_instr1.wav'
            except (KeyError, TypeError):
                return False

        def noga_func(tag):
            try:
                return tag['name'] == 'ERDS_instr2.wav'
            except (KeyError, TypeError):
                return False

        return (reka_func, noga_func), epoch_labels
    elif blok_type == "wzrokowe_kot_movie":
        epoch_labels = ('kot_target', 'kot_nontarget')

        def target_func(tag):
            try:
                return tag['name'] == 'kot_target'
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['name'] == 'kot_nontarget'
            except (KeyError, TypeError):
                return False
        return (target_func, nontarget_func), epoch_labels
    elif blok_type == "wzrokowe_ptak_movie":
        epoch_labels = ('ptak_target', 'ptak_nontarget')

        def target_func(tag):
            try:
                return tag['name'] == 'ptak_target'
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['name'] == 'ptak_nontarget'
            except (KeyError, TypeError):
                return False
        return (target_func, nontarget_func), epoch_labels
    elif blok_type == "wzrokowe_both_movie":
        epoch_labels = ('target', 'nontarget')

        def target_func(tag):
            try:
                return tag['name'] in ('kot_target', 'ptak_target')
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['name'] in ('kot_nontarget', 'ptak_nontarget')
            except (KeyError, TypeError):
                return False

        return (target_func, nontarget_func), epoch_labels
    elif blok_type == "multimodal_both_movie":
        epoch_labels = ('target', 'nontarget')

        def target_func(tag):
            try:
                return tag['name'] in ('kaczka_target', 'pies_target')
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['name'] in ('kaczka_nontarget', 'pies_nontarget', 'sowa_nontarget', 'zaba_nontarget')
            except (KeyError, TypeError):
                return False

        return (target_func, nontarget_func), epoch_labels
    elif blok_type == "multimodal_both":
        epoch_labels = ('target', 'nontarget')

        def target_func(tag):
            try:
                return tag['desc']['trial_type'] == "target"
            except (KeyError, TypeError):
                return False

        def nontarget_func(tag):
            try:
                return tag['desc']['trial_type'] == "nontarget"
            except (KeyError, TypeError):
                return False

        return (target_func, nontarget_func), epoch_labels
    else:
        raise Exception("Nie ma bloków tego typu!")
def savetags(stags, filename, start_offset = 0, duration = 0.1):
    """Create a tags XML file from a list of smart tags.

    Args:
        stags: iterable of smart tags; the first tag of each is written out
        filename: output XML path
        start_offset: seconds added to every tag's start timestamp
        duration: seconds added (on top of start_offset) to every end timestamp
    """
    writer = TagsFileWriter(filename)
    for stag in stags:
        # deep-copy so the offsets do not mutate the caller's tag dicts
        # (the original wrote the offsets back into the smart tags themselves)
        tag = deepcopy(stag.get_tags()[0])
        tag['start_timestamp'] += start_offset
        tag['end_timestamp'] += duration + start_offset
        writer.tag_received(tag)
    writer.finish_saving(0.0)
- def get_microvolt_samples(stag, channels = None):
- """Does get_samples on smart tag (read manager), but multiplied by channel gain"""
- if not channels: # returns for all channels
- gains = np.array([float(i) for i in stag.get_param('channels_gains')], ndmin = 2).T
- return stag.get_samples() * gains
- elif isinstance(channels, (str, unicode)): # returns for specific channel
- channel = channels
- ch_n = stag.get_param('channels_names').index(channel)
- gain = stag.get_param('channels_gains')[ch_n]
- return stag.get_channel_samples(channel) * float(gain)
- else: # returns for specified list of channels
- gains = stag.get_param('channels_gains')
- gains = np.array([float(gains[channels.index(ch)]) for ch in channels], ndmin = 2).T
- return stag.get_channels_samples(channels) * gains
def shift_tags_relatively_to_signal_beginnig(rm, shift):
    """Rescale every tag position in ``rm`` by a factor of (1 + shift).

    The shift is relative to the beginning of the signal: each tag moves by
    an amount proportional to its distance from the start, so tags further
    into the recording are moved more.
    """
    factor = 1 + shift
    updated = rm.get_tags()
    for entry in updated:
        entry['start_timestamp'] = entry['start_timestamp'] * factor
        entry['end_timestamp'] = entry['end_timestamp'] * factor
    rm.set_tags(updated)
def align_tags(rm, tag_correction_chnls, start_offset = -0.1, duration = 0.3, thr = None, reverse = False, offset = 0):
    """Align tags in a read manager to the onset of a sudden amplitude change
    (std => 3) on any channel in ``tag_correction_chnls``.

    The change is searched for inside the window
    [tag_time + start_offset; tag_time + start_offset + duration].
    If no such change occurs the tag is left in place (apart from ``offset``).

    Args:
        rm: read manager whose tags are corrected in place
        tag_correction_chnls: channel names (e.g. photodiode/trigger) whose
            rectified sum forms the trigger trace
        start_offset: window start relative to the tag, in seconds (negative)
        duration: window length in seconds
        thr: detection threshold; if falsy it is estimated as
            mean + 3*std of the trigger trace, capped at half its maximum
        reverse: if True, search for the END of the stimulation instead
        offset: constant shift in seconds added to every tag unconditionally
    """
    tags = rm.get_tags()
    Fs = float(rm.get_param('sampling_frequency'))
    # rectified sum of all correction channels acts as a single trigger trace
    trigger_chnl = np.zeros(int(rm.get_param('number_of_samples')))
    for tag_correction_chnl in tag_correction_chnls:
        trigger_chnl += np.abs(rm.get_channel_samples(tag_correction_chnl))
    if not thr:
        thr = 3 * np.std(trigger_chnl) + np.mean(trigger_chnl)
        # do not let the automatic threshold exceed half of the peak value
        maksimum = trigger_chnl.max()
        if thr > 0.5 * maksimum:
            thr = 0.5 * maksimum

    for tag in tags:
        start = int((tag['start_timestamp'] + start_offset) * Fs)
        end = int((tag['start_timestamp'] + start_offset + duration) * Fs)
        try:
            if reverse:
                # first above-threshold sample counted from the window's end
                trig_pos_s_r = np.argmax(np.flipud(trigger_chnl[start:end] > thr))
                trig_pos_s = (end - start - 1) - trig_pos_s_r
            else:
                trig_pos_s = np.argmax(trigger_chnl[start:end] > thr) # will find first True, or first False if no Trues
        except ValueError:
            # empty search window (tag outside the signal) - apply only the fixed offset
            tag['start_timestamp'] += offset
            tag['end_timestamp'] += offset
            continue
        # Debuging code:
        # print trig_pos_s, Fs, reverse,
        # print 'thr', thr, 'value at pos', trigger_chnl[start+trig_pos_s], trigger_chnl[start+trig_pos_s]>thr
        # pb.plot(np.linspace(0, (end-start)/Fs, len(trigger_chnl[start:end])), trigger_chnl[start:end])
        # pb.axvline(trig_pos_s/Fs, color='k')
        # pb.title(str(tag))
        # pb.show()
        # Debug code end
        # shift the tag only if a genuine threshold crossing was found
        # (argmax returns 0 when no sample in the window exceeds thr)
        if trigger_chnl[start + trig_pos_s] > thr:
            trig_pos_t = trig_pos_s * 1.0 / Fs
            tag_change = trig_pos_t + start_offset
            tag['start_timestamp'] += tag_change
            tag['end_timestamp'] += tag_change
        tag['start_timestamp'] += offset
        tag['end_timestamp'] += offset
    rm.set_tags(tags)
def show_eog_ica(rm, ica,
                 eog_chnl = 'eog',
                 blink_rejection_dict = dict(eeg = 0.000250, eog = 0.000500), # V
                 correlation_treshhold = 0.25,
                 results_path =""):
    """Interactively preview a fitted ICA solution and choose bad components.

    rm - read manager with the data the ICA was fitted on
    ica - fitted mne ICA object
    eog_chnl - channel used as the EOG reference for correlation scoring
    blink_rejection_dict - peak-to-peak amplitudes (V) used to drop outlier
        segments before correlating components with the EOG channel
    correlation_treshhold - absolute correlation above which the single best
        component is proposed automatically
    results_path - output directory path shown to the user
    Returns
    - list of bad component indexes (user-confirmed or automatic)
    """
    print('removing eog artifact')
    raw = read_manager_continious_to_mne(rm)
    n = len(raw.ch_names)
    print('n chnls {}'.format(n))
    raw.plot(block = True, show = False, scalings = 'auto', title = 'Simple preview of signal', n_channels = n)
    data = raw.get_data()
    # data holds everything from raw;
    # data2 will hold the same data with outlier segments removed:
    data2, drop_inds = mne.preprocessing.ica._reject_data_segments(data, blink_rejection_dict, flat = None, decim = None, info = raw.info, tstep = 2.)
    # it is essential to pass raw.info above (not ica.info): ica.info lacks the
    # eog channel and we want outliers removed from it as well
    raw2 = mne.io.RawArray(data2, raw.info)
    scores = ica.score_sources(raw2, target = eog_chnl, l_freq = 1, h_freq = 9)

    # automatically propose the single most EOG-correlated component
    if max(np.abs(scores)) > correlation_treshhold:
        bads = [np.argmax(np.abs(scores))]
        winnig_correlation = max(np.abs(scores))
    else:
        bads = []
        winnig_correlation = 0
    print(ica)

    print('CORRELATION SCORES:')
    for nr, score in enumerate(scores):
        msg = '{} {} {}'.format('*' if np.abs(score) > correlation_treshhold else ' ', nr, score)
        print msg
    if winnig_correlation:
        title = 'Components with artifacts: {}, Corr. with {} = {} > thr = {}'.format(bads, eog_chnl, winnig_correlation, correlation_treshhold)
        print 'Wybrane złe komponenty:', bads
    else:
        title = 'Components with artifacts: None, (corr. with {}), thr = {}'.format(eog_chnl, correlation_treshhold)
        print 'Nie udało się wybrać złych komponent:', bads

    ica.plot_components(res = 128, show = False, title = title, colorbar = True)
    ica.plot_sources(raw, show = True)

    print "\nDane będą zapisane do katalogu:"
    print results_path

    print('\nWpisz indeksy komponent rozdzielone spacjami jeśli chcesz nadpisać (po zamknięciu okienek)\n'
          'jeśli nic nie wpiszesz użyją się wybrane automatycznie [potwierdź ENTERem]\n'
          'jeśli nie chcesz usuwać żadnej wpisz -1.')

    # let the user override the automatic choice from stdin:
    # empty input keeps the automatic selection, "-1" selects none
    good = False
    while not good:
        try:
            inp = raw_input()
            if inp.split():
                man_bads = [int(i) for i in inp.split() if 0 <= int(i) < n]
                # reject nonsense input (no valid index and not an explicit -1)
                if not man_bads and not int(inp) == -1:
                    raise Exception
                bads = man_bads
            good = True
        except Exception:
            print('Błąd, wpisz jeszcze raz\n')
    print 'Wybrane złe komponenty:', bads

    return bads
def fit_eog_ica(rm,
                eog_chnl = 'eog',
                montage = None,
                ds = '',
                use_eog_events = False,
                manual = False,
                rejection_dict = dict(eeg = 0.000150,
                                      eog = 0.000250), # V
                blink_rejection_dict = dict(eeg = 0.000250,
                                            eog = 0.000500), # V
                correlation_treshhold = 0.5,
                outdir = os.path.expanduser(os.path.join(main_outdir, "unknown/noname/ica_and_artifacts/ica_maps"))):
    """Fit an ICA decomposition and pick the components correlated with EOG.

    rm - read manager with training data for ICA
    eog_chnl - channel to use as EOG source
    montage - montage of the read manager (for logging and filenames of generated images)
    ds - dataset path used in the names of generated figures and logs
    use_eog_events True - split to EOG epochs - do ICA
    use_eog_events False - use whole file
    use_eog_events None - only create eog events (returns (rm, events) instead)
    manual True/False - shows ICA components, ICA components map,
        prints correlations with EOG and then lets user write space separated
        indexes of components to remove
    rejection_dict - peak-to-peak rejection limits (V) used when fitting ICA
    blink_rejection_dict - rejection limits (V) used when scoring components
    correlation_treshhold - absolute EOG correlation above which a component
        is marked bad
    outdir - directory receiving figures and the correlation log
    Returns
    - fitted mne.ICA object to be used in remove_ica_components to correct read manager
    - list of bad components
    - detected eog events
    """
    print('removing eog artifact')
    raw = read_manager_continious_to_mne(rm)
    n = len(raw.ch_names)
    print('n chnls {}'.format(n))
    if manual:
        raw.plot(block = True, show = True, scalings = 'auto', title = 'Simple preview of signal', n_channels = n)
        print('HHHHHHHHHHHHHHHHh', raw.ch_names, eog_chnl)
        raw.plot(block=True, show=True, scalings='auto', title='Simple preview of signal', n_channels=n)
    events = mne.preprocessing.find_eog_events(raw, ch_name = eog_chnl)
    print('EOG EVENTS\n', events)
    if use_eog_events is None:
        # caller only wanted the detected events
        return rm, events

    no_eeg_channels = len([ch for ch in raw.ch_names if chtype(ch) == "eeg"])
    # one PCA dimension lost to the reference, one more for a common-average
    # ("car") montage - presumably; TODO(review) confirm against the montage code
    max_pca_components = no_eeg_channels - 1 - ("car" in montage)
    ica = mne.preprocessing.ICA(method = 'extended-infomax', max_pca_components = max_pca_components)
    if use_eog_events:
        eog = mne.preprocessing.create_eog_epochs(raw, ch_name = eog_chnl)
        ica.fit(eog, reject = rejection_dict, picks = mne.pick_types(raw.info, eeg = True, eog = True))
    else:
        ica.fit(raw, reject = rejection_dict, tstep = 0.3)
    data = raw.get_data()
    # data holds everything from raw;
    # data2 will hold the same data with outlier segments removed:

    data2, drop_inds = mne.preprocessing.ica._reject_data_segments(data, blink_rejection_dict, flat = None, decim = None, info = raw.info, tstep = 2.)
    # it is essential to pass raw.info above (not ica.info): ica.info lacks the
    # eog channel and we want outliers removed from it as well
    raw2 = mne.io.RawArray(data2, raw.info)
    scores = ica.score_sources(raw2, target = eog_chnl, l_freq = 1, h_freq = 9)
    # every component above the threshold is bad; if none crosses it, take the
    # single best one provided it reaches at least half the threshold
    bads = list(np.arange(scores.size)[np.abs(scores) > correlation_treshhold])
    if len(bads) == 0 and max(np.abs(scores)) > correlation_treshhold/2.:
        bads = [np.argmax(np.abs(scores))]
    if bads:
        winnig_correlations = np.abs(scores[bads])
    else:
        winnig_correlations = None

    print(ica)

    filename = 'ICA_eog_' + os.path.basename(ds) + '_{}'.format(montage)

    # persist the per-component correlation scores next to the figures
    log = open(os.path.join(outdir, filename + '.txt'), 'w')
    print('CORRELATION SCORES:')
    log.write('CORRELATION SCORES:\n')
    for nr, score in enumerate(scores):
        msg = '{} {} {}'.format('*' if np.abs(score) > correlation_treshhold else ' ', nr, score)
        print msg
        log.write(msg + '\n')
    log.close()
    if np.any(winnig_correlations):
        title = 'Components with artifacts: {}, Corr. with {} = {} > thr = {}'.format(bads, eog_chnl, winnig_correlations, correlation_treshhold)
        print 'Wybrane złe komponenty:', bads
    else:
        title = 'Components with artifacts: None, (corr. with {}), thr = {}'.format(eog_chnl, correlation_treshhold)
        print 'Nie udało się wybrać złych komponent:', bads
    fig = ica.plot_components(res = 128, show = False, title = title, colorbar = True)
    # example 40 seconds of the signal starting at second 10,
    # together with the corresponding ICA components
    fig_sig = raw.plot(start = 10, duration = 40, block = False, show = False, scalings = 'auto', title = 'Preview of dirty signal', n_channels = n)
    fig_com = ica.plot_sources(raw, start = 10, stop = 50, show = False)
    fig_com.savefig(os.path.join(outdir, filename + 'example1_components.png'))
    fig_sig.savefig(os.path.join(outdir, filename + 'example2_signal.png'))
    pb.close(fig_sig)
    pb.close(fig_com)
    if manual:
        ica.plot_sources(raw, show = True)

        print('Wpisz indeksy komponent rozdzielone spacjami jeśli chcesz nadpisać (po zamknięciu okienek)\n'
              'jeśli nic nie wpiszesz użyją się wybrane automatycznie [potwierdź ENTERem]\n'
              'jeśli nie chcesz usuwać żadnej wpisz -1.')

        # redraw: once the user has closed the windows there is nothing left to save to file
        fig = ica.plot_components(res = 128, show = False, title = title, colorbar = True)
        # let the user override the automatic choice from stdin:
        # empty input keeps the automatic selection, "-1" selects none
        good = False
        while not good:
            try:
                inp = raw_input()
                if inp.split():
                    man_bads = [int(i) for i in inp.split() if 0 <= int(i) < n]
                    # reject nonsense input (no valid index and not an explicit -1)
                    if not man_bads and not int(inp) == -1:
                        raise Exception
                    bads = man_bads
                good = True
            except Exception:
                print('Błąd, wpisz jeszcze raz\n')
        print 'Wybrane złe komponenty:', bads

    # plot_components may return a list of figures (one per page of components)
    if isinstance(fig, list):
        for nr, figura in enumerate(fig):
            figura.savefig(os.path.join(outdir, filename + '_{}'.format(nr) + '.png'))
            if not manual:
                pb.close(figura)
    else:
        fig.savefig(os.path.join(outdir, filename + '.png'))
        if not manual:
            pb.close(fig)

    return ica, bads, events
def remove_ica_components(rm, ica, bads,
                          events = [],
                          scalings = {'eeg': 4e-5, 'eog': 4e-5},
                          silent = False,
                          ds = '',
                          montage = [],
                          outdir = os.path.expanduser(os.path.join(main_outdir, "unknown/noname/ica_and_artifacts/ica_maps"))):
    """Apply a fitted ICA to a read manager, removing the bad components.

    rm - read manager with data to clean
    ica - fitted mne.ica object to be used (e.g. returned by fit_eog_ica)
    bads - list of bad components of ICA; if empty, data is returned unchanged
    events - detected eog events (nd.array), overlaid on the preview plots
    scalings - plot scalings passed to raw.plot
    silent - if False, show blocking before/after signal previews
    ds - dataset path; when non-empty an example figure of the cleaned
        signal is written to ``outdir``
    montage - montage identifier used in the example figure's filename
    outdir - directory for the example figure
    Returns ICA-corrected read manager (a deep copy; the input is untouched)
    """
    # work on a copy so the caller's read manager is left unmodified
    rm = deepcopy(rm)
    if bads:
        # read_manager to mne conversion
        raw = read_manager_continious_to_mne(rm)
        n = len(raw.ch_names)
        if not silent:
            raw.copy().plot(scalings = scalings, events = events, block = True, show = False, title = 'PRZED ICA', n_channels = n)

        raw_clean = ica.apply(raw, exclude = bads)

        if not silent:
            raw_clean.copy().plot(scalings = scalings, events = events, block = True, show = True, title = 'PO ICA', n_channels = n)
        if ds:
            # example 40 seconds of the signal starting at second 10, after blink removal
            filename = 'ICA_eog_' + os.path.basename(ds) + '_{}'.format(montage)
            fig = raw_clean.plot(start = 10, duration = 40, block = False, show = False, scalings = 'auto', title = 'Preview of cleaned signal', n_channels = n)
            fig.savefig(os.path.join(outdir, filename + 'example3_clean_signal.png'))
            pb.close(fig)
        data = np.array(raw_clean.to_data_frame())
        print "CONTROL INFO"
        print data.shape
        print np.median(np.abs(data), axis = 1)
        print np.std(data, axis = 1)

        # mne to read_manager conversion
        rm.set_samples(data.T, rm.get_param('channels_names'))

    return rm
def remove_eog_ica(rm,
                   eog_chnl = 'eog',
                   montage = None,
                   ds = '',
                   use_eog_events = False,
                   manual = False,
                   rejection_dict = dict(eeg = 0.000150,
                                         eog = 0.000250), # V
                   correlation_treshhold = 0.5):
    """
    Exists for compatibility reasons.
    rm - read manager with training data for ICA
    eog_chnl - channel to use as EOG source
    montage - montage of the read manager (for logging and filenames of generated images)
    use_eog_events True - split to EOG epochs - do ICA
    use eog events False - use whole file
    use_eog_events None - only create eog events
    manual True/False - shows ICA components, ICA components map
    prints correlations with EOG and then lets user write space seperated
    indexes of components to remove
    Returns
    - ICA-corrected read manager
    - detected eog events
    """
    # NOTE(review): when use_eog_events is None, fit_eog_ica returns only
    # (rm, events) and this 3-way unpack will fail - confirm intended usage.
    # BUGFIX: correlation_treshhold is now passed by keyword; positionally it
    # used to land in fit_eog_ica's blink_rejection_dict parameter slot.
    ica, bads, eog_events = fit_eog_ica(rm, eog_chnl, montage, ds, use_eog_events, manual, rejection_dict,
                                        correlation_treshhold = correlation_treshhold)
    if bads:
        # BUGFIX: remove_ica_components takes the read manager as its first
        # argument; the call was previously missing ``rm`` entirely.
        clean_rm = remove_ica_components(rm, ica, bads, eog_events)
    else:
        clean_rm = rm

    return clean_rm, eog_events
def interp_bads(rm, bads):
    """Interpolate the listed bad channels of a read manager via MNE.

    Marks ``bads`` on the converted Raw object, interpolates them, and writes
    the result back into ``rm`` (which is modified in place and returned).
    """
    raw = read_manager_continious_to_mne(rm)
    raw.info['bads'] = bads
    raw.interpolate_bads()
    frame = np.array(raw.to_data_frame())
    # the 1e-6 factor converts the data-frame values back to the read
    # manager's unit scale (presumably uV -> V; TODO confirm)
    rm.set_samples(frame.T * 1e-6, rm.get_param('channels_names'))
    return rm
def undrop_channels(mgr_dropped, mgr_full):
    """Complete mgr_dropped with the channels it is missing, taken from mgr_full.

    Effectively the channels of mgr_full are replaced by the channels of
    mgr_dropped, because all remaining information (params, tags) comes from
    mgr_full (both managers are assumed to describe the same recording).

    Args:
        mgr_dropped: read manager holding a subset of the channels
        mgr_full: read manager holding the full channel set
    Returns:
        a new ReadManager with every channel of mgr_full, where the data of
        the channels shared with mgr_dropped comes from mgr_dropped.
    """
    available_ch = mgr_dropped.get_param('channels_names')
    all_ch = mgr_full.get_param('channels_names')

    new_params = deepcopy(mgr_full.get_params())

    samples_full = mgr_full.get_samples()
    samples_dropped = mgr_dropped.get_samples()

    new_tags = deepcopy(mgr_full.get_tags())
    new_samples = np.zeros((int(new_params['number_of_channels']), len(samples_full[0])))


    # Define new samples and params list values
    keys = ['channels_names', 'channels_numbers', 'channels_gains', 'channels_offsets']
    keys_to_remove = []
    for k in keys:
        try:
            # Exclude from keys those keys that are missing in mgr
            mgr_full.get_params()[k]
        except KeyError:
            keys_to_remove.append(k)
            continue
        new_params[k] = []
    for k in keys_to_remove:
        keys.remove(k)
    new_ind = 0

    for ch_ind, ch in enumerate(all_ch):
        # take the channel's data from mgr_dropped when available,
        # otherwise fall back to mgr_full's data
        if ch not in available_ch:
            new_samples[ch_ind, :] = samples_full[ch_ind, :]
        else:
            new_samples[ch_ind, :] = samples_dropped[available_ch.index(ch), :]

        # per-channel params are always copied from mgr_full
        for k in keys:
            new_params[k].append(mgr_full.get_params()[k][ch_ind])
        new_ind += 1
    info_source = read_info_source.MemoryInfoSource(new_params)
    tags_source = read_tags_source.MemoryTagsSource(new_tags)
    samples_source = read_data_source.MemoryDataSource(new_samples)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
def exclude_channels(mgr, channels):
    """Return a new ReadManager with all channels in ``channels`` removed.

    Channel names not present in ``mgr`` are silently ignored.

    Args:
        mgr: source read manager (left unmodified)
        channels: list of channel names to drop
    Returns:
        a new ReadManager without the excluded channels.
    """
    # keep only the requested channels that actually exist in the manager
    available = set(mgr.get_param('channels_names'))
    exclude = set(channels)
    channels = list(available.intersection(exclude))

    new_params = deepcopy(mgr.get_params())
    samples = mgr.get_samples()
    new_tags = deepcopy(mgr.get_tags())

    ex_channels_inds = [new_params['channels_names'].index(ch) for ch in channels]
    assert (-1 not in ex_channels_inds)

    new_samples = np.zeros((int(new_params['number_of_channels']) - len(channels),
                            len(samples[0])))
    # Define new samples and params list values
    keys = ['channels_names', 'channels_numbers', 'channels_gains', 'channels_offsets']
    keys_to_remove = []
    for k in keys:
        try:
            # Exclude from keys those keys that are missing in mgr
            mgr.get_params()[k]
        except KeyError:
            keys_to_remove.append(k)
            continue
        new_params[k] = []

    for k in keys_to_remove:
        keys.remove(k)
    new_ind = 0
    # copy every surviving channel's samples and per-channel params
    for ch_ind, ch in enumerate(samples):
        if ch_ind in ex_channels_inds:
            continue
        else:
            new_samples[new_ind, :] = ch
            for k in keys:
                new_params[k].append(mgr.get_params()[k][ch_ind])

            new_ind += 1

    # Define other new new_params
    new_params['number_of_channels'] = str(int(new_params['number_of_channels']) - len(channels))

    info_source = read_info_source.MemoryInfoSource(new_params)
    tags_source = read_tags_source.MemoryTagsSource(new_tags)
    samples_source = read_data_source.MemoryDataSource(new_samples)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
def leave_channels(mgr, channels):
    """Exclude every channel of ``mgr`` except those listed in ``channels``.

    Note: each name in ``channels`` must exist in the manager, otherwise
    list.remove raises ValueError (same behavior as before).
    """
    to_drop = deepcopy(mgr.get_param('channels_names'))
    for wanted in channels:
        to_drop.remove(wanted)
    return exclude_channels(mgr, to_drop)
def GetEpochsFromRM(rm, tags_function_list,
                    start_offset = -0.1, duration = 2.0,
                    tag_name = None,
                    get_last_tags = False):
    """Extracts stimulus epochs from ReadManager to list of SmartTags
    Args:
        rm: ReadManager with dataset
        start_offset: baseline in negative seconds,
        duration: duration of the epoch (excluding baseline),
        tags_function_list: list of tag filtering functions to get epochs for
        tag_name: tag name to be considered, if you want to use all tags use None
        get_last_tags: keep only the trailing tags (see NOTE in the body)
    Return:
        list of smarttags corresponding to tags_function_list"""

    # drop tags that do not have enough signal before their start to apply
    # the baseline offset; unfortunately this loses the first tag or two, but
    # without it all tags would be lost because SmartTagManager has a bug :(
    new_tags = [tag for tag in rm.get_tags() if float(tag['start_timestamp']) > -start_offset]
    rm.set_tags(new_tags)

    if get_last_tags:
        tags = rm.get_tags()
        # NOTE(review): this slice keeps the last 100 tags, although the
        # original docstring said 99 - confirm which count is intended
        rm.set_tags(tags[-1 - 99:])
    tag_def = SmartTagDurationDefinition(start_tag_name = tag_name,
                                         start_offset = start_offset,
                                         end_offset = 0.0,
                                         duration = duration)
    stags = SmartTagsManager(tag_def, '', '', '', p_read_manager = rm)
    returntags = []
    for tagfunction in tags_function_list:
        returntags.append(stags.get_smart_tags(p_func = tagfunction, ))
    print 'Found epochs in defined groups:', [len(i) for i in returntags]
    return returntags
def evoked_from_smart_tags(tags, chnames, bas = -0.1):
    """Average epochs held in smart tags into an evoked response.

    Args:
        tags: list of smart tags to average
        chnames: list of channel names to use for averaging
        bas: baseline length in negative seconds (falsy disables correction)
    Returns:
        (mean, sem) arrays of shape (len(chnames), n_samples).
    """
    # epochs generated by smart tags can vary in length by one sample,
    # so everything is truncated to the shortest epoch
    epoch_len = min(tag.get_samples().shape[1] for tag in tags)
    sampling = float(tags[0].get_param('sampling_frequency'))
    kept_epochs = []
    for tag in tags:
        raw = tag.get_channels_samples(chnames)
        try:
            epoch = raw[:, :epoch_len]
        except IndexError:
            # a single channel comes back 1-D; promote it to 2-D first
            epoch = raw[None, :][:, :epoch_len]

        if bas:
            n_baseline = int(-sampling * bas)
            for row in range(len(epoch)):
                # subtract the mean of the baseline interval per channel
                epoch[row] = epoch[row] - np.mean(epoch[row][0:n_baseline])
        # keep only epochs free of infinities (NaN also fails this check)
        if np.max(np.abs(epoch)) < np.inf:
            kept_epochs.append(epoch)

    return np.mean(kept_epochs, axis = 0), stats.sem(kept_epochs, axis = 0)
def do_permutation_test(taglist, chnames):
    """Permutation cluster test between the conditions in taglist, computed
    independently for every channel in ``chnames``.

    With two (or more) conditions an independent-samples cluster test is
    used; with a single condition a one-sample test is used.
    returns: list of clusters (per channel) with tuples (clusters, clusters_p_values)
    """

    print 'LEN TAGLIST =', len(taglist), "|", [len(t) for t in taglist]
    # epochs from smart tags can differ in length by a sample - truncate all
    min_length = min([min(i.get_samples().shape[1] for i in tags) for tags in taglist])

    clusters = [] # per channel
    for channel in chnames:
        data_test = []
        print 'clustering for channel {}'.format(channel)
        # build an (epochs x samples) array per condition for this channel
        for tags in taglist:
            data_tag = []
            for tag in tags:
                chnls_data = tag.get_channel_samples(channel)[:min_length]
                data_tag.append(chnls_data.T)
            data_test.append(np.array(data_tag))
        if len(data_test) > 1:
            T_obs, clusters_, cluster_p_values, H0 = mne.stats.permutation_cluster_test(data_test, step_down_p = 0.05, n_jobs = cpu_count()/2, seed = 42)
        else:
            T_obs, clusters_, cluster_p_values, H0 = mne.stats.permutation_cluster_1samp_test(data_test[0], step_down_p = 0.05, n_jobs = cpu_count()/2, seed = 42)
        clusters.append((clusters_, cluster_p_values))
    return clusters
def do_permutation_test_spatiotemporal(taglist, chnames):
    """Does permutation cluster test between 2 conditions in taglist
    (a one-sample test when only one condition is given).
    Includes spatiotemporal calculation: clusters are formed jointly over
    time and channel space using the cap's channel-adjacency matrix.
    returns: (clusters, cluster_p_values)
    """
    connectivity = get_our_connectivity_sparse(chnames)
    # show_connectivity(connectivity, chnames)
    print 'LEN TAGLIST =', len(taglist), "|", [len(t) for t in taglist]
    # epochs from smart tags can differ in length by a sample - truncate all
    min_length = min([min(i.get_samples().shape[1] for i in tags) for tags in taglist])
    data_tags = []
    # build an (epochs x samples x channels) array per condition
    for tags in taglist:
        data_tag = []
        for tag in tags:
            chnls_data = tag.get_channels_samples(chnames)[:, :min_length]
            data_tag.append(chnls_data.T)
        data_tag = np.array(data_tag)
        data_tags.append(data_tag)
    if len(data_tags) > 1:
        # import IPython
        # IPython.embed()
        T_obs, clusters_, cluster_p_values, H0 = mne.stats.spatio_temporal_cluster_test(data_tags, step_down_p=0.05, n_jobs = cpu_count()/2, seed=42, n_permutations=1024, connectivity=connectivity)
    else:
        T_obs, clusters_, cluster_p_values, H0 = mne.stats.spatio_temporal_cluster_1samp_test(data_tags[0], step_down_p=0.05, n_jobs = cpu_count()/2, seed=42, n_permutations=1024, connectivity=connectivity)
    print("P_Values:", cluster_p_values)
    print "FOND {} CLUSTERS, in THOSE with P<0.05: {}, with P<0.15: {}".format(len(cluster_p_values), np.sum(cluster_p_values<0.05), np.sum(cluster_p_values<0.15))
    return clusters_, cluster_p_values
def find_consecutive_ranges(data):
    """Collapse a sequence of increasing integers into inclusive ranges.

    E.g. [1, 2, 3, 7, 8, 10] -> [(1, 3), (7, 8), (10, 10)].

    Args:
        data: iterable of integers in increasing order.
    Returns:
        list of (first, last) tuples, one per run of consecutive values.

    Note: rewritten without the Python-2-only tuple-unpacking lambda
    (``lambda (i, x): ...``) so the function also runs under Python 3;
    behavior is unchanged.
    """
    ranges = []
    # members of a consecutive run share a constant (index - value) difference
    for _, run in groupby(enumerate(data), lambda pair: pair[0] - pair[1]):
        members = [value for _, value in run]
        ranges.append((members[0], members[-1]))
    return ranges
- def evoked_list_plot_smart_tags(taglist, chnames = ('Fz', 'Pz','Cz'), chnls_to_clusterize = ('Fz', 'Pz','Cz'),
- start_offset = -0.2, roi = (-1e10, 1e10), labels = ('target', 'nontarget'), show = True, size = (5, 5),
- addline = [], one_scale = True, anatomical = True, std_corridor = True, permutation_test = True,
- spatiotemporal=True, classification_pipeline_clusters=False):
- """debug evoked potential plot,
- plot list of smarttags,
- blocks thread
- Args:
- taglist: list of smarttags
- labels: list of labels
- taglist, labels: lists of equal lengths,
- chnames: channels to plot
- start_offset: baseline in seconds
- addline: list of floats - seconds to add vertical barons.py", line 927, in do_autoreject
- segment_shape = ts.bad_segments.shape
- one_scale: binary - to force the same scale
- anatomical: plot all 10-20 electrodes with positions
- permutation_test: do a permutation test between target/nontarget
- spatiotemporal: if true permutation clustering will be spatiotemporal not per channel
- """
-
- for tag in taglist[0] + taglist[1]:
- available_chnls = tag.get_param('channels_names')
- chnames = [chname for chname in chnames if chname in available_chnls]
- chnls_to_clusterize = [chname for chname in chnls_to_clusterize if chname in available_chnls]
-
- evs, stds = [], []
- for tags in taglist:
- ev, std = evoked_from_smart_tags(tags, chnames, start_offset)
- evs.append(ev)
- stds.append(std)
-
- Fs = float(taglist[0][0].get_param('sampling_frequency'))
- times = np.linspace(0 + start_offset, ev.shape[1] / Fs + start_offset, ev.shape[1])
- # baseline correction
- for tag in taglist[0] + taglist[1]:
- samples = tag.get_channels_samples(tag.get_param('channels_names'))
- for s in samples:
- s -= np.mean(s[:-int(Fs*start_offset)])
- tag.set_samples(samples, tag.get_param('channels_names'))
- # truncation according to expected timing of relevant potentials
- truncated_taglist = deepcopy(taglist)
- for tag in truncated_taglist[0] + truncated_taglist[1]:
- samples = tag.get_channels_samples(tag.get_param('channels_names'))
- roi_start = np.argmin(np.abs(times-roi[0]*1e-3))
- roi_end = np.argmin(np.abs(times-roi[1]*1e-3))
- truncated_samples = samples[:, roi_start:roi_end]
- tag.set_samples(truncated_samples, tag.get_param('channels_names'))
-
- truncated_times = times[roi_start:roi_end]
-
- if permutation_test and not spatiotemporal and not classification_pipeline_clusters:
- clusters_per_chnl = do_permutation_test(truncated_taglist, chnames)
- p_values = []
- elif permutation_test and spatiotemporal:
- clusters, clusters_p_values = do_permutation_test_spatiotemporal(truncated_taglist, chnls_to_clusterize)
- p_values = [p for p in clusters_p_values]
- clusters_significant = [clusters[i] for i in range(len(clusters_p_values)) if clusters_p_values[i] < 0.1]
- clusters_significant_p_values = [clusters_p_values[i] for i in range(len(clusters_p_values)) if clusters_p_values[i] < 0.1]
- print clusters_p_values
-
- # todo - to poniżej prawie na pewno nie działa, bo było dużo zmian, do których nie było to dostosowywane, jako że nieużywane
- elif permutation_test and classification_pipeline_clusters:
- clusteriser = EEGClusterisationSelector(Fs=Fs, bas=start_offset, chnames=chnames, cluster_time_channel_mask=True)
- clusteriser.fit_smarttags(taglist)
- p_values = [p for p in clusteriser.cluster_p_values]
-
- if one_scale:
- vmax = np.max(np.array(evs) + np.array(stds))
- vmin = np.min(np.array(evs) - np.array(stds))
-
- if anatomical:
- fig, axs = pb.subplots(6, 5, figsize = (1. * 5 * figure_scale, .5625 * 5 * figure_scale))
- for ch in channels_not2draw(chnames):
- pos = map1020[ch]
- axs[pos.y, pos.x].axis("off")
- fig.subplots_adjust(left = 0.03, bottom = 0.03, right = 0.98, top = 0.93, wspace = 0.15, hspace = 0.27)
- else:
- fig = pb.figure(figsize = size)
- for nr, chname in enumerate(chnames):
- if anatomical:
- pos = map1020[chname]
- ax = axs[pos.y, pos.x]
- else:
- ax = fig.add_subplot((len(chnames) + 1) / 2, 2, nr + 1)
- # zaznaczenie ROI na wykresie:
- x1 = truncated_times[0]
- x2 = truncated_times[-1]
- ax.plot([x1, x1], [vmin, vmax], "k--", linewidth = 2)
- ax.plot([x2, x2], [vmin, vmax], "k--", linewidth = 2)
-
- if permutation_test and not spatiotemporal and not classification_pipeline_clusters:
- cl, p_val = clusters_per_chnl[nr]
-
- for cc, pp in zip(cl, p_val):
- if pp < 0.05:
- color = "blue"
- alpha = 0.3
- paint_patch = True
- elif pp < 0.1:
- color = "blue"
- alpha = 0.1
- paint_patch = True
- else:
- color = "gray"
- alpha = 1 - pp
- paint_patch = False
- if paint_patch:
- ax.axvspan(truncated_times[cc[0].start], truncated_times[cc[0].stop - 1],
- color = color, alpha = alpha, zorder = 1)
- p_values.append(pp)
- elif permutation_test and spatiotemporal:
- color_cycle_p05 = ['tab:blue', 'tab:purple', 'tab:green', 'xkcd:navy']
- color_cycle_p10 = ['tab:orange', 'tab:brown', 'xkcd:yellow', 'xkcd:burgundy']
- current_color_p05_id = 0
- current_color_p10_id = 0
- for clust_id in range(len(clusters_significant)):
- cluster = clusters_significant[clust_id]
- try:
- cluster_number = chnls_to_clusterize.index(chname)
- except ValueError:
- continue
- cluster_channel_mask = cluster[1] == cluster_number
- if clusters_significant_p_values[clust_id] < 0.05:
- color = color_cycle_p05[current_color_p05_id]
- current_color_p05_id += 1
- if current_color_p05_id == len(color_cycle_p05):
- current_color_p05_id = 0
- else:
- color = color_cycle_p10[current_color_p10_id]
- current_color_p10_id += 1
- if current_color_p10_id == len(color_cycle_p10):
- current_color_p10_id = 0
- print"SIGNIFICANT CLUSTER: ", clust_id, "P value", clusters_significant_p_values[clust_id], [chnls_to_clusterize[kkk] for kkk in np.unique(cluster[1])]
- if np.any(cluster_channel_mask): # czy jest w tym klastrze ten kanał
- cluster_time_idx = cluster[0][cluster_channel_mask] # indexy czasowe należące do klastra
- for span in find_consecutive_ranges(list(cluster_time_idx)):
- ax.axvspan(truncated_times[span[0]], truncated_times[span[1]], alpha = 0.4 if color is not "xkcd:yellow" else 0.6, zorder = 1, color = color)
- elif permutation_test and classification_pipeline_clusters:
- in_thr = min(clusteriser.cluster_p_values) < 0.05
- color = 'tab:blue' if in_thr else 'tab:red'
- for cluster in clusteriser.clusters:
- cluster_for_channel = cluster[:, nr]
- ranges = clusteriser.find_cluster_ranges(cluster_for_channel)
- for cluster_range in ranges:
- ax.axvspan(cluster_range[0], cluster_range[1], color=color, zorder=1, alpha = 0.4)
- # import IPython
- # IPython.embed()
- for tagsnr in xrange(len(taglist)):
- color = None # standard colors
- if 'nontarget' in labels[tagsnr]:
- color = 'green'
- elif 'target' in labels[tagsnr]:
- color = 'red'
-
- if 'standard' in labels[tagsnr]:
- color = 'green'
- elif 'dewiant' in labels[tagsnr]:
- color = 'red'
-
- lines, = ax.plot(times,
- evs[tagsnr][nr],
- label = labels[tagsnr] + ' N:{}'.format(len(taglist[tagsnr])),
- color = color,
- zorder = 3)
-
- ax.axvline(0, color = 'k')
- ax.axhline(0, color = 'k')
- for l in addline:
- ax.axvline(l, color = 'k')
- if std_corridor:
- ax.fill_between(times,
- evs[tagsnr][nr] - stds[tagsnr][nr],
- evs[tagsnr][nr] + stds[tagsnr][nr],
- color = lines.get_color(),
- alpha = 0.3,
- zorder = 2)
-
- if one_scale:
- ax.set_ylim(vmin, vmax)
- elif type(one_scale) == list:
- ax.set_ylim(one_scale[0], one_scale[1])
- ax.set_xlim(round(times[0], 2), round(times[-1], 2))
- ax.set_title(chname)
- set_axis_fontsize(ax, fontsize * figure_scale / 8)
- ax.ticklabel_format(style='plain')
- ax.yaxis.set_major_formatter(pb.FormatStrFormatter('%.0f'))
- ax.legend(fontsize = fontsize * figure_scale / 8)
- if show:
- pb.show()
- return fig, p_values
def mgr_decimate(mgr, factor):
    """Return a new ReadManager with the signal downsampled by *factor*.

    Downsampling is performed as repeated factor-of-2 ``scipy.signal.decimate``
    stages (with the matching anti-alias filtering), so *factor* should be a
    power of two.  The declared sampling frequency in the new manager's
    metadata is divided by the full *factor*.

    :param mgr: ReadManager holding the signal (channels x samples)
    :param factor: int, overall decimation factor; 1 returns an undecimated
        copy, powers of two are decimated in log2(factor) stages
    :return: new ReadManager with deep-copied metadata/tags and the
        decimated samples
    """
    # BUGFIX: the previous version ran int(factor / 2) halving stages, which
    # is only correct for factors 2 and 4 (factor 8 would decimate by 16),
    # and raised NameError for factor < 2.  Use log2(factor) stages instead.
    new_samples = mgr.get_samples()
    n_stages = int(round(np.log2(factor))) if factor > 1 else 0
    for _ in xrange(n_stages):
        new_samples = signal.decimate(new_samples, 2, zero_phase = True)
    info_source = deepcopy(mgr.info_source)
    info_source.get_params()['number_of_samples'] = new_samples.shape[1]
    info_source.get_params()['sampling_frequency'] = float(mgr.get_param('sampling_frequency')) / factor
    tags_source = deepcopy(mgr.tags_source)
    samples_source = read_data_source.MemoryDataSource(new_samples)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
def mgr_order_filter(mgr, order = 0, Wn = [49, 51], rp = None, rs = None, ftype = 'cheby2', btype = 'bandstop',
                     output = 'ba', use_filtfilt = True, meancorr = 1.0):
    """Filter the manager's signal with an explicitly-ordered IIR filter.

    Designs a filter with ``scipy.signal.iirfilter`` (critical frequencies
    *Wn* are given in Hz and normalised to Nyquist here) and applies it
    either zero-phase per channel (filtfilt, mutates the manager's samples
    in place) or causally over all channels at once (lfilter).

    :param mgr: ReadManager with the signal; mutated in place when
        use_filtfilt is True
    :param order: filter order passed to iirfilter
    :param Wn: critical frequencies in Hz (scalar or pair for band filters)
    :param rp, rs: passband ripple / stopband attenuation in dB, used only
        for the 'ellip' and 'cheby2' families
    :param ftype: IIR family name understood by iirfilter
    :param btype: band type, e.g. 'bandstop', 'bandpass'
    :param output: filter representation, 'ba' expected
    :param use_filtfilt: zero-phase filtering when True
    :param meancorr: fraction of each channel's mean subtracted before
        filtfilt
    :return: new ReadManager wrapping the filtered samples
    """
    nyquist = float(mgr.get_param('sampling_frequency')) / 2.0
    normalized_wn = np.array(Wn) / nyquist
    if ftype in ('ellip', 'cheby2'):
        # these families additionally need the ripple/attenuation specs
        b, a = signal.iirfilter(order, normalized_wn, rp, rs, btype = btype, ftype = ftype, output = output)
    else:
        b, a = signal.iirfilter(order, normalized_wn, btype = btype, ftype = ftype, output = output)

    if use_filtfilt:
        n_channels = int(mgr.get_param('number_of_channels'))
        for chnl_idx in range(n_channels):
            chnl = mgr.get_samples()[chnl_idx]
            # demean (scaled by meancorr) before zero-phase filtering
            mgr.get_samples()[chnl_idx, :] = signal.filtfilt(b, a, chnl - np.mean(chnl) * meancorr)
        samples_source = read_data_source.MemoryDataSource(mgr.get_samples(), False)
    else:
        print("FILTER CHANNELs")
        filtered = signal.lfilter(b, a, mgr.get_samples())
        print("FILTER CHANNELs finished")
        samples_source = read_data_source.MemoryDataSource(filtered, True)
    info_source = deepcopy(mgr.info_source)
    tags_source = deepcopy(mgr.tags_source)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
def mgr_filter(mgr, wp, ws, gpass, gstop, analog = 0, ftype = 'ellip', output = 'ba', unit = 'hz', use_filtfilt = True, meancorr = 1.0):
    """Filter the manager's signal with an IIR filter designed from specs.

    The filter is designed with ``scipy.signal.iirdesign``.  With
    unit='hz' the band edges *wp*/*ws* are given in Hz and normalised to
    Nyquist here; with unit='radians' they are passed through as-is and
    the designed frequency response is plotted for inspection.

    :param mgr: ReadManager with the signal; its samples are mutated in
        place when use_filtfilt is True
    :param wp, ws: passband/stopband edge frequencies, scalar or sequence
    :param gpass, gstop: max passband loss / min stopband attenuation, dB
    :param analog: forwarded to iirdesign
    :param ftype: IIR family name (e.g. 'ellip')
    :param output: filter representation, 'ba' expected
    :param unit: 'hz' or 'radians' (see above)
    :param use_filtfilt: zero-phase per-channel filtering when True,
        causal lfilter over all channels otherwise
    :param meancorr: fraction of the channel mean subtracted before filtfilt
    :return: new ReadManager wrapping the filtered samples
    """
    if unit == 'radians':
        b, a = signal.iirdesign(wp, ws, gpass, gstop, analog, ftype, output)
        w, h = signal.freqz(b, a, 1000)

        fff = pb.figure()
        # BUGFIX: Figure.add_subplot() without arguments is invalid in the
        # matplotlib versions this code targets - request the single-axes
        # layout explicitly.
        ax = fff.add_subplot(111)
        ax.plot(w, 20 * np.log10(np.abs(h)))
        pb.show()
    elif unit == 'hz':
        nyquist = float(mgr.get_param('sampling_frequency')) / 2.0
        try:
            # scalar band edges
            wp = wp / nyquist
            ws = ws / nyquist
        except TypeError:
            # sequences of band edges (band-pass / band-stop designs)
            wp = [i / nyquist for i in wp]
            ws = [i / nyquist for i in ws]
        b, a = signal.iirdesign(wp, ws, gpass, gstop, analog, ftype, output)
    if use_filtfilt:
        for i in range(int(mgr.get_param('number_of_channels'))):
            # demean (scaled by meancorr), then zero-phase filter in place
            mgr.get_samples()[i, :] = signal.filtfilt(b, a, mgr.get_samples()[i] - np.mean(mgr.get_samples()[i]) * meancorr)
        samples_source = read_data_source.MemoryDataSource(mgr.get_samples(), False)
    else:
        print("FILTER CHANNELs")
        filtered = signal.lfilter(b, a, mgr.get_samples())
        print("FILTER CHANNELs finished")
        samples_source = read_data_source.MemoryDataSource(filtered, True)
    info_source = deepcopy(mgr.info_source)
    tags_source = deepcopy(mgr.tags_source)
    new_mgr = read_manager.ReadManager(info_source, samples_source, tags_source)
    return new_mgr
- def _exclude_from_montage_indexes(mgr, chnames):
- exclude_from_montage_indexes = []
-
- for i in chnames:
- try:
- exclude_from_montage_indexes.append(mgr.get_param('channels_names').index(i))
- except ValueError:
- pass
- return exclude_from_montage_indexes
def montage_csa(mgr, exclude_from_montage = []):
    """Rereference the signal to a common-spatial-average montage.

    :param mgr: ReadManager with the signal to remontage
    :param exclude_from_montage: channel names left out of the average
        (names absent from the recording are ignored)
    :return: new ReadManager with the remontaged samples and deep-copied
        metadata/tags
    """
    excluded_idx = _exclude_from_montage_indexes(mgr, exclude_from_montage)
    n_channels = int(mgr.get_param('number_of_channels'))
    montage_matrix = get_montage_matrix_csa(n_channels, exclude_from_montage = excluded_idx)
    remontaged_samples = get_montage(mgr.get_samples(), montage_matrix)
    samples_source = read_data_source.MemoryDataSource(remontaged_samples)
    return read_manager.ReadManager(deepcopy(mgr.info_source),
                                    samples_source,
                                    deepcopy(mgr.tags_source))
def montage_ears(mgr, l_ear_channel, r_ear_channel, exclude_from_montage = []):
    """Rereference the signal to linked ears.

    If one of the two ear channels is missing from the recording, falls
    back to a single-ear montage via ``montage_custom`` (a message to that
    effect is printed, in Polish: "ear channel missing, montaging to one
    ear only").

    :param mgr: ReadManager with the signal to remontage
    :param l_ear_channel: name of the left-ear reference channel
    :param r_ear_channel: name of the right-ear reference channel
    :param exclude_from_montage: channel names excluded from remontaging
        (names absent from the recording are ignored)
    :return: new ReadManager with the remontaged samples and deep-copied
        metadata/tags
    """
    try:
        left_index = mgr.get_param('channels_names').index(l_ear_channel)
    except ValueError:
        print("Brakuje kanału usznego {}. Wykonuję montaż tylko do jegnego ucha.".format(l_ear_channel))
        return montage_custom(mgr, [r_ear_channel], exclude_from_montage)
    try:
        right_index = mgr.get_param('channels_names').index(r_ear_channel)
    except ValueError:
        print("Brakuje kanału usznego {}. Wykonuję montaż tylko do jegnego ucha.".format(r_ear_channel))
        return montage_custom(mgr, [l_ear_channel], exclude_from_montage)

    exclude_from_montage_indexes = _exclude_from_montage_indexes(mgr, exclude_from_montage)

    # NOTE: list.index either returns a valid (non-negative) index or raises
    # ValueError, which is handled above - the old "index < 0" check here was
    # unreachable and has been removed.
    new_samples = get_montage(mgr.get_samples(),
                              get_montage_matrix_ears(int(mgr.get_param('number_of_channels')),
                                                      left_index,
                                                      right_index,
                                                      exclude_from_montage_indexes)
                              )
    info_source = deepcopy(mgr.info_source)
    tags_source = deepcopy(mgr.tags_source)
    samples_source = read_data_source.MemoryDataSource(new_samples)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
- def get_channel_indexes(channels, toindex):
- """get list of indexes of channels in toindex list as found in
- channels list"""
- indexes = []
- for chnl in toindex:
- index = channels.index(chnl)
- if index < 0:
- raise Exception("Montage