# helper_functions/helper_functions.py
# Source repository: https://gitlab.com/BudzikFUW/budzik_analiza
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # based on obci.analysis.p300.analysis_offline
  4. # Marian Dovgialo
  5. import os
  6. import sys
  7. import glob
  8. from multiprocessing import cpu_count
  9. from itertools import groupby
  10. from operator import itemgetter
  11. from copy import deepcopy
  12. from scipy import signal, stats
  13. import numpy as np
  14. import matplotlib
  15. from p300_classifier_MMP.clusterisation_selector import EEGClusterisationSelector
  16. from .helper_data.ourcap_neighb import get_our_connectivity_sparse
  17. havedisplay = "DISPLAY" in os.environ
  18. if not havedisplay and "linux" in sys.platform:
  19. matplotlib.use('Agg')
  20. import pylab as pb
  21. from copy import deepcopy
  22. from obci.analysis.obci_signal_processing import read_manager
  23. from obci.analysis.obci_signal_processing.signal import read_info_source, read_data_source
  24. from obci.analysis.obci_signal_processing.tags import read_tags_source
  25. from obci.analysis.obci_signal_processing.tags.smart_tag_definition import SmartTagDurationDefinition
  26. from obci.analysis.obci_signal_processing.tags.tags_file_writer import TagsFileWriter
  27. from obci.analysis.obci_signal_processing.smart_tags_manager import SmartTagsManager
  28. from mne_conversions import read_manager_continious_to_mne, chtype
  29. import mne
  30. from config import main_outdir, figure_scale, fontsize
  31. from collections import namedtuple
# Named 2-D grid position (column, row) used to place a channel on a plot grid.
Pos = namedtuple('Pos', ['x', 'y'])
# 10-20 electrode name -> Pos on a 5x5 layout grid; 'Null' marks an unused
# slot. Presumably used for "anatomical" channel plots - confirm with the
# plotting code that consumes it.
map1020 = {'eog': Pos(0, 0), 'Fp1': Pos(1, 0), 'Fpz': Pos(2, 0), 'Fp2': Pos(3, 0), 'Null': Pos(4, 0),
           'F7': Pos(0, 1), 'F3': Pos(1, 1), 'Fz': Pos(2, 1), 'F4': Pos(3, 1), 'F8': Pos(4, 1),
           'T3': Pos(0, 2), 'C3': Pos(1, 2), 'Cz': Pos(2, 2), 'C4': Pos(3, 2), 'T4': Pos(4, 2),
           'T5': Pos(0, 3), 'P3': Pos(1, 3), 'Pz': Pos(2, 3), 'P4': Pos(3, 3), 'T6': Pos(4, 3),
           'M1': Pos(0, 4), 'O1': Pos(1, 4), 'Oz': Pos(2, 4), 'O2': Pos(3, 4), 'M2': Pos(4, 4)}
  38. def get_filelist(filelist):
  39. if len(filelist) == 1:
  40. if os.path.exists(filelist[0]):
  41. pass
  42. # print 'pracuję nad pojedyńczym plikiem', filelist[0]
  43. else:
  44. print 'pracuję ze wzorem', filelist[0]
  45. filelist = glob.glob(filelist[0])
  46. else:
  47. print 'pracuję nad listą plików', filelist
  48. return filelist
  49. def get_tagfilters(blok_type):
  50. epoch_labels = None
  51. if not blok_type:
  52. return (None,), epoch_labels
  53. if blok_type == 1:
  54. epoch_labels = ('target', 'nontarget')
  55. def target_func(tag):
  56. try:
  57. return tag['desc']['blok_type'] == '1' and tag['desc']['type'] == 'target'
  58. except:
  59. return False
  60. def nontarget_func(tag):
  61. try:
  62. return tag['desc']['blok_type'] == '1' and tag['desc']['type'] == 'nontarget'
  63. except:
  64. return False
  65. return (target_func, nontarget_func), epoch_labels
  66. elif blok_type == 2:
  67. epoch_labels = ('target', 'nontarget')
  68. def target_func(tag):
  69. try:
  70. return tag['desc']['blok_type'] == '2' and tag['desc']['type'] == 'target'
  71. except:
  72. return False
  73. def nontarget_func(tag):
  74. try:
  75. return tag['desc']['blok_type'] == '2' and tag['desc']['type'] == 'nontarget'
  76. except:
  77. return False
  78. return (target_func, nontarget_func), epoch_labels
  79. elif blok_type == 'local':
  80. epoch_labels = ('dewiant', 'standard')
  81. def target_func(tag):
  82. try:
  83. return tag['desc']['type_local'] == 'dewiant'
  84. except:
  85. return False
  86. def nontarget_func(tag):
  87. try:
  88. return tag['desc']['type_local'] == 'standard'
  89. except:
  90. return False
  91. return (target_func, nontarget_func), epoch_labels
  92. elif blok_type == 'global':
  93. epoch_labels = ('target', 'nontarget')
  94. def target_func(tag):
  95. try:
  96. return tag['desc']['type_global'] == 'target'
  97. except:
  98. return False
  99. def nontarget_func(tag):
  100. try:
  101. return tag['desc']['type_global'] == 'nontarget'
  102. except:
  103. return False
  104. return (target_func, nontarget_func), epoch_labels
  105. elif blok_type == "erds":
  106. epoch_labels = ('hand_mvt', 'leg_mvt')
  107. def reka_func(tag):
  108. try:
  109. return tag['name'] == 'ERDS_instr1.wav'
  110. except:
  111. return False
  112. def noga_func(tag):
  113. try:
  114. return tag['name'] == 'ERDS_instr2.wav'
  115. except:
  116. return False
  117. return (reka_func, noga_func), epoch_labels
  118. elif blok_type == "wzrokowe_kot_movie":
  119. epoch_labels = ('kot_target', 'kot_nontarget')
  120. def target_func(tag):
  121. try:
  122. return tag['name'] == 'kot_target'
  123. except:
  124. return False
  125. def nontarget_func(tag):
  126. try:
  127. return tag['name'] == 'kot_nontarget'
  128. except:
  129. return False
  130. return (target_func, nontarget_func), epoch_labels
  131. elif blok_type == "wzrokowe_ptak_movie":
  132. epoch_labels = ('ptak_target', 'ptak_nontarget')
  133. def target_func(tag):
  134. try:
  135. return tag['name'] == 'ptak_target'
  136. except:
  137. return False
  138. def nontarget_func(tag):
  139. try:
  140. return tag['name'] == 'ptak_nontarget'
  141. except:
  142. return False
  143. return (target_func, nontarget_func), epoch_labels
  144. elif blok_type == "wzrokowe_both_movie":
  145. epoch_labels = ('target', 'nontarget')
  146. def target_func(tag):
  147. try:
  148. return tag['name'] in ('kot_target', 'ptak_target')
  149. except:
  150. return False
  151. def nontarget_func(tag):
  152. try:
  153. return tag['name'] in ('kot_nontarget', 'ptak_nontarget')
  154. except:
  155. return False
  156. return (target_func, nontarget_func), epoch_labels
  157. elif blok_type == "multimodal_both_movie":
  158. epoch_labels = ('target', 'nontarget')
  159. def target_func(tag):
  160. try:
  161. return tag['name'] in ('kaczka_target', 'pies_target')
  162. except:
  163. return False
  164. def nontarget_func(tag):
  165. try:
  166. return tag['name'] in ('kaczka_nontarget', 'pies_nontarget', 'sowa_nontarget', 'zaba_nontarget')
  167. except:
  168. return False
  169. return (target_func, nontarget_func), epoch_labels
  170. elif blok_type == "multimodal_both":
  171. epoch_labels = ('target', 'nontarget')
  172. def target_func(tag):
  173. try:
  174. return tag['desc']['trial_type'] == "target"
  175. except:
  176. return False
  177. def nontarget_func(tag):
  178. try:
  179. return tag['desc']['trial_type'] == "nontarget"
  180. except:
  181. return False
  182. return (target_func, nontarget_func), epoch_labels
  183. else:
  184. raise Exception("Nie ma bloków tego typu!")
  185. def savetags(stags, filename, start_offset = 0, duration = 0.1):
  186. """Create tags XML from smart tag list"""
  187. writer = TagsFileWriter(filename)
  188. for stag in stags:
  189. tag = stag.get_tags()[0]
  190. tag['start_timestamp'] += start_offset
  191. tag['end_timestamp'] += duration + start_offset
  192. writer.tag_received(tag)
  193. writer.finish_saving(0.0)
  194. def get_microvolt_samples(stag, channels = None):
  195. """Does get_samples on smart tag (read manager), but multiplied by channel gain"""
  196. if not channels: # returns for all channels
  197. gains = np.array([float(i) for i in stag.get_param('channels_gains')], ndmin = 2).T
  198. return stag.get_samples() * gains
  199. elif isinstance(channels, (str, unicode)): # returns for specific channel
  200. channel = channels
  201. ch_n = stag.get_param('channels_names').index(channel)
  202. gain = stag.get_param('channels_gains')[ch_n]
  203. return stag.get_channel_samples(channel) * float(gain)
  204. else: # returns for specified list of channels
  205. gains = stag.get_param('channels_gains')
  206. gains = np.array([float(gains[channels.index(ch)]) for ch in channels], ndmin = 2).T
  207. return stag.get_channels_samples(channels) * gains
  208. def shift_tags_relatively_to_signal_beginnig(rm, shift):
  209. """przesuwa wszystkie tagi w rm o pewną względną odległość
  210. zatem każdy tag jest przesuwany o inną wartość w zależności od jego pozycji
  211. im dalej od początku sygnału, tym bardziej jest przesuwany.
  212. """
  213. tags = rm.get_tags()
  214. for tag in tags:
  215. tag['start_timestamp'] *= 1+shift
  216. tag['end_timestamp'] *= 1+shift
  217. rm.set_tags(tags)
