PageRenderTime 53ms CodeModel.GetById 25ms RepoModel.GetById 1ms app.codeStats 0ms

/AnalysisLFP.py

https://github.com/cxrodgers/ns5_process
Python | 345 lines | 272 code | 54 blank | 19 comment | 68 complexity | 15c466407814aaeba9c03bb8d0d0cf11 MD5 | raw file
from __future__ import print_function
from __future__ import absolute_import

from builtins import zip
from builtins import range
from builtins import object

import glob
import os.path
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas
import scipy.stats
from matplotlib import mlab

from . import SpikeAnalysis
from . import myutils
  15. class MultiServeLFP(object):
  16. def __init__(self, lfp_dir_list=None, tdf_filename_list=None, filter_list=None):
  17. self.lfp_dir_list = lfp_dir_list
  18. self.tdf_filename_list = tdf_filename_list
  19. self.ssl_list = None
  20. self.filter_list = filter_list
  21. def refresh_files(self):
  22. self.ssl_list = []
  23. if self.filter_list is not None:
  24. fla = np.array(self.filter_list)
  25. for lfp_dir, tdf_filename in zip(self.lfp_dir_list,
  26. self.tdf_filename_list):
  27. ssl = SingleServeLFP(dirname=lfp_dir, tdf_file=tdf_filename)
  28. if self.filter_list is not None:
  29. keep_tets = fla[fla[:, 0] == ssl.session][:, 1].astype(np.int)
  30. else:
  31. keep_tets = None
  32. ssl.split_on_filter = keep_tets
  33. self.ssl_list.append(ssl)
  34. self.session_list = [ssl.session for ssl in self.ssl_list]
  35. def read(self):
  36. self.refresh_files()
  37. ddf = pandas.DataFrame()
  38. bins = None
  39. tt = None
  40. for ssl in self.ssl_list:
  41. df, t = ssl.read(return_t=True, include_trials='hits')
  42. if len(df) > 0:
  43. bins = ssl.bins
  44. tt = ssl.t
  45. ddf = ddf.append(df, ignore_index=True)
  46. self._lfp = ddf
  47. self.t = tt
  48. self.bins = bins
  49. def average_grouped_by_sound(self, detrend=None, **kwargs):
  50. # Optionally filter by keywork, eg tetrode=[2]
  51. if len(kwargs) > 0:
  52. lfpf = self.lfp.copy()
  53. for key, val in list(kwargs.items()):
  54. lfpf = lfpf[lfpf[key].isin(val)]
  55. else:
  56. lfpf = self.lfp
  57. # Iterate over sound * block
  58. g = lfpf.groupby(['sound', 'block'])
  59. res = {}
  60. for sound in ['lehi', 'rihi', 'lelo', 'rilo']:
  61. res[sound] = {}
  62. for block in ['LB', 'PB']:
  63. df = lfpf.ix[g.groups[sound, block]]
  64. # Now try to group by session * tetrode
  65. g2 = df.groupby(['session', 'tetrode'])
  66. if len(g2.groups) == 1:
  67. # Only a single session * tetrode, plot across trials
  68. res[sound][block] = df[self.bins]
  69. else:
  70. ddf = pandas.DataFrame()
  71. res[sound][block] = np.array([
  72. df.ix[val][self.bins].mean(axis=0)
  73. for val in list(g2.groups.values())])
  74. # Optional detrend
  75. if detrend == 'baseline':
  76. pre_bins = self.bins[self.t < 0.0]
  77. baseline = res[sound][block][pre_bins].mean(axis=1)
  78. tb = np.tile(baseline[:, np.newaxis], (1, len(self.t)))
  79. res[sound][block] -= tb
  80. elif detrend is not None:
  81. raise ValueError("detrend mode not supported")
  82. return res
  83. @property
  84. def lfp(self):
  85. if hasattr(self, '_lfp'):
  86. return self._lfp
  87. else:
  88. self.read()
  89. return self._lfp
  90. class SingleServeLFP(object):
  91. def __init__(self, dirname=None, filename_filter=None, tdf_file=None,
  92. split_on_filter=None):
  93. self.dirname = dirname
  94. if dirname is not None:
  95. self.session = os.path.split(self.dirname)[1]
  96. self.filename_filter = filename_filter
  97. self.tdf_file = tdf_file
  98. self.split_on_filter = split_on_filter
  99. if self.split_on_filter is None:
  100. self.split_on_filter = list(range(16))
  101. def refresh_files(self):
  102. self.lfp_filenames = []
  103. self.lfp_tetrodes = []
  104. if self.split_on_filter is None:
  105. self.split_on_filter = list(range(16))
  106. if self.filename_filter is None:
  107. putative_files = sorted(glob.glob(os.path.join(
  108. self.dirname, '*')))
  109. else:
  110. putative_files = sorted(glob.glob(os.path.join(
  111. self.dirname, ('*%s*' % self.filename_filter))))
  112. for fn in putative_files:
  113. for tetrode in self.split_on_filter:
  114. m = glob.re.search('\.lfp\.%d\.npz$' % tetrode, fn)
  115. if m is not None:
  116. self.lfp_filenames.append(fn)
  117. self.lfp_tetrodes.append(tetrode)
  118. self._load_tdf()
  119. def _load_tdf(self):
  120. try:
  121. self._tdf = pandas.load(self.tdf_file)
  122. except IOError:
  123. print("warning: cannot load trials %s" % self.tdf_file)
  124. self._tdf = None
  125. def _read_numpy_z_format(self, fn):
  126. nz = np.load(fn)
  127. lfp = nz['lfp']
  128. stored_trial_numbers = nz['trial_numbers']
  129. N = nz['lfp'].shape[0]
  130. if len(stored_trial_numbers) > N:
  131. # sometimes stored duplicate!
  132. assert np.all(stored_trial_numbers[:N] == stored_trial_numbers[N:])
  133. stored_trial_numbers = stored_trial_numbers[:N]
  134. t = nz['t']
  135. nz.close()
  136. return lfp, t, stored_trial_numbers
  137. def read(self, return_t=False, include_trials='hits', stim_number_filter=None):
  138. """Return DataFrame containing all LFP from this session"""
  139. if stim_number_filter is None:
  140. stim_number_filter = list(range(5, 13))
  141. # Search directory for files
  142. self.refresh_files()
  143. bigdf = pandas.DataFrame()
  144. t_vals = None
  145. bins = None
  146. for lfp_filename, tetrode in zip(self.lfp_filenames, self.lfp_tetrodes):
  147. # Load numpy z format
  148. lfp, t, trial_numbers = self._read_numpy_z_format(lfp_filename)
  149. if t_vals is None:
  150. t_vals = t
  151. else:
  152. assert np.all(t_vals - t < 1e-7)
  153. bins = np.array(['t%d' % n for n in range(lfp.shape[1])])
  154. df = pandas.DataFrame(lfp, columns=bins)
  155. df.insert(loc=0, column='trial', value=trial_numbers)
  156. df.insert(loc=0, column='tetrode', value=tetrode)
  157. df.insert(loc=0, column='session',
  158. value=[os.path.split(self.dirname)[1]]*len(df))
  159. bigdf = bigdf.append(df, ignore_index=True)
  160. self.t = t_vals
  161. self.bins = bins
  162. if len(bigdf) > 0 and self.tdf is not None:
  163. bigdf = bigdf.join(self.tdf[
  164. ['block', 'outcome', 'stim_number', 't_center', 'nonrandom']], on='trial')
  165. if include_trials == 'hits':
  166. bigdf = bigdf[(bigdf.outcome == 1) & (bigdf.nonrandom == 0)]
  167. bigdf.pop('nonrandom')
  168. bigdf.pop('outcome')
  169. bigdf = bigdf[bigdf.stim_number.isin(stim_number_filter)]
  170. SpikeAnalysis.replace_stim_numbers_with_names(bigdf)
  171. if return_t:
  172. return bigdf, self.t
  173. else:
  174. return bigdf
  175. @property
  176. def lfp(self):
  177. if hasattr(self, '_lfp'):
  178. return self._lfp
  179. else:
  180. self._lfp = self.read()
  181. return self._lfp
  182. @property
  183. def tdf(self):
  184. if hasattr(self, '_tdf'):
  185. return self._tdf
  186. else:
  187. self._load_tdf()
  188. return self._tdf
  189. def average_grouped_by_sound(self, detrend=None, **kwargs):
  190. # Optionally filter by keywork, eg tetrode=[2]
  191. if len(kwargs) > 0:
  192. lfpf = self.lfp.copy()
  193. for key, val in list(kwargs.items()):
  194. lfpf = lfpf[lfpf[key].isin(val)]
  195. else:
  196. lfpf = self.lfp
  197. # Iterate over sound * block
  198. g = lfpf.groupby(['sound', 'block'])
  199. res = {}
  200. for sound in ['lehi', 'rihi', 'lelo', 'rilo']:
  201. res[sound] = {}
  202. for block in ['LB', 'PB']:
  203. df = lfpf.ix[g.groups[sound, block]]
  204. # Now try to group by session * tetrode
  205. g2 = df.groupby(['session', 'tetrode'])
  206. if len(g2.groups) == 1:
  207. # Only a single session * tetrode, plot across trials
  208. res[sound][block] = df[self.bins]
  209. else:
  210. ddf = pandas.DataFrame()
  211. res[sound][block] = np.array([
  212. df.ix[val][self.bins].mean(axis=0)
  213. for val in list(g2.groups.values())])
  214. # Optional detrend
  215. if detrend == 'baseline':
  216. pre_bins = self.bins[self.t < 0.0]
  217. baseline = res[sound][block][pre_bins].mean(axis=1)
  218. tb = np.tile(baseline[:, np.newaxis], (1, len(self.t)))
  219. res[sound][block] -= tb
  220. elif detrend is not None:
  221. raise ValueError("detrend mode not supported")
  222. return res
