
/bcontrol.py

https://github.com/cxrodgers/ns5_process
from __future__ import print_function
from builtins import str
from builtins import range
from builtins import object
import numpy as np
import os.path
import scipy.io
import glob
import pickle
import pandas

class LBPB_constants(object):
    def __init__(self, sn2name=None):
        if sn2name is None:
            sn2name = \
                {1: u'lo_pc_go', 2: u'hi_pc_no', 3: u'le_lc_go', 4: u'ri_lc_no',
                 5: u'le_hi_pc', 6: u'ri_hi_pc', 7: u'le_lo_pc', 8: u'ri_lo_pc',
                 9: u'le_hi_lc', 10: u'ri_hi_lc', 11: u'le_lo_lc', 12: u'ri_lo_lc'}
        self.sn2name = sn2name
        self.name2sn = dict([(val, key) for key, val in list(sn2name.items())])

        self.sn2shortname = \
            {1: u'lo', 2: u'hi', 3: u'le', 4: u'ri',
             5: u'lehi', 6: u'rihi', 7: u'lelo', 8: u'rilo',
             9: u'lehi', 10: u'rihi', 11: u'lelo', 12: u'rilo'}

        self.sn2block = \
            {1: u'P', 2: u'P', 3: u'L', 4: u'L',
             5: u'PB', 6: u'PB', 7: u'PB', 8: u'PB',
             9: u'LB', 10: u'LB', 11: u'LB', 12: u'LB'}

    def ordered_by_sound(self):
        sns = (5, 9, 6, 10, 7, 11, 8, 12)
        return (sns, tuple([self.sn2name[sn] for sn in sns]))

    def LB(self):
        return set(('le_hi_lc', 'ri_hi_lc', 'le_lo_lc', 'ri_lo_lc'))

    def PB(self):
        return set(('le_hi_pc', 'ri_hi_pc', 'le_lo_pc', 'ri_lo_pc'))

    def go(self):
        return set(('le_hi_lc', 'le_lo_lc', 'le_lo_pc', 'ri_lo_pc'))

    def nogo(self):
        return set(('ri_hi_lc', 'ri_lo_lc', 'le_hi_pc', 'ri_hi_pc'))

    def lo(self):
        return set(('le_lo_lc', 'ri_lo_lc', 'le_lo_pc', 'ri_lo_pc'))

    def hi(self):
        return set(('le_hi_lc', 'ri_hi_lc', 'le_hi_pc', 'ri_hi_pc'))

    def le(self):
        return set(('le_lo_lc', 'le_hi_lc', 'le_lo_pc', 'le_hi_pc'))

    def ri(self):
        return set(('ri_lo_lc', 'ri_hi_lc', 'ri_lo_pc', 'ri_hi_pc'))
    def comparisons(self, comp='sound'):
        """Returns meaningful comparisons.

        Returns a tuple (names, idxs, groupnames).
        `names` and `idxs` have the same form: an N-tuple of 2-tuples,
        where N is the number of pairwise comparisons. Each entry of a
        2-tuple is a tuple of stimuli to be pooled.
        `groupnames` is an N-tuple of 2-tuples of strings, the name of each
        pool.

        Example: blockwise comparison
            (((5,6,7,8), (9,10,11,12)))
        Example: soundwise comparison
            (((5,), (9,)), ((6,), (10,)), ((7,), (11,)), ((8,), (12,)))

        Usage:
            names, idxs, groupnames = comparisons()
            len(names)          # the number of comparisons
            len(names[n])       # length of nth comparison, always 2 since pairwise
            len(names[n][m])    # size of mth pool in nth comparison
            groupnames[n][m]    # name of the mth pool in nth comparison
        """
        x_labels = []
        stim_groups = []
        groupnames = []
        idxs, names = self.ordered_by_sound()

        if comp == 'sound':
            for n_pairs in range(4):
                n = n_pairs * 2
                pool1 = (idxs[n],)
                pool2 = (idxs[n+1],)
                stim_groups.append((pool1, pool2))
                pool1 = (names[n],)
                pool2 = (names[n+1],)
                x_labels.append((pool1, pool2))
                groupnames.append((names[n], names[n+1]))
        elif comp == 'block':
            for n_pairs in range(1):
                pool1 = tuple(idxs[::2])
                pool2 = tuple(idxs[1::2])
                stim_groups.append((pool1, pool2))
                pool1 = tuple(names[::2])
                pool2 = tuple(names[1::2])
                x_labels.append((pool1, pool2))
                groupnames.append(('PB', 'LB'))
        elif comp == 'leri':
            for n_pairs in range(1):
                pool1 = [idxs[0], idxs[1], idxs[4], idxs[5]]
                pool2 = [idxs[2], idxs[3], idxs[6], idxs[7]]
                stim_groups.append((pool1, pool2))
                pool1 = [names[0], names[1], names[4], names[5]]
                pool2 = [names[2], names[3], names[6], names[7]]
                x_labels.append((pool1, pool2))
                groupnames.append(('Le', 'Ri'))
        elif comp == 'lohi':
            for n_pairs in range(1):
                pool1 = idxs[4:8]
                pool2 = idxs[0:4]
                stim_groups.append((pool1, pool2))
                pool1 = names[4:8]
                pool2 = names[0:4]
                x_labels.append((pool1, pool2))
                groupnames.append(('Lo', 'Hi'))
        else:
            raise ValueError("unrecognized comparison: %s" % comp)

        return x_labels, stim_groups, groupnames
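
# A usage sketch (not from the original file) showing how the comparison
# pools come out; the expected values follow from ordered_by_sound() above.
#
#   consts = LBPB_constants()
#   names, idxs, groupnames = consts.comparisons(comp='sound')
#   len(names)        # 4 pairwise comparisons, one per sound
#   groupnames[0]     # ('le_hi_pc', 'le_hi_lc')
#   names, idxs, groupnames = consts.comparisons(comp='block')
#   groupnames[0]     # ('PB', 'LB')
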
class Bcontrol_Loader_By_Dir(object):
    """Wrapper for Bcontrol_Loader to load/save from directory.

    Methods
    -------
    load : Get data from directory
    get_sn2trials : returns a dict of stimulus numbers and trial numbers
    get_sn2name : returns a dict of stimulus numbers and name

    Other useful information (TRIALS_INFO, SOUNDS_INFO, etc) is
    available in my dict `data` after loading.
    """
    def __init__(self, dirname, auto_validate=True, v2_behavior=False,
        skip_trial_set=[], dictify=True):
        """Initialize loader, specifying directory containing info.

        dictify : store a sanitized version that replaces mat_struct objects
        with dicts. This is done here, not in the lower object, to avoid
        rewriting code that depends on mat_struct.

        For other parameters, see Bcontrol_Loader.
        """
        self.dirname = dirname
        self._pickle_name = 'bdata.pickle'
        self._bcontrol_matfilename = 'data_*.mat'

        # Build a Bcontrol_Loader with same parameters
        self._bcl = Bcontrol_Loader(auto_validate=auto_validate,
            v2_behavior=v2_behavior, skip_trial_set=skip_trial_set)
        self.dictify = dictify
    def load(self, force=False):
        """Loads Bcontrol info into self.data.

        First checks to see if bdata pickle exists, in which case it loads
        that pickle. Otherwise, uses self._bcl to load data from matfile.
        If force is True, skips check for pickle.
        """
        # Decide whether to run
        if force:
            pickle_found = False
            data = None
        else:
            data, pickle_found = self._check_for_pickle()

        if pickle_found:
            self._bcl.data = data
        else:
            filename = self._find_bcontrol_matfile()
            self._bcl.filename = filename
            self._bcl.load()
            if self.dictify:
                self._bcl.data = dictify_mat_struct(self._bcl.data)

            # Pickle self._bcl.data
            self._pickle_data()

        self.data = self._bcl.data
    def _check_for_pickle(self):
        """Tries to load bdata pickle if it exists.

        Returns (data, True) if bdata pickle is found in self.dirname.
        Otherwise returns (None, False)
        """
        data = None
        possible_pickles = glob.glob(os.path.join(self.dirname,
            self._pickle_name))
        if len(possible_pickles) == 1:
            # A pickle was found, load it in binary mode (the py2-only
            # `file()` builtin was replaced with `open()`)
            f = open(possible_pickles[0], 'rb')
            try:
                data = pickle.load(f)
            except AttributeError:
                # not sure why this is the pickling error
                print("bad pickle")
                return (None, False)
            f.close()

        return (data, len(possible_pickles) == 1)
    def _find_bcontrol_matfile(self):
        """Returns filename to BControl matfile in self.dirname"""
        fn_bdata = glob.glob(os.path.join(self.dirname,
            self._bcontrol_matfilename))
        if len(fn_bdata) == 0:
            raise IOError("cannot find bcontrol file in %s" % self.dirname)
        elif len(fn_bdata) > 1:
            raise IOError("multiple bcontrol files in %s" % self.dirname)
        #assert(len(fn_bdata) == 1)
        return fn_bdata[0]
    def _pickle_data(self):
        """Pickles self._bcl.data for future use."""
        to_pickle = self._bcl.data
        fn_pickle = os.path.join(self.dirname, self._pickle_name)
        # Binary mode (the py2-only `file()` builtin was replaced with `open()`)
        f = open(fn_pickle, 'wb')
        pickle.dump(to_pickle, f)
        f.close()
    def get_sn2trials(self, outcome='hit'):
        return self._bcl.get_sn2trials(outcome)

    def get_sn2names(self):
        return self._bcl.get_sn2names()

    def get_sn2name(self):
        return self._bcl.get_sn2names()
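
# A usage sketch (not from the original file); '/path/to/session' is a
# hypothetical directory containing a single data_*.mat BControl file.
#
#   bld = Bcontrol_Loader_By_Dir('/path/to/session')
#   bld.load()              # uses bdata.pickle if present, else the matfile
#   TRIALS_INFO = bld.data['TRIALS_INFO']
#   sn2trials = bld.get_sn2trials(outcome='hit')
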
class Bcontrol_Loader(object):
    """Loads matlab BControl data and validates"""
    def __init__(self, filename=None, auto_validate=True, v2_behavior=False,
        skip_trial_set=[], dictify=True):
        """Initialize loader, optionally specifying filename.

        If auto_validate is True, then the validation script will run
        after loading the data. In any case, you can always call the
        validation method manually.

        v2_behavior : boolean. If True, then looks for variables that
        work with TwoAltChoice_v2 (no datasink).

        skip_trial_set : list. Wherever TRIALS_INFO['TRIAL_NUMBER'] is
        a member of skip_trial_set, that trial will be skipped in the
        validation process.
        """
        self.filename = filename
        self.auto_validate = auto_validate
        self.v2_behavior = v2_behavior
        self.skip_trial_set = np.array(skip_trial_set)

        # Set a variable for accessing TwoAltChoice_vx variable names
        if self.v2_behavior:
            self._vstring = 'v2'
        else:
            self._vstring = 'v4'
    def load(self, filename=None, keep_matdata=False):
        """Loads the bcontrol matlab file.

        Loads the data from disk. Then, optionally, validates it. Finally,
        returns a dict of useful information from the file, containing
        the following keys:
            TRIALS_INFO : a recarray of trial-by-trial info
            SOUNDS_INFO : describes the rules associated with each sound
            CONSTS : helps in decoding the integer values
            peh : the raw events and pokes from BControl
            datasink : debugging trial-by-trial snapshots
            onsets : the stimulus onsets, extracted from peh

        Note: for compatibility with Matlab, the stimulus numbers in
        TRIALS_INFO are numbered beginning with 1.

        If keep_matdata, then the raw data in the matfile is saved in
        `matdata`. This is mainly useful if you want to investigate 'saved'
        and 'saved_history'.
        """
        if filename is not None:
            self.filename = filename

        # Actually load the file from disk and store variables in self.data
        self._load(keep_matdata=keep_matdata)

        # Optionally, run validation
        # Will fail assertion if errors, otherwise you're fine
        if self.auto_validate:
            self.validate()

        # Return dict of useful info
        return self.data
    def get_sn2trials(self, outcome='hit'):
        """Returns a dict: stimulus number -> trials on which it occurred.

        For each stimulus number, finds trials with that stimulus number
        that were not forced and with the specified outcome.

        Parameters
        ----------
        outcome : string. Will be tested against TRIALS_INFO['OUTCOME'].
            Should be hit, error, or wrong_port.

        Returns
        -------
        dict sn2trials, such that sn2trials[sn] is the list of trials on
        which sn occurred.
        """
        TRIALS_INFO = self.data['TRIALS_INFO']
        CONSTS = self.data['CONSTS']
        trial_numbers_vs_sn = dict()

        # Find all trials matching the requirements.
        for sn in np.unique(TRIALS_INFO['STIM_NUMBER']):
            keep_rows = \
                (TRIALS_INFO['STIM_NUMBER'] == sn) & \
                (TRIALS_INFO['OUTCOME'] == CONSTS[outcome.upper()]) & \
                (TRIALS_INFO['NONRANDOM'] == 0)
            trial_numbers_vs_sn[sn] = TRIALS_INFO['TRIAL_NUMBER'][keep_rows]

        return trial_numbers_vs_sn
    def get_sn2names(self):
        sn2name = dict([(n+1, sndname) for n, sndname in
            enumerate(self.data['SOUNDS_INFO']['sound_name'])])
        return sn2name
    def _load(self, keep_matdata=False):
        """Hidden method that actually loads matfile data and stores it.

        This is the low-level code that parses the BControl `saved`,
        `saved_history`, etc.
        """
        # Load the matlab file
        matdata = scipy.io.loadmat(self.filename, squeeze_me=True,
            struct_as_record=False)
        saved = matdata['saved']
        saved_history = matdata['saved_history']

        # Optionally store
        if keep_matdata:
            self.matdata = matdata

        # Load TRIALS_INFO matrix as recarray
        TRIALS_INFO = self._format_trials_info(saved)

        # Load CONSTS
        CONSTS = saved.__dict__[
            'TwoAltChoice_%s_CONSTS' % self._vstring].__dict__.copy()
        CONSTS.pop('_fieldnames')
        for (k, v) in list(CONSTS.items()):
            try:
                # This will work if v is a 0d array (EPD loadmat)
                CONSTS[k] = v.flatten()[0]
            except AttributeError:
                # With other versions of loadmat, v is an int
                CONSTS[k] = v

        # Load SOUNDS_INFO
        SOUNDS_INFO = saved.__dict__[
            'TwoAltChoice_%s_SOUNDS_INFO' % self._vstring].__dict__.copy()
        SOUNDS_INFO.pop('_fieldnames')

        # Now the trial-by-trial datasink, which does not exist in v2
        datasink = None
        if not self.v2_behavior:
            datasink = saved_history.__dict__[
                'TwoAltChoice_%s_datasink' % self._vstring]

        # And finally the stored behavioral events
        peh = saved_history.ProtocolsSection_parsed_events

        # Extract out the parameter of most interest: stimulus onset
        onsets = np.array([
            trial.__dict__['states'].__dict__['play_stimulus'][0]
            for trial in peh])

        # Store
        self.data = dict((
            ('TRIALS_INFO', TRIALS_INFO),
            ('SOUNDS_INFO', SOUNDS_INFO),
            ('CONSTS', CONSTS),
            ('peh', peh),
            ('datasink', datasink),
            ('onsets', onsets)))
    def _format_trials_info(self, saved):
        """Hidden method to format TRIALS_INFO.

        Converts the matrix to a recarray and names it with the column
        names from TRIALS_INFO_COLS.
        """
        # Some constants that need to be converted from structs to dicts
        d2 = saved.__dict__[
            'TwoAltChoice_%s_TRIALS_INFO_COLS' % self._vstring].__dict__.copy()
        d2.pop('_fieldnames')
        try:
            # This will work if loadmat returns 0d arrays (EPD)
            d3 = dict((v.flatten()[0], k) for k, v in d2.items())
        except AttributeError:
            # With other versions, v is an int
            d3 = dict((v, k) for k, v in d2.items())

        # Check that all the columns are named
        if len(d3) != len(d2):
            print("Multiple columns with same number in TRIALS_INFO_COLS")

        # Write the column names in order
        # Will error here if the column names are messed up
        # Note inherent conversion from 1-based to 0-based indexing
        field_names = [d3[col] for col in range(1, 1+len(d3))]

        TRIALS_INFO = np.rec.fromrecords(
            saved.__dict__['TwoAltChoice_%s_TRIALS_INFO' % self._vstring],
            titles=field_names)
        return TRIALS_INFO
    def validate(self):
        """Runs validation checks on the loaded data.

        There are unlimited consistency checks we could do, but only a few
        easy checks are implemented. The most problematic error would be
        inconsistent data in TRIALS_INFO, for example if the rows were
        written with the wrong trial number or something. That's the
        primary thing that is checked.

        It is assumed that we can trust the state machine states. So, the
        pokes are not explicitly checked to ensure the exact timing of
        behavioral events. This would be a good feature to add though.
        Instead, the indicator states are matched to TRIALS_INFO. When easy,
        I check that at least one poke in the right port occurred, but I
        don't check that it actually occurred in the window of opportunity.

        No block information is checked. This is usually pretty obvious
        if it's wrong.

        If there is a known problem on certain trials, set
        self.skip_trial_set to a list of trials to skip. Rows of
        TRIALS_INFO for which TRIAL_NUMBER matches a member of this set
        will be skipped (not validated).

        Checks:
        1) Does the *_istate outcome match the TRIALS_INFO outcome.
        2) For each possible trial outcome, the correct port must have
           been entered (or not entered).
        3) The stim number in TRIALS_INFO should match the other TRIALS_INFO
           characteristics in accordance with SOUNDS_INFO.
        4) Every trial in peh should be in TRIALS_INFO; all others should be
           FUTURE_TRIAL.
        """
        # Shortcut references to save typing
        CONSTS = self.data['CONSTS']
        TRIALS_INFO = self.data['TRIALS_INFO']
        SOUNDS_INFO = self.data['SOUNDS_INFO']
        peh = self.data['peh']
        datasink = self.data['datasink']

        # Some inverse maps for looking up data in TRIALS_INFO
        outcome_map = dict((CONSTS[str.upper(s)], s) for s in
            ('hit', 'error', 'wrong_port'))
        left_right_map = dict((CONSTS[str.upper(s)], s) for s in
            ('left', 'right'))
        go_or_nogo_map = dict((CONSTS[str.upper(s)], s) for s in
            ('go', 'nogo'))

        # Go through peh and for each trial, match data to TRIALS_INFO.
        # Also match to datasink. Note that datasink is a snapshot taken
        # immediately before the next trial state machine was uploaded.
        # So it contains some information about previous trial and some
        # about next. It is also always length 1 more than peh.
        for n, trial in enumerate(peh):
            # Skip trials
            if TRIALS_INFO['TRIAL_NUMBER'][n] in self.skip_trial_set:
                continue

            # Extract info from the current row of TRIALS_INFO
            outcome = outcome_map[TRIALS_INFO['OUTCOME'][n]]
            correct_side = left_right_map[TRIALS_INFO['CORRECT_SIDE'][n]]
            go_or_nogo = go_or_nogo_map[TRIALS_INFO['GO_OR_NOGO'][n]]

            # Note that we correct for 1- and 0- indexing into SOUNDS_INFO here
            stim_number = TRIALS_INFO['STIM_NUMBER'][n] - 1

            # TRIALS_INFO is internally consistent with sound parameters
            assert(TRIALS_INFO['CORRECT_SIDE'][n] ==
                SOUNDS_INFO['correct_side'][stim_number])
            assert(TRIALS_INFO['GO_OR_NOGO'][n] ==
                SOUNDS_INFO['go_or_nogo'][stim_number])

            # If possible, check datasink
            if not self.v2_behavior:
                # Check that datasink is consistent with TRIALS_INFO.
                # First load the n and n+1 sinks, since the info is split
                # across them. The funny .item() syntax is because loading
                # Matlab structs sometimes produces 0d arrays.
                # This little segment of code is the only place where the
                # datasink is checked.
                prev_sink = datasink[n]
                next_sink = datasink[n+1]
                try:
                    assert(prev_sink.next_sound_id.stimulus.item() ==
                        TRIALS_INFO['STIM_NUMBER'][n])
                    assert(prev_sink.next_side.item() ==
                        TRIALS_INFO['CORRECT_SIDE'][n])
                    assert(prev_sink.next_trial_type.item() ==
                        TRIALS_INFO['GO_OR_NOGO'][n])
                    assert(next_sink.finished_trial_num.item() ==
                        TRIALS_INFO['TRIAL_NUMBER'][n])
                    assert(CONSTS[next_sink.finished_trial_outcome.item()] ==
                        TRIALS_INFO['OUTCOME'][n])
                except AttributeError:
                    # .item() syntax only required for some versions of scipy
                    assert(prev_sink.next_sound_id.stimulus ==
                        TRIALS_INFO['STIM_NUMBER'][n])
                    assert(prev_sink.next_side ==
                        TRIALS_INFO['CORRECT_SIDE'][n])
                    assert(prev_sink.next_trial_type ==
                        TRIALS_INFO['GO_OR_NOGO'][n])
                    assert(next_sink.finished_trial_num ==
                        TRIALS_INFO['TRIAL_NUMBER'][n])
                    assert(CONSTS[next_sink.finished_trial_outcome] ==
                        TRIALS_INFO['OUTCOME'][n])

                # Sound name is correct
                # assert(SOUNDS_INFO.sound_names[stim_number] == datasink[sound name]

            # Validate trial
            self._validate_trial(trial, outcome, correct_side, go_or_nogo)

        # All future trials should be marked as such.
        # Under certain circumstances, TRIALS_INFO can contain information
        # about one more trial than peh. I think this is if the protocol
        # is turned off before the end of the trial.
        try:
            assert(np.all(TRIALS_INFO['OUTCOME'][len(peh):] ==
                CONSTS['FUTURE_TRIAL']))
        except AssertionError:
            print("warn: at least one more trial in TRIALS_INFO than peh.")
            print("checking that it is no more than one ...")
            assert(np.all(TRIALS_INFO['OUTCOME'][len(peh)+1:] ==
                CONSTS['FUTURE_TRIAL']))
    def _validate_trial(self, trial, outcome, correct_side, go_or_nogo):
        """Dispatches to appropriate trial validation method"""
        # Check if *_istate matches TRIALS_INFO
        assert(trial.states.__dict__[outcome+'_istate'].size == 2)

        dispatch_table = dict((
            (('hit', 'go'), self._validate_hit_on_go),
            (('error', 'go'), self._validate_error_on_go),
            (('hit', 'nogo'), self._validate_hit_on_nogo),
            (('error', 'nogo'), self._validate_error_on_nogo),
            (('wrong_port', 'go'), self._validate_wrong_port),
            (('wrong_port', 'nogo'), self._validate_wrong_port),
            ))
        validation_method = dispatch_table[(outcome, go_or_nogo)]
        validation_method(trial, outcome, correct_side)
    def _validate_hit_on_go(self, trial, outcome, correct_side):
        """For hits on go trials, rewarded side should match correct side,
        and there should be at least one poke in correct side.
        """
        assert(trial.states.__dict__[correct_side+'_reward'].size == 2)
        assert(trial.pokes.__dict__[str.upper(correct_side[0])].size > 0)
        assert(trial.states.hit_on_go.size == 2)

    def _validate_error_on_go(self, trial, outcome, correct_side):
        """For errors on go trials, the reward state should not be entered"""
        assert(trial.states.left_reward.size == 0)
        assert(trial.states.right_reward.size == 0)
        assert(trial.states.error_on_go.size == 2)

    def _validate_error_on_nogo(self, trial, outcome, correct_side):
        """For errors on nogo trials, no reward should have been delivered,
        and at least one entry into the correct side.
        """
        assert(trial.states.left_reward.size == 0)
        assert(trial.states.right_reward.size == 0)
        assert(trial.pokes.__dict__[str.upper(correct_side[0])].size > 0)
        assert(trial.states.error_on_nogo.size == 2)

    def _validate_hit_on_nogo(self, trial, outcome, correct_side):
        """For hits on nogo trials, a very short reward state should have
        occurred (this is just how it is handled in the protocol).
        """
        assert(np.diff(trial.states.__dict__[correct_side+'_reward']) < .002)
        assert(trial.states.hit_on_nogo.size == 2)

    def _validate_wrong_port(self, trial, outcome, correct_side):
        """For wrong port trials, no reward state, and should have
        entered wrong side at least once.
        """
        assert(trial.states.left_reward.size == 0)
        assert(trial.states.right_reward.size == 0)
        if correct_side == 'left':
            assert(trial.pokes.R.size > 0)
        else:
            assert(trial.pokes.L.size > 0)
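
# A usage sketch (not from the original file); 'data_session.mat' is a
# hypothetical saved BControl matfile.
#
#   bcl = Bcontrol_Loader(filename='data_session.mat', auto_validate=True)
#   data = bcl.load()       # validates, then returns the dict of useful info
#   data['onsets']          # stimulus onset times extracted from peh
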
# Helper functions to remove mat_struct ugliness
def is_mat_struct(obj):
    res = True
    try:
        obj.__dict__['_fieldnames']
    except (AttributeError, KeyError):
        res = False
    return res
def dictify_mat_struct(mat_struct, flatten_0d=True, max_depth=-1):
    """Recursively turn mat struct objects into simple dicts.

    flatten_0d : if a 0d array is encountered, simply store its item
    max_depth : stop if this recursion depth is exceeded.
        if -1, no max depth
    """
    # Check recursion depth
    if max_depth == 0:
        raise ValueError("max depth exceeded!")

    # If not a mat struct, simply return (or optionally flatten)
    if not is_mat_struct(mat_struct):
        if hasattr(mat_struct, 'ndim'):
            # array-like
            if flatten_0d and mat_struct.ndim == 0:
                # flattened 0d array
                return mat_struct.flatten()[0]
            elif mat_struct.dtype != np.dtype('object'):
                # simple arrays
                return mat_struct
            else:
                # object arrays
                return np.array([
                    dictify_mat_struct(val, flatten_0d, max_depth=max_depth-1)
                    for val in mat_struct], dtype=mat_struct.dtype)
        elif hasattr(mat_struct, 'values') and hasattr(mat_struct, 'keys'):
            # dict-like
            return dict(
                [(key, dictify_mat_struct(val, flatten_0d, max_depth=max_depth-1))
                 for key, val in list(mat_struct.items())])
        elif hasattr(mat_struct, '__len__'):
            # list-like
            # this case now seems to catch strings too!
            # detect which it is
            try:
                if str(mat_struct) == mat_struct:
                    is_a_string = True
                else:
                    is_a_string = False
            except Exception:
                raise ValueError("cannot convert obj to string")

            # If a string, return directly.
            # Else, dictify every object in the list-like object.
            if is_a_string:
                return mat_struct
            else:
                return [
                    dictify_mat_struct(val, flatten_0d, max_depth=max_depth-1)
                    for val in mat_struct]
        else:
            # everything else
            return mat_struct

    # Create a new dict to store result
    res = {}

    # Convert troublesome mat_struct to a dict, then recursively remove
    # contained mat structs.
    msd = mat_struct.__dict__
    for key in msd['_fieldnames']:
        res[key] = dictify_mat_struct(msd[key], flatten_0d,
            max_depth=max_depth-1)

    return res
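
# A usage sketch (not from the original file): after loadmat with
# struct_as_record=False, mat_struct fields become plain dict keys.
# 'data_session.mat' is a hypothetical filename.
#
#   matdata = scipy.io.loadmat('data_session.mat', squeeze_me=True,
#       struct_as_record=False)
#   saved = dictify_mat_struct(matdata['saved'])
#   sorted(saved.keys())    # field names are now dict keys
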
def generate_event_list(peh, TRIALS_INFO=None, TI_start_idx=0, sort=True,
    error_check_last_trial=5):
    """Given a peh object (or list of objects), return list of events.

    TRIALS_INFO : provide this, and useful information from it, such
        as stimulus number and trial number, will be inserted into the
        event list. It should be the original TRIALS_INFO, not the demunged.
    TI_start_idx : The index (into TRIALS_INFO) of the first trial in peh.
    error_check_last_trial : the trial after the last one in peh should
        have OUTCOME equal to this (5 means future trial). Otherwise a
        warning will be printed. To disable, set to None.

    Returns a record array with columns 'event' and 'time'.
    """
    # These will be used if trial and stim numbers are to be inserted
    trialnumber = None
    stimnumber = None
    ti_idx = TI_start_idx

    # Check whether a full peh was passed or just one trial
    if not hasattr(peh, 'keys'):
        # must be a list of trials
        res = []
        for trial in peh:
            # Get the info about this trial
            if TRIALS_INFO is not None:
                trialnumber = TRIALS_INFO.TRIAL_NUMBER[ti_idx]
                stimnumber = TRIALS_INFO.STIM_NUMBER[ti_idx]
                ti_idx += 1

            # append to growing list
            res += generate_event_list_from_trial(trial, trialnumber,
                stimnumber)

        # error check that TRIALS_INFO is lined up right
        if TRIALS_INFO is not None and error_check_last_trial is not None:
            if TRIALS_INFO.OUTCOME[ti_idx] != error_check_last_trial:
                print("warning: last trial in TRIALS_INFO not right")
    else:
        # must be a single trial
        res = generate_event_list_from_trial(peh, trialnumber, stimnumber)

    # Return as record array
    #df = pandas.DataFrame(data=res, columns=['event', 'time'])
    df = np.rec.fromrecords(res, names=['event', 'time'])
    if sort:
        df = df[~np.isnan(df.time)]
        df = df[np.argsort(df.time)]
    return df
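
# A usage sketch (not from the original file): the peh stored by
# Bcontrol_Loader_By_Dir is already dictified, as this function requires.
# '/path/to/session' is a hypothetical directory.
#
#   bld = Bcontrol_Loader_By_Dir('/path/to/session')
#   bld.load()
#   events = generate_event_list(bld.data['peh'])
#   events.event[:3], events.time[:3]   # event names and times, sorted by time
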
def generate_event_list_from_trial(trial, trialnumber=None, stimnumber=None):
    """Helper function to operate on a single trial.

    Converts states in dictified trial to a list of records.
    First is event name, then event time.
    States are parsed into events called: state_name_in and state_name_out

    If trialnumber:
        a special state called 'trial_%d_in' will be added at 'state_0'
    If stimnumber:
        a special state called 'play_stimulus_%d_in' will be added at
        'play_stimulus'
    """
    rec_l = []
    states = trial['states']
    for key, val in list(states.items()):
        if key == 'starting_state':
            # field for starting state, error check
            assert val == 'state_0'
        elif key == 'ending_state':
            # field for ending state, error check
            assert val == 'state_0'
        elif len(val) == 0:
            # This state did not occur
            pass
        elif val.ndim == 1:
            # occurred once
            assert len(val) == 2
            rec_l.append((key+'_in', float(val[0])))
            rec_l.append((key+'_out', float(val[1])))

            if key == 'play_stimulus' and stimnumber is not None:
                # Special processing to add stimulus number
                key2 = 'play_stimulus_%d' % stimnumber
                rec_l.append((key2+'_in', float(val[0])))
                rec_l.append((key2+'_out', float(val[1])))
        else:
            # occurred multiple times
            # key == state_0 should always end up here
            assert val.shape[1] == 2

            if key == 'state_0' and trialnumber is not None:
                # Special processing to add trial number
                key2 = 'trial_%d' % trialnumber
                rec_l.append((key2+'_in', float(val[0, 0])))
                rec_l.append((key2+'_out', float(val[0, 1])))

            for subval in val:
                rec_l.append((key+'_in', float(subval[0])))
                rec_l.append((key+'_out', float(subval[1])))

    return rec_l
def demung_trials_info(bcl, TRIALS_INFO=None, CONSTS=None, sound_name=None,
    replace_consts=True, replace_stim_names=True, col2consts=None):
    """Returns DataFrame version of TRIALS_INFO matrix.

    The current version of this loader returns a munged version of
    TRIALS_INFO with column names like ('CORRECT_SIDE', 'f1'),
    which confuses pandas. This method replaces munged names like 'f1'
    with good names like 'correct_side'. It can also replace integer
    codes with string constants and add a column of stimulus names.

    bcl : Bcontrol_Loader object, already loaded.
        If None, then provide TRIALS_INFO and CONSTS.
    replace_consts : if True, replaces integers with string constants
    replace_stim_names : adds column 'stim_name' based on the 1-indexed
        stimulus numbers in bcl, and values in SOUNDS_INFO
    col2consts : a dict mapper from column name to contained constants.
        If None, a reasonable default is used.
    """
    if bcl is not None:
        TRIALS_INFO = bcl.data['TRIALS_INFO']
        CONSTS = bcl.data['CONSTS']
        sound_name = bcl.data['SOUNDS_INFO']['sound_name']

    # Dict from column name to appropriate entries in CONSTS
    if col2consts is None:
        col2consts = {
            'go_or_nogo': ('GO', 'NOGO', 'TWOAC'),
            'outcome': ('HIT', 'SHORT_CPOKE', 'WRONG_PORT', 'ERROR',
                'FUTURE_TRIAL', 'CHOICE_TIME_UP', 'CURRENT_TRIAL'),
            'correct_side': ('LEFT', 'RIGHT'),
            }

    # Create DataFrame
    df = pandas.DataFrame.from_records(TRIALS_INFO)

    # Rename columns nicely
    rename_d = {}
    for munged_name in df.columns:
        good_name = TRIALS_INFO.dtype.fields[munged_name][2]
        rename_d[munged_name] = good_name.lower()
    df = df.rename(columns=rename_d)

    # Replace each integer in each column with appropriate string from CONSTS
    if replace_consts:
        for col, col_consts in list(col2consts.items()):
            newcol = np.empty(df[col].shape, dtype=object)
            for col_const in col_consts:
                #df[col][df[col] == CONSTS[col_const]] = col_const.lower()
                newcol[df[col] == CONSTS[col_const]] = col_const.lower()
            df[col] = newcol

    # Rename stimuli with appropriate names
    if replace_stim_names:
        newcol = np.array([sound_name[n-1] for n in df['stim_number']])
        df['stim_name'] = newcol

    return df
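
# A usage sketch (not from the original file); the column names shown are
# assumptions based on typical TRIALS_INFO_COLS entries, and
# 'data_session.mat' is a hypothetical filename.
#
#   bcl = Bcontrol_Loader(filename='data_session.mat')
#   bcl.load()
#   df = demung_trials_info(bcl)
#   df[['trial_number', 'stim_name', 'outcome']].head()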