def align_tags(rm, tag_correction_chnls, start_offset=-0.1, duration=0.3, thr=None, reverse=False, offset=0):
    """Align tags in a read manager to the onset of trigger-channel activity.

    The absolute values of all *tag_correction_chnls* are summed into one
    trigger signal; for every tag the window
    [tag_time + start_offset, tag_time + start_offset + duration] is searched
    for the first sample above a threshold and the tag is shifted there.
    If no such change occurs the tag is left where it was (only *offset* is
    applied).

    Args:
        rm: read manager whose tags are modified in place.
        tag_correction_chnls: list of channel names summed into the trigger.
        start_offset: window start relative to the tag (seconds, negative).
        duration: window length (seconds).
        thr: detection threshold; when falsy it is derived automatically as
            mean + 3*std of the trigger, capped at half its maximum.
        reverse: search for the END of the stimulation (last suprathreshold
            sample in the window) instead of the start.
        offset: additional shift in seconds added to every tag forcibly.
    """
    tags = rm.get_tags()
    Fs = float(rm.get_param('sampling_frequency'))
    trigger_chnl = np.zeros(int(rm.get_param('number_of_samples')))
    for tag_correction_chnl in tag_correction_chnls:
        trigger_chnl += np.abs(rm.get_channel_samples(tag_correction_chnl))
    if not thr:
        # automatic threshold: mean + 3 std, but never above half the peak
        # NOTE(review): indentation was lost in this copy of the file - the
        # cap is assumed to apply only to the automatic threshold; confirm.
        thr = 3 * np.std(trigger_chnl) + np.mean(trigger_chnl)
        maksimum = trigger_chnl.max()
        if thr > 0.5 * maksimum:
            thr = 0.5 * maksimum
    for tag in tags:
        start = int((tag['start_timestamp'] + start_offset) * Fs)
        end = int((tag['start_timestamp'] + start_offset + duration) * Fs)
        try:
            if reverse:
                # last suprathreshold sample: argmax over the flipped window
                trig_pos_s_r = np.argmax(np.flipud(trigger_chnl[start:end] > thr))
                trig_pos_s = (end - start - 1) - trig_pos_s_r
            else:
                trig_pos_s = np.argmax(trigger_chnl[start:end] > thr)  # will find first True, or first False if no Trues
        except ValueError:
            # window empty (tag outside the signal): apply only the fixed offset
            tag['start_timestamp'] += offset
            tag['end_timestamp'] += offset
            continue
        # (commented-out debug plotting of the trigger window removed)
        if trigger_chnl[start + trig_pos_s] > thr:
            # a genuine threshold crossing was found - align the tag to it
            trig_pos_t = trig_pos_s * 1.0 / Fs
            tag_change = trig_pos_t + start_offset
            tag['start_timestamp'] += tag_change
            tag['end_timestamp'] += tag_change
        tag['start_timestamp'] += offset
        tag['end_timestamp'] += offset
    rm.set_tags(tags)
def show_eog_ica(rm, ica,
                 eog_chnl='eog',
                 blink_rejection_dict=dict(eeg=0.000250, eog=0.000500),  # V
                 correlation_treshhold=0.25,
                 results_path=""):
    """Interactively review a fitted ICA and choose EOG components to remove.

    Shows the raw signal, correlates each ICA source with *eog_chnl* (after
    removing outlier segments), pre-selects the best-correlated component if
    it exceeds *correlation_treshhold*, plots components/sources and lets the
    user override the selection from stdin.

    Args:
        rm: read manager with the data to preview.
        ica: already-fitted mne ICA object.
        eog_chnl: channel used as the EOG reference.
        blink_rejection_dict: amplitude limits (V) for outlier rejection.
        correlation_treshhold: minimal |correlation| for auto-selection.
        results_path: directory path shown to the user for information only.
    Returns:
        list of bad (to-be-removed) component indices.
    """
    print('removing eog artifact')
    raw = read_manager_continious_to_mne(rm)
    n = len(raw.ch_names)
    print('n chnls {}'.format(n))
    raw.plot(block=True, show=False, scalings='auto', title='Simple preview of signal', n_channels=n)
    data = raw.get_data()
    # data holds everything from raw; data2 will hold it with outlier
    # segments removed. NOTE: _reject_data_segments is a private mne API -
    # may break across mne versions.
    data2, drop_inds = mne.preprocessing.ica._reject_data_segments(data, blink_rejection_dict, flat=None, decim=None, info=raw.info, tstep=2.)
    # important: pass raw.info (not ica.info) - ica.info lacks the EOG
    # channel and we want outliers removed from it as well
    raw2 = mne.io.RawArray(data2, raw.info)
    scores = ica.score_sources(raw2, target=eog_chnl, l_freq=1, h_freq=9)
    if max(np.abs(scores)) > correlation_treshhold:
        bads = [np.argmax(np.abs(scores))]
        winnig_correlation = max(np.abs(scores))
    else:
        bads = []
        winnig_correlation = 0
    print(ica)
    print('CORRELATION SCORES:')
    for nr, score in enumerate(scores):
        # components above threshold are flagged with '*'
        msg = '{} {} {}'.format('*' if np.abs(score) > correlation_treshhold else ' ', nr, score)
        print msg
    if winnig_correlation:
        title = 'Components with artifacts: {}, Corr. with {} = {} > thr = {}'.format(bads, eog_chnl, winnig_correlation, correlation_treshhold)
        print 'Wybrane złe komponenty:', bads
    else:
        title = 'Components with artifacts: None, (corr. with {}), thr = {}'.format(eog_chnl, correlation_treshhold)
        print 'Nie udało się wybrać złych komponent:', bads
    ica.plot_components(res=128, show=False, title=title, colorbar=True)
    ica.plot_sources(raw, show=True)
    print "\nDane będą zapisane do katalogu:"
    print results_path
    print('\nWpisz indeksy komponent rozdzielone spacjami jeśli chcesz nadpisać (po zamknięciu okienek)\n'
          'jeśli nic nie wpiszesz użyją się wybrane automatycznie [potwierdź ENTERem]\n'
          'jeśli nie chcesz usuwać żadnej wpisz -1.')
    # interactive override: empty input keeps the automatic selection,
    # '-1' clears it, anything else replaces it with the typed indices
    good = False
    while not good:
        try:
            inp = raw_input()
            if inp.split():
                man_bads = [int(i) for i in inp.split() if 0 <= int(i) < n]
                if not man_bads and not int(inp) == -1:
                    raise Exception
                bads = man_bads
            good = True
        except Exception:
            print('Błąd, wpisz jeszcze raz\n')
    print 'Wybrane złe komponenty:', bads
    return bads
def fit_eog_ica(rm,
                eog_chnl='eog',
                montage=None,
                ds='',
                use_eog_events=False,
                manual=False,
                rejection_dict=dict(eeg=0.000150,
                                    eog=0.000250),  # V
                blink_rejection_dict=dict(eeg=0.000250,
                                          eog=0.000500),  # V
                correlation_treshhold=0.5,
                outdir=os.path.expanduser(os.path.join(main_outdir, "unknown/noname/ica_and_artifacts/ica_maps"))):
    """Fit an ICA on *rm* and select components correlated with the EOG.

    rm - read manager with training data for ICA
    eog_chnl - channel to use as EOG source
    montage - montage of the read manager (for logging and filenames of
        generated images); "car" in it reduces max_pca_components by one
    ds - dataset path, used only to build output file names
    use_eog_events True - split to EOG epochs - do ICA
    use_eog_events False - use whole file
    use_eog_events None - only create eog events (returns (rm, events)!)
    manual True/False - shows ICA components, ICA components map,
        prints correlations with EOG and then lets the user type
        space-separated indices of components to remove
    Returns
    - fitted mne.ICA object to be used in remove_ica_components
    - list of bad components
    - detected eog events
    """
    print('removing eog artifact')
    raw = read_manager_continious_to_mne(rm)
    n = len(raw.ch_names)
    print('n chnls {}'.format(n))
    # NOTE(review): indentation was lost in this copy - the debug print and
    # second preview plot are assumed to belong to the manual branch; confirm.
    if manual:
        raw.plot(block=True, show=True, scalings='auto', title='Simple preview of signal', n_channels=n)
        print('HHHHHHHHHHHHHHHHh', raw.ch_names, eog_chnl)
        raw.plot(block=True, show=True, scalings='auto', title='Simple preview of signal', n_channels=n)
    events = mne.preprocessing.find_eog_events(raw, ch_name=eog_chnl)
    print('EOG EVENTS\n', events)
    if use_eog_events is None:
        # caller only wanted the events - note the different return shape
        return rm, events
    no_eeg_channels = len([ch for ch in raw.ch_names if chtype(ch) == "eeg"])
    # one component is lost to average reference ("car"), one to rank
    max_pca_components = no_eeg_channels - 1 - ("car" in montage)
    ica = mne.preprocessing.ICA(method='extended-infomax', max_pca_components=max_pca_components)
    if use_eog_events:
        eog = mne.preprocessing.create_eog_epochs(raw, ch_name=eog_chnl)
        ica.fit(eog, reject=rejection_dict, picks=mne.pick_types(raw.info, eeg=True, eog=True))
    else:
        ica.fit(raw, reject=rejection_dict, tstep=0.3)
    data = raw.get_data()
    # data holds everything from raw; data2 will hold it with outlier
    # segments removed (private mne API - may break across versions)
    data2, drop_inds = mne.preprocessing.ica._reject_data_segments(data, blink_rejection_dict, flat=None, decim=None, info=raw.info, tstep=2.)
    # important: pass raw.info (not ica.info) - ica.info lacks the EOG
    # channel and we want outliers removed from it as well
    raw2 = mne.io.RawArray(data2, raw.info)
    scores = ica.score_sources(raw2, target=eog_chnl, l_freq=1, h_freq=9)
    # all components above threshold; if none, fall back to the single best
    # one provided it clears half the threshold
    bads = list(np.arange(scores.size)[np.abs(scores) > correlation_treshhold])
    if len(bads) == 0 and max(np.abs(scores)) > correlation_treshhold / 2.:
        bads = [np.argmax(np.abs(scores))]
    if bads:
        winnig_correlations = np.abs(scores[bads])
    else:
        winnig_correlations = None
    print(ica)
    filename = 'ICA_eog_' + os.path.basename(ds) + '_{}'.format(montage)
    log = open(os.path.join(outdir, filename + '.txt'), 'w')
    print('CORRELATION SCORES:')
    log.write('CORRELATION SCORES:\n')
    for nr, score in enumerate(scores):
        msg = '{} {} {}'.format('*' if np.abs(score) > correlation_treshhold else ' ', nr, score)
        print msg
        log.write(msg + '\n')
    log.close()
    if np.any(winnig_correlations):
        title = 'Components with artifacts: {}, Corr. with {} = {} > thr = {}'.format(bads, eog_chnl, winnig_correlations, correlation_treshhold)
        print 'Wybrane złe komponenty:', bads
    else:
        title = 'Components with artifacts: None, (corr. with {}), thr = {}'.format(eog_chnl, correlation_treshhold)
        print 'Nie udało się wybrać złych komponent:', bads
    fig = ica.plot_components(res=128, show=False, title=title, colorbar=True)
    # example 40 seconds of signal starting at second 10,
    # plus the corresponding ICA components
    fig_sig = raw.plot(start=10, duration=40, block=False, show=False, scalings='auto', title='Preview of dirty signal', n_channels=n)
    fig_com = ica.plot_sources(raw, start=10, stop=50, show=False)
    fig_com.savefig(os.path.join(outdir, filename + 'example1_components.png'))
    fig_sig.savefig(os.path.join(outdir, filename + 'example2_signal.png'))
    pb.close(fig_sig)
    pb.close(fig_com)
    if manual:
        ica.plot_sources(raw, show=True)
        print('Wpisz indeksy komponent rozdzielone spacjami jeśli chcesz nadpisać (po zamknięciu okienek)\n'
              'jeśli nic nie wpiszesz użyją się wybrane automatycznie [potwierdź ENTERem]\n'
              'jeśli nie chcesz usuwać żadnej wpisz -1.')
        # redraw: if the user closed the window there would be nothing
        # left to save to file afterwards
        fig = ica.plot_components(res=128, show=False, title=title, colorbar=True)
        # interactive override: empty input keeps the automatic selection,
        # '-1' clears it, anything else replaces it with the typed indices
        good = False
        while not good:
            try:
                inp = raw_input()
                if inp.split():
                    man_bads = [int(i) for i in inp.split() if 0 <= int(i) < n]
                    if not man_bads and not int(inp) == -1:
                        raise Exception
                    bads = man_bads
                good = True
            except Exception:
                print('Błąd, wpisz jeszcze raz\n')
        print 'Wybrane złe komponenty:', bads
    # plot_components may return a list of figures when there are many
    # components - save each one
    if isinstance(fig, list):
        for nr, figura in enumerate(fig):
            figura.savefig(os.path.join(outdir, filename + '_{}'.format(nr) + '.png'))
            if not manual:
                pb.close(figura)
    else:
        fig.savefig(os.path.join(outdir, filename + '.png'))
        if not manual:
            pb.close(fig)
    return ica, bads, events
def remove_ica_components(rm, ica, bads,
                          events=[],
                          scalings={'eeg': 4e-5, 'eog': 4e-5},
                          silent=False,
                          ds='',
                          montage=[],
                          outdir=os.path.expanduser(os.path.join(main_outdir, "unknown/noname/ica_and_artifacts/ica_maps"))):
    """Apply a fitted ICA to a read manager, zeroing the bad components.

    rm - read manager with data to clean (a deep copy is modified)
    ica - fitted mne.ICA object to be used (e.g. returned by fit_eog_ica)
    bads - list of bad components of ICA; empty list -> copy returned as-is
    events - detected eog events (nd.array), only used for plotting
    scalings - plot scalings for the before/after previews
    silent - when False, shows interactive before/after plots
    ds, montage - used only to build the preview image file name
    Returns ICA-corrected read manager (deep copy of *rm*).
    NOTE: the default mutable arguments (events=[], montage=[]) are never
    mutated here, so sharing them between calls is harmless.
    """
    rm = deepcopy(rm)
    if bads:
        # read_manager to mne conversion
        raw = read_manager_continious_to_mne(rm)
        n = len(raw.ch_names)
        if not silent:
            raw.copy().plot(scalings=scalings, events=events, block=True, show=False, title='PRZED ICA', n_channels=n)
        raw_clean = ica.apply(raw, exclude=bads)
        if not silent:
            raw_clean.copy().plot(scalings=scalings, events=events, block=True, show=True, title='PO ICA', n_channels=n)
        if ds:
            # example 40 seconds of signal starting at second 10,
            # after blink removal
            filename = 'ICA_eog_' + os.path.basename(ds) + '_{}'.format(montage)
            fig = raw_clean.plot(start=10, duration=40, block=False, show=False, scalings='auto', title='Preview of cleaned signal', n_channels=n)
            fig.savefig(os.path.join(outdir, filename + 'example3_clean_signal.png'))
            pb.close(fig)
        data = np.array(raw_clean.to_data_frame())
        print "CONTROL INFO"
        print data.shape
        print np.median(np.abs(data), axis=1)
        print np.std(data, axis=1)
        # mne to read_manager conversion
        # NOTE(review): unlike interp_bads, no 1e-6 unit rescaling is applied
        # here - confirm the expected units of to_data_frame output
        rm.set_samples(data.T, rm.get_param('channels_names'))
    return rm
  487. def remove_eog_ica(rm,
  488. eog_chnl = 'eog',
  489. montage = None,
  490. ds = '',
  491. use_eog_events = False,
  492. manual = False,
  493. rejection_dict = dict(eeg = 0.000150,
  494. eog = 0.000250), # V
  495. correlation_treshhold = 0.5):
  496. """
  497. Exists for compatibility reasons.
  498. rm - read manager with training data for ICA
  499. eog_chnl - channel to use as EOG source
  500. montage - montage of the read manager (for logging and filenames of generated images)
  501. use_eog_events True - split to EOG epochs - do ICA
  502. use eog events False - use whole file
  503. use_eog_events None - only create eog events
  504. manual True/False - shows ICA components, ICA components map
  505. prints correlations with EOG and then lets user write space seperated
  506. indexes of components to remove
  507. Returns
  508. - ICA-corrected read manager
  509. - detected eog events
  510. """
  511. ica, bads, eog_events = fit_eog_ica(rm, eog_chnl, montage, ds, use_eog_events, manual, rejection_dict, correlation_treshhold)
  512. if bads:
  513. clean_rm = remove_ica_components(ica, bads, eog_events)
  514. else:
  515. clean_rm = rm
  516. return clean_rm, eog_events
  517. def interp_bads(rm, bads):
  518. ds = read_manager_continious_to_mne(rm)
  519. ds.info['bads'] = bads
  520. ds.interpolate_bads()
  521. data = np.array(ds.to_data_frame())
  522. rm.set_samples(data.T * 1e-6, rm.get_param('channels_names'))
  523. return rm
def undrop_channels(mgr_dropped, mgr_full):
    """Fill the channels missing from *mgr_dropped* with data from *mgr_full*.

    Effectively: take *mgr_full* (params, tags, all channels) and overwrite
    the channels that exist in *mgr_dropped* with the dropped manager's data
    (the rest of the metadata is assumed identical between the two).
    Returns a brand-new ReadManager.
    """
    available_ch = mgr_dropped.get_param('channels_names')
    all_ch = mgr_full.get_param('channels_names')
    new_params = deepcopy(mgr_full.get_params())
    samples_full = mgr_full.get_samples()
    samples_dropped = mgr_dropped.get_samples()
    new_tags = deepcopy(mgr_full.get_tags())
    new_samples = np.zeros((int(new_params['number_of_channels']), len(samples_full[0])))
    # Define new samples and params list values
    keys = ['channels_names', 'channels_numbers', 'channels_gains', 'channels_offsets']
    keys_to_remove = []
    for k in keys:
        try:
            # Exclude from keys those keys that are missing in mgr
            mgr_full.get_params()[k]
        except KeyError:
            keys_to_remove.append(k)
            continue
        new_params[k] = []
    for k in keys_to_remove:
        keys.remove(k)
    new_ind = 0
    for ch_ind, ch in enumerate(all_ch):
        if ch not in available_ch:
            # channel was dropped - take it from the full manager
            new_samples[ch_ind, :] = samples_full[ch_ind, :]
        else:
            # channel survived - take it from the dropped manager
            new_samples[ch_ind, :] = samples_dropped[available_ch.index(ch), :]
        # per-channel params are rebuilt for every output channel
        for k in keys:
            new_params[k].append(mgr_full.get_params()[k][ch_ind])
        new_ind += 1  # NOTE(review): never read - vestigial counter
    info_source = read_info_source.MemoryInfoSource(new_params)
    tags_source = read_tags_source.MemoryTagsSource(new_tags)
    samples_source = read_data_source.MemoryDataSource(new_samples)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
def exclude_channels(mgr, channels):
    """Return a new ReadManager with every channel in *channels* removed.

    Channel names not present in the manager are silently ignored (the
    requested set is intersected with the available channels first).
    Per-channel parameter lists and 'number_of_channels' are rebuilt to
    match; tags and remaining params are deep-copied.
    """
    available = set(mgr.get_param('channels_names'))
    exclude = set(channels)
    channels = list(available.intersection(exclude))
    new_params = deepcopy(mgr.get_params())
    samples = mgr.get_samples()
    new_tags = deepcopy(mgr.get_tags())
    ex_channels_inds = [new_params['channels_names'].index(ch) for ch in channels]
    assert (-1 not in ex_channels_inds)
    new_samples = np.zeros((int(new_params['number_of_channels']) - len(channels),
                            len(samples[0])))
    # Define new samples and params list values
    keys = ['channels_names', 'channels_numbers', 'channels_gains', 'channels_offsets']
    keys_to_remove = []
    for k in keys:
        try:
            # Exclude from keys those keys that are missing in mgr
            mgr.get_params()[k]
        except KeyError:
            keys_to_remove.append(k)
            continue
        new_params[k] = []
    for k in keys_to_remove:
        keys.remove(k)
    new_ind = 0
    for ch_ind, ch in enumerate(samples):
        if ch_ind in ex_channels_inds:
            # excluded channel - skip it entirely
            continue
        else:
            new_samples[new_ind, :] = ch
            for k in keys:
                new_params[k].append(mgr.get_params()[k][ch_ind])
            new_ind += 1
    # Define other new new_params
    new_params['number_of_channels'] = str(int(new_params['number_of_channels']) - len(channels))
    info_source = read_info_source.MemoryInfoSource(new_params)
    tags_source = read_tags_source.MemoryTagsSource(new_tags)
    samples_source = read_data_source.MemoryDataSource(new_samples)
    return read_manager.ReadManager(info_source, samples_source, tags_source)
  601. def leave_channels(mgr, channels):
  602. """exclude all channels except those in channels list"""
  603. chans = deepcopy(mgr.get_param('channels_names'))
  604. for leave in channels:
  605. chans.remove(leave)
  606. return exclude_channels(mgr, chans)
def GetEpochsFromRM(rm, tags_function_list,
                    start_offset=-0.1, duration=2.0,
                    tag_name=None,
                    get_last_tags=False):
    """Extracts stimulus epochs from ReadManager to list of SmartTags
    Args:
        rm: ReadManager with dataset (its tag list is modified in place!)
        start_offset: baseline in negative seconds,
        duration: duration of the epoch (excluding baseline),
        tags_function_list: list of tag filtering functions to get epochs for
        tag_name: tag name to be considered, if you want to use all tags use None
        get_last_tags: keep only the trailing tags (NB: the slice below keeps
            the last 100 tags, not 99 as previously documented - confirm
            which count was intended)
    Return:
        list of smarttags corresponding to tags_function_list"""
    # drop tags that do not have enough signal before them to apply the
    # baseline offset; we unfortunately lose the first tag or two this way,
    # but without it ALL tags would be dropped - SmartTagManager has a bug :(
    new_tags = [tag for tag in rm.get_tags() if float(tag['start_timestamp']) > -start_offset]
    rm.set_tags(new_tags)
    if get_last_tags:
        tags = rm.get_tags()
        rm.set_tags(tags[-1 - 99:])
    tag_def = SmartTagDurationDefinition(start_tag_name=tag_name,
                                         start_offset=start_offset,
                                         end_offset=0.0,
                                         duration=duration)
    stags = SmartTagsManager(tag_def, '', '', '', p_read_manager=rm)
    returntags = []
    for tagfunction in tags_function_list:
        returntags.append(stags.get_smart_tags(p_func=tagfunction, ))
    print 'Found epochs in defined groups:', [len(i) for i in returntags]
    return returntags
  638. def evoked_from_smart_tags(tags, chnames, bas = -0.1):
  639. """
  640. Args:
  641. tags: smart tag list, to average
  642. chnames: list of channels to use for averaging,
  643. bas: baseline (in negative seconds)"""
  644. min_length = min(i.get_samples().shape[1] for i in tags)
  645. # really don't like this, but epochs generated by smart tags can vary in length by 1 sample
  646. channels_data = []
  647. Fs = float(tags[0].get_param('sampling_frequency'))
  648. for i in tags:
  649. try:
  650. data = i.get_channels_samples(chnames)[:, :min_length]
  651. except IndexError: # in case of len(chnames)==1
  652. data = i.get_channels_samples(chnames)[None, :][:, :min_length]
  653. if bas:
  654. for nr, chnl in enumerate(data):
  655. data[nr] = chnl - np.mean(chnl[0:int(-Fs * bas)]) # baseline correction
  656. if np.max(np.abs(data)) < np.inf:
  657. channels_data.append(data)
  658. return np.mean(channels_data, axis = 0), stats.sem(channels_data, axis = 0)
def do_permutation_test(taglist, chnames):
    """ between 2 conditions in taglist
    Runs a per-channel permutation cluster test; with a single condition a
    one-sample test is used instead.
    returns: list of clusters (per channel) with tuples (clusters, clusters_p_values)
    """
    print 'LEN TAGLIST =', len(taglist), "|", [len(t) for t in taglist]
    # cut every epoch to the globally shortest one across all conditions
    min_length = min([min(i.get_samples().shape[1] for i in tags) for tags in taglist])
    clusters = []  # per channel
    for channel in chnames:
        data_test = []
        print 'clustering for channel {}'.format(channel)
        for tags in taglist:
            data_tag = []
            for tag in tags:
                chnls_data = tag.get_channel_samples(channel)[:min_length]
                data_tag.append(chnls_data.T)
            data_test.append(np.array(data_tag))
        # NOTE(review): under Python 3 cpu_count()/2 would be a float
        # (mne expects int n_jobs); fine under Python 2 integer division
        if len(data_test) > 1:
            T_obs, clusters_, cluster_p_values, H0 = mne.stats.permutation_cluster_test(data_test, step_down_p = 0.05, n_jobs = cpu_count()/2, seed = 42)
        else:
            T_obs, clusters_, cluster_p_values, H0 = mne.stats.permutation_cluster_1samp_test(data_test[0], step_down_p = 0.05, n_jobs = cpu_count()/2, seed = 42)
        clusters.append((clusters_, cluster_p_values))
    return clusters
