/mne/epochs.py

http://github.com/mne-tools/mne-python · Python · 3770 lines · 2897 code · 227 blank · 646 comment · 406 complexity · 991ba055254a88290a273569e9c91910 MD5 · raw file

  1. # -*- coding: utf-8 -*-
  2. """Tools for working with epoched data."""
  3. # Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
  4. # Matti Hämäläinen <msh@nmr.mgh.harvard.edu>
  5. # Daniel Strohmeier <daniel.strohmeier@tu-ilmenau.de>
  6. # Denis Engemann <denis.engemann@gmail.com>
  7. # Mainak Jas <mainak@neuro.hut.fi>
  8. # Stefan Appelhoff <stefan.appelhoff@mailbox.org>
  9. #
  10. # License: BSD-3-Clause
  11. from functools import partial
  12. from collections import Counter
  13. from copy import deepcopy
  14. import json
  15. import operator
  16. import os.path as op
  17. import numpy as np
  18. from .io.utils import _construct_bids_filename
  19. from .io.write import (start_file, start_block, end_file, end_block,
  20. write_int, write_float, write_float_matrix,
  21. write_double_matrix, write_complex_float_matrix,
  22. write_complex_double_matrix, write_id, write_string,
  23. _get_split_size, _NEXT_FILE_BUFFER, INT32_MAX)
  24. from .io.meas_info import (read_meas_info, write_meas_info, _merge_info,
  25. _ensure_infos_match)
  26. from .io.open import fiff_open, _get_next_fname
  27. from .io.tree import dir_tree_find
  28. from .io.tag import read_tag, read_tag_info
  29. from .io.constants import FIFF
  30. from .io.fiff.raw import _get_fname_rep
  31. from .io.pick import (channel_indices_by_type, channel_type,
  32. pick_channels, pick_info, _pick_data_channels,
  33. _DATA_CH_TYPES_SPLIT, _picks_to_idx)
  34. from .io.proj import setup_proj, ProjMixin
  35. from .io.base import BaseRaw, TimeMixin, _get_ch_factors
  36. from .bem import _check_origin
  37. from .evoked import EvokedArray, _check_decim
  38. from .baseline import rescale, _log_rescale, _check_baseline
  39. from .channels.channels import (ContainsMixin, UpdateChannelsMixin,
  40. SetChannelsMixin, InterpolationMixin)
  41. from .filter import detrend, FilterMixin, _check_fun
  42. from .parallel import parallel_func
  43. from .event import _read_events_fif, make_fixed_length_events
  44. from .fixes import rng_uniform
  45. from .viz import (plot_epochs, plot_epochs_psd, plot_epochs_psd_topomap,
  46. plot_epochs_image, plot_topo_image_epochs, plot_drop_log)
  47. from .utils import (_check_fname, check_fname, logger, verbose,
  48. _time_mask, check_random_state, warn, _pl,
  49. sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc,
  50. _check_pandas_installed,
  51. _check_preload, GetEpochsMixin,
  52. _prepare_read_metadata, _prepare_write_metadata,
  53. _check_event_id, _gen_events, _check_option,
  54. _check_combine, ShiftTimeMixin, _build_data_frame,
  55. _check_pandas_index_arguments, _convert_times,
  56. _scale_dataframe_data, _check_time_format, object_size,
  57. _on_missing, _validate_type, _ensure_events,
  58. _path_like)
  59. from .utils.docs import fill_doc
  60. from .data.html_templates import epochs_template
  61. def _pack_reject_params(epochs):
  62. reject_params = dict()
  63. for key in ('reject', 'flat', 'reject_tmin', 'reject_tmax'):
  64. val = getattr(epochs, key, None)
  65. if val is not None:
  66. reject_params[key] = val
  67. return reject_params
  68. def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming,
  69. overwrite):
  70. """Split epochs.
  71. Anything new added to this function also needs to be added to
  72. BaseEpochs.save to account for new file sizes.
  73. """
  74. # insert index in filename
  75. base, ext = op.splitext(fname)
  76. if part_idx > 0:
  77. if split_naming == 'neuromag':
  78. fname = '%s-%d%s' % (base, part_idx, ext)
  79. else:
  80. assert split_naming == 'bids'
  81. fname = _construct_bids_filename(base, ext, part_idx,
  82. validate=False)
  83. _check_fname(fname, overwrite=overwrite)
  84. next_fname = None
  85. if part_idx < n_parts - 1:
  86. if split_naming == 'neuromag':
  87. next_fname = '%s-%d%s' % (base, part_idx + 1, ext)
  88. else:
  89. assert split_naming == 'bids'
  90. next_fname = _construct_bids_filename(base, ext, part_idx + 1,
  91. validate=False)
  92. next_idx = part_idx + 1
  93. else:
  94. next_idx = None
  95. with start_file(fname) as fid:
  96. _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx)
  97. def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx):
  98. info = epochs.info
  99. meas_id = info['meas_id']
  100. start_block(fid, FIFF.FIFFB_MEAS)
  101. write_id(fid, FIFF.FIFF_BLOCK_ID)
  102. if info['meas_id'] is not None:
  103. write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id'])
  104. # Write measurement info
  105. write_meas_info(fid, info)
  106. # One or more evoked data sets
  107. start_block(fid, FIFF.FIFFB_PROCESSED_DATA)
  108. start_block(fid, FIFF.FIFFB_MNE_EPOCHS)
  109. # write events out after getting data to ensure bad events are dropped
  110. data = epochs.get_data()
  111. _check_option('fmt', fmt, ['single', 'double'])
  112. if np.iscomplexobj(data):
  113. if fmt == 'single':
  114. write_function = write_complex_float_matrix
  115. elif fmt == 'double':
  116. write_function = write_complex_double_matrix
  117. else:
  118. if fmt == 'single':
  119. write_function = write_float_matrix
  120. elif fmt == 'double':
  121. write_function = write_double_matrix
  122. start_block(fid, FIFF.FIFFB_MNE_EVENTS)
  123. write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, epochs.events.T)
  124. write_string(fid, FIFF.FIFF_DESCRIPTION, _event_id_string(epochs.event_id))
  125. end_block(fid, FIFF.FIFFB_MNE_EVENTS)
  126. # Metadata
  127. if epochs.metadata is not None:
  128. start_block(fid, FIFF.FIFFB_MNE_METADATA)
  129. metadata = _prepare_write_metadata(epochs.metadata)
  130. write_string(fid, FIFF.FIFF_DESCRIPTION, metadata)
  131. end_block(fid, FIFF.FIFFB_MNE_METADATA)
  132. # First and last sample
  133. first = int(round(epochs.tmin * info['sfreq'])) # round just to be safe
  134. last = first + len(epochs.times) - 1
  135. write_int(fid, FIFF.FIFF_FIRST_SAMPLE, first)
  136. write_int(fid, FIFF.FIFF_LAST_SAMPLE, last)
  137. # save baseline
  138. if epochs.baseline is not None:
  139. bmin, bmax = epochs.baseline
  140. write_float(fid, FIFF.FIFF_MNE_BASELINE_MIN, bmin)
  141. write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, bmax)
  142. # The epochs itself
  143. decal = np.empty(info['nchan'])
  144. for k in range(info['nchan']):
  145. decal[k] = 1.0 / (info['chs'][k]['cal'] *
  146. info['chs'][k].get('scale', 1.0))
  147. data *= decal[np.newaxis, :, np.newaxis]
  148. write_function(fid, FIFF.FIFF_EPOCH, data)
  149. # undo modifications to data
  150. data /= decal[np.newaxis, :, np.newaxis]
  151. write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG,
  152. json.dumps(epochs.drop_log))
  153. reject_params = _pack_reject_params(epochs)
  154. if reject_params:
  155. write_string(fid, FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT,
  156. json.dumps(reject_params))
  157. write_int(fid, FIFF.FIFF_MNE_EPOCHS_SELECTION,
  158. epochs.selection)
  159. # And now write the next file info in case epochs are split on disk
  160. if next_fname is not None and n_parts > 1:
  161. start_block(fid, FIFF.FIFFB_REF)
  162. write_int(fid, FIFF.FIFF_REF_ROLE, FIFF.FIFFV_ROLE_NEXT_FILE)
  163. write_string(fid, FIFF.FIFF_REF_FILE_NAME, op.basename(next_fname))
  164. if meas_id is not None:
  165. write_id(fid, FIFF.FIFF_REF_FILE_ID, meas_id)
  166. write_int(fid, FIFF.FIFF_REF_FILE_NUM, next_idx)
  167. end_block(fid, FIFF.FIFFB_REF)
  168. end_block(fid, FIFF.FIFFB_MNE_EPOCHS)
  169. end_block(fid, FIFF.FIFFB_PROCESSED_DATA)
  170. end_block(fid, FIFF.FIFFB_MEAS)
  171. end_file(fid)
  172. def _event_id_string(event_id):
  173. return ';'.join([k + ':' + str(v) for k, v in event_id.items()])
  174. def _merge_events(events, event_id, selection):
  175. """Merge repeated events."""
  176. event_id = event_id.copy()
  177. new_events = events.copy()
  178. event_idxs_to_delete = list()
  179. unique_events, counts = np.unique(events[:, 0], return_counts=True)
  180. for ev in unique_events[counts > 1]:
  181. # indices at which the non-unique events happened
  182. idxs = (events[:, 0] == ev).nonzero()[0]
  183. # Figure out new value for events[:, 1]. Set to 0, if mixed vals exist
  184. unique_priors = np.unique(events[idxs, 1])
  185. new_prior = unique_priors[0] if len(unique_priors) == 1 else 0
  186. # If duplicate time samples have same event val, "merge" == "drop"
  187. # and no new event_id key will be created
  188. ev_vals = np.unique(events[idxs, 2])
  189. if len(ev_vals) <= 1:
  190. new_event_val = ev_vals[0]
  191. # Else, make a new event_id for the merged event
  192. else:
  193. # Find all event_id keys involved in duplicated events. These
  194. # keys will be merged to become a new entry in "event_id"
  195. event_id_keys = list(event_id.keys())
  196. event_id_vals = list(event_id.values())
  197. new_key_comps = [event_id_keys[event_id_vals.index(value)]
  198. for value in ev_vals]
  199. # Check if we already have an entry for merged keys of duplicate
  200. # events ... if yes, reuse it
  201. for key in event_id:
  202. if set(key.split('/')) == set(new_key_comps):
  203. new_event_val = event_id[key]
  204. break
  205. # Else, find an unused value for the new key and make an entry into
  206. # the event_id dict
  207. else:
  208. ev_vals = np.unique(
  209. np.concatenate((list(event_id.values()),
  210. events[:, 1:].flatten()),
  211. axis=0))
  212. if ev_vals[0] > 1:
  213. new_event_val = 1
  214. else:
  215. diffs = np.diff(ev_vals)
  216. idx = np.where(diffs > 1)[0]
  217. idx = -1 if len(idx) == 0 else idx[0]
  218. new_event_val = ev_vals[idx] + 1
  219. new_event_id_key = '/'.join(sorted(new_key_comps))
  220. event_id[new_event_id_key] = int(new_event_val)
  221. # Replace duplicate event times with merged event and remember which
  222. # duplicate indices to delete later
  223. new_events[idxs[0], 1] = new_prior
  224. new_events[idxs[0], 2] = new_event_val
  225. event_idxs_to_delete.extend(idxs[1:])
  226. # Delete duplicate event idxs
  227. new_events = np.delete(new_events, event_idxs_to_delete, 0)
  228. new_selection = np.delete(selection, event_idxs_to_delete, 0)
  229. return new_events, event_id, new_selection
  230. def _handle_event_repeated(events, event_id, event_repeated, selection,
  231. drop_log):
  232. """Handle repeated events.
  233. Note that drop_log will be modified inplace
  234. """
  235. assert len(events) == len(selection)
  236. selection = np.asarray(selection)
  237. unique_events, u_ev_idxs = np.unique(events[:, 0], return_index=True)
  238. # Return early if no duplicates
  239. if len(unique_events) == len(events):
  240. return events, event_id, selection, drop_log
  241. # Else, we have duplicates. Triage ...
  242. _check_option('event_repeated', event_repeated, ['error', 'drop', 'merge'])
  243. drop_log = list(drop_log)
  244. if event_repeated == 'error':
  245. raise RuntimeError('Event time samples were not unique. Consider '
  246. 'setting the `event_repeated` parameter."')
  247. elif event_repeated == 'drop':
  248. logger.info('Multiple event values for single event times found. '
  249. 'Keeping the first occurrence and dropping all others.')
  250. new_events = events[u_ev_idxs]
  251. new_selection = selection[u_ev_idxs]
  252. drop_ev_idxs = np.setdiff1d(selection, new_selection)
  253. for idx in drop_ev_idxs:
  254. drop_log[idx] = drop_log[idx] + ('DROP DUPLICATE',)
  255. selection = new_selection
  256. elif event_repeated == 'merge':
  257. logger.info('Multiple event values for single event times found. '
  258. 'Creating new event value to reflect simultaneous events.')
  259. new_events, event_id, new_selection = \
  260. _merge_events(events, event_id, selection)
  261. drop_ev_idxs = np.setdiff1d(selection, new_selection)
  262. for idx in drop_ev_idxs:
  263. drop_log[idx] = drop_log[idx] + ('MERGE DUPLICATE',)
  264. selection = new_selection
  265. drop_log = tuple(drop_log)
  266. # Remove obsolete kv-pairs from event_id after handling
  267. keys = new_events[:, 1:].flatten()
  268. event_id = {k: v for k, v in event_id.items() if v in keys}
  269. return new_events, event_id, selection, drop_log
  270. @fill_doc
  271. class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, ShiftTimeMixin,
  272. SetChannelsMixin, InterpolationMixin, FilterMixin,
  273. TimeMixin, SizeMixin, GetEpochsMixin):
  274. """Abstract base class for `~mne.Epochs`-type classes.
  275. .. warning:: This class provides basic functionality and should never be
  276. instantiated directly.
  277. Parameters
  278. ----------
  279. %(info_not_none)s
  280. data : ndarray | None
  281. If ``None``, data will be read from the Raw object. If ndarray, must be
  282. of shape (n_epochs, n_channels, n_times).
  283. %(epochs_events_event_id)s
  284. %(epochs_tmin_tmax)s
  285. %(baseline_epochs)s
  286. Defaults to ``(None, 0)``, i.e. beginning of the the data until
  287. time point zero.
  288. %(epochs_raw)s
  289. %(picks_all)s
  290. %(reject_epochs)s
  291. %(flat)s
  292. %(decim)s
  293. %(epochs_reject_tmin_tmax)s
  294. %(epochs_detrend)s
  295. %(proj_epochs)s
  296. %(epochs_on_missing)s
  297. preload_at_end : bool
  298. %(epochs_preload)s
  299. selection : iterable | None
  300. Iterable of indices of selected epochs. If ``None``, will be
  301. automatically generated, corresponding to all non-zero events.
  302. drop_log : tuple | None
  303. Tuple of tuple of strings indicating which epochs have been marked to
  304. be ignored.
  305. filename : str | None
  306. The filename (if the epochs are read from disk).
  307. %(epochs_metadata)s
  308. %(epochs_event_repeated)s
  309. %(verbose)s
  310. Notes
  311. -----
  312. The ``BaseEpochs`` class is public to allow for stable type-checking in
  313. user code (i.e., ``isinstance(my_epochs, BaseEpochs)``) but should not be
  314. used as a constructor for Epochs objects (use instead :class:`mne.Epochs`).
  315. """
  316. @verbose
  317. def __init__(self, info, data, events, event_id=None, tmin=-0.2, tmax=0.5,
  318. baseline=(None, 0), raw=None, picks=None, reject=None,
  319. flat=None, decim=1, reject_tmin=None, reject_tmax=None,
  320. detrend=None, proj=True, on_missing='raise',
  321. preload_at_end=False, selection=None, drop_log=None,
  322. filename=None, metadata=None, event_repeated='error',
  323. verbose=None): # noqa: D102
  324. self.verbose = verbose
  325. if events is not None: # RtEpochs can have events=None
  326. events = _ensure_events(events)
  327. events_max = events.max()
  328. if events_max > INT32_MAX:
  329. raise ValueError(
  330. f'events array values must not exceed {INT32_MAX}, '
  331. f'got {events_max}')
  332. event_id = _check_event_id(event_id, events)
  333. self.event_id = event_id
  334. del event_id
  335. if events is not None: # RtEpochs can have events=None
  336. for key, val in self.event_id.items():
  337. if val not in events[:, 2]:
  338. msg = ('No matching events found for %s '
  339. '(event id %i)' % (key, val))
  340. _on_missing(on_missing, msg)
  341. # ensure metadata matches original events size
  342. self.selection = np.arange(len(events))
  343. self.events = events
  344. self.metadata = metadata
  345. del events
  346. values = list(self.event_id.values())
  347. selected = np.where(np.in1d(self.events[:, 2], values))[0]
  348. if selection is None:
  349. selection = selected
  350. else:
  351. selection = np.array(selection, int)
  352. if selection.shape != (len(selected),):
  353. raise ValueError('selection must be shape %s got shape %s'
  354. % (selected.shape, selection.shape))
  355. self.selection = selection
  356. if drop_log is None:
  357. self.drop_log = tuple(
  358. () if k in self.selection else ('IGNORED',)
  359. for k in range(max(len(self.events),
  360. max(self.selection) + 1)))
  361. else:
  362. self.drop_log = drop_log
  363. self.events = self.events[selected]
  364. self.events, self.event_id, self.selection, self.drop_log = \
  365. _handle_event_repeated(
  366. self.events, self.event_id, event_repeated,
  367. self.selection, self.drop_log)
  368. # then subselect
  369. sub = np.where(np.in1d(selection, self.selection))[0]
  370. if isinstance(metadata, list):
  371. metadata = [metadata[s] for s in sub]
  372. elif metadata is not None:
  373. metadata = metadata.iloc[sub]
  374. self.metadata = metadata
  375. del metadata
  376. n_events = len(self.events)
  377. if n_events > 1:
  378. if np.diff(self.events.astype(np.int64)[:, 0]).min() <= 0:
  379. warn('The events passed to the Epochs constructor are not '
  380. 'chronologically ordered.', RuntimeWarning)
  381. if n_events > 0:
  382. logger.info('%d matching events found' % n_events)
  383. else:
  384. raise ValueError('No desired events found.')
  385. else:
  386. self.drop_log = tuple()
  387. self.selection = np.array([], int)
  388. self.metadata = metadata
  389. # do not set self.events here, let subclass do it
  390. if (detrend not in [None, 0, 1]) or isinstance(detrend, bool):
  391. raise ValueError('detrend must be None, 0, or 1')
  392. self.detrend = detrend
  393. self._raw = raw
  394. info._check_consistency()
  395. self.picks = _picks_to_idx(info, picks, none='all', exclude=(),
  396. allow_empty=False)
  397. self.info = pick_info(info, self.picks)
  398. del info
  399. self._current = 0
  400. if data is None:
  401. self.preload = False
  402. self._data = None
  403. self._do_baseline = True
  404. else:
  405. assert decim == 1
  406. if data.ndim != 3 or data.shape[2] != \
  407. round((tmax - tmin) * self.info['sfreq']) + 1:
  408. raise RuntimeError('bad data shape')
  409. if data.shape[0] != len(self.events):
  410. raise ValueError(
  411. 'The number of epochs and the number of events must match')
  412. self.preload = True
  413. self._data = data
  414. self._do_baseline = False
  415. self._offset = None
  416. if tmin > tmax:
  417. raise ValueError('tmin has to be less than or equal to tmax')
  418. # Handle times
  419. sfreq = float(self.info['sfreq'])
  420. start_idx = int(round(tmin * sfreq))
  421. self._raw_times = np.arange(start_idx,
  422. int(round(tmax * sfreq)) + 1) / sfreq
  423. self._set_times(self._raw_times)
  424. # check reject_tmin and reject_tmax
  425. if reject_tmin is not None:
  426. if (np.isclose(reject_tmin, tmin)):
  427. # adjust for potential small deviations due to sampling freq
  428. reject_tmin = self.tmin
  429. elif reject_tmin < tmin:
  430. raise ValueError(f'reject_tmin needs to be None or >= tmin '
  431. f'(got {reject_tmin})')
  432. if reject_tmax is not None:
  433. if (np.isclose(reject_tmax, tmax)):
  434. # adjust for potential small deviations due to sampling freq
  435. reject_tmax = self.tmax
  436. elif reject_tmax > tmax:
  437. raise ValueError(f'reject_tmax needs to be None or <= tmax '
  438. f'(got {reject_tmax})')
  439. if (reject_tmin is not None) and (reject_tmax is not None):
  440. if reject_tmin >= reject_tmax:
  441. raise ValueError(f'reject_tmin ({reject_tmin}) needs to be '
  442. f' < reject_tmax ({reject_tmax})')
  443. self.reject_tmin = reject_tmin
  444. self.reject_tmax = reject_tmax
  445. # decimation
  446. self._decim = 1
  447. self.decimate(decim)
  448. # baseline correction: replace `None` tuple elements with actual times
  449. self.baseline = _check_baseline(baseline, times=self.times,
  450. sfreq=self.info['sfreq'])
  451. if self.baseline is not None and self.baseline != baseline:
  452. logger.info(f'Setting baseline interval to '
  453. f'[{self.baseline[0]}, {self.baseline[1]}] sec')
  454. logger.info(_log_rescale(self.baseline))
  455. # setup epoch rejection
  456. self.reject = None
  457. self.flat = None
  458. self._reject_setup(reject, flat)
  459. # do the rest
  460. valid_proj = [True, 'delayed', False]
  461. if proj not in valid_proj:
  462. raise ValueError('"proj" must be one of %s, not %s'
  463. % (valid_proj, proj))
  464. if proj == 'delayed':
  465. self._do_delayed_proj = True
  466. logger.info('Entering delayed SSP mode.')
  467. else:
  468. self._do_delayed_proj = False
  469. activate = False if self._do_delayed_proj else proj
  470. self._projector, self.info = setup_proj(self.info, False,
  471. activate=activate)
  472. if preload_at_end:
  473. assert self._data is None
  474. assert self.preload is False
  475. self.load_data() # this will do the projection
  476. elif proj is True and self._projector is not None and data is not None:
  477. # let's make sure we project if data was provided and proj
  478. # requested
  479. # we could do this with np.einsum, but iteration should be
  480. # more memory safe in most instances
  481. for ii, epoch in enumerate(self._data):
  482. self._data[ii] = np.dot(self._projector, epoch)
  483. self._filename = str(filename) if filename is not None else filename
  484. self._check_consistency()
  485. def _check_consistency(self):
  486. """Check invariants of epochs object."""
  487. if hasattr(self, 'events'):
  488. assert len(self.selection) == len(self.events)
  489. assert len(self.drop_log) >= len(self.events)
  490. assert len(self.selection) == sum(
  491. (len(dl) == 0 for dl in self.drop_log))
  492. assert hasattr(self, '_times_readonly')
  493. assert not self.times.flags['WRITEABLE']
  494. assert isinstance(self.drop_log, tuple)
  495. assert all(isinstance(log, tuple) for log in self.drop_log)
  496. assert all(isinstance(s, str) for log in self.drop_log for s in log)
  497. def reset_drop_log_selection(self):
  498. """Reset the drop_log and selection entries.
  499. This method will simplify ``self.drop_log`` and ``self.selection``
  500. so that they are meaningless (tuple of empty tuples and increasing
  501. integers, respectively). This can be useful when concatenating
  502. many Epochs instances, as ``drop_log`` can accumulate many entries
  503. which can become problematic when saving.
  504. """
  505. self.selection = np.arange(len(self.events))
  506. self.drop_log = (tuple(),) * len(self.events)
  507. self._check_consistency()
  508. def load_data(self):
  509. """Load the data if not already preloaded.
  510. Returns
  511. -------
  512. epochs : instance of Epochs
  513. The epochs object.
  514. Notes
  515. -----
  516. This function operates in-place.
  517. .. versionadded:: 0.10.0
  518. """
  519. if self.preload:
  520. return self
  521. self._data = self._get_data()
  522. self.preload = True
  523. self._do_baseline = False
  524. self._decim_slice = slice(None, None, None)
  525. self._decim = 1
  526. self._raw_times = self.times
  527. assert self._data.shape[-1] == len(self.times)
  528. self._raw = None # shouldn't need it anymore
  529. return self
  530. @verbose
  531. def decimate(self, decim, offset=0, verbose=None):
  532. """Decimate the epochs.
  533. Parameters
  534. ----------
  535. %(decim)s
  536. %(decim_offset)s
  537. %(verbose_meth)s
  538. Returns
  539. -------
  540. epochs : instance of Epochs
  541. The decimated Epochs object.
  542. See Also
  543. --------
  544. mne.Evoked.decimate
  545. mne.Epochs.resample
  546. mne.io.Raw.resample
  547. Notes
  548. -----
  549. %(decim_notes)s
  550. If ``decim`` is 1, this method does not copy the underlying data.
  551. .. versionadded:: 0.10.0
  552. References
  553. ----------
  554. .. footbibliography::
  555. """
  556. decim, offset, new_sfreq = _check_decim(self.info, decim, offset)
  557. start_idx = int(round(-self._raw_times[0] * (self.info['sfreq'] *
  558. self._decim)))
  559. self._decim *= decim
  560. i_start = start_idx % self._decim + offset
  561. decim_slice = slice(i_start, None, self._decim)
  562. with self.info._unlock():
  563. self.info['sfreq'] = new_sfreq
  564. if self.preload:
  565. if decim != 1:
  566. self._data = self._data[:, :, decim_slice].copy()
  567. self._raw_times = self._raw_times[decim_slice].copy()
  568. else:
  569. self._data = np.ascontiguousarray(self._data)
  570. self._decim_slice = slice(None)
  571. self._decim = 1
  572. else:
  573. self._decim_slice = decim_slice
  574. self._set_times(self._raw_times[self._decim_slice])
  575. return self
  576. @verbose
  577. def apply_baseline(self, baseline=(None, 0), *, verbose=None):
  578. """Baseline correct epochs.
  579. Parameters
  580. ----------
  581. %(baseline_epochs)s
  582. Defaults to ``(None, 0)``, i.e. beginning of the the data until
  583. time point zero.
  584. %(verbose_meth)s
  585. Returns
  586. -------
  587. epochs : instance of Epochs
  588. The baseline-corrected Epochs object.
  589. Notes
  590. -----
  591. Baseline correction can be done multiple times, but can never be
  592. reverted once the data has been loaded.
  593. .. versionadded:: 0.10.0
  594. """
  595. baseline = _check_baseline(baseline, times=self.times,
  596. sfreq=self.info['sfreq'])
  597. if self.preload:
  598. if self.baseline is not None and baseline is None:
  599. raise RuntimeError('You cannot remove baseline correction '
  600. 'from preloaded data once it has been '
  601. 'applied.')
  602. self._do_baseline = True
  603. picks = self._detrend_picks
  604. rescale(self._data, self.times, baseline, copy=False, picks=picks)
  605. self._do_baseline = False
  606. else: # logging happens in "rescale" in "if" branch
  607. logger.info(_log_rescale(baseline))
  608. assert self._do_baseline is True
  609. self.baseline = baseline
  610. return self
  611. def _reject_setup(self, reject, flat):
  612. """Set self._reject_time and self._channel_type_idx."""
  613. idx = channel_indices_by_type(self.info)
  614. reject = deepcopy(reject) if reject is not None else dict()
  615. flat = deepcopy(flat) if flat is not None else dict()
  616. for rej, kind in zip((reject, flat), ('reject', 'flat')):
  617. if not isinstance(rej, dict):
  618. raise TypeError('reject and flat must be dict or None, not %s'
  619. % type(rej))
  620. bads = set(rej.keys()) - set(idx.keys())
  621. if len(bads) > 0:
  622. raise KeyError('Unknown channel types found in %s: %s'
  623. % (kind, bads))
  624. for key in idx.keys():
  625. # don't throw an error if rejection/flat would do nothing
  626. if len(idx[key]) == 0 and (np.isfinite(reject.get(key, np.inf)) or
  627. flat.get(key, -1) >= 0):
  628. # This is where we could eventually add e.g.
  629. # self.allow_missing_reject_keys check to allow users to
  630. # provide keys that don't exist in data
  631. raise ValueError("No %s channel found. Cannot reject based on "
  632. "%s." % (key.upper(), key.upper()))
  633. # check for invalid values
  634. for rej, kind in zip((reject, flat), ('Rejection', 'Flat')):
  635. for key, val in rej.items():
  636. if val is None or val < 0:
  637. raise ValueError('%s value must be a number >= 0, not "%s"'
  638. % (kind, val))
  639. # now check to see if our rejection and flat are getting more
  640. # restrictive
  641. old_reject = self.reject if self.reject is not None else dict()
  642. old_flat = self.flat if self.flat is not None else dict()
  643. bad_msg = ('{kind}["{key}"] == {new} {op} {old} (old value), new '
  644. '{kind} values must be at least as stringent as '
  645. 'previous ones')
  646. # copy thresholds for channel types that were used previously, but not
  647. # passed this time
  648. for key in set(old_reject) - set(reject):
  649. reject[key] = old_reject[key]
  650. # make sure new thresholds are at least as stringent as the old ones
  651. for key in reject:
  652. if key in old_reject and reject[key] > old_reject[key]:
  653. raise ValueError(
  654. bad_msg.format(kind='reject', key=key, new=reject[key],
  655. old=old_reject[key], op='>'))
  656. # same for flat thresholds
  657. for key in set(old_flat) - set(flat):
  658. flat[key] = old_flat[key]
  659. for key in flat:
  660. if key in old_flat and flat[key] < old_flat[key]:
  661. raise ValueError(
  662. bad_msg.format(kind='flat', key=key, new=flat[key],
  663. old=old_flat[key], op='<'))
  664. # after validation, set parameters
  665. self._bad_dropped = False
  666. self._channel_type_idx = idx
  667. self.reject = reject if len(reject) > 0 else None
  668. self.flat = flat if len(flat) > 0 else None
  669. if (self.reject_tmin is None) and (self.reject_tmax is None):
  670. self._reject_time = None
  671. else:
  672. if self.reject_tmin is None:
  673. reject_imin = None
  674. else:
  675. idxs = np.nonzero(self.times >= self.reject_tmin)[0]
  676. reject_imin = idxs[0]
  677. if self.reject_tmax is None:
  678. reject_imax = None
  679. else:
  680. idxs = np.nonzero(self.times <= self.reject_tmax)[0]
  681. reject_imax = idxs[-1]
  682. self._reject_time = slice(reject_imin, reject_imax)
  683. @verbose # verbose is used by mne-realtime
  684. def _is_good_epoch(self, data, verbose=None):
  685. """Determine if epoch is good."""
  686. if isinstance(data, str):
  687. return False, (data,)
  688. if data is None:
  689. return False, ('NO_DATA',)
  690. n_times = len(self.times)
  691. if data.shape[1] < n_times:
  692. # epoch is too short ie at the end of the data
  693. return False, ('TOO_SHORT',)
  694. if self.reject is None and self.flat is None:
  695. return True, None
  696. else:
  697. if self._reject_time is not None:
  698. data = data[:, self._reject_time]
  699. return _is_good(data, self.ch_names, self._channel_type_idx,
  700. self.reject, self.flat, full_report=True,
  701. ignore_chs=self.info['bads'])
  702. @verbose
  703. def _detrend_offset_decim(self, epoch, picks, verbose=None):
  704. """Aux Function: detrend, baseline correct, offset, decim.
  705. Note: operates inplace
  706. """
  707. if (epoch is None) or isinstance(epoch, str):
  708. return epoch
  709. # Detrend
  710. if self.detrend is not None:
  711. # We explicitly detrend just data channels (not EMG, ECG, EOG which
  712. # are processed by baseline correction)
  713. use_picks = _pick_data_channels(self.info, exclude=())
  714. epoch[use_picks] = detrend(epoch[use_picks], self.detrend, axis=1)
  715. # Baseline correct
  716. if self._do_baseline:
  717. rescale(
  718. epoch, self._raw_times, self.baseline, picks=picks, copy=False,
  719. verbose=False)
  720. # Decimate if necessary (i.e., epoch not preloaded)
  721. epoch = epoch[:, self._decim_slice]
  722. # handle offset
  723. if self._offset is not None:
  724. epoch += self._offset
  725. return epoch
  726. def iter_evoked(self, copy=False):
  727. """Iterate over epochs as a sequence of Evoked objects.
  728. The Evoked objects yielded will each contain a single epoch (i.e., no
  729. averaging is performed).
  730. This method resets the object iteration state to the first epoch.
  731. Parameters
  732. ----------
  733. copy : bool
  734. If False copies of data and measurement info will be omitted
  735. to save time.
  736. """
  737. self.__iter__()
  738. while True:
  739. try:
  740. out = self.__next__(True)
  741. except StopIteration:
  742. break
  743. data, event_id = out
  744. tmin = self.times[0]
  745. info = self.info
  746. if copy:
  747. info = deepcopy(self.info)
  748. data = data.copy()
  749. yield EvokedArray(data, info, tmin, comment=str(event_id))
  750. def subtract_evoked(self, evoked=None):
  751. """Subtract an evoked response from each epoch.
  752. Can be used to exclude the evoked response when analyzing induced
  753. activity, see e.g. [1]_.
  754. Parameters
  755. ----------
  756. evoked : instance of Evoked | None
  757. The evoked response to subtract. If None, the evoked response
  758. is computed from Epochs itself.
  759. Returns
  760. -------
  761. self : instance of Epochs
  762. The modified instance (instance is also modified inplace).
  763. References
  764. ----------
  765. .. [1] David et al. "Mechanisms of evoked and induced responses in
  766. MEG/EEG", NeuroImage, vol. 31, no. 4, pp. 1580-1591, July 2006.
  767. """
  768. logger.info('Subtracting Evoked from Epochs')
  769. if evoked is None:
  770. picks = _pick_data_channels(self.info, exclude=[])
  771. evoked = self.average(picks)
  772. # find the indices of the channels to use
  773. picks = pick_channels(evoked.ch_names, include=self.ch_names)
  774. # make sure the omitted channels are not data channels
  775. if len(picks) < len(self.ch_names):
  776. sel_ch = [evoked.ch_names[ii] for ii in picks]
  777. diff_ch = list(set(self.ch_names).difference(sel_ch))
  778. diff_idx = [self.ch_names.index(ch) for ch in diff_ch]
  779. diff_types = [channel_type(self.info, idx) for idx in diff_idx]
  780. bad_idx = [diff_types.index(t) for t in diff_types if t in
  781. _DATA_CH_TYPES_SPLIT]
  782. if len(bad_idx) > 0:
  783. bad_str = ', '.join([diff_ch[ii] for ii in bad_idx])
  784. raise ValueError('The following data channels are missing '
  785. 'in the evoked response: %s' % bad_str)
  786. logger.info(' The following channels are not included in the '
  787. 'subtraction: %s' % ', '.join(diff_ch))
  788. # make sure the times match
  789. if (len(self.times) != len(evoked.times) or
  790. np.max(np.abs(self.times - evoked.times)) >= 1e-7):
  791. raise ValueError('Epochs and Evoked object do not contain '
  792. 'the same time points.')
  793. # handle SSPs
  794. if not self.proj and evoked.proj:
  795. warn('Evoked has SSP applied while Epochs has not.')
  796. if self.proj and not evoked.proj:
  797. evoked = evoked.copy().apply_proj()
  798. # find the indices of the channels to use in Epochs
  799. ep_picks = [self.ch_names.index(evoked.ch_names[ii]) for ii in picks]
  800. # do the subtraction
  801. if self.preload:
  802. self._data[:, ep_picks, :] -= evoked.data[picks][None, :, :]
  803. else:
  804. if self._offset is None:
  805. self._offset = np.zeros((len(self.ch_names), len(self.times)),
  806. dtype=np.float64)
  807. self._offset[ep_picks] -= evoked.data[picks]
  808. logger.info('[done]')
  809. return self
  810. @fill_doc
  811. def average(self, picks=None, method="mean", by_event_type=False):
  812. """Compute an average over epochs.
  813. Parameters
  814. ----------
  815. %(picks_all_data)s
  816. method : str | callable
  817. How to combine the data. If "mean"/"median", the mean/median
  818. are returned.
  819. Otherwise, must be a callable which, when passed an array of shape
  820. (n_epochs, n_channels, n_time) returns an array of shape
  821. (n_channels, n_time).
  822. Note that due to file type limitations, the kind for all
  823. these will be "average".
  824. by_event_type : bool
  825. When ``False`` (the default) all epochs are averaged and a single
  826. :class:`Evoked` object is returned. When ``True``, epochs are first
  827. grouped by event type (as specified using the ``event_id``
  828. parameter) and a list is returned containing a separate
  829. :class:`Evoked` object for each event type. The ``.comment``
  830. attribute is set to the label of the event type.
  831. .. versionadded:: 0.24.0
  832. Returns
  833. -------
  834. evoked : instance of Evoked | list of Evoked
  835. The averaged epochs. When ``by_event_type=True`` was specified, a
  836. list is returned containing a separate :class:`Evoked` object
  837. for each event type. The list has the same order as the event types
  838. as specified in the ``event_id`` dictionary.
  839. Notes
  840. -----
  841. Computes an average of all epochs in the instance, even if
  842. they correspond to different conditions. To average by condition,
  843. do ``epochs[condition].average()`` for each condition separately.
  844. When picks is None and epochs contain only ICA channels, no channels
  845. are selected, resulting in an error. This is because ICA channels
  846. are not considered data channels (they are of misc type) and only data
  847. channels are selected when picks is None.
  848. The ``method`` parameter allows e.g. robust averaging.
  849. For example, one could do:
  850. >>> from scipy.stats import trim_mean # doctest:+SKIP
  851. >>> trim = lambda x: trim_mean(x, 0.1, axis=0) # doctest:+SKIP
  852. >>> epochs.average(method=trim) # doctest:+SKIP
  853. This would compute the trimmed mean.
  854. """
  855. if by_event_type:
  856. evokeds = list()
  857. for event_type in self.event_id.keys():
  858. ev = self[event_type]._compute_aggregate(picks=picks,
  859. mode=method)
  860. ev.comment = event_type
  861. evokeds.append(ev)
  862. else:
  863. evokeds = self._compute_aggregate(picks=picks, mode=method)
  864. return evokeds
  865. @fill_doc
  866. def standard_error(self, picks=None, by_event_type=False):
  867. """Compute standard error over epochs.
  868. Parameters
  869. ----------
  870. %(picks_all_data)s
  871. by_event_type : bool
  872. When ``False`` (the default) all epochs are averaged and a single
  873. :class:`Evoked` object is returned. When ``True``, epochs are first
  874. grouped by event type (as specified using the ``event_id``
  875. parameter) and a list is returned containing a separate
  876. :class:`Evoked` object for each event type. The ``.comment``
  877. attribute is set to the label of the event type.
  878. .. versionadded:: 0.24.0
  879. Returns
  880. -------
  881. std_err : instance of Evoked | list of Evoked
  882. The standard error over epochs. When ``by_event_type=True`` was
  883. specified, a list is returned containing a separate :class:`Evoked`
  884. object for each event type. The list has the same order as the
  885. event types as specified in the ``event_id`` dictionary.
  886. """
  887. return self.average(picks=picks, method="std",
  888. by_event_type=by_event_type)
  889. def _compute_aggregate(self, picks, mode='mean'):
  890. """Compute the mean, median, or std over epochs and return Evoked."""
  891. # if instance contains ICA channels they won't be included unless picks
  892. # is specified
  893. if picks is None:
  894. check_ICA = [x.startswith('ICA') for x in self.ch_names]
  895. if np.all(check_ICA):
  896. raise TypeError('picks must be specified (i.e. not None) for '
  897. 'ICA channel data')
  898. elif np.any(check_ICA):
  899. warn('ICA channels will not be included unless explicitly '
  900. 'selected in picks')
  901. n_channels = len(self.ch_names)
  902. n_times = len(self.times)
  903. if self.preload:
  904. n_events = len(self.events)
  905. fun = _check_combine(mode, valid=('mean', 'median', 'std'))
  906. data = fun(self._data)
  907. assert len(self.events) == len(self._data)
  908. if data.shape != self._data.shape[1:]:
  909. raise RuntimeError(
  910. 'You passed a function that resulted n data of shape {}, '
  911. 'but it should be {}.'.format(
  912. data.shape, self._data.shape[1:]))
  913. else:
  914. if mode not in {"mean", "std"}:
  915. raise ValueError("If data are not preloaded, can only compute "
  916. "mean or standard deviation.")
  917. data = np.zeros((n_channels, n_times))
  918. n_events = 0
  919. for e in self:
  920. if np.iscomplexobj(e):
  921. data = data.astype(np.complex128)
  922. data += e
  923. n_events += 1
  924. if n_events > 0:
  925. data /= n_events
  926. else:
  927. data.fill(np.nan)
  928. # convert to stderr if requested, could do in one pass but do in
  929. # two (slower) in case there are large numbers
  930. if mode == "std":
  931. data_mean = data.copy()
  932. data.fill(0.)
  933. for e in self:
  934. data += (e - data_mean) ** 2
  935. data = np.sqrt(data / n_events)
  936. if mode == "std":
  937. kind = 'standard_error'
  938. data /= np.sqrt(n_events)
  939. else:
  940. kind = "average"
  941. return self._evoked_from_epoch_data(data, self.info, picks, n_events,
  942. kind, self._name)
  943. @property
  944. def _name(self):
  945. """Give a nice string representation based on event ids."""
  946. if len(self.event_id) == 1:
  947. comment = next(iter(self.event_id.keys()))
  948. else:
  949. count = Counter(self.events[:, 2])
  950. comments = list()
  951. for key, value in self.event_id.items():
  952. comments.append('%.2f × %s' % (
  953. float(count[value]) / len(self.events), key))
  954. comment = ' + '.join(comments)
  955. return comment
  956. def _evoked_from_epoch_data(self, data, info, picks, n_events, kind,
  957. comment):
  958. """Create an evoked object from epoch data."""
  959. info = deepcopy(info)
  960. # don't apply baseline correction; we'll set evoked.baseline manually
  961. evoked = EvokedArray(data, info, tmin=self.times[0], comment=comment,
  962. nave=n_events, kind=kind, baseline=None,
  963. verbose=self.verbose)
  964. evoked.baseline = self.baseline
  965. # the above constructor doesn't recreate the times object precisely
  966. # due to numerical precision issues
  967. evoked.times = self.times.copy()
  968. # pick channels
  969. picks = _picks_to_idx(self.info, picks, 'data_or_ica', ())
  970. ch_names = [evoked.ch_names[p] for p in picks]
  971. evoked.pick_channels(ch_names)
  972. if len(evoked.info['ch_names']) == 0:
  973. raise ValueError('No data channel found when averaging.')
  974. if evoked.nave < 1:
  975. warn('evoked object is empty (based on less than 1 epoch)')
  976. return evoked
  977. @property
  978. def ch_names(self):
  979. """Channel names."""
  980. return self.info['ch_names']
  981. @copy_function_doc_to_method_doc(plot_epochs)
  982. def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20,
  983. title=None, events=None, event_color=None,
  984. order=None, show=True, block=False, decim='auto', noise_cov=None,
  985. butterfly=False, show_scrollbars=True, show_scalebars=True,
  986. epoch_colors=None, event_id=None, group_by='type'):
  987. return plot_epochs(self, picks=picks, scalings=scalings,
  988. n_epochs=n_epochs, n_channels=n_channels,
  989. title=title, events=events, event_color=event_color,
  990. order=order, show=show, block=block, decim=decim,
  991. noise_cov=noise_cov, butterfly=butterfly,
  992. show_scrollbars=show_scrollbars,
  993. show_scalebars=show_scalebars,
  994. epoch_colors=epoch_colors, event_id=event_id,
  995. group_by=group_by)
  996. @copy_function_doc_to_method_doc(plot_epochs_psd)
  997. def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None,
  998. proj=False, bandwidth=None, adaptive=False, low_bias=True,
  999. normalization='length', picks=None, ax=None, color='black',
  1000. xscale='linear', area_mode='std', area_alpha=0.33,
  1001. dB=True, estimate='auto', show=True, n_jobs=1,
  1002. average=False, line_alpha=None, spatial_colors=True,
  1003. sphere=None, exclude='bads', verbose=None):
  1004. return plot_epochs_psd(self, fmin=fmin, fmax=fmax, tmin=tmin,
  1005. tmax=tmax, proj=proj, bandwidth=bandwidth,
  1006. adaptive=adaptive, low_bias=low_bias,
  1007. normalization=normalization, picks=picks, ax=ax,
  1008. color=color, xscale=xscale, area_mode=area_mode,
  1009. area_alpha=area_alpha, dB=dB, estimate=estimate,
  1010. show=show, n_jobs=n_jobs, average=average,
  1011. line_alpha=line_alpha,
  1012. spatial_colors=spatial_colors, sphere=sphere,
  1013. exclude=exclude, verbose=verbose)
  1014. @copy_function_doc_to_method_doc(plot_epochs_psd_topomap)
  1015. def plot_psd_topomap(self, bands=None, tmin=None,
  1016. tmax=None, proj=False, bandwidth=None, adaptive=False,
  1017. low_bias=True, normalization='length', ch_type=None,
  1018. cmap=None, agg_fun=None, dB=True,
  1019. n_jobs=1, normalize=False, cbar_fmt='auto',
  1020. outlines='head', axes=None, show=True,
  1021. sphere=None, vlim=(None, None), verbose=None):
  1022. return plot_epochs_psd_topomap(
  1023. self, bands=bands, tmin=tmin, tmax=tmax,
  1024. proj=proj, bandwidth=bandwidth, adaptive=adaptive,
  1025. low_bias=low_bias, normalization=normalization, ch_type=ch_type,
  1026. cmap=cmap, agg_fun=agg_fun, dB=dB, n_jobs=n_jobs,
  1027. normalize=normalize, cbar_fmt=cbar_fmt, outlines=outlines,
  1028. axes=axes, show=show, sphere=sphere, vlim=vlim, verbose=verbose)
  1029. @copy_function_doc_to_method_doc(plot_topo_image_epochs)
  1030. def plot_topo_image(self, layout=None, sigma=0., vmin=None, vmax=None,
  1031. colorbar=None, order=None, cmap='RdBu_r',
  1032. layout_scale=.95, title=None, scalings=None,
  1033. border='none', fig_facecolor='k', fig_background=None,
  1034. font_color='w', show=True):
  1035. return plot_topo_image_epochs(
  1036. self, layout=layout, sigma=sigma, vmin=vmin, vmax=vmax,
  1037. colorbar=colorbar, order=order, cmap=cmap,
  1038. layout_scale=layout_scale, title=title, scalings=scalings,
  1039. border=border, fig_facecolor=fig_facecolor,
  1040. fig_background=fig_background, font_color=font_color, show=show)
  1041. @verbose
  1042. def drop_bad(self, reject='existing', flat='existing', verbose=None):
  1043. """Drop bad epochs without retaining the epochs data.
  1044. Should be used before slicing operations.
  1045. .. warning:: This operation is slow since all epochs have to be read
  1046. from disk. To avoid reading epochs from disk multiple
  1047. times, use :meth:`mne.Epochs.load_data()`.
  1048. .. note:: To constrain the time period used for estimation of signal
  1049. quality, set ``epochs.reject_tmin`` and
  1050. ``epochs.reject_tmax``, respectively.
  1051. Parameters
  1052. ----------
  1053. %(reject_drop_bad)s
  1054. %(flat_drop_bad)s
  1055. %(verbose_meth)s
  1056. Returns
  1057. -------
  1058. epochs : instance of Epochs
  1059. The epochs with bad epochs dropped. Operates in-place.
  1060. Notes
  1061. -----
  1062. Dropping bad epochs can be done multiple times with different
  1063. ``reject`` and ``flat`` parameters. However, once an epoch is
  1064. dropped, it is dropped forever, so if more lenient thresholds may
  1065. subsequently be applied, `epochs.copy <mne.Epochs.copy>` should be
  1066. used.
  1067. """
  1068. if reject == 'existing':
  1069. if flat == 'existing' and self._bad_dropped:
  1070. return
  1071. reject = self.reject
  1072. if flat == 'existing':
  1073. flat = self.flat
  1074. if any(isinstance(rej, str) and rej != 'existing' for
  1075. rej in (reject, flat)):
  1076. raise ValueError('reject and flat, if strings, must be "existing"')
  1077. self._reject_setup(reject, flat)
  1078. self._get_data(out=False, verbose=verbose)
  1079. return self
  1080. def drop_log_stats(self, ignore=('IGNORED',)):
  1081. """Compute the channel stats based on a drop_log from Epochs.
  1082. Parameters
  1083. ----------
  1084. ignore : list
  1085. The drop reasons to ignore.
  1086. Returns
  1087. -------
  1088. perc : float
  1089. Total percentage of epochs dropped.
  1090. See Also
  1091. --------
  1092. plot_drop_log
  1093. """
  1094. return _drop_log_stats(self.drop_log, ignore)
  1095. @copy_function_doc_to_method_doc(plot_drop_log)
  1096. def plot_drop_log(self, threshold=0, n_max_plot=20, subject='Unknown subj',
  1097. color=(0.9, 0.9, 0.9), width=0.8, ignore=('IGNORED',),
  1098. show=True):
  1099. if not self._bad_dropped:
  1100. raise ValueError("You cannot use plot_drop_log since bad "
  1101. "epochs have not yet been dropped. "
  1102. "Use epochs.drop_bad().")
  1103. return plot_drop_log(self.drop_log, threshold, n_max_plot, subject,
  1104. color=color, width=width, ignore=ignore,
  1105. show=show)
  1106. @copy_function_doc_to_method_doc(plot_epochs_image)
  1107. def plot_image(self, picks=None, sigma=0., vmin=None, vmax=None,
  1108. colorbar=True, order=None, show=True, units=None,
  1109. scalings=None, cmap=None, fig=None, axes=None,
  1110. overlay_times=None, combine=None, group_by=None,
  1111. evoked=True, ts_args=None, title=None, clear=False):
  1112. return plot_epochs_image(self, picks=picks, sigma=sigma, vmin=vmin,
  1113. vmax=vmax, colorbar=colorbar, order=order,
  1114. show=show, units=units, scalings=scalings,
  1115. cmap=cmap, fig=fig, axes=axes,
  1116. overlay_times=overlay_times, combine=combine,
  1117. group_by=group_by, evoked=evoked,
  1118. ts_args=ts_args, title=title, clear=clear)
  1119. @verbose
  1120. def drop(self, indices, reason='USER', verbose=None):
  1121. """Drop epochs based on indices or boolean mask.
  1122. .. note:: The indices refer to the current set of undropped epochs
  1123. rather than the complete set of dropped and undropped epochs.
  1124. They are therefore not necessarily consistent with any
  1125. external indices (e.g., behavioral logs). To drop epochs
  1126. based on external criteria, do not use the ``preload=True``
  1127. flag when constructing an Epochs object, and call this
  1128. method before calling the :meth:`mne.Epochs.drop_bad` or
  1129. :meth:`mne.Epochs.load_data` methods.
  1130. Parameters
  1131. ----------
  1132. indices : array of int or bool
  1133. Set epochs to remove by specifying indices to remove or a boolean
  1134. mask to apply (where True values get removed). Events are
  1135. correspondingly modified.
  1136. reason : str
  1137. Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc).
  1138. Default: 'USER'.
  1139. %(verbose_meth)s
  1140. Returns
  1141. -------
  1142. epochs : instance of Epochs
  1143. The epochs with indices dropped. Operates in-place.
  1144. """
  1145. indices = np.atleast_1d(indices)
  1146. if indices.ndim > 1:
  1147. raise ValueError("indices must be a scalar or a 1-d array")
  1148. if indices.dtype == bool:
  1149. indices = np.where(indices)[0]
  1150. try_idx = np.where(indices < 0, indices + len(self.events), indices)
  1151. out_of_bounds = (try_idx < 0) | (try_idx >= len(self.events))
  1152. if out_of_bounds.any():
  1153. first = indices[out_of_bounds][0]
  1154. raise IndexError("Epoch index %d is out of bounds" % first)
  1155. keep = np.setdiff1d(np.arange(len(self.events)), try_idx)
  1156. self._getitem(keep, reason, copy=False, drop_event_id=False)
  1157. count = len(try_idx)
  1158. logger.info('Dropped %d epoch%s: %s' %
  1159. (count, _pl(count), ', '.join(map(str, np.sort(try_idx)))))
  1160. return self
  1161. def _get_epoch_from_raw(self, idx, verbose=None):
  1162. """Get a given epoch from disk."""
  1163. raise NotImplementedError
  1164. def _project_epoch(self, epoch):
  1165. """Process a raw epoch based on the delayed param."""
  1166. # whenever requested, the first epoch is being projected.
  1167. if (epoch is None) or isinstance(epoch, str):
  1168. # can happen if t < 0 or reject based on annotations
  1169. return epoch
  1170. proj = self._do_delayed_proj or self.proj
  1171. if self._projector is not None and proj is True:
  1172. epoch = np.dot(self._projector, epoch)
  1173. return epoch
  1174. @verbose
  1175. def _get_data(self, out=True, picks=None, item=None, *, units=None,
  1176. tmin=None, tmax=None, verbose=None):
  1177. """Load all data, dropping bad epochs along the way.
  1178. Parameters
  1179. ----------
  1180. out : bool
  1181. Return the data. Setting this to False is used to reject bad
  1182. epochs without caching all the data, which saves memory.
  1183. %(picks_all)s
  1184. item : slice | array-like | str | list | None
  1185. See docstring of get_data method.
  1186. %(units)s
  1187. tmin : int | float | None
  1188. Start time of data to get in seconds.
  1189. tmax : int | float | None
  1190. End time of data to get in seconds.
  1191. %(verbose_meth)s
  1192. """
  1193. start, stop = self._handle_tmin_tmax(tmin, tmax)
  1194. if item is None:
  1195. item = slice(None)
  1196. elif not self._bad_dropped:
  1197. raise ValueError(
  1198. 'item must be None in epochs.get_data() unless bads have been '
  1199. 'dropped. Consider using epochs.drop_bad().')
  1200. select = self._item_to_select(item) # indices or slice
  1201. use_idx = np.arange(len(self.events))[select]
  1202. n_events = len(use_idx)
  1203. # in case there are no good events
  1204. if self.preload:
  1205. # we will store our result in our existing array
  1206. data = self._data
  1207. else:
  1208. # we start out with an empty array, allocate only if necessary
  1209. data = np.empty((0, len(self.info['ch_names']), len(self.times)))
  1210. logger.info('Loading data for %s events and %s original time '
  1211. 'points ...' % (n_events, len(self._raw_times)))
  1212. orig_picks = picks
  1213. if orig_picks is None:
  1214. picks = _picks_to_idx(self.info, picks, "all", exclude=())
  1215. else:
  1216. picks = _picks_to_idx(self.info, picks)
  1217. # handle units param only if we are going to return data (out==True)
  1218. if (units is not None) and out:
  1219. ch_factors = _get_ch_factors(self, units, picks)
  1220. if self._bad_dropped:
  1221. if not out:
  1222. return
  1223. if self.preload:
  1224. data = data[select]
  1225. if orig_picks is not None:
  1226. data = data[:, picks]
  1227. if units is not None:
  1228. data *= ch_factors[:, np.newaxis]
  1229. if start != 0 or stop != self.times.size:
  1230. data = data[..., start:stop]
  1231. return data
  1232. # we need to load from disk, drop, and return data
  1233. detrend_picks = self._detrend_picks
  1234. for ii, idx in enumerate(use_idx):
  1235. # faster to pre-allocate memory here
  1236. epoch_noproj = self._get_epoch_from_raw(idx)
  1237. epoch_noproj = self._detrend_offset_decim(
  1238. epoch_noproj, detrend_picks)
  1239. if self._do_delayed_proj:
  1240. epoch_out = epoch_noproj
  1241. else:
  1242. epoch_out = self._project_epoch(epoch_noproj)
  1243. if ii == 0:
  1244. data = np.empty((n_events, len(self.ch_names),
  1245. len(self.times)), dtype=epoch_out.dtype)
  1246. data[ii] = epoch_out
  1247. else:
  1248. # bads need to be dropped, this might occur after a preload
  1249. # e.g., when calling drop_bad w/new params
  1250. good_idx = []
  1251. n_out = 0
  1252. drop_log = list(self.drop_log)
  1253. assert n_events == len(self.selection)
  1254. if not self.preload:
  1255. detrend_picks = self._detrend_picks
  1256. for idx, sel in enumerate(self.selection):
  1257. if self.preload: # from memory
  1258. if self._do_delayed_proj:
  1259. epoch_noproj = self._data[idx]
  1260. epoch = self._project_epoch(epoch_noproj)
  1261. else:
  1262. epoch_noproj = None
  1263. epoch = self._data[idx]
  1264. else: # from disk
  1265. epoch_noproj = self._get_epoch_from_raw(idx)
  1266. epoch_noproj = self._detrend_offset_decim(
  1267. epoch_noproj, detrend_picks)
  1268. epoch = self._project_epoch(epoch_noproj)
  1269. epoch_out = epoch_noproj if self._do_delayed_proj else epoch
  1270. is_good, bad_tuple = self._is_good_epoch(
  1271. epoch, verbose=verbose)
  1272. if not is_good:
  1273. assert isinstance(bad_tuple, tuple)
  1274. assert all(isinstance(x, str) for x in bad_tuple)
  1275. drop_log[sel] = drop_log[sel] + bad_tuple
  1276. continue
  1277. good_idx.append(idx)
  1278. # store the epoch if there is a reason to (output or update)
  1279. if out or self.preload:
  1280. # faster to pre-allocate, then trim as necessary
  1281. if n_out == 0 and not self.preload:
  1282. data = np.empty((n_events, epoch_out.shape[0],
  1283. epoch_out.shape[1]),
  1284. dtype=epoch_out.dtype, order='C')
  1285. data[n_out] = epoch_out
  1286. n_out += 1
  1287. self.drop_log = tuple(drop_log)
  1288. del drop_log
  1289. self._bad_dropped = True
  1290. logger.info("%d bad epochs dropped" % (n_events - len(good_idx)))
  1291. # adjust the data size if there is a reason to (output or update)
  1292. if out or self.preload:
  1293. if data.flags['OWNDATA'] and data.flags['C_CONTIGUOUS']:
  1294. data.resize((n_out,) + data.shape[1:], refcheck=False)
  1295. else:
  1296. data = data[:n_out]
  1297. if self.preload:
  1298. self._data = data
  1299. # Now update our properties (excepd data, which is already fixed)
  1300. self._getitem(good_idx, None, copy=False, drop_event_id=False,
  1301. select_data=False)
  1302. if out:
  1303. if orig_picks is not None:
  1304. data = data[:, picks]
  1305. if units is not None:
  1306. data *= ch_factors[:, np.newaxis]
  1307. if start != 0 or stop != self.times.size:
  1308. data = data[..., start:stop]
  1309. return data
  1310. else:
  1311. return None
  1312. @property
  1313. def _detrend_picks(self):
  1314. if self._do_baseline:
  1315. return _pick_data_channels(
  1316. self.info, with_ref_meg=True, with_aux=True, exclude=())
  1317. else:
  1318. return []
  1319. @fill_doc
  1320. def get_data(self, picks=None, item=None, units=None, tmin=None,
  1321. tmax=None):
  1322. """Get all epochs as a 3D array.
  1323. Parameters
  1324. ----------
  1325. %(picks_all)s
  1326. item : slice | array-like | str | list | None
  1327. The items to get. See :meth:`mne.Epochs.__getitem__` for
  1328. a description of valid options. This can be substantially faster
  1329. for obtaining an ndarray than :meth:`~mne.Epochs.__getitem__`
  1330. for repeated access on large Epochs objects.
  1331. None (default) is an alias for ``slice(None)``.
  1332. .. versionadded:: 0.20
  1333. %(units)s
  1334. .. versionadded:: 0.24
  1335. tmin : int | float | None
  1336. Start time of data to get in seconds.
  1337. .. versionadded:: 0.24.0
  1338. tmax : int | float | None
  1339. End time of data to get in seconds.
  1340. .. versionadded:: 0.24.0
  1341. Returns
  1342. -------
  1343. data : array of shape (n_epochs, n_channels, n_times)
  1344. A view on epochs data.
  1345. """
  1346. return self._get_data(picks=picks, item=item, units=units, tmin=tmin,
  1347. tmax=tmax)
  1348. @verbose
  1349. def apply_function(self, fun, picks=None, dtype=None, n_jobs=1,
  1350. channel_wise=True, verbose=None, **kwargs):
  1351. """Apply a function to a subset of channels.
  1352. %(applyfun_summary_epochs)s
  1353. Parameters
  1354. ----------
  1355. %(applyfun_fun)s
  1356. %(picks_all_data_noref)s
  1357. %(applyfun_dtype)s
  1358. %(n_jobs)s
  1359. %(applyfun_chwise_epo)s
  1360. %(verbose_meth)s
  1361. %(kwarg_fun)s
  1362. Returns
  1363. -------
  1364. self : instance of Epochs
  1365. The epochs object with transformed data.
  1366. """
  1367. _check_preload(self, 'epochs.apply_function')
  1368. picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False)
  1369. if not callable(fun):
  1370. raise ValueError('fun needs to be a function')
  1371. data_in = self._data
  1372. if dtype is not None and dtype != self._data.dtype:
  1373. self._data = self._data.astype(dtype)
  1374. if channel_wise:
  1375. if n_jobs == 1:
  1376. _fun = partial(_check_fun, fun, **kwargs)
  1377. # modify data inplace to save memory
  1378. for idx in picks:
  1379. self._data[:, idx, :] = np.apply_along_axis(
  1380. _fun, -1, data_in[:, idx, :])
  1381. else:
  1382. # use parallel function
  1383. parallel, p_fun, _ = parallel_func(_check_fun, n_jobs)
  1384. data_picks_new = parallel(p_fun(
  1385. fun, data_in[:, p, :], **kwargs) for p in picks)
  1386. for pp, p in enumerate(picks):
  1387. self._data[:, p, :] = data_picks_new[pp]
  1388. else:
  1389. self._data = _check_fun(fun, data_in, **kwargs)
  1390. return self
  1391. @property
  1392. def times(self):
  1393. """Time vector in seconds."""
  1394. return self._times_readonly
  1395. def _set_times(self, times):
  1396. """Set self._times_readonly (and make it read only)."""
  1397. # naming used to indicate that it shouldn't be
  1398. # changed directly, but rather via this method
  1399. self._times_readonly = times.copy()
  1400. self._times_readonly.flags['WRITEABLE'] = False
  1401. @property
  1402. def tmin(self):
  1403. """First time point."""
  1404. return self.times[0]
  1405. @property
  1406. def filename(self):
  1407. """The filename."""
  1408. return self._filename
  1409. @property
  1410. def tmax(self):
  1411. """Last time point."""
  1412. return self.times[-1]
  1413. def __repr__(self):
  1414. """Build string representation."""
  1415. s = ' %s events ' % len(self.events)
  1416. s += '(all good)' if self._bad_dropped else '(good & bad)'
  1417. s += ', %g - %g sec' % (self.tmin, self.tmax)
  1418. s += ', baseline '
  1419. if self.baseline is None:
  1420. s += 'off'
  1421. else:
  1422. s += f'{self.baseline[0]:g} – {self.baseline[1]:g} sec'
  1423. if self.baseline != _check_baseline(
  1424. self.baseline, times=self.times, sfreq=self.info['sfreq'],
  1425. on_baseline_outside_data='adjust'):
  1426. s += ' (baseline period was cropped after baseline correction)'
  1427. s += ', ~%s' % (sizeof_fmt(self._size),)
  1428. s += ', data%s loaded' % ('' if self.preload else ' not')
  1429. s += ', with metadata' if self.metadata is not None else ''
  1430. counts = ['%r: %i' % (k, sum(self.events[:, 2] == v))
  1431. for k, v in sorted(self.event_id.items())]
  1432. if len(self.event_id) > 0:
  1433. s += ',' + '\n '.join([''] + counts)
  1434. class_name = self.__class__.__name__
  1435. class_name = 'Epochs' if class_name == 'BaseEpochs' else class_name
  1436. return '<%s | %s>' % (class_name, s)
  1437. def _repr_html_(self):
  1438. if self.baseline is None:
  1439. baseline = 'off'
  1440. else:
  1441. baseline = tuple([f'{b:.3f}' for b in self.baseline])
  1442. baseline = f'{baseline[0]} – {baseline[1]} sec'
  1443. if isinstance(self.event_id, dict):
  1444. events = ''
  1445. for k, v in sorted(self.event_id.items()):
  1446. n_events = sum(self.events[:, 2] == v)
  1447. events += f'{k}: {n_events}<br>'
  1448. elif isinstance(self.event_id, list):
  1449. events = ''
  1450. for k in self.event_id:
  1451. n_events = sum(self.events[:, 2] == k)
  1452. events += f'{k}: {n_events}<br>'
  1453. elif isinstance(self.event_id, int):
  1454. n_events = len(self.events[:, 2])
  1455. events = f'{self.event_id}: {n_events}<br>'
  1456. else:
  1457. events = None
  1458. return epochs_template.substitute(epochs=self, baseline=baseline,
  1459. events=events)
  1460. @verbose
  1461. def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None):
  1462. """Crop a time interval from the epochs.
  1463. Parameters
  1464. ----------
  1465. tmin : float | None
  1466. Start time of selection in seconds.
  1467. tmax : float | None
  1468. End time of selection in seconds.
  1469. %(include_tmax)s
  1470. %(verbose_meth)s
  1471. Returns
  1472. -------
  1473. epochs : instance of Epochs
  1474. The cropped epochs object, modified in-place.
  1475. Notes
  1476. -----
  1477. %(notes_tmax_included_by_default)s
  1478. """
  1479. # XXX this could be made to work on non-preloaded data...
  1480. _check_preload(self, 'Modifying data of epochs')
  1481. if tmin is None:
  1482. tmin = self.tmin
  1483. elif tmin < self.tmin:
  1484. warn('tmin is not in epochs time interval. tmin is set to '
  1485. 'epochs.tmin')
  1486. tmin = self.tmin
  1487. if tmax is None:
  1488. tmax = self.tmax
  1489. elif tmax > self.tmax:
  1490. warn('tmax is not in epochs time interval. tmax is set to '
  1491. 'epochs.tmax')
  1492. tmax = self.tmax
  1493. include_tmax = True
  1494. tmask = _time_mask(self.times, tmin, tmax, sfreq=self.info['sfreq'],
  1495. include_tmax=include_tmax)
  1496. self._set_times(self.times[tmask])
  1497. self._raw_times = self._raw_times[tmask]
  1498. self._data = self._data[:, :, tmask]
  1499. # Adjust rejection period
  1500. if self.reject_tmin is not None and self.reject_tmin < self.tmin:
  1501. logger.info(
  1502. f'reject_tmin is not in epochs time interval. '
  1503. f'Setting reject_tmin to epochs.tmin ({self.tmin} sec)')
  1504. self.reject_tmin = self.tmin
  1505. if self.reject_tmax is not None and self.reject_tmax > self.tmax:
  1506. logger.info(
  1507. f'reject_tmax is not in epochs time interval. '
  1508. f'Setting reject_tmax to epochs.tmax ({self.tmax} sec)')
  1509. self.reject_tmax = self.tmax
  1510. return self
  1511. def copy(self):
  1512. """Return copy of Epochs instance.
  1513. Returns
  1514. -------
  1515. epochs : instance of Epochs
  1516. A copy of the object.
  1517. """
  1518. return deepcopy(self)
  1519. def __deepcopy__(self, memodict):
  1520. """Make a deepcopy."""
  1521. cls = self.__class__
  1522. result = cls.__new__(cls)
  1523. for k, v in self.__dict__.items():
  1524. # drop_log is immutable and _raw is private (and problematic to
  1525. # deepcopy)
  1526. if k in ('drop_log', '_raw', '_times_readonly'):
  1527. memodict[id(v)] = v
  1528. else:
  1529. v = deepcopy(v, memodict)
  1530. result.__dict__[k] = v
  1531. return result
  1532. @verbose
  1533. def save(self, fname, split_size='2GB', fmt='single', overwrite=False,
  1534. split_naming='neuromag', verbose=True):
  1535. """Save epochs in a fif file.
  1536. Parameters
  1537. ----------
  1538. fname : str
  1539. The name of the file, which should end with ``-epo.fif`` or
  1540. ``-epo.fif.gz``.
  1541. split_size : str | int
  1542. Large raw files are automatically split into multiple pieces. This
  1543. parameter specifies the maximum size of each piece. If the
  1544. parameter is an integer, it specifies the size in Bytes. It is
  1545. also possible to pass a human-readable string, e.g., 100MB.
  1546. Note: Due to FIFF file limitations, the maximum split size is 2GB.
  1547. .. versionadded:: 0.10.0
  1548. fmt : str
  1549. Format to save data. Valid options are 'double' or
  1550. 'single' for 64- or 32-bit float, or for 128- or
  1551. 64-bit complex numbers respectively. Note: Data are processed with
  1552. double precision. Choosing single-precision, the saved data
  1553. will slightly differ due to the reduction in precision.
  1554. .. versionadded:: 0.17
  1555. %(overwrite)s
  1556. To overwrite original file (the same one that was loaded),
  1557. data must be preloaded upon reading. This defaults to True in 0.18
  1558. but will change to False in 0.19.
  1559. .. versionadded:: 0.18
  1560. %(split_naming)s
  1561. .. versionadded:: 0.24
  1562. %(verbose_meth)s
  1563. Notes
  1564. -----
  1565. Bad epochs will be dropped before saving the epochs to disk.
  1566. """
  1567. check_fname(fname, 'epochs', ('-epo.fif', '-epo.fif.gz',
  1568. '_epo.fif', '_epo.fif.gz'))
  1569. # check for file existence and expand `~` if present
  1570. fname = _check_fname(fname=fname, overwrite=overwrite)
  1571. split_size_bytes = _get_split_size(split_size)
  1572. _check_option('fmt', fmt, ['single', 'double'])
  1573. # to know the length accurately. The get_data() call would drop
  1574. # bad epochs anyway
  1575. self.drop_bad()
  1576. # total_size tracks sizes that get split
  1577. # over_size tracks overhead (tags, things that get written to each)
  1578. if len(self) == 0:
  1579. warn('Saving epochs with no data')
  1580. total_size = 0
  1581. else:
  1582. d = self[0].get_data()
  1583. # this should be guaranteed by subclasses
  1584. assert d.dtype in ('>f8', '<f8', '>c16', '<c16')
  1585. total_size = d.nbytes * len(self)
  1586. self._check_consistency()
  1587. over_size = 0
  1588. if fmt == "single":
  1589. total_size //= 2 # 64bit data converted to 32bit before writing.
  1590. over_size += 32 # FIF tags
  1591. # Account for all the other things we write, too
  1592. # 1. meas_id block plus main epochs block
  1593. over_size += 132
  1594. # 2. measurement info (likely slight overestimate, but okay)
  1595. over_size += object_size(self.info) + 16 * len(self.info)
  1596. # 3. events and event_id in its own block
  1597. total_size += self.events.size * 4
  1598. over_size += len(_event_id_string(self.event_id)) + 72
  1599. # 4. Metadata in a block of its own
  1600. if self.metadata is not None:
  1601. total_size += len(_prepare_write_metadata(self.metadata))
  1602. over_size += 56
  1603. # 5. first sample, last sample, baseline
  1604. over_size += 40 * (self.baseline is not None) + 40
  1605. # 6. drop log: gets written to each, with IGNORE for ones that are
  1606. # not part of it. So make a fake one with all having entries.
  1607. drop_size = len(json.dumps(self.drop_log)) + 16
  1608. drop_size += 8 * (len(self.selection) - 1) # worst case: all but one
  1609. over_size += drop_size
  1610. # 7. reject params
  1611. reject_params = _pack_reject_params(self)
  1612. if reject_params:
  1613. over_size += len(json.dumps(reject_params)) + 16
  1614. # 8. selection
  1615. total_size += self.selection.size * 4
  1616. over_size += 16
  1617. # 9. end of file tags
  1618. over_size += _NEXT_FILE_BUFFER
  1619. logger.debug(f' Overhead size: {str(over_size).rjust(15)}')
  1620. logger.debug(f' Splittable size: {str(total_size).rjust(15)}')
  1621. logger.debug(f' Split size: {str(split_size_bytes).rjust(15)}')
  1622. # need at least one per
  1623. n_epochs = len(self)
  1624. n_per = total_size // n_epochs if n_epochs else 0
  1625. min_size = n_per + over_size
  1626. if split_size_bytes < min_size:
  1627. raise ValueError(
  1628. f'The split size {split_size} is too small to safely write '
  1629. 'the epochs contents, minimum split size is '
  1630. f'{sizeof_fmt(min_size)} ({min_size} bytes)')
  1631. # This is like max(int(ceil(total_size / split_size)), 1) but cleaner
  1632. n_parts = max(
  1633. (total_size - 1) // (split_size_bytes - over_size) + 1, 1)
  1634. assert n_parts >= 1, n_parts
  1635. if n_parts > 1:
  1636. logger.info(f'Splitting into {n_parts} parts')
  1637. if n_parts > 100: # This must be an error
  1638. raise ValueError(
  1639. f'Split size {split_size} would result in writing '
  1640. f'{n_parts} files')
  1641. if len(self.drop_log) > 100000:
  1642. warn(f'epochs.drop_log contains {len(self.drop_log)} entries '
  1643. f'which will incur up to a {sizeof_fmt(drop_size)} writing '
  1644. f'overhead (per split file), consider using '
  1645. f'epochs.reset_drop_log_selection() prior to writing')
  1646. epoch_idxs = np.array_split(np.arange(n_epochs), n_parts)
  1647. for part_idx, epoch_idx in enumerate(epoch_idxs):
  1648. this_epochs = self[epoch_idx] if n_parts > 1 else self
  1649. # avoid missing event_ids in splits
  1650. this_epochs.event_id = self.event_id
  1651. _save_split(this_epochs, fname, part_idx, n_parts, fmt,
  1652. split_naming, overwrite)
  1653. @verbose
  1654. def export(self, fname, fmt='auto', *, overwrite=False, verbose=None):
  1655. """Export Epochs to external formats.
  1656. Supported formats: EEGLAB (set, uses :mod:`eeglabio`)
  1657. %(export_warning)s :meth:`save` instead.
  1658. Parameters
  1659. ----------
  1660. %(export_params_fname)s
  1661. %(export_params_fmt)s
  1662. %(overwrite)s
  1663. .. versionadded:: 0.24.1
  1664. %(verbose)s
  1665. Notes
  1666. -----
  1667. %(export_eeglab_note)s
  1668. """
  1669. from .export import export_epochs
  1670. export_epochs(fname, self, fmt, overwrite=overwrite, verbose=verbose)
  1671. def equalize_event_counts(self, event_ids=None, method='mintime'):
  1672. """Equalize the number of trials in each condition.
  1673. It tries to make the remaining epochs occurring as close as possible in
  1674. time. This method works based on the idea that if there happened to be
  1675. some time-varying (like on the scale of minutes) noise characteristics
  1676. during a recording, they could be compensated for (to some extent) in
  1677. the equalization process. This method thus seeks to reduce any of
  1678. those effects by minimizing the differences in the times of the events
  1679. within a `~mne.Epochs` instance. For example, if one event type
  1680. occurred at time points ``[1, 2, 3, 4, 120, 121]`` and the another one
  1681. at ``[3.5, 4.5, 120.5, 121.5]``, this method would remove the events at
  1682. times ``[1, 2]`` for the first event type and not the events at times
  1683. ``[120, 121]``.
  1684. Parameters
  1685. ----------
  1686. event_ids : None | list | dict
  1687. The event types to equalize.
  1688. If ``None`` (default), equalize the counts of **all** event types
  1689. present in the `~mne.Epochs` instance.
  1690. If a list, each element can either be a string (event name) or a
  1691. list of strings. In the case where one of the entries is a list of
  1692. strings, event types in that list will be grouped together before
  1693. equalizing trial counts across conditions.
  1694. If a dictionary, the keys are considered as the event names whose
  1695. counts to equalize, i.e., passing ``dict(A=1, B=2)`` will have the
  1696. same effect as passing ``['A', 'B']``. This is useful if you intend
  1697. to pass an ``event_id`` dictionary that was used when creating
  1698. `~mne.Epochs`.
  1699. In the case where partial matching is used (using ``/`` in
  1700. the event names), the event types will be matched according to the
  1701. provided tags, that is, processing works as if the ``event_ids``
  1702. matched by the provided tags had been supplied instead.
  1703. The ``event_ids`` must identify non-overlapping subsets of the
  1704. epochs.
  1705. method : str
  1706. If ``'truncate'``, events will be truncated from the end of each
  1707. type of events. If ``'mintime'``, timing differences between each
  1708. event type will be minimized.
  1709. Returns
  1710. -------
  1711. epochs : instance of Epochs
  1712. The modified instance. It is modified in-place.
  1713. indices : array of int
  1714. Indices from the original events list that were dropped.
  1715. Notes
  1716. -----
  1717. For example (if ``epochs.event_id`` was ``{'Left': 1, 'Right': 2,
  1718. 'Nonspatial':3}``:
  1719. epochs.equalize_event_counts([['Left', 'Right'], 'Nonspatial'])
  1720. would equalize the number of trials in the ``'Nonspatial'`` condition
  1721. with the total number of trials in the ``'Left'`` and ``'Right'``
  1722. conditions combined.
  1723. If multiple indices are provided (e.g. ``'Left'`` and ``'Right'`` in
  1724. the example above), it is not guaranteed that after equalization the
  1725. conditions will contribute equally. E.g., it is possible to end up
  1726. with 70 ``'Nonspatial'`` epochs, 69 ``'Left'`` and 1 ``'Right'``.
  1727. .. versionchanged:: 0.23
  1728. Default to equalizing all events in the passed instance if no
  1729. event names were specified explicitly.
  1730. """
  1731. from collections.abc import Iterable
  1732. _validate_type(event_ids, types=(Iterable, None),
  1733. item_name='event_ids', type_name='list-like or None')
  1734. if isinstance(event_ids, str):
  1735. raise TypeError(f'event_ids must be list-like or None, but '
  1736. f'received a string: {event_ids}')
  1737. if event_ids is None:
  1738. event_ids = list(self.event_id)
  1739. elif not event_ids:
  1740. raise ValueError('event_ids must have at least one element')
  1741. if not self._bad_dropped:
  1742. self.drop_bad()
  1743. # figure out how to equalize
  1744. eq_inds = list()
  1745. # deal with hierarchical tags
  1746. ids = self.event_id
  1747. orig_ids = list(event_ids)
  1748. tagging = False
  1749. if "/" in "".join(ids):
  1750. # make string inputs a list of length 1
  1751. event_ids = [[x] if isinstance(x, str) else x
  1752. for x in event_ids]
  1753. for ids_ in event_ids: # check if tagging is attempted
  1754. if any([id_ not in ids for id_ in ids_]):
  1755. tagging = True
  1756. # 1. treat everything that's not in event_id as a tag
  1757. # 2a. for tags, find all the event_ids matched by the tags
  1758. # 2b. for non-tag ids, just pass them directly
  1759. # 3. do this for every input
  1760. event_ids = [[k for k in ids
  1761. if all((tag in k.split("/")
  1762. for tag in id_))] # ids matching all tags
  1763. if all(id__ not in ids for id__ in id_)
  1764. else id_ # straight pass for non-tag inputs
  1765. for id_ in event_ids]
  1766. for ii, id_ in enumerate(event_ids):
  1767. if len(id_) == 0:
  1768. raise KeyError(f"{orig_ids[ii]} not found in the epoch "
  1769. "object's event_id.")
  1770. elif len({sub_id in ids for sub_id in id_}) != 1:
  1771. err = ("Don't mix hierarchical and regular event_ids"
  1772. " like in \'%s\'." % ", ".join(id_))
  1773. raise ValueError(err)
  1774. # raise for non-orthogonal tags
  1775. if tagging is True:
  1776. events_ = [set(self[x].events[:, 0]) for x in event_ids]
  1777. doubles = events_[0].intersection(events_[1])
  1778. if len(doubles):
  1779. raise ValueError("The two sets of epochs are "
  1780. "overlapping. Provide an "
  1781. "orthogonal selection.")
  1782. for eq in event_ids:
  1783. eq_inds.append(self._keys_to_idx(eq))
  1784. event_times = [self.events[e, 0] for e in eq_inds]
  1785. indices = _get_drop_indices(event_times, method)
  1786. # need to re-index indices
  1787. indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)])
  1788. self.drop(indices, reason='EQUALIZED_COUNT')
  1789. # actually remove the indices
  1790. return self, indices
  1791. @fill_doc
  1792. def to_data_frame(self, picks=None, index=None,
  1793. scalings=None, copy=True, long_format=False,
  1794. time_format='ms'):
  1795. """Export data in tabular structure as a pandas DataFrame.
  1796. Channels are converted to columns in the DataFrame. By default,
  1797. additional columns "time", "epoch" (epoch number), and "condition"
  1798. (epoch event description) are added, unless ``index`` is not ``None``
  1799. (in which case the columns specified in ``index`` will be used to form
  1800. the DataFrame's index instead).
  1801. Parameters
  1802. ----------
  1803. %(picks_all)s
  1804. %(df_index_epo)s
  1805. Valid string values are 'time', 'epoch', and 'condition'.
  1806. Defaults to ``None``.
  1807. %(df_scalings)s
  1808. %(df_copy)s
  1809. %(df_longform_epo)s
  1810. %(df_time_format)s
  1811. .. versionadded:: 0.20
  1812. Returns
  1813. -------
  1814. %(df_return)s
  1815. """
  1816. # check pandas once here, instead of in each private utils function
  1817. pd = _check_pandas_installed() # noqa
  1818. # arg checking
  1819. valid_index_args = ['time', 'epoch', 'condition']
  1820. valid_time_formats = ['ms', 'timedelta']
  1821. index = _check_pandas_index_arguments(index, valid_index_args)
  1822. time_format = _check_time_format(time_format, valid_time_formats)
  1823. # get data
  1824. picks = _picks_to_idx(self.info, picks, 'all', exclude=())
  1825. data = self.get_data()[:, picks, :]
  1826. times = self.times
  1827. n_epochs, n_picks, n_times = data.shape
  1828. data = np.hstack(data).T # (time*epochs) x signals
  1829. if copy:
  1830. data = data.copy()
  1831. data = _scale_dataframe_data(self, data, picks, scalings)
  1832. # prepare extra columns / multiindex
  1833. mindex = list()
  1834. times = np.tile(times, n_epochs)
  1835. times = _convert_times(self, times, time_format)
  1836. mindex.append(('time', times))
  1837. rev_event_id = {v: k for k, v in self.event_id.items()}
  1838. conditions = [rev_event_id[k] for k in self.events[:, 2]]
  1839. mindex.append(('condition', np.repeat(conditions, n_times)))
  1840. mindex.append(('epoch', np.repeat(self.selection, n_times)))
  1841. assert all(len(mdx) == len(mindex[0]) for mdx in mindex)
  1842. # build DataFrame
  1843. df = _build_data_frame(self, data, picks, long_format, mindex, index,
  1844. default_index=['condition', 'epoch', 'time'])
  1845. return df
  1846. def as_type(self, ch_type='grad', mode='fast'):
  1847. """Compute virtual epochs using interpolated fields.
  1848. .. Warning:: Using virtual epochs to compute inverse can yield
  1849. unexpected results. The virtual channels have ``'_v'`` appended
  1850. at the end of the names to emphasize that the data contained in
  1851. them are interpolated.
  1852. Parameters
  1853. ----------
  1854. ch_type : str
  1855. The destination channel type. It can be 'mag' or 'grad'.
  1856. mode : str
  1857. Either ``'accurate'`` or ``'fast'``, determines the quality of the
  1858. Legendre polynomial expansion used. ``'fast'`` should be sufficient
  1859. for most applications.
  1860. Returns
  1861. -------
  1862. epochs : instance of mne.EpochsArray
  1863. The transformed epochs object containing only virtual channels.
  1864. Notes
  1865. -----
  1866. This method returns a copy and does not modify the data it
  1867. operates on. It also returns an EpochsArray instance.
  1868. .. versionadded:: 0.20.0
  1869. """
  1870. from .forward import _as_meg_type_inst
  1871. return _as_meg_type_inst(self, ch_type=ch_type, mode=mode)
  1872. def _drop_log_stats(drop_log, ignore=('IGNORED',)):
  1873. """Compute drop log stats.
  1874. Parameters
  1875. ----------
  1876. drop_log : list of list
  1877. Epoch drop log from Epochs.drop_log.
  1878. ignore : list
  1879. The drop reasons to ignore.
  1880. Returns
  1881. -------
  1882. perc : float
  1883. Total percentage of epochs dropped.
  1884. """
  1885. if not isinstance(drop_log, tuple) or \
  1886. not all(isinstance(d, tuple) for d in drop_log) or \
  1887. not all(isinstance(s, str) for d in drop_log for s in d):
  1888. raise TypeError('drop_log must be a tuple of tuple of str')
  1889. perc = 100 * np.mean([len(d) > 0 for d in drop_log
  1890. if not any(r in ignore for r in d)])
  1891. return perc
  1892. def make_metadata(events, event_id, tmin, tmax, sfreq,
  1893. row_events=None, keep_first=None, keep_last=None):
  1894. """Generate metadata from events for use with `mne.Epochs`.
  1895. This function mimics the epoching process (it constructs time windows
  1896. around time-locked "events of interest") and collates information about
  1897. any other events that occurred within those time windows. The information
  1898. is returned as a :class:`pandas.DataFrame` suitable for use as
  1899. `~mne.Epochs` metadata: one row per time-locked event, and columns
  1900. indicating presence/absence and latency of each ancillary event type.
  1901. The function will also return a new ``events`` array and ``event_id``
  1902. dictionary that correspond to the generated metadata.
  1903. Parameters
  1904. ----------
  1905. events : array, shape (m, 3)
  1906. The :term:`events array <events>`. By default, the returned metadata
  1907. :class:`~pandas.DataFrame` will have as many rows as the events array.
  1908. To create rows for only a subset of events, pass the ``row_events``
  1909. parameter.
  1910. event_id : dict
  1911. A mapping from event names (keys) to event IDs (values). The event
  1912. names will be incorporated as columns of the returned metadata
  1913. :class:`~pandas.DataFrame`.
  1914. tmin, tmax : float
  1915. Start and end of the time interval for metadata generation in seconds,
  1916. relative to the time-locked event of the respective time window.
  1917. .. note::
  1918. If you are planning to attach the generated metadata to
  1919. `~mne.Epochs` and intend to include only events that fall inside
  1920. your epochs time interval, pass the same ``tmin`` and ``tmax``
  1921. values here as you use for your epochs.
  1922. sfreq : float
  1923. The sampling frequency of the data from which the events array was
  1924. extracted.
  1925. row_events : list of str | str | None
  1926. Event types around which to create the time windows / for which to
  1927. create **rows** in the returned metadata :class:`pandas.DataFrame`. If
  1928. provided, the string(s) must be keys of ``event_id``. If ``None``
  1929. (default), rows are created for **all** event types present in
  1930. ``event_id``.
  1931. keep_first : str | list of str | None
  1932. Specify subsets of :term:`hierarchical event descriptors` (HEDs,
  1933. inspired by :footcite:`BigdelyShamloEtAl2013`) matching events of which
  1934. the **first occurrence** within each time window shall be stored in
  1935. addition to the original events.
  1936. .. note::
  1937. There is currently no way to retain **all** occurrences of a
  1938. repeated event. The ``keep_first`` parameter can be used to specify
  1939. subsets of HEDs, effectively creating a new event type that is the
  1940. union of all events types described by the matching HED pattern.
  1941. Only the very first event of this set will be kept.
  1942. For example, you might have two response events types,
  1943. ``response/left`` and ``response/right``; and in trials with both
  1944. responses occurring, you want to keep only the first response. In this
  1945. case, you can pass ``keep_first='response'``. This will add two new
  1946. columns to the metadata: ``response``, indicating at what **time** the
  1947. event occurred, relative to the time-locked event; and
  1948. ``first_response``, stating which **type** (``'left'`` or ``'right'``)
  1949. of event occurred.
  1950. To match specific subsets of HEDs describing different sets of events,
  1951. pass a list of these subsets, e.g.
  1952. ``keep_first=['response', 'stimulus']``. If ``None`` (default), no
  1953. event aggregation will take place and no new columns will be created.
  1954. .. note::
  1955. By default, this function will always retain the first instance
  1956. of any event in each time window. For example, if a time window
  1957. contains two ``'response'`` events, the generated ``response``
  1958. column will automatically refer to the first of the two events. In
  1959. this specific case, it is therefore **not** necessary to make use of
  1960. the ``keep_first`` parameter unless you need to differentiate
  1961. between two types of responses, like in the example above.
  1962. keep_last : list of str | None
  1963. Same as ``keep_first``, but for keeping only the **last** occurrence
  1964. of matching events. The column indicating the **type** of an event
  1965. ``myevent`` will be named ``last_myevent``.
  1966. Returns
  1967. -------
  1968. metadata : pandas.DataFrame
  1969. Metadata for each row event, with the following columns:
  1970. - ``event_name``, with strings indicating the name of the time-locked
  1971. event ("row event") for that specific time window
  1972. - one column per event type in ``event_id``, with the same name; floats
  1973. indicating the latency of the event in seconds, relative to the
  1974. time-locked event
  1975. - if applicable, additional columns named after the ``keep_first`` and
  1976. ``keep_last`` event types; floats indicating the latency of the
  1977. event in seconds, relative to the time-locked event
  1978. - if applicable, additional columns ``first_{event_type}`` and
  1979. ``last_{event_type}`` for ``keep_first`` and ``keep_last`` event
  1980. types, respetively; the values will be strings indicating which event
  1981. types were matched by the provided HED patterns
  1982. events : array, shape (n, 3)
  1983. The events corresponding to the generated metadata, i.e. one
  1984. time-locked event per row.
  1985. event_id : dict
  1986. The event dictionary corresponding to the new events array. This will
  1987. be identical to the input dictionary unless ``row_events`` is supplied,
  1988. in which case it will only contain the events provided there.
  1989. Notes
  1990. -----
  1991. The time window used for metadata generation need not correspond to the
  1992. time window used to create the `~mne.Epochs`, to which the metadata will
  1993. be attached; it may well be much shorter or longer, or not overlap at all,
  1994. if desired. The can be useful, for example, to include events that occurred
  1995. before or after an epoch, e.g. during the inter-trial interval.
  1996. .. versionadded:: 0.23
  1997. References
  1998. ----------
  1999. .. footbibliography::
  2000. """
  2001. from .utils.mixin import _hid_match
  2002. pd = _check_pandas_installed()
  2003. _validate_type(event_id, types=(dict,), item_name='event_id')
  2004. _validate_type(row_events, types=(None, str, list, tuple),
  2005. item_name='row_events')
  2006. _validate_type(keep_first, types=(None, str, list, tuple),
  2007. item_name='keep_first')
  2008. _validate_type(keep_last, types=(None, str, list, tuple),
  2009. item_name='keep_last')
  2010. if not event_id:
  2011. raise ValueError('event_id dictionary must contain at least one entry')
  2012. def _ensure_list(x):
  2013. if x is None:
  2014. return []
  2015. elif isinstance(x, str):
  2016. return [x]
  2017. else:
  2018. return list(x)
  2019. row_events = _ensure_list(row_events)
  2020. keep_first = _ensure_list(keep_first)
  2021. keep_last = _ensure_list(keep_last)
  2022. keep_first_and_last = set(keep_first) & set(keep_last)
  2023. if keep_first_and_last:
  2024. raise ValueError(f'The event names in keep_first and keep_last must '
  2025. f'be mutually exclusive. Specified in both: '
  2026. f'{", ".join(sorted(keep_first_and_last))}')
  2027. del keep_first_and_last
  2028. for param_name, values in dict(keep_first=keep_first,
  2029. keep_last=keep_last).items():
  2030. for first_last_event_name in values:
  2031. try:
  2032. _hid_match(event_id, [first_last_event_name])
  2033. except KeyError:
  2034. raise ValueError(
  2035. f'Event "{first_last_event_name}", specified in '
  2036. f'{param_name}, cannot be found in event_id dictionary')
  2037. event_name_diff = sorted(set(row_events) - set(event_id.keys()))
  2038. if event_name_diff:
  2039. raise ValueError(
  2040. f'Present in row_events, but missing from event_id: '
  2041. f'{", ".join(event_name_diff)}')
  2042. del event_name_diff
  2043. # First and last sample of each epoch, relative to the time-locked event
  2044. # This follows the approach taken in mne.Epochs
  2045. start_sample = int(round(tmin * sfreq))
  2046. stop_sample = int(round(tmax * sfreq)) + 1
  2047. # Make indexing easier
  2048. # We create the DataFrame before subsetting the events so we end up with
  2049. # indices corresponding to the original event indices. Not used for now,
  2050. # but might come in handy sometime later
  2051. events_df = pd.DataFrame(events, columns=('sample', 'prev_id', 'id'))
  2052. id_to_name_map = {v: k for k, v in event_id.items()}
  2053. # Only keep events that are of interest
  2054. events = events[np.in1d(events[:, 2], list(event_id.values()))]
  2055. events_df = events_df.loc[events_df['id'].isin(event_id.values()), :]
  2056. # Prepare & condition the metadata DataFrame
  2057. # Avoid column name duplications if the exact same event name appears in
  2058. # event_id.keys() and keep_first / keep_last simultaneously
  2059. keep_first_cols = [col for col in keep_first if col not in event_id]
  2060. keep_last_cols = [col for col in keep_last if col not in event_id]
  2061. first_cols = [f'first_{col}' for col in keep_first_cols]
  2062. last_cols = [f'last_{col}' for col in keep_last_cols]
  2063. columns = ['event_name',
  2064. *event_id.keys(),
  2065. *keep_first_cols,
  2066. *keep_last_cols,
  2067. *first_cols,
  2068. *last_cols]
  2069. data = np.empty((len(events_df), len(columns)))
  2070. metadata = pd.DataFrame(data=data, columns=columns, index=events_df.index)
  2071. # Event names
  2072. metadata.iloc[:, 0] = ''
  2073. # Event times
  2074. start_idx = 1
  2075. stop_idx = (start_idx + len(event_id.keys()) +
  2076. len(keep_first_cols + keep_last_cols))
  2077. metadata.iloc[:, start_idx:stop_idx] = np.nan
  2078. # keep_first and keep_last names
  2079. start_idx = stop_idx
  2080. metadata.iloc[:, start_idx:] = None
  2081. # We're all set, let's iterate over all eventns and fill in in the
  2082. # respective cells in the metadata. We will subset this to include only
  2083. # `row_events` later
  2084. for row_event in events_df.itertuples(name='RowEvent'):
  2085. row_idx = row_event.Index
  2086. metadata.loc[row_idx, 'event_name'] = \
  2087. id_to_name_map[row_event.id]
  2088. # Determine which events fall into the current epoch
  2089. window_start_sample = row_event.sample + start_sample
  2090. window_stop_sample = row_event.sample + stop_sample
  2091. events_in_window = events_df.loc[
  2092. (events_df['sample'] >= window_start_sample) &
  2093. (events_df['sample'] <= window_stop_sample), :]
  2094. assert not events_in_window.empty
  2095. # Store the metadata
  2096. for event in events_in_window.itertuples(name='Event'):
  2097. event_sample = event.sample - row_event.sample
  2098. event_time = event_sample / sfreq
  2099. event_time = 0 if np.isclose(event_time, 0) else event_time
  2100. event_name = id_to_name_map[event.id]
  2101. if not np.isnan(metadata.loc[row_idx, event_name]):
  2102. # Event already exists in current time window!
  2103. assert metadata.loc[row_idx, event_name] <= event_time
  2104. if event_name not in keep_last:
  2105. continue
  2106. metadata.loc[row_idx, event_name] = event_time
  2107. # Handle keep_first and keep_last event aggregation
  2108. for event_group_name in keep_first + keep_last:
  2109. if event_name not in _hid_match(event_id, [event_group_name]):
  2110. continue
  2111. if event_group_name in keep_first:
  2112. first_last_col = f'first_{event_group_name}'
  2113. else:
  2114. first_last_col = f'last_{event_group_name}'
  2115. old_time = metadata.loc[row_idx, event_group_name]
  2116. if not np.isnan(old_time):
  2117. if ((event_group_name in keep_first and
  2118. old_time <= event_time) or
  2119. (event_group_name in keep_last and
  2120. old_time >= event_time)):
  2121. continue
  2122. if event_group_name not in event_id:
  2123. # This is an HED. Strip redundant information from the
  2124. # event name
  2125. name = (event_name
  2126. .replace(event_group_name, '')
  2127. .replace('//', '/')
  2128. .strip('/'))
  2129. metadata.loc[row_idx, first_last_col] = name
  2130. del name
  2131. metadata.loc[row_idx, event_group_name] = event_time
  2132. # Only keep rows of interest
  2133. if row_events:
  2134. event_id_timelocked = {name: val for name, val in event_id.items()
  2135. if name in row_events}
  2136. events = events[np.in1d(events[:, 2],
  2137. list(event_id_timelocked.values()))]
  2138. metadata = metadata.loc[
  2139. metadata['event_name'].isin(event_id_timelocked)]
  2140. assert len(events) == len(metadata)
  2141. event_id = event_id_timelocked
  2142. return metadata, events, event_id
  2143. @fill_doc
  2144. class Epochs(BaseEpochs):
  2145. """Epochs extracted from a Raw instance.
  2146. Parameters
  2147. ----------
  2148. %(epochs_raw)s
  2149. %(epochs_events_event_id)s
  2150. %(epochs_tmin_tmax)s
  2151. %(baseline_epochs)s
  2152. Defaults to ``(None, 0)``, i.e. beginning of the the data until
  2153. time point zero.
  2154. %(picks_all)s
  2155. preload : bool
  2156. %(epochs_preload)s
  2157. %(reject_epochs)s
  2158. %(flat)s
  2159. %(proj_epochs)s
  2160. %(decim)s
  2161. %(epochs_reject_tmin_tmax)s
  2162. %(epochs_detrend)s
  2163. %(epochs_on_missing)s
  2164. %(reject_by_annotation_epochs)s
  2165. %(epochs_metadata)s
  2166. %(epochs_event_repeated)s
  2167. %(verbose)s
  2168. Attributes
  2169. ----------
  2170. %(info_not_none)s
  2171. event_id : dict
  2172. Names of conditions corresponding to event_ids.
  2173. ch_names : list of string
  2174. List of channel names.
  2175. selection : array
  2176. List of indices of selected events (not dropped or ignored etc.). For
  2177. example, if the original event array had 4 events and the second event
  2178. has been dropped, this attribute would be np.array([0, 2, 3]).
  2179. preload : bool
  2180. Indicates whether epochs are in memory.
  2181. drop_log : tuple of tuple
  2182. A tuple of the same length as the event array used to initialize the
  2183. Epochs object. If the i-th original event is still part of the
  2184. selection, drop_log[i] will be an empty tuple; otherwise it will be
  2185. a tuple of the reasons the event is not longer in the selection, e.g.:
  2186. - 'IGNORED'
  2187. If it isn't part of the current subset defined by the user
  2188. - 'NO_DATA' or 'TOO_SHORT'
  2189. If epoch didn't contain enough data names of channels that exceeded
  2190. the amplitude threshold
  2191. - 'EQUALIZED_COUNTS'
  2192. See :meth:`~mne.Epochs.equalize_event_counts`
  2193. - 'USER'
  2194. For user-defined reasons (see :meth:`~mne.Epochs.drop`).
  2195. filename : str
  2196. The filename of the object.
  2197. times : ndarray
  2198. Time vector in seconds. Goes from ``tmin`` to ``tmax``. Time interval
  2199. between consecutive time samples is equal to the inverse of the
  2200. sampling frequency.
  2201. %(verbose)s
  2202. See Also
  2203. --------
  2204. mne.epochs.combine_event_ids
  2205. mne.Epochs.equalize_event_counts
  2206. Notes
  2207. -----
  2208. When accessing data, Epochs are detrended, baseline-corrected, and
  2209. decimated, then projectors are (optionally) applied.
  2210. For indexing and slicing using ``epochs[...]``, see
  2211. :meth:`mne.Epochs.__getitem__`.
  2212. All methods for iteration over objects (using :meth:`mne.Epochs.__iter__`,
  2213. :meth:`mne.Epochs.iter_evoked` or :meth:`mne.Epochs.next`) use the same
  2214. internal state.
  2215. If ``event_repeated`` is set to ``'merge'``, the coinciding events
  2216. (duplicates) will be merged into a single event_id and assigned a new
  2217. id_number as::
  2218. event_id['{event_id_1}/{event_id_2}/...'] = new_id_number
  2219. For example with the event_id ``{'aud': 1, 'vis': 2}`` and the events
  2220. ``[[0, 0, 1], [0, 0, 2]]``, the "merge" behavior will update both event_id
  2221. and events to be: ``{'aud/vis': 3}`` and ``[[0, 0, 3]]`` respectively.
  2222. """
  2223. @verbose
  2224. def __init__(self, raw, events, event_id=None, tmin=-0.2, tmax=0.5,
  2225. baseline=(None, 0), picks=None, preload=False, reject=None,
  2226. flat=None, proj=True, decim=1, reject_tmin=None,
  2227. reject_tmax=None, detrend=None, on_missing='raise',
  2228. reject_by_annotation=True, metadata=None,
  2229. event_repeated='error', verbose=None): # noqa: D102
  2230. if not isinstance(raw, BaseRaw):
  2231. raise ValueError('The first argument to `Epochs` must be an '
  2232. 'instance of mne.io.BaseRaw')
  2233. info = deepcopy(raw.info)
  2234. # proj is on when applied in Raw
  2235. proj = proj or raw.proj
  2236. self.reject_by_annotation = reject_by_annotation
  2237. # call BaseEpochs constructor
  2238. super(Epochs, self).__init__(
  2239. info, None, events, event_id, tmin, tmax, metadata=metadata,
  2240. baseline=baseline, raw=raw, picks=picks, reject=reject,
  2241. flat=flat, decim=decim, reject_tmin=reject_tmin,
  2242. reject_tmax=reject_tmax, detrend=detrend,
  2243. proj=proj, on_missing=on_missing, preload_at_end=preload,
  2244. event_repeated=event_repeated, verbose=verbose)
  2245. @verbose
  2246. def _get_epoch_from_raw(self, idx, verbose=None):
  2247. """Load one epoch from disk.
  2248. Returns
  2249. -------
  2250. data : array | str | None
  2251. If string, it's details on rejection reason.
  2252. If array, it's the data in the desired range (good segment)
  2253. If None, it means no data is available.
  2254. """
  2255. if self._raw is None:
  2256. # This should never happen, as raw=None only if preload=True
  2257. raise ValueError('An error has occurred, no valid raw file found. '
  2258. 'Please report this to the mne-python '
  2259. 'developers.')
  2260. sfreq = self._raw.info['sfreq']
  2261. event_samp = self.events[idx, 0]
  2262. # Read a data segment from "start" to "stop" in samples
  2263. first_samp = self._raw.first_samp
  2264. start = int(round(event_samp + self._raw_times[0] * sfreq))
  2265. start -= first_samp
  2266. stop = start + len(self._raw_times)
  2267. # reject_tmin, and reject_tmax need to be converted to samples to
  2268. # check the reject_by_annotation boundaries: reject_start, reject_stop
  2269. reject_tmin = self.reject_tmin
  2270. if reject_tmin is None:
  2271. reject_tmin = self._raw_times[0]
  2272. reject_start = int(round(event_samp + reject_tmin * sfreq))
  2273. reject_start -= first_samp
  2274. reject_tmax = self.reject_tmax
  2275. if reject_tmax is None:
  2276. reject_tmax = self._raw_times[-1]
  2277. diff = int(round((self._raw_times[-1] - reject_tmax) * sfreq))
  2278. reject_stop = stop - diff
  2279. logger.debug(' Getting epoch for %d-%d' % (start, stop))
  2280. data = self._raw._check_bad_segment(start, stop, self.picks,
  2281. reject_start, reject_stop,
  2282. self.reject_by_annotation)
  2283. return data
  2284. @fill_doc
  2285. class EpochsArray(BaseEpochs):
  2286. """Epochs object from numpy array.
  2287. Parameters
  2288. ----------
  2289. data : array, shape (n_epochs, n_channels, n_times)
  2290. The channels' time series for each epoch. See notes for proper units of
  2291. measure.
  2292. %(info_not_none)s Consider using :func:`mne.create_info` to populate this
  2293. structure.
  2294. events : None | array of int, shape (n_events, 3)
  2295. The events typically returned by the read_events function.
  2296. If some events don't match the events of interest as specified
  2297. by event_id, they will be marked as 'IGNORED' in the drop log.
  2298. If None (default), all event values are set to 1 and event time-samples
  2299. are set to range(n_epochs).
  2300. tmin : float
  2301. Start time before event. If nothing provided, defaults to 0.
  2302. event_id : int | list of int | dict | None
  2303. The id of the event to consider. If dict,
  2304. the keys can later be used to access associated events. Example:
  2305. dict(auditory=1, visual=3). If int, a dict will be created with
  2306. the id as string. If a list, all events with the IDs specified
  2307. in the list are used. If None, all events will be used with
  2308. and a dict is created with string integer names corresponding
  2309. to the event id integers.
  2310. %(reject_epochs)s
  2311. %(flat)s
  2312. reject_tmin : scalar | None
  2313. Start of the time window used to reject epochs (with the default None,
  2314. the window will start with tmin).
  2315. reject_tmax : scalar | None
  2316. End of the time window used to reject epochs (with the default None,
  2317. the window will end with tmax).
  2318. %(baseline_epochs)s
  2319. Defaults to ``None``, i.e. no baseline correction.
  2320. proj : bool | 'delayed'
  2321. Apply SSP projection vectors. See :class:`mne.Epochs` for details.
  2322. on_missing : str
  2323. See :class:`mne.Epochs` docstring for details.
  2324. metadata : instance of pandas.DataFrame | None
  2325. See :class:`mne.Epochs` docstring for details.
  2326. .. versionadded:: 0.16
  2327. selection : ndarray | None
  2328. The selection compared to the original set of epochs.
  2329. Can be None to use ``np.arange(len(events))``.
  2330. .. versionadded:: 0.16
  2331. %(verbose)s
  2332. See Also
  2333. --------
  2334. create_info
  2335. EvokedArray
  2336. io.RawArray
  2337. Notes
  2338. -----
  2339. Proper units of measure:
  2340. * V: eeg, eog, seeg, dbs, emg, ecg, bio, ecog
  2341. * T: mag
  2342. * T/m: grad
  2343. * M: hbo, hbr
  2344. * Am: dipole
  2345. * AU: misc
  2346. """
  2347. @verbose
  2348. def __init__(self, data, info, events=None, tmin=0, event_id=None,
  2349. reject=None, flat=None, reject_tmin=None,
  2350. reject_tmax=None, baseline=None, proj=True,
  2351. on_missing='raise', metadata=None, selection=None,
  2352. verbose=None): # noqa: D102
  2353. dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64
  2354. data = np.asanyarray(data, dtype=dtype)
  2355. if data.ndim != 3:
  2356. raise ValueError('Data must be a 3D array of shape (n_epochs, '
  2357. 'n_channels, n_samples)')
  2358. if len(info['ch_names']) != data.shape[1]:
  2359. raise ValueError('Info and data must have same number of '
  2360. 'channels.')
  2361. if events is None:
  2362. n_epochs = len(data)
  2363. events = _gen_events(n_epochs)
  2364. info = info.copy() # do not modify original info
  2365. tmax = (data.shape[2] - 1) / info['sfreq'] + tmin
  2366. super(EpochsArray, self).__init__(
  2367. info, data, events, event_id, tmin, tmax, baseline, reject=reject,
  2368. flat=flat, reject_tmin=reject_tmin, reject_tmax=reject_tmax,
  2369. decim=1, metadata=metadata, selection=selection, proj=proj,
  2370. on_missing=on_missing)
  2371. if self.baseline is not None:
  2372. self._do_baseline = True
  2373. if len(events) != np.in1d(self.events[:, 2],
  2374. list(self.event_id.values())).sum():
  2375. raise ValueError('The events must only contain event numbers from '
  2376. 'event_id')
  2377. detrend_picks = self._detrend_picks
  2378. for e in self._data:
  2379. # This is safe without assignment b/c there is no decim
  2380. self._detrend_offset_decim(e, detrend_picks)
  2381. self.drop_bad()
  2382. def combine_event_ids(epochs, old_event_ids, new_event_id, copy=True):
  2383. """Collapse event_ids from an epochs instance into a new event_id.
  2384. Parameters
  2385. ----------
  2386. epochs : instance of Epochs
  2387. The epochs to operate on.
  2388. old_event_ids : str, or list
  2389. Conditions to collapse together.
  2390. new_event_id : dict, or int
  2391. A one-element dict (or a single integer) for the new
  2392. condition. Note that for safety, this cannot be any
  2393. existing id (in epochs.event_id.values()).
  2394. copy : bool
  2395. Whether to return a new instance or modify in place.
  2396. Returns
  2397. -------
  2398. epochs : instance of Epochs
  2399. The modified epochs.
  2400. Notes
  2401. -----
  2402. This For example (if epochs.event_id was ``{'Left': 1, 'Right': 2}``::
  2403. combine_event_ids(epochs, ['Left', 'Right'], {'Directional': 12})
  2404. would create a 'Directional' entry in epochs.event_id replacing
  2405. 'Left' and 'Right' (combining their trials).
  2406. """
  2407. epochs = epochs.copy() if copy else epochs
  2408. old_event_ids = np.asanyarray(old_event_ids)
  2409. if isinstance(new_event_id, int):
  2410. new_event_id = {str(new_event_id): new_event_id}
  2411. else:
  2412. if not isinstance(new_event_id, dict):
  2413. raise ValueError('new_event_id must be a dict or int')
  2414. if not len(list(new_event_id.keys())) == 1:
  2415. raise ValueError('new_event_id dict must have one entry')
  2416. new_event_num = list(new_event_id.values())[0]
  2417. new_event_num = operator.index(new_event_num)
  2418. if new_event_num in epochs.event_id.values():
  2419. raise ValueError('new_event_id value must not already exist')
  2420. # could use .pop() here, but if a latter one doesn't exist, we're
  2421. # in trouble, so run them all here and pop() later
  2422. old_event_nums = np.array([epochs.event_id[key] for key in old_event_ids])
  2423. # find the ones to replace
  2424. inds = np.any(epochs.events[:, 2][:, np.newaxis] ==
  2425. old_event_nums[np.newaxis, :], axis=1)
  2426. # replace the event numbers in the events list
  2427. epochs.events[inds, 2] = new_event_num
  2428. # delete old entries
  2429. for key in old_event_ids:
  2430. epochs.event_id.pop(key)
  2431. # add the new entry
  2432. epochs.event_id.update(new_event_id)
  2433. return epochs
  2434. def equalize_epoch_counts(epochs_list, method='mintime'):
  2435. """Equalize the number of trials in multiple Epoch instances.
  2436. Parameters
  2437. ----------
  2438. epochs_list : list of Epochs instances
  2439. The Epochs instances to equalize trial counts for.
  2440. method : str
  2441. If 'truncate', events will be truncated from the end of each event
  2442. list. If 'mintime', timing differences between each event list will be
  2443. minimized.
  2444. Notes
  2445. -----
  2446. This tries to make the remaining epochs occurring as close as possible in
  2447. time. This method works based on the idea that if there happened to be some
  2448. time-varying (like on the scale of minutes) noise characteristics during
  2449. a recording, they could be compensated for (to some extent) in the
  2450. equalization process. This method thus seeks to reduce any of those effects
  2451. by minimizing the differences in the times of the events in the two sets of
  2452. epochs. For example, if one had event times [1, 2, 3, 4, 120, 121] and the
  2453. other one had [3.5, 4.5, 120.5, 121.5], it would remove events at times
  2454. [1, 2] in the first epochs and not [120, 121].
  2455. Examples
  2456. --------
  2457. >>> equalize_epoch_counts([epochs1, epochs2]) # doctest: +SKIP
  2458. """
  2459. if not all(isinstance(e, BaseEpochs) for e in epochs_list):
  2460. raise ValueError('All inputs must be Epochs instances')
  2461. # make sure bad epochs are dropped
  2462. for e in epochs_list:
  2463. if not e._bad_dropped:
  2464. e.drop_bad()
  2465. event_times = [e.events[:, 0] for e in epochs_list]
  2466. indices = _get_drop_indices(event_times, method)
  2467. for e, inds in zip(epochs_list, indices):
  2468. e.drop(inds, reason='EQUALIZED_COUNT')
  2469. def _get_drop_indices(event_times, method):
  2470. """Get indices to drop from multiple event timing lists."""
  2471. small_idx = np.argmin([e.shape[0] for e in event_times])
  2472. small_e_times = event_times[small_idx]
  2473. _check_option('method', method, ['mintime', 'truncate'])
  2474. indices = list()
  2475. for e in event_times:
  2476. if method == 'mintime':
  2477. mask = _minimize_time_diff(small_e_times, e)
  2478. else:
  2479. mask = np.ones(e.shape[0], dtype=bool)
  2480. mask[small_e_times.shape[0]:] = False
  2481. indices.append(np.where(np.logical_not(mask))[0])
  2482. return indices
  2483. def _minimize_time_diff(t_shorter, t_longer):
  2484. """Find a boolean mask to minimize timing differences."""
  2485. from scipy.interpolate import interp1d
  2486. keep = np.ones((len(t_longer)), dtype=bool)
  2487. # special case: length zero or one
  2488. if len(t_shorter) < 2: # interp1d won't work
  2489. keep.fill(False)
  2490. if len(t_shorter) == 1:
  2491. idx = np.argmin(np.abs(t_longer - t_shorter))
  2492. keep[idx] = True
  2493. return keep
  2494. scores = np.ones((len(t_longer)))
  2495. x1 = np.arange(len(t_shorter))
  2496. # The first set of keep masks to test
  2497. kwargs = dict(copy=False, bounds_error=False, assume_sorted=True)
  2498. shorter_interp = interp1d(x1, t_shorter, fill_value=t_shorter[-1],
  2499. **kwargs)
  2500. for ii in range(len(t_longer) - len(t_shorter)):
  2501. scores.fill(np.inf)
  2502. # set up the keep masks to test, eliminating any rows that are already
  2503. # gone
  2504. keep_mask = ~np.eye(len(t_longer), dtype=bool)[keep]
  2505. keep_mask[:, ~keep] = False
  2506. # Check every possible removal to see if it minimizes
  2507. x2 = np.arange(len(t_longer) - ii - 1)
  2508. t_keeps = np.array([t_longer[km] for km in keep_mask])
  2509. longer_interp = interp1d(x2, t_keeps, axis=1,
  2510. fill_value=t_keeps[:, -1],
  2511. **kwargs)
  2512. d1 = longer_interp(x1) - t_shorter
  2513. d2 = shorter_interp(x2) - t_keeps
  2514. scores[keep] = np.abs(d1, d1).sum(axis=1) + np.abs(d2, d2).sum(axis=1)
  2515. keep[np.argmin(scores)] = False
  2516. return keep
  2517. @verbose
  2518. def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False,
  2519. ignore_chs=[], verbose=None):
  2520. """Test if data segment e is good according to reject and flat.
  2521. If full_report=True, it will give True/False as well as a list of all
  2522. offending channels.
  2523. """
  2524. bad_tuple = tuple()
  2525. has_printed = False
  2526. checkable = np.ones(len(ch_names), dtype=bool)
  2527. checkable[np.array([c in ignore_chs
  2528. for c in ch_names], dtype=bool)] = False
  2529. for refl, f, t in zip([reject, flat], [np.greater, np.less], ['', 'flat']):
  2530. if refl is not None:
  2531. for key, thresh in refl.items():
  2532. idx = channel_type_idx[key]
  2533. name = key.upper()
  2534. if len(idx) > 0:
  2535. e_idx = e[idx]
  2536. deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
  2537. checkable_idx = checkable[idx]
  2538. idx_deltas = np.where(np.logical_and(f(deltas, thresh),
  2539. checkable_idx))[0]
  2540. if len(idx_deltas) > 0:
  2541. bad_names = [ch_names[idx[i]] for i in idx_deltas]
  2542. if (not has_printed):
  2543. logger.info(' Rejecting %s epoch based on %s : '
  2544. '%s' % (t, name, bad_names))
  2545. has_printed = True
  2546. if not full_report:
  2547. return False
  2548. else:
  2549. bad_tuple += tuple(bad_names)
  2550. if not full_report:
  2551. return True
  2552. else:
  2553. if bad_tuple == ():
  2554. return True, None
  2555. else:
  2556. return False, bad_tuple
  2557. def _read_one_epoch_file(f, tree, preload):
  2558. """Read a single FIF file."""
  2559. with f as fid:
  2560. # Read the measurement info
  2561. info, meas = read_meas_info(fid, tree, clean_bads=True)
  2562. events, mappings = _read_events_fif(fid, tree)
  2563. # Metadata
  2564. metadata = None
  2565. metadata_tree = dir_tree_find(tree, FIFF.FIFFB_MNE_METADATA)
  2566. if len(metadata_tree) > 0:
  2567. for dd in metadata_tree[0]['directory']:
  2568. kind = dd.kind
  2569. pos = dd.pos
  2570. if kind == FIFF.FIFF_DESCRIPTION:
  2571. metadata = read_tag(fid, pos).data
  2572. metadata = _prepare_read_metadata(metadata)
  2573. break
  2574. # Locate the data of interest
  2575. processed = dir_tree_find(meas, FIFF.FIFFB_PROCESSED_DATA)
  2576. del meas
  2577. if len(processed) == 0:
  2578. raise ValueError('Could not find processed data')
  2579. epochs_node = dir_tree_find(tree, FIFF.FIFFB_MNE_EPOCHS)
  2580. if len(epochs_node) == 0:
  2581. # before version 0.11 we errantly saved with this tag instead of
  2582. # an MNE tag
  2583. epochs_node = dir_tree_find(tree, FIFF.FIFFB_MNE_EPOCHS)
  2584. if len(epochs_node) == 0:
  2585. epochs_node = dir_tree_find(tree, 122) # 122 used before v0.11
  2586. if len(epochs_node) == 0:
  2587. raise ValueError('Could not find epochs data')
  2588. my_epochs = epochs_node[0]
  2589. # Now find the data in the block
  2590. data = None
  2591. data_tag = None
  2592. bmin, bmax = None, None
  2593. baseline = None
  2594. selection = None
  2595. drop_log = None
  2596. reject_params = {}
  2597. for k in range(my_epochs['nent']):
  2598. kind = my_epochs['directory'][k].kind
  2599. pos = my_epochs['directory'][k].pos
  2600. if kind == FIFF.FIFF_FIRST_SAMPLE:
  2601. tag = read_tag(fid, pos)
  2602. first = int(tag.data)
  2603. elif kind == FIFF.FIFF_LAST_SAMPLE:
  2604. tag = read_tag(fid, pos)
  2605. last = int(tag.data)
  2606. elif kind == FIFF.FIFF_EPOCH:
  2607. # delay reading until later
  2608. fid.seek(pos, 0)
  2609. data_tag = read_tag_info(fid)
  2610. data_tag.pos = pos
  2611. data_tag.type = data_tag.type ^ (1 << 30)
  2612. elif kind in [FIFF.FIFF_MNE_BASELINE_MIN, 304]:
  2613. # Constant 304 was used before v0.11
  2614. tag = read_tag(fid, pos)
  2615. bmin = float(tag.data)
  2616. elif kind in [FIFF.FIFF_MNE_BASELINE_MAX, 305]:
  2617. # Constant 305 was used before v0.11
  2618. tag = read_tag(fid, pos)
  2619. bmax = float(tag.data)
  2620. elif kind == FIFF.FIFF_MNE_EPOCHS_SELECTION:
  2621. tag = read_tag(fid, pos)
  2622. selection = np.array(tag.data)
  2623. elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG:
  2624. tag = read_tag(fid, pos)
  2625. drop_log = tag.data
  2626. drop_log = json.loads(drop_log)
  2627. drop_log = tuple(tuple(x) for x in drop_log)
  2628. elif kind == FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT:
  2629. tag = read_tag(fid, pos)
  2630. reject_params = json.loads(tag.data)
  2631. if bmin is not None or bmax is not None:
  2632. baseline = (bmin, bmax)
  2633. n_samp = last - first + 1
  2634. logger.info(' Found the data of interest:')
  2635. logger.info(' t = %10.2f ... %10.2f ms'
  2636. % (1000 * first / info['sfreq'],
  2637. 1000 * last / info['sfreq']))
  2638. if info['comps'] is not None:
  2639. logger.info(' %d CTF compensation matrices available'
  2640. % len(info['comps']))
  2641. # Inspect the data
  2642. if data_tag is None:
  2643. raise ValueError('Epochs data not found')
  2644. epoch_shape = (len(info['ch_names']), n_samp)
  2645. size_expected = len(events) * np.prod(epoch_shape)
  2646. # on read double-precision is always used
  2647. if data_tag.type == FIFF.FIFFT_FLOAT:
  2648. datatype = np.float64
  2649. fmt = '>f4'
  2650. elif data_tag.type == FIFF.FIFFT_DOUBLE:
  2651. datatype = np.float64
  2652. fmt = '>f8'
  2653. elif data_tag.type == FIFF.FIFFT_COMPLEX_FLOAT:
  2654. datatype = np.complex128
  2655. fmt = '>c8'
  2656. elif data_tag.type == FIFF.FIFFT_COMPLEX_DOUBLE:
  2657. datatype = np.complex128
  2658. fmt = '>c16'
  2659. fmt_itemsize = np.dtype(fmt).itemsize
  2660. assert fmt_itemsize in (4, 8, 16)
  2661. size_actual = data_tag.size // fmt_itemsize - 16 // fmt_itemsize
  2662. if not size_actual == size_expected:
  2663. raise ValueError('Incorrect number of samples (%d instead of %d)'
  2664. % (size_actual, size_expected))
  2665. # Calibration factors
  2666. cals = np.array([[info['chs'][k]['cal'] *
  2667. info['chs'][k].get('scale', 1.0)]
  2668. for k in range(info['nchan'])], np.float64)
  2669. # Read the data
  2670. if preload:
  2671. data = read_tag(fid, data_tag.pos).data.astype(datatype)
  2672. data *= cals
  2673. # Put it all together
  2674. tmin = first / info['sfreq']
  2675. tmax = last / info['sfreq']
  2676. event_id = ({str(e): e for e in np.unique(events[:, 2])}
  2677. if mappings is None else mappings)
  2678. # In case epochs didn't have a FIFF.FIFF_MNE_EPOCHS_SELECTION tag
  2679. # (version < 0.8):
  2680. if selection is None:
  2681. selection = np.arange(len(events))
  2682. if drop_log is None:
  2683. drop_log = ((),) * len(events)
  2684. return (info, data, data_tag, events, event_id, metadata, tmin, tmax,
  2685. baseline, selection, drop_log, epoch_shape, cals, reject_params,
  2686. fmt)
  2687. @verbose
  2688. def read_epochs(fname, proj=True, preload=True, verbose=None):
  2689. """Read epochs from a fif file.
  2690. Parameters
  2691. ----------
  2692. %(epochs_fname)s
  2693. %(proj_epochs)s
  2694. preload : bool
  2695. If True, read all epochs from disk immediately. If ``False``, epochs
  2696. will be read on demand.
  2697. %(verbose)s
  2698. Returns
  2699. -------
  2700. epochs : instance of Epochs
  2701. The epochs.
  2702. """
  2703. return EpochsFIF(fname, proj, preload, verbose)
  2704. class _RawContainer(object):
  2705. """Helper for a raw data container."""
  2706. def __init__(self, fid, data_tag, event_samps, epoch_shape,
  2707. cals, fmt): # noqa: D102
  2708. self.fid = fid
  2709. self.data_tag = data_tag
  2710. self.event_samps = event_samps
  2711. self.epoch_shape = epoch_shape
  2712. self.cals = cals
  2713. self.proj = False
  2714. self.fmt = fmt
  2715. def __del__(self): # noqa: D105
  2716. self.fid.close()
  2717. @fill_doc
  2718. class EpochsFIF(BaseEpochs):
  2719. """Epochs read from disk.
  2720. Parameters
  2721. ----------
  2722. %(epochs_fname)s
  2723. %(proj_epochs)s
  2724. preload : bool
  2725. If True, read all epochs from disk immediately. If False, epochs will
  2726. be read on demand.
  2727. %(verbose)s
  2728. See Also
  2729. --------
  2730. mne.Epochs
  2731. mne.epochs.combine_event_ids
  2732. mne.Epochs.equalize_event_counts
  2733. """
  2734. @verbose
  2735. def __init__(self, fname, proj=True, preload=True,
  2736. verbose=None): # noqa: D102
  2737. if _path_like(fname):
  2738. check_fname(
  2739. fname=fname, filetype='epochs',
  2740. endings=('-epo.fif', '-epo.fif.gz', '_epo.fif', '_epo.fif.gz')
  2741. )
  2742. fname = _check_fname(fname=fname, must_exist=True,
  2743. overwrite='read')
  2744. elif not preload:
  2745. raise ValueError('preload must be used with file-like objects')
  2746. fnames = [fname]
  2747. ep_list = list()
  2748. raw = list()
  2749. for fname in fnames:
  2750. fname_rep = _get_fname_rep(fname)
  2751. logger.info('Reading %s ...' % fname_rep)
  2752. fid, tree, _ = fiff_open(fname, preload=preload)
  2753. next_fname = _get_next_fname(fid, fname, tree)
  2754. (info, data, data_tag, events, event_id, metadata, tmin, tmax,
  2755. baseline, selection, drop_log, epoch_shape, cals,
  2756. reject_params, fmt) = \
  2757. _read_one_epoch_file(fid, tree, preload)
  2758. if (events[:, 0] < 0).any():
  2759. events = events.copy()
  2760. warn('Incorrect events detected on disk, setting event '
  2761. 'numbers to consecutive increasing integers')
  2762. events[:, 0] = np.arange(1, len(events) + 1)
  2763. # here we ignore missing events, since users should already be
  2764. # aware of missing events if they have saved data that way
  2765. # we also retain original baseline without re-applying baseline
  2766. # correction (data is being baseline-corrected when written to
  2767. # disk)
  2768. epoch = BaseEpochs(
  2769. info, data, events, event_id, tmin, tmax,
  2770. baseline=None,
  2771. metadata=metadata, on_missing='ignore',
  2772. selection=selection, drop_log=drop_log,
  2773. proj=False, verbose=False)
  2774. epoch.baseline = baseline
  2775. epoch._do_baseline = False # might be superfluous but won't hurt
  2776. ep_list.append(epoch)
  2777. if not preload:
  2778. # store everything we need to index back to the original data
  2779. raw.append(_RawContainer(fiff_open(fname)[0], data_tag,
  2780. events[:, 0].copy(), epoch_shape,
  2781. cals, fmt))
  2782. if next_fname is not None:
  2783. fnames.append(next_fname)
  2784. (info, data, events, event_id, tmin, tmax, metadata, baseline,
  2785. selection, drop_log, _) = \
  2786. _concatenate_epochs(ep_list, with_data=preload, add_offset=False)
  2787. # we need this uniqueness for non-preloaded data to work properly
  2788. if len(np.unique(events[:, 0])) != len(events):
  2789. raise RuntimeError('Event time samples were not unique')
  2790. # correct the drop log
  2791. assert len(drop_log) % len(fnames) == 0
  2792. step = len(drop_log) // len(fnames)
  2793. offsets = np.arange(step, len(drop_log) + 1, step)
  2794. drop_log = list(drop_log)
  2795. for i1, i2 in zip(offsets[:-1], offsets[1:]):
  2796. other_log = drop_log[i1:i2]
  2797. for k, (a, b) in enumerate(zip(drop_log, other_log)):
  2798. if a == ('IGNORED',) and b != ('IGNORED',):
  2799. drop_log[k] = b
  2800. drop_log = tuple(drop_log[:step])
  2801. # call BaseEpochs constructor
  2802. # again, ensure we're retaining the baseline period originally loaded
  2803. # from disk without trying to re-apply baseline correction
  2804. super(EpochsFIF, self).__init__(
  2805. info, data, events, event_id, tmin, tmax, baseline=None, raw=raw,
  2806. proj=proj, preload_at_end=False, on_missing='ignore',
  2807. selection=selection, drop_log=drop_log, filename=fname_rep,
  2808. metadata=metadata, verbose=verbose, **reject_params)
  2809. self.baseline = baseline
  2810. self._do_baseline = False
  2811. # use the private property instead of drop_bad so that epochs
  2812. # are not all read from disk for preload=False
  2813. self._bad_dropped = True
  2814. @verbose
  2815. def _get_epoch_from_raw(self, idx, verbose=None):
  2816. """Load one epoch from disk."""
  2817. # Find the right file and offset to use
  2818. event_samp = self.events[idx, 0]
  2819. for raw in self._raw:
  2820. idx = np.where(raw.event_samps == event_samp)[0]
  2821. if len(idx) == 1:
  2822. fmt = raw.fmt
  2823. idx = idx[0]
  2824. size = np.prod(raw.epoch_shape) * np.dtype(fmt).itemsize
  2825. offset = idx * size + 16 # 16 = Tag header
  2826. break
  2827. else:
  2828. # read the correct subset of the data
  2829. raise RuntimeError('Correct epoch could not be found, please '
  2830. 'contact mne-python developers')
  2831. # the following is equivalent to this, but faster:
  2832. #
  2833. # >>> data = read_tag(raw.fid, raw.data_tag.pos).data.astype(float)
  2834. # >>> data *= raw.cals[np.newaxis, :, :]
  2835. # >>> data = data[idx]
  2836. #
  2837. # Eventually this could be refactored in io/tag.py if other functions
  2838. # could make use of it
  2839. raw.fid.seek(raw.data_tag.pos + offset, 0)
  2840. if fmt == '>c8':
  2841. read_fmt = '>f4'
  2842. elif fmt == '>c16':
  2843. read_fmt = '>f8'
  2844. else:
  2845. read_fmt = fmt
  2846. data = np.frombuffer(raw.fid.read(size), read_fmt)
  2847. if read_fmt != fmt:
  2848. data = data.view(fmt)
  2849. data = data.astype(np.complex128)
  2850. else:
  2851. data = data.astype(np.float64)
  2852. data.shape = raw.epoch_shape
  2853. data *= raw.cals
  2854. return data
  2855. @fill_doc
  2856. def bootstrap(epochs, random_state=None):
  2857. """Compute epochs selected by bootstrapping.
  2858. Parameters
  2859. ----------
  2860. epochs : Epochs instance
  2861. epochs data to be bootstrapped
  2862. %(random_state)s
  2863. Returns
  2864. -------
  2865. epochs : Epochs instance
  2866. The bootstrap samples
  2867. """
  2868. if not epochs.preload:
  2869. raise RuntimeError('Modifying data of epochs is only supported '
  2870. 'when preloading is used. Use preload=True '
  2871. 'in the constructor.')
  2872. rng = check_random_state(random_state)
  2873. epochs_bootstrap = epochs.copy()
  2874. n_events = len(epochs_bootstrap.events)
  2875. idx = rng_uniform(rng)(0, n_events, n_events)
  2876. epochs_bootstrap = epochs_bootstrap[idx]
  2877. return epochs_bootstrap
  2878. def _check_merge_epochs(epochs_list):
  2879. """Aux function."""
  2880. if len({tuple(epochs.event_id.items()) for epochs in epochs_list}) != 1:
  2881. raise NotImplementedError("Epochs with unequal values for event_id")
  2882. if len({epochs.tmin for epochs in epochs_list}) != 1:
  2883. raise NotImplementedError("Epochs with unequal values for tmin")
  2884. if len({epochs.tmax for epochs in epochs_list}) != 1:
  2885. raise NotImplementedError("Epochs with unequal values for tmax")
  2886. if len({epochs.baseline for epochs in epochs_list}) != 1:
  2887. raise NotImplementedError("Epochs with unequal values for baseline")
  2888. @verbose
  2889. def add_channels_epochs(epochs_list, verbose=None):
  2890. """Concatenate channels, info and data from two Epochs objects.
  2891. Parameters
  2892. ----------
  2893. epochs_list : list of Epochs
  2894. Epochs object to concatenate.
  2895. %(verbose)s Defaults to True if any of the input epochs have verbose=True.
  2896. Returns
  2897. -------
  2898. epochs : instance of Epochs
  2899. Concatenated epochs.
  2900. """
  2901. if not all(e.preload for e in epochs_list):
  2902. raise ValueError('All epochs must be preloaded.')
  2903. info = _merge_info([epochs.info for epochs in epochs_list])
  2904. data = [epochs._data for epochs in epochs_list]
  2905. _check_merge_epochs(epochs_list)
  2906. for d in data:
  2907. if len(d) != len(data[0]):
  2908. raise ValueError('all epochs must be of the same length')
  2909. data = np.concatenate(data, axis=1)
  2910. if len(info['chs']) != data.shape[1]:
  2911. err = "Data shape does not match channel number in measurement info"
  2912. raise RuntimeError(err)
  2913. events = epochs_list[0].events.copy()
  2914. all_same = all(np.array_equal(events, epochs.events)
  2915. for epochs in epochs_list[1:])
  2916. if not all_same:
  2917. raise ValueError('Events must be the same.')
  2918. proj = any(e.proj for e in epochs_list)
  2919. if verbose is None:
  2920. verbose = any(e.verbose for e in epochs_list)
  2921. epochs = epochs_list[0].copy()
  2922. epochs.info = info
  2923. epochs.picks = None
  2924. epochs.verbose = verbose
  2925. epochs.events = events
  2926. epochs.preload = True
  2927. epochs._bad_dropped = True
  2928. epochs._data = data
  2929. epochs._projector, epochs.info = setup_proj(epochs.info, False,
  2930. activate=proj)
  2931. return epochs
  2932. def _concatenate_epochs(epochs_list, with_data=True, add_offset=True, *,
  2933. on_mismatch='raise'):
  2934. """Auxiliary function for concatenating epochs."""
  2935. if not isinstance(epochs_list, (list, tuple)):
  2936. raise TypeError('epochs_list must be a list or tuple, got %s'
  2937. % (type(epochs_list),))
  2938. for ei, epochs in enumerate(epochs_list):
  2939. if not isinstance(epochs, BaseEpochs):
  2940. raise TypeError('epochs_list[%d] must be an instance of Epochs, '
  2941. 'got %s' % (ei, type(epochs)))
  2942. out = epochs_list[0]
  2943. offsets = [0]
  2944. if with_data:
  2945. out.drop_bad()
  2946. offsets.append(len(out))
  2947. events = [out.events]
  2948. metadata = [out.metadata]
  2949. baseline, tmin, tmax = out.baseline, out.tmin, out.tmax
  2950. info = deepcopy(out.info)
  2951. verbose = out.verbose
  2952. drop_log = out.drop_log
  2953. event_id = deepcopy(out.event_id)
  2954. selection = out.selection
  2955. # offset is the last epoch + tmax + 10 second
  2956. shift = int((10 + tmax) * out.info['sfreq'])
  2957. events_offset = int(np.max(events[0][:, 0])) + shift
  2958. events_overflow = False
  2959. for ii, epochs in enumerate(epochs_list[1:], 1):
  2960. _ensure_infos_match(epochs.info, info, f'epochs[{ii}]',
  2961. on_mismatch=on_mismatch)
  2962. if not np.allclose(epochs.times, epochs_list[0].times):
  2963. raise ValueError('Epochs must have same times')
  2964. if epochs.baseline != baseline:
  2965. raise ValueError('Baseline must be same for all epochs')
  2966. # compare event_id
  2967. common_keys = list(set(event_id).intersection(set(epochs.event_id)))
  2968. for key in common_keys:
  2969. if not event_id[key] == epochs.event_id[key]:
  2970. msg = ('event_id values must be the same for identical keys '
  2971. 'for all concatenated epochs. Key "{}" maps to {} in '
  2972. 'some epochs and to {} in others.')
  2973. raise ValueError(msg.format(key, event_id[key],
  2974. epochs.event_id[key]))
  2975. if with_data:
  2976. epochs.drop_bad()
  2977. offsets.append(len(epochs))
  2978. evs = epochs.events.copy()
  2979. if len(epochs.events) == 0:
  2980. warn('One of the Epochs objects to concatenate was empty.')
  2981. elif add_offset:
  2982. # We need to cast to a native Python int here to detect an
  2983. # overflow of a numpy int32 (which is the default on windows)
  2984. max_timestamp = int(np.max(evs[:, 0]))
  2985. evs[:, 0] += events_offset
  2986. events_offset += max_timestamp + shift
  2987. if events_offset > INT32_MAX:
  2988. warn(f'Event number greater than {INT32_MAX} created, '
  2989. 'events[:, 0] will be assigned consecutive increasing '
  2990. 'integer values')
  2991. events_overflow = True
  2992. add_offset = False # we no longer need to add offset
  2993. events.append(evs)
  2994. selection = np.concatenate((selection, epochs.selection))
  2995. drop_log = drop_log + epochs.drop_log
  2996. event_id.update(epochs.event_id)
  2997. metadata.append(epochs.metadata)
  2998. events = np.concatenate(events, axis=0)
  2999. # check to see if we exceeded our maximum event offset
  3000. if events_overflow:
  3001. events[:, 0] = np.arange(1, len(events) + 1)
  3002. # Create metadata object (or make it None)
  3003. n_have = sum(this_meta is not None for this_meta in metadata)
  3004. if n_have == 0:
  3005. metadata = None
  3006. elif n_have != len(metadata):
  3007. raise ValueError('%d of %d epochs instances have metadata, either '
  3008. 'all or none must have metadata'
  3009. % (n_have, len(metadata)))
  3010. else:
  3011. pd = _check_pandas_installed(strict=False)
  3012. if pd is not False:
  3013. metadata = pd.concat(metadata)
  3014. else: # dict of dicts
  3015. metadata = sum(metadata, list())
  3016. assert len(offsets) == (len(epochs_list) if with_data else 0) + 1
  3017. data = None
  3018. if with_data:
  3019. offsets = np.cumsum(offsets)
  3020. for start, stop, epochs in zip(offsets[:-1], offsets[1:], epochs_list):
  3021. this_data = epochs.get_data()
  3022. if data is None:
  3023. data = np.empty(
  3024. (offsets[-1], len(out.ch_names), len(out.times)),
  3025. dtype=this_data.dtype)
  3026. data[start:stop] = this_data
  3027. return (info, data, events, event_id, tmin, tmax, metadata, baseline,
  3028. selection, drop_log, verbose)
  3029. def _finish_concat(info, data, events, event_id, tmin, tmax, metadata,
  3030. baseline, selection, drop_log, verbose):
  3031. """Finish concatenation for epochs not read from disk."""
  3032. selection = np.where([len(d) == 0 for d in drop_log])[0]
  3033. out = BaseEpochs(
  3034. info, data, events, event_id, tmin, tmax, baseline=baseline,
  3035. selection=selection, drop_log=drop_log, proj=False,
  3036. on_missing='ignore', metadata=metadata, verbose=verbose)
  3037. out.drop_bad()
  3038. return out
  3039. @verbose
  3040. def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise',
  3041. verbose=None):
  3042. """Concatenate a list of epochs into one epochs object.
  3043. Parameters
  3044. ----------
  3045. epochs_list : list
  3046. List of Epochs instances to concatenate (in order).
  3047. add_offset : bool
  3048. If True, a fixed offset is added to the event times from different
  3049. Epochs sets, such that they are easy to distinguish after the
  3050. concatenation.
  3051. If False, the event times are unaltered during the concatenation.
  3052. %(on_info_mismatch)s
  3053. %(verbose)s
  3054. .. versionadded:: 0.24
  3055. Returns
  3056. -------
  3057. epochs : instance of Epochs
  3058. The result of the concatenation (first Epochs instance passed in).
  3059. Notes
  3060. -----
  3061. .. versionadded:: 0.9.0
  3062. """
  3063. return _finish_concat(*_concatenate_epochs(epochs_list,
  3064. add_offset=add_offset,
  3065. on_mismatch=on_mismatch))
  3066. @verbose
  3067. def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
  3068. origin='auto', weight_all=True, int_order=8, ext_order=3,
  3069. destination=None, ignore_ref=False, return_mapping=False,
  3070. mag_scale=100., verbose=None):
  3071. """Average data using Maxwell filtering, transforming using head positions.
  3072. Parameters
  3073. ----------
  3074. epochs : instance of Epochs
  3075. The epochs to operate on.
  3076. %(maxwell_pos)s
  3077. orig_sfreq : float | None
  3078. The original sample frequency of the data (that matches the
  3079. event sample numbers in ``epochs.events``). Can be ``None``
  3080. if data have not been decimated or resampled.
  3081. %(picks_all_data)s
  3082. %(maxwell_origin)s
  3083. weight_all : bool
  3084. If True, all channels are weighted by the SSS basis weights.
  3085. If False, only MEG channels are weighted, other channels
  3086. receive uniform weight per epoch.
  3087. %(maxwell_int)s
  3088. %(maxwell_ext)s
  3089. %(maxwell_dest)s
  3090. %(maxwell_ref)s
  3091. return_mapping : bool
  3092. If True, return the mapping matrix.
  3093. %(maxwell_mag)s
  3094. .. versionadded:: 0.13
  3095. %(verbose)s
  3096. Returns
  3097. -------
  3098. evoked : instance of Evoked
  3099. The averaged epochs.
  3100. See Also
  3101. --------
  3102. mne.preprocessing.maxwell_filter
  3103. mne.chpi.read_head_pos
  3104. Notes
  3105. -----
  3106. The Maxwell filtering version of this algorithm is described in [1]_,
  3107. in section V.B "Virtual signals and movement correction", equations
  3108. 40-44. For additional validation, see [2]_.
  3109. Regularization has not been added because in testing it appears to
  3110. decrease dipole localization accuracy relative to using all components.
  3111. Fine calibration and cross-talk cancellation, however, could be added
  3112. to this algorithm based on user demand.
  3113. .. versionadded:: 0.11
  3114. References
  3115. ----------
  3116. .. [1] Taulu S. and Kajola M. "Presentation of electromagnetic
  3117. multichannel data: The signal space separation method,"
  3118. Journal of Applied Physics, vol. 97, pp. 124905 1-10, 2005.
  3119. .. [2] Wehner DT, Hämäläinen MS, Mody M, Ahlfors SP. "Head movements
  3120. of children in MEG: Quantification, effects on source
  3121. estimation, and compensation. NeuroImage 40:541550, 2008.
  3122. """ # noqa: E501
  3123. from .preprocessing.maxwell import (_trans_sss_basis, _reset_meg_bads,
  3124. _check_usable, _col_norm_pinv,
  3125. _get_n_moments, _get_mf_picks_fix_mags,
  3126. _prep_mf_coils, _check_destination,
  3127. _remove_meg_projs, _get_coil_scale)
  3128. if head_pos is None:
  3129. raise TypeError('head_pos must be provided and cannot be None')
  3130. from .chpi import head_pos_to_trans_rot_t
  3131. if not isinstance(epochs, BaseEpochs):
  3132. raise TypeError('epochs must be an instance of Epochs, not %s'
  3133. % (type(epochs),))
  3134. orig_sfreq = epochs.info['sfreq'] if orig_sfreq is None else orig_sfreq
  3135. orig_sfreq = float(orig_sfreq)
  3136. if isinstance(head_pos, np.ndarray):
  3137. head_pos = head_pos_to_trans_rot_t(head_pos)
  3138. trn, rot, t = head_pos
  3139. del head_pos
  3140. _check_usable(epochs)
  3141. origin = _check_origin(origin, epochs.info, 'head')
  3142. recon_trans = _check_destination(destination, epochs.info, True)
  3143. logger.info('Aligning and averaging up to %s epochs'
  3144. % (len(epochs.events)))
  3145. if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])):
  3146. raise RuntimeError('Epochs must have monotonically increasing events')
  3147. info_to = epochs.info.copy()
  3148. meg_picks, mag_picks, grad_picks, good_mask, _ = \
  3149. _get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref)
  3150. coil_scale, mag_scale = _get_coil_scale(
  3151. meg_picks, mag_picks, grad_picks, mag_scale, info_to)
  3152. n_channels, n_times = len(epochs.ch_names), len(epochs.times)
  3153. other_picks = np.setdiff1d(np.arange(n_channels), meg_picks)
  3154. data = np.zeros((n_channels, n_times))
  3155. count = 0
  3156. # keep only MEG w/bad channels marked in "info_from"
  3157. info_from = pick_info(info_to, meg_picks[good_mask], copy=True)
  3158. all_coils_recon = _prep_mf_coils(info_to, ignore_ref=ignore_ref)
  3159. all_coils = _prep_mf_coils(info_from, ignore_ref=ignore_ref)
  3160. # remove MEG bads in "to" info
  3161. _reset_meg_bads(info_to)
  3162. # set up variables
  3163. w_sum = 0.
  3164. n_in, n_out = _get_n_moments([int_order, ext_order])
  3165. S_decomp = 0. # this will end up being a weighted average
  3166. last_trans = None
  3167. decomp_coil_scale = coil_scale[good_mask]
  3168. exp = dict(int_order=int_order, ext_order=ext_order, head_frame=True,
  3169. origin=origin)
  3170. n_in = _get_n_moments(int_order)
  3171. for ei, epoch in enumerate(epochs):
  3172. event_time = epochs.events[epochs._current - 1, 0] / orig_sfreq
  3173. use_idx = np.where(t <= event_time)[0]
  3174. if len(use_idx) == 0:
  3175. trans = info_to['dev_head_t']['trans']
  3176. else:
  3177. use_idx = use_idx[-1]
  3178. trans = np.vstack([np.hstack([rot[use_idx], trn[[use_idx]].T]),
  3179. [[0., 0., 0., 1.]]])
  3180. loc_str = ', '.join('%0.1f' % tr for tr in (trans[:3, 3] * 1000))
  3181. if last_trans is None or not np.allclose(last_trans, trans):
  3182. logger.info(' Processing epoch %s (device location: %s mm)'
  3183. % (ei + 1, loc_str))
  3184. reuse = False
  3185. last_trans = trans
  3186. else:
  3187. logger.info(' Processing epoch %s (device location: same)'
  3188. % (ei + 1,))
  3189. reuse = True
  3190. epoch = epoch.copy() # because we operate inplace
  3191. if not reuse:
  3192. S = _trans_sss_basis(exp, all_coils, trans,
  3193. coil_scale=decomp_coil_scale)
  3194. # Get the weight from the un-regularized version (eq. 44)
  3195. weight = np.linalg.norm(S[:, :n_in])
  3196. # XXX Eventually we could do cross-talk and fine-cal here
  3197. S *= weight
  3198. S_decomp += S # eq. 41
  3199. epoch[slice(None) if weight_all else meg_picks] *= weight
  3200. data += epoch # eq. 42
  3201. w_sum += weight
  3202. count += 1
  3203. del info_from
  3204. mapping = None
  3205. if count == 0:
  3206. data.fill(np.nan)
  3207. else:
  3208. data[meg_picks] /= w_sum
  3209. data[other_picks] /= w_sum if weight_all else count
  3210. # Finalize weighted average decomp matrix
  3211. S_decomp /= w_sum
  3212. # Get recon matrix
  3213. # (We would need to include external here for regularization to work)
  3214. exp['ext_order'] = 0
  3215. S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans)
  3216. exp['ext_order'] = ext_order
  3217. # We could determine regularization on basis of destination basis
  3218. # matrix, restricted to good channels, as regularizing individual
  3219. # matrices within the loop above does not seem to work. But in
  3220. # testing this seemed to decrease localization quality in most cases,
  3221. # so we do not provide the option here.
  3222. S_recon /= coil_scale
  3223. # Invert
  3224. pS_ave = _col_norm_pinv(S_decomp)[0][:n_in]
  3225. pS_ave *= decomp_coil_scale.T
  3226. # Get mapping matrix
  3227. mapping = np.dot(S_recon, pS_ave)
  3228. # Apply mapping
  3229. data[meg_picks] = np.dot(mapping, data[meg_picks[good_mask]])
  3230. info_to['dev_head_t'] = recon_trans # set the reconstruction transform
  3231. evoked = epochs._evoked_from_epoch_data(data, info_to, picks,
  3232. n_events=count, kind='average',
  3233. comment=epochs._name)
  3234. _remove_meg_projs(evoked) # remove MEG projectors, they won't apply now
  3235. logger.info('Created Evoked dataset from %s epochs' % (count,))
  3236. return (evoked, mapping) if return_mapping else evoked
  3237. @verbose
  3238. def make_fixed_length_epochs(raw, duration=1., preload=False,
  3239. reject_by_annotation=True, proj=True, overlap=0.,
  3240. id=1, verbose=None):
  3241. """Divide continuous raw data into equal-sized consecutive epochs.
  3242. Parameters
  3243. ----------
  3244. raw : instance of Raw
  3245. Raw data to divide into segments.
  3246. duration : float
  3247. Duration of each epoch in seconds. Defaults to 1.
  3248. %(preload)s
  3249. %(reject_by_annotation_epochs)s
  3250. .. versionadded:: 0.21.0
  3251. %(proj_epochs)s
  3252. .. versionadded:: 0.22.0
  3253. overlap : float
  3254. The overlap between epochs, in seconds. Must be
  3255. ``0 <= overlap < duration``. Default is 0, i.e., no overlap.
  3256. .. versionadded:: 0.23.0
  3257. id : int
  3258. The id to use (default 1).
  3259. .. versionadded:: 0.24.0
  3260. %(verbose)s
  3261. Returns
  3262. -------
  3263. epochs : instance of Epochs
  3264. Segmented data.
  3265. Notes
  3266. -----
  3267. .. versionadded:: 0.20
  3268. """
  3269. events = make_fixed_length_events(raw, id=id, duration=duration,
  3270. overlap=overlap)
  3271. delta = 1. / raw.info['sfreq']
  3272. return Epochs(raw, events, event_id=[id], tmin=0, tmax=duration - delta,
  3273. baseline=None, preload=preload,
  3274. reject_by_annotation=reject_by_annotation, proj=proj,
  3275. verbose=verbose)