def plot_lfp_grouped_by_sound(ssl, plot_difference=True, p_adj_meth=None,
    mark_significance=True, t_start=None, t_stop=None, **kwargs):
    """Plot block-averaged LFP traces for each sound in a 2x2 figure.

    ssl : object with average_grouped_by_sound() and a time base `.t`
        (SingleServeLFP or MultiServeLFP)
    plot_difference : also plot the LB - PB difference trace in magenta
    p_adj_meth : optional p-value adjustment method forwarded to
        myutils.r_adj_pval (presumably multiple-comparison correction;
        NOTE(review): semantics live in myutils — confirm there)
    mark_significance : mark time bins where LB and PB differ by a
        paired t-test: '*' at p < .05, open 'o' at p < .01
    t_start, t_stop : time window to plot (default: full range of ssl.t)
    kwargs : forwarded to ssl.average_grouped_by_sound (e.g. tetrode=[2])
    """
    # First get grouped averages
    res_d = ssl.average_grouped_by_sound(**kwargs)
    # set time limits
    if t_start is None:
        t_start = ssl.t[0]
    if t_stop is None:
        t_stop = ssl.t[-1]
    # Convert times to nearest bin indices; t2bin is exclusive, hence +1
    t1bin = np.argmin(np.abs(ssl.t - t_start))
    t2bin = np.argmin(np.abs(ssl.t - t_stop)) + 1
    f = plt.figure()
    for n, sound in enumerate(['lehi', 'rihi', 'lelo', 'rilo']):
        # Create an axis for this sound and plot both blocks
        ax = f.add_subplot(2, 2, n+1)
        for block in ['LB', 'PB']:
            myutils.plot_mean_trace(ax=ax, x=ssl.t[t1bin:t2bin],
                data=res_d[sound][block][:, t1bin:t2bin], label=block)
        # Optionally plot difference
        if plot_difference:
            di = res_d[sound]['LB'] - res_d[sound]['PB']
            myutils.plot_mean_trace(ax=ax, x=ssl.t[t1bin:t2bin],
                data=di[:, t1bin:t2bin], label='diff', color='m')
        # Optionally mark significance
        if mark_significance:
            # Paired t-test LB vs PB across traces, one p-value per bin
            p_vals = scipy.stats.ttest_rel(res_d[sound]['LB'][:, t1bin:t2bin],
                res_d[sound]['PB'][:, t1bin:t2bin])[1]
            if p_adj_meth is not None:
                p_vals = myutils.r_adj_pval(p_vals, meth=p_adj_meth)
            # Stars at y=0 for p < .05, open circles for p < .01
            pp = np.where(p_vals < .05)[0]
            plt.plot(ssl.t[t1bin:t2bin][pp], np.zeros_like(pp), 'k*')
            pp = np.where(p_vals < .01)[0]
            plt.plot(ssl.t[t1bin:t2bin][pp], np.zeros_like(pp), 'ko',
                markerfacecolor='w')
        plt.legend(loc='best')
        ax.set_title(sound)
        ax.set_xlim((t_start, t_stop))
    plt.show()
  261. def get_tetrode_filter(ratname=None):
  262. fn_d = {
  263. 'CR12B': '/media/STELLATE/20111208_CR12B_allsessions_sorted/data_params_CR12B.csv',
  264. 'CR17B': '/media/STELLATE/20110907_CR17B_allsessions_sorted/data_params_CR17B.csv',
  265. 'CR13A': '/media/STELLATE/20110816_CR13A_allsessions_sorted/data_params_CR13A.csv'
  266. }
  267. if ratname is None:
  268. l = []
  269. for r in list(fn_d.keys()):
  270. l += get_tetrode_filter(r)
  271. return l
  272. dp = mlab.csv2rec(fn_d[ratname])
  273. tetrode_filter = []
  274. for row in dp:
  275. if row['session_type'] != 'behaving':
  276. continue
  277. for t in myutils.parse_space_sep(row['auditory_tetrodes']):
  278. tetrode_filter.append((row['session_name'], t))
  279. return sorted(tetrode_filter)
  280. def get_subdir_list(ratname):
  281. fn_d = {
  282. 'CR12B': '/media/STELLATE/20111208_CR12B_allsessions_sorted/*behaving',
  283. 'CR17B': '/media/STELLATE/20110907_CR17B_allsessions_sorted/*behaving',
  284. 'CR13A': '/media/STELLATE/20110816_CR13A_allsessions_sorted/*behaving',
  285. }
  286. return sorted(glob.glob(fn_d[ratname]))
  287. def build_tdf_filename_list(subdir_list):
  288. session_list = [os.path.split(subdir)[1] for subdir in subdir_list]
  289. tdf_file_list = ['/media/TBLABDATA/20111208_frame_data/%s_trials' % session
  290. for session in session_list]
  291. return tdf_file_list