def do_permutation_test_spatiotemporal(taglist, chnames):
    """ Does permutation cluster test between 2 conditions in taglist
    Includes spatiotemporal calculation (channel connectivity taken from
    get_our_connectivity_sparse); with a single condition the one-sample
    variant is used.
    returns: spatiotemporal_cluster - (clusters, cluster_p_values)
    """
    connectivity = get_our_connectivity_sparse(chnames)
    # show_connectivity(connectivity, chnames)
    print 'LEN TAGLIST =', len(taglist), "|", [len(t) for t in taglist]
    # cut every epoch to the globally shortest one across all conditions
    min_length = min([min(i.get_samples().shape[1] for i in tags) for tags in taglist])
    data_tags = []
    for tags in taglist:
        data_tag = []
        for tag in tags:
            # transposed to (time, channels) as the spatio-temporal
            # cluster tests expect observations x time x space
            chnls_data = tag.get_channels_samples(chnames)[:, :min_length]
            data_tag.append(chnls_data.T)
        data_tag = np.array(data_tag)
        data_tags.append(data_tag)
    # NOTE(review): under Python 3 cpu_count()/2 would be a float
    # (mne expects int n_jobs); fine under Python 2 integer division
    if len(data_tags) > 1:
        T_obs, clusters_, cluster_p_values, H0 = mne.stats.spatio_temporal_cluster_test(data_tags, step_down_p=0.05, n_jobs = cpu_count()/2, seed=42, n_permutations=1024, connectivity=connectivity)
    else:
        T_obs, clusters_, cluster_p_values, H0 = mne.stats.spatio_temporal_cluster_1samp_test(data_tags[0], step_down_p=0.05, n_jobs = cpu_count()/2, seed=42, n_permutations=1024, connectivity=connectivity)
    print("P_Values:", cluster_p_values)
    print "FOND {} CLUSTERS, in THOSE with P<0.05: {}, with P<0.15: {}".format(len(cluster_p_values), np.sum(cluster_p_values<0.05), np.sum(cluster_p_values<0.15))
    return clusters_, cluster_p_values
  707. def find_consecutive_ranges(data):
  708. ranges = []
  709. for k, g in groupby(enumerate(data), lambda (i, x): i - x):
  710. group = map(itemgetter(1), g)
  711. ranges.append((group[0], group[-1]))
  712. return ranges
def evoked_list_plot_smart_tags(taglist, chnames = ('Fz', 'Pz','Cz'), chnls_to_clusterize = ('Fz', 'Pz','Cz'),
                                start_offset = -0.2, roi = (-1e10, 1e10), labels = ('target', 'nontarget'), show = True, size = (5, 5),
                                addline = [], one_scale = True, anatomical = True, std_corridor = True, permutation_test = True,
                                spatiotemporal=True, classification_pipeline_clusters=False):
    """debug evoked potential plot,
    plot list of smarttags,
    blocks thread
    Args:
        taglist: list of smarttags (two conditions expected, e.g. target/nontarget)
        labels: list of labels
        taglist, labels: lists of equal lengths,
        chnames: channels to plot
        chnls_to_clusterize: channels used for the spatiotemporal cluster test
        start_offset: baseline in seconds (negative, relative to tag onset)
        roi: (start, end) in milliseconds - window the permutation tests are restricted to
        addline: list of floats - seconds at which to add vertical bars
        one_scale: binary - to force the same scale; a [ymin, ymax] list sets it explicitly
        anatomical: plot all 10-20 electrodes with positions
        std_corridor: draw +-1 std corridors around the averages
        permutation_test: do a permutation test between target/nontarget
        spatiotemporal: if true permutation clustering will be spatiotemporal not per channel
        classification_pipeline_clusters: use EEGClusterisationSelector instead (unmaintained)
    Returns:
        (fig, p_values)
        NOTE(review): p_values is only assigned inside the permutation_test
        branches - with permutation_test=False the final return raises
        NameError; confirm callers always enable the test.
    """
    # restrict requested channels to those actually present in every tag
    for tag in taglist[0] + taglist[1]:
        available_chnls = tag.get_param('channels_names')
        chnames = [chname for chname in chnames if chname in available_chnls]
        chnls_to_clusterize = [chname for chname in chnls_to_clusterize if chname in available_chnls]
    # per-condition evoked averages and standard deviations
    evs, stds = [], []
    for tags in taglist:
        ev, std = evoked_from_smart_tags(tags, chnames, start_offset)
        evs.append(ev)
        stds.append(std)
    Fs = float(taglist[0][0].get_param('sampling_frequency'))
    # time axis in seconds; uses the last `ev` from the loop - assumes all
    # conditions share the same epoch length (TODO confirm)
    times = np.linspace(0 + start_offset, ev.shape[1] / Fs + start_offset, ev.shape[1])
    # baseline correction
    # (start_offset is negative, so -int(Fs*start_offset) selects the
    # pre-stimulus samples; channels are corrected in place)
    for tag in taglist[0] + taglist[1]:
        samples = tag.get_channels_samples(tag.get_param('channels_names'))
        for s in samples:
            s -= np.mean(s[:-int(Fs*start_offset)])
        tag.set_samples(samples, tag.get_param('channels_names'))
    # truncation according to expected timing of relevant potentials
    truncated_taglist = deepcopy(taglist)
    for tag in truncated_taglist[0] + truncated_taglist[1]:
        samples = tag.get_channels_samples(tag.get_param('channels_names'))
        # roi is given in milliseconds, times in seconds
        roi_start = np.argmin(np.abs(times-roi[0]*1e-3))
        roi_end = np.argmin(np.abs(times-roi[1]*1e-3))
        truncated_samples = samples[:, roi_start:roi_end]
        tag.set_samples(truncated_samples, tag.get_param('channels_names'))
    truncated_times = times[roi_start:roi_end]
    # run the requested flavour of permutation cluster test on the ROI window
    if permutation_test and not spatiotemporal and not classification_pipeline_clusters:
        clusters_per_chnl = do_permutation_test(truncated_taglist, chnames)
        p_values = []
    elif permutation_test and spatiotemporal:
        clusters, clusters_p_values = do_permutation_test_spatiotemporal(truncated_taglist, chnls_to_clusterize)
        p_values = [p for p in clusters_p_values]
        # keep only clusters with p < 0.1 for drawing
        clusters_significant = [clusters[i] for i in range(len(clusters_p_values)) if clusters_p_values[i] < 0.1]
        clusters_significant_p_values = [clusters_p_values[i] for i in range(len(clusters_p_values)) if clusters_p_values[i] < 0.1]
        print clusters_p_values
    # TODO: the branch below almost certainly no longer works - many changes
    # were made elsewhere without updating it, since it is unused
    elif permutation_test and classification_pipeline_clusters:
        clusteriser = EEGClusterisationSelector(Fs=Fs, bas=start_offset, chnames=chnames, cluster_time_channel_mask=True)
        clusteriser.fit_smarttags(taglist)
        p_values = [p for p in clusteriser.cluster_p_values]
    # common y-scale across all channels and conditions
    # NOTE(review): vmin/vmax are referenced below for the ROI markers even
    # when one_scale is falsy - confirm one_scale is always truthy or a list
    if one_scale:
        vmax = np.max(np.array(evs) + np.array(stds))
        vmin = np.min(np.array(evs) - np.array(stds))
    if anatomical:
        # 6x5 grid of axes laid out like the 10-20 electrode map (map1020)
        fig, axs = pb.subplots(6, 5, figsize = (1. * 5 * figure_scale, .5625 * 5 * figure_scale))
        for ch in channels_not2draw(chnames):
            pos = map1020[ch]
            axs[pos.y, pos.x].axis("off")
        fig.subplots_adjust(left = 0.03, bottom = 0.03, right = 0.98, top = 0.93, wspace = 0.15, hspace = 0.27)
    else:
        fig = pb.figure(figsize = size)
    for nr, chname in enumerate(chnames):
        if anatomical:
            pos = map1020[chname]
            ax = axs[pos.y, pos.x]
        else:
            ax = fig.add_subplot((len(chnames) + 1) / 2, 2, nr + 1)
        # mark the ROI boundaries on the plot:
        x1 = truncated_times[0]
        x2 = truncated_times[-1]
        ax.plot([x1, x1], [vmin, vmax], "k--", linewidth = 2)
        ax.plot([x2, x2], [vmin, vmax], "k--", linewidth = 2)
        if permutation_test and not spatiotemporal and not classification_pipeline_clusters:
            # per-channel clusters: shade significant time windows
            cl, p_val = clusters_per_chnl[nr]
            for cc, pp in zip(cl, p_val):
                if pp < 0.05:
                    color = "blue"
                    alpha = 0.3
                    paint_patch = True
                elif pp < 0.1:
                    color = "blue"
                    alpha = 0.1
                    paint_patch = True
                else:
                    color = "gray"
                    alpha = 1 - pp
                    paint_patch = False
                if paint_patch:
                    ax.axvspan(truncated_times[cc[0].start], truncated_times[cc[0].stop - 1],
                               color = color, alpha = alpha, zorder = 1)
                p_values.append(pp)
        elif permutation_test and spatiotemporal:
            # each significant cluster gets its own colour, cycling separately
            # for p<0.05 and 0.05<=p<0.1 clusters
            color_cycle_p05 = ['tab:blue', 'tab:purple', 'tab:green', 'xkcd:navy']
            color_cycle_p10 = ['tab:orange', 'tab:brown', 'xkcd:yellow', 'xkcd:burgundy']
            current_color_p05_id = 0
            current_color_p10_id = 0
            for clust_id in range(len(clusters_significant)):
                cluster = clusters_significant[clust_id]
                try:
                    cluster_number = chnls_to_clusterize.index(chname)
                except ValueError:
                    # current channel was not part of the clusterisation
                    continue
                cluster_channel_mask = cluster[1] == cluster_number
                if clusters_significant_p_values[clust_id] < 0.05:
                    color = color_cycle_p05[current_color_p05_id]
                    current_color_p05_id += 1
                    if current_color_p05_id == len(color_cycle_p05):
                        current_color_p05_id = 0
                else:
                    color = color_cycle_p10[current_color_p10_id]
                    current_color_p10_id += 1
                    if current_color_p10_id == len(color_cycle_p10):
                        current_color_p10_id = 0
                print"SIGNIFICANT CLUSTER: ", clust_id, "P value", clusters_significant_p_values[clust_id], [chnls_to_clusterize[kkk] for kkk in np.unique(cluster[1])]
                if np.any(cluster_channel_mask):  # does this cluster include the current channel
                    cluster_time_idx = cluster[0][cluster_channel_mask]  # time indices belonging to the cluster
                    for span in find_consecutive_ranges(list(cluster_time_idx)):
                        # NOTE(review): `is not` compares identity, not equality -
                        # the yellow special-case likely never triggers; confirm
                        ax.axvspan(truncated_times[span[0]], truncated_times[span[1]], alpha = 0.4 if color is not "xkcd:yellow" else 0.6, zorder = 1, color = color)
        elif permutation_test and classification_pipeline_clusters:
            in_thr = min(clusteriser.cluster_p_values) < 0.05
            color = 'tab:blue' if in_thr else 'tab:red'
            for cluster in clusteriser.clusters:
                cluster_for_channel = cluster[:, nr]
                ranges = clusteriser.find_cluster_ranges(cluster_for_channel)
                for cluster_range in ranges:
                    ax.axvspan(cluster_range[0], cluster_range[1], color=color, zorder=1, alpha = 0.4)
        # import IPython
        # IPython.embed()
        # draw the evoked averages for every condition on this channel's axes
        for tagsnr in xrange(len(taglist)):
            color = None  # standard colors
            if 'nontarget' in labels[tagsnr]:
                color = 'green'
            elif 'target' in labels[tagsnr]:
                color = 'red'
            if 'standard' in labels[tagsnr]:
                color = 'green'
            elif 'dewiant' in labels[tagsnr]:
                color = 'red'
            lines, = ax.plot(times,
                             evs[tagsnr][nr],
                             label = labels[tagsnr] + ' N:{}'.format(len(taglist[tagsnr])),
                             color = color,
                             zorder = 3)
            ax.axvline(0, color = 'k')
            ax.axhline(0, color = 'k')
            for l in addline:
                ax.axvline(l, color = 'k')
            if std_corridor:
                # +-1 std corridor in the same colour as the average line
                ax.fill_between(times,
                                evs[tagsnr][nr] - stds[tagsnr][nr],
                                evs[tagsnr][nr] + stds[tagsnr][nr],
                                color = lines.get_color(),
                                alpha = 0.3,
                                zorder = 2)
        if one_scale:
            ax.set_ylim(vmin, vmax)
        elif type(one_scale) == list:
            ax.set_ylim(one_scale[0], one_scale[1])
        ax.set_xlim(round(times[0], 2), round(times[-1], 2))
        ax.set_title(chname)
        set_axis_fontsize(ax, fontsize * figure_scale / 8)
        ax.ticklabel_format(style='plain')
        ax.yaxis.set_major_formatter(pb.FormatStrFormatter('%.0f'))
        ax.legend(fontsize = fontsize * figure_scale / 8)
    if show:
        pb.show()
    return fig, p_values
  890. def mgr_decimate(mgr, factor):
  891. steps = int(factor / 2)
  892. x = mgr.get_samples()
  893. for step in xrange(steps):
  894. if step == 0:
  895. new_samples = x
  896. y = signal.decimate(new_samples, 2, zero_phase = True)
  897. new_samples = y
  898. info_source = deepcopy(mgr.info_source)
  899. info_source.get_params()['number_of_samples'] = new_samples.shape[1]
  900. info_source.get_params()['sampling_frequency'] = float(mgr.get_param('sampling_frequency')) / factor
  901. tags_source = deepcopy(mgr.tags_source)
  902. samples_source = read_data_source.MemoryDataSource(new_samples)
  903. return read_manager.ReadManager(info_source, samples_source, tags_source)
  904. def mgr_order_filter(mgr, order = 0, Wn = [49, 51], rp = None, rs = None, ftype = 'cheby2', btype = 'bandstop',
  905. output = 'ba', use_filtfilt = True, meancorr = 1.0):
  906. nyquist = float(mgr.get_param('sampling_frequency')) / 2.0
  907. if ftype in ['ellip', 'cheby2']:
  908. b, a = signal.iirfilter(order, np.array(Wn) / nyquist, rp, rs, btype = btype, ftype = ftype, output = output)
  909. else:
  910. b, a = signal.iirfilter(order, np.array(Wn) / nyquist, btype = btype, ftype = ftype, output = output)
  911. if use_filtfilt:
  912. for i in range(int(mgr.get_param('number_of_channels'))):
  913. mgr.get_samples()[i, :] = signal.filtfilt(b, a, mgr.get_samples()[i] - np.mean(mgr.get_samples()[i]) * meancorr)
  914. samples_source = read_data_source.MemoryDataSource(mgr.get_samples(), False)
  915. else:
  916. print("FILTER CHANNELs")
  917. filtered = signal.lfilter(b, a, mgr.get_samples())
  918. print("FILTER CHANNELs finished")
  919. samples_source = read_data_source.MemoryDataSource(filtered, True)
  920. info_source = deepcopy(mgr.info_source)
  921. tags_source = deepcopy(mgr.tags_source)
  922. new_mgr = read_manager.ReadManager(info_source, samples_source, tags_source)
  923. return new_mgr
def mgr_filter(mgr, wp, ws, gpass, gstop, analog = 0, ftype = 'ellip', output = 'ba', unit = 'hz', use_filtfilt = True, meancorr = 1.0):
    """Filter a ReadManager's signal with an IIR filter designed from band specs.

    Args:
        mgr: ReadManager; its samples are modified in place when
            use_filtfilt is True.
        wp, ws: passband / stopband edge frequencies; scalars or sequences.
            Interpreted in Hz when unit == 'hz', otherwise taken as
            already-normalized frequencies.
        gpass, gstop: maximum passband loss / minimum stopband attenuation [dB].
        analog: passed through to scipy.signal.iirdesign.
        ftype: filter family, e.g. 'ellip'.
        output: representation requested from iirdesign ('ba' expected here).
        unit: 'hz' or 'radians' -- see wp/ws above.
            NOTE(review): any other value leaves b, a undefined and raises
            NameError below - confirm only these two values are ever passed.
        use_filtfilt: zero-phase per-channel filtfilt if True, else one-pass
            lfilter over all channels at once.
        meancorr: fraction of the channel mean subtracted before filtfilt.

    Returns:
        New ReadManager wrapping the filtered samples (info/tags deep-copied).
    """
    if unit == 'radians':
        b, a = signal.iirdesign(wp, ws, gpass, gstop, analog, ftype, output)
        # NOTE(review): this branch also pops up a blocking frequency-response
        # plot; fff.add_subplot() with no arguments needs matplotlib >= 3.1 --
        # confirm this is intended.
        w, h = signal.freqz(b, a, 1000)
        fff = pb.figure()
        ax = fff.add_subplot()
        ax.plot(w, 20 * np.log10(np.abs(h)))
        pb.show()
    elif unit == 'hz':
        nyquist = float(mgr.get_param('sampling_frequency')) / 2.0
        # scalar edges divide directly; sequences raise TypeError and are
        # normalized element by element instead
        try:
            wp = wp / nyquist
            ws = ws / nyquist
        except TypeError:
            wp = [i / nyquist for i in wp]
            ws = [i / nyquist for i in ws]
        b, a = signal.iirdesign(wp, ws, gpass, gstop, analog, ftype, output)
    if use_filtfilt:
        # samples_source = read_data_source.MemoryDataSource(mgr.get_samples(), False)
        for i in range(int(mgr.get_param('number_of_channels'))):
            # ~ print("FILT FILT CHANNEL "+str(i))
            # zero-phase filtering; the (scaled) channel mean is removed first
            mgr.get_samples()[i, :] = signal.filtfilt(b, a, mgr.get_samples()[i] - np.mean(mgr.get_samples()[i]) * meancorr)
        samples_source = read_data_source.MemoryDataSource(mgr.get_samples(), False)
    else:
        print("FILTER CHANNELs")
        filtered = signal.lfilter(b, a, mgr.get_samples())
        print("FILTER CHANNELs finished")
        samples_source = read_data_source.MemoryDataSource(filtered, True)
    info_source = deepcopy(mgr.info_source)
    tags_source = deepcopy(mgr.tags_source)
    new_mgr = read_manager.ReadManager(info_source, samples_source, tags_source)
    return new_mgr
  956. def _exclude_from_montage_indexes(mgr, chnames):
  957. exclude_from_montage_indexes = []
  958. for i in chnames:
  959. try:
  960. exclude_from_montage_indexes.append(mgr.get_param('channels_names').index(i))
  961. except ValueError:
  962. pass
  963. return exclude_from_montage_indexes
  964. def montage_csa(mgr, exclude_from_montage = []):
  965. exclude_from_montage_indexes = _exclude_from_montage_indexes(mgr, exclude_from_montage)
  966. new_samples = get_montage(mgr.get_samples(),
  967. get_montage_matrix_csa(int(mgr.get_param('number_of_channels')),
  968. exclude_from_montage = exclude_from_montage_indexes))
  969. info_source = deepcopy(mgr.info_source)
  970. tags_source = deepcopy(mgr.tags_source)
  971. samples_source = read_data_source.MemoryDataSource(new_samples)
  972. return read_manager.ReadManager(info_source, samples_source, tags_source)
  973. def montage_ears(mgr, l_ear_channel, r_ear_channel, exclude_from_montage = []):
  974. try:
  975. left_index = mgr.get_param('channels_names').index(l_ear_channel)
  976. except ValueError:
  977. print "Brakuje kanału usznego {}. Wykonuję montaż tylko do jegnego ucha.".format(l_ear_channel)
  978. return montage_custom(mgr, [r_ear_channel], exclude_from_montage)
  979. try:
  980. right_index = mgr.get_param('channels_names').index(r_ear_channel)
  981. except ValueError:
  982. print "Brakuje kanału usznego {}. Wykonuję montaż tylko do jegnego ucha.".format(r_ear_channel)
  983. return montage_custom(mgr, [l_ear_channel], exclude_from_montage)
  984. exclude_from_montage_indexes = _exclude_from_montage_indexes(mgr, exclude_from_montage)
  985. if left_index < 0 or right_index < 0:
  986. raise Exception("Montage - couldn`t find ears channels: " + str(l_ear_channel) + ", " + str(r_ear_channel))
  987. new_samples = get_montage(mgr.get_samples(),
  988. get_montage_matrix_ears(int(mgr.get_param('number_of_channels')),
  989. left_index,
  990. right_index,
  991. exclude_from_montage_indexes)
  992. )
  993. info_source = deepcopy(mgr.info_source)
  994. tags_source = deepcopy(mgr.tags_source)
  995. samples_source = read_data_source.MemoryDataSource(new_samples)
  996. return read_manager.ReadManager(info_source, samples_source, tags_source)
  997. def get_channel_indexes(channels, toindex):
  998. """get list of indexes of channels in toindex list as found in
  999. channels list"""
  1000. indexes = []
  1001. for chnl in toindex:
  1002. index = channels.index(chnl)
  1003. if index < 0:
  1004. raise Exception("Montage