PageRenderTime 1471ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/src/pathology.py

https://bitbucket.org/davidzhang/camda2013
Python | 242 lines | 234 code | 5 blank | 3 comment | 0 complexity | 9c17d30caa792409da52432dc0d3cfab MD5 | raw file
  1. #!/usr/bin/env python
  2. from sklearn.svm import SVC
  3. from sklearn.ensemble import RandomForestClassifier
  4. from sklearn.cross_validation import KFold
  5. from sklearn.cross_validation import StratifiedKFold
  6. from sklearn.metrics import roc_curve, auc, precision_score, classification_report
  7. import pandas as pd
  8. import numpy as np
  9. from scipy import interp
  10. import os
  11. import pylab as pl
  12. import config
  13. import util
  14. def randomforest(data, targets, tree_num=100):
  15. model = RandomForestClassifier(n_estimators=tree_num,
  16. n_jobs=4,
  17. max_features=data.shape[1]/2+1,
  18. verbose=0,
  19. oob_score=True,
  20. compute_importances=True,
  21. random_state=12345678)
  22. model.fit(data, targets)
  23. return model
  24. def get_control_data():
  25. # invivo single control
  26. folder = config.rat_invivo_single_folder
  27. '''
  28. fpath = os.path.join(folder, "expr_Control_3hr.csv")
  29. df_3hr = pd.read_csv(fpath, header=0, index_col=0)
  30. fpath = os.path.join(folder, "expr_Control_6hr.csv")
  31. df_6hr = pd.read_csv(fpath, header=0, index_col=0)
  32. fpath = os.path.join(folder, "expr_Control_9hr.csv")
  33. df_9hr = pd.read_csv(fpath, header=0, index_col=0)
  34. '''
  35. fpath = os.path.join(folder, "expr_Control_24hr.csv")
  36. df_24hr = pd.read_csv(fpath, header=0, index_col=0)
  37. #single_control_df = pd.concat([df_3hr, df_6hr, df_9hr, df_24hr])
  38. single_control_df = pd.concat([df_24hr])
  39. '''
  40. # invivo repeat control
  41. folder = config.rat_invivo_repeat_folder
  42. fpath = os.path.join(folder, "expr_Control_4day.csv")
  43. df_4day = pd.read_csv(fpath, header=0, index_col=0)
  44. fpath = os.path.join(folder, "expr_Control_8day.csv")
  45. df_8day = pd.read_csv(fpath, header=0, index_col=0)
  46. fpath = os.path.join(folder, "expr_Control_15day.csv")
  47. df_15day = pd.read_csv(fpath, header=0, index_col=0)
  48. fpath = os.path.join(folder, "expr_Control_29day.csv")
  49. df_29day = pd.read_csv(fpath, header=0, index_col=0)
  50. repeat_control_df = pd.concat([df_4day, df_8day, df_15day, df_29day])
  51. '''
  52. #control_df = pd.concat([single_control_df, repeat_control_df])
  53. control_df = pd.concat([single_control_df])
  54. del control_df['DILI_Class']
  55. return control_df
  56. def invivo():
  57. control_df = get_control_data()
  58. print "Control:", control_df.shape
  59. # single collapsed expression
  60. single_expr_df = pd.read_csv(config.rat_invivo_single_expression, header=None, index_col=False)
  61. sample_names = util.read_sample_names(config.rat_invivo_single_sample)
  62. gene_names = util.read_gene_names(config.rat_invivo_single_gene)
  63. single_expr_df.columns = gene_names
  64. single_expr_df.index = sample_names
  65. # repeat collapsed expression
  66. repeat_expr_df = pd.read_csv(config.rat_invivo_repeat_expression, header=None, index_col=False)
  67. sample_names = util.read_sample_names(config.rat_invivo_repeat_sample)
  68. gene_names = util.read_gene_names(config.rat_invivo_repeat_gene)
  69. repeat_expr_df.columns = gene_names
  70. repeat_expr_df.index = sample_names
  71. findings = ["Necrosis", "Hypertrophy", "Microgranuloma", "Cellular infiltration", "Change"]
  72. for finding in findings:
  73. pathology_df = pd.read_csv("../data/pathology/%s.csv" % finding, header=0, index_col=False)
  74. expr_df = pd.DataFrame(columns=single_expr_df.columns)
  75. for ind, row in single_expr_df.iterrows():
  76. snames = ind.split("_")
  77. snames = [int(sname.lstrip("00")) for sname in snames]
  78. for sname in snames:
  79. if sname in list(pathology_df['CEL']):
  80. expr_df = expr_df.append(row, verify_integrity=True)
  81. break
  82. for ind, row in repeat_expr_df.iterrows():
  83. snames = ind.split("_")
  84. snames = [int(sname.lstrip("00")) for sname in snames]
  85. for sname in snames:
  86. if sname in list(pathology_df['CEL']):
  87. expr_df = expr_df.append(row, verify_integrity=True)
  88. break
  89. targets = []
  90. targets.extend([0]*control_df.shape[0])
  91. targets.extend([1]*expr_df.shape[0])
  92. targets = np.array(targets)
  93. data_df = pd.concat([control_df, expr_df])
  94. print data_df.shape
  95. print targets.shape
  96. data = np.array(data_df)
  97. kf = KFold(data.shape[0], k=5, shuffle=True)
  98. mean_tpr = 0.0
  99. mean_fpr = np.linspace(0, 1, 100)
  100. all_tpr = []
  101. count = 1
  102. for train_index, test_index in kf:
  103. model = randomforest(data[train_index], targets[train_index], tree_num=100)
  104. probas_ = model.predict_proba(data[test_index])
  105. fpr, tpr, thresholds = roc_curve(targets[test_index], probas_[:, 1])
  106. mean_tpr += interp(mean_fpr, fpr, tpr)
  107. mean_tpr[0] = 0.0
  108. roc_auc = auc(fpr, tpr)
  109. count += 1
  110. mean_tpr /= len(kf)
  111. mean_tpr[-1] = 1.0
  112. mean_auc = auc(mean_fpr, mean_tpr)
  113. pl.plot(mean_fpr, mean_tpr, label='Mean AUC = %0.2f (%s)' % (mean_auc, finding), lw=2)
  114. pl.xlim([-0.05, 1.05])
  115. pl.ylim([-0.05, 1.05])
  116. pl.xlabel('False Positive Rate')
  117. pl.ylabel('True Positive Rate')
  118. pl.title('Mean ROC Curve')
  119. pl.legend(loc="lower right")
  120. pl.savefig('../plots/all_pathology.pdf')
  121. pl.show()
  122. def invitro():
  123. expr_df = pd.read_csv(config.rat_invitro_expression, header=None, index_col=False)
  124. sample_names = util.read_sample_names(config.rat_invitro_sample)
  125. gene_names = util.read_gene_names(config.rat_invitro_gene)
  126. expr_df.columns = gene_names
  127. expr_df.index = sample_names
  128. findings = ["Necrosis", "Hypertrophy", "Microgranuloma", "Cellular infiltration", "Change"]
  129. new_expr_df = pd.DataFrame(columns=expr_df.columns)
  130. for finding in findings:
  131. barcodes = []
  132. with open("../data/invitro_pathology/%s" % finding) as fh:
  133. for line in fh:
  134. barcodes.append(int(line))
  135. for ind, row in expr_df.iterrows():
  136. snames = ind.split("_")
  137. snames = [int(sname.lstrip("00")) for sname in snames]
  138. for sname in snames:
  139. if sname in barcodes:
  140. #new_expr_df = new_expr_df.append(row, verify_integrity=True)
  141. new_expr_df = new_expr_df.append(row)
  142. break
  143. for finding in findings:
  144. pos_data_df = pd.DataFrame(columns=new_expr_df.columns)
  145. neg_data_df = pd.DataFrame(columns=new_expr_df.columns)
  146. barcodes = []
  147. with open("../data/invitro_pathology/%s" % finding) as fh:
  148. for line in fh:
  149. barcodes.append(int(line))
  150. for ind, row in new_expr_df.iterrows():
  151. snames = ind.split("_")
  152. snames = [int(sname.lstrip("00")) for sname in snames]
  153. flag = True
  154. for sname in snames:
  155. if sname in barcodes:
  156. pos_data_df = pos_data_df.append(row)
  157. flag = False
  158. break
  159. if flag:
  160. neg_data_df = neg_data_df.append(row)
  161. targets = []
  162. targets.extend([1]*pos_data_df.shape[0])
  163. targets.extend([0]*neg_data_df.shape[0])
  164. targets = np.array(targets)
  165. data_df = pd.concat([pos_data_df, neg_data_df])
  166. data = np.array(data_df)
  167. print data.shape
  168. print targets.shape
  169. kf = KFold(data.shape[0], k=5, shuffle=True)
  170. mean_tpr = 0.0
  171. mean_fpr = np.linspace(0, 1, 100)
  172. all_tpr = []
  173. count = 1
  174. for train_index, test_index in kf:
  175. model = randomforest(data[train_index], targets[train_index], tree_num=100)
  176. probas_ = model.predict_proba(data[test_index])
  177. fpr, tpr, thresholds = roc_curve(targets[test_index], probas_[:, 1])
  178. mean_tpr += interp(mean_fpr, fpr, tpr)
  179. mean_tpr[0] = 0.0
  180. roc_auc = auc(fpr, tpr)
  181. count += 1
  182. mean_tpr /= len(kf)
  183. mean_tpr[-1] = 1.0
  184. mean_auc = auc(mean_fpr, mean_tpr)
  185. pl.plot(mean_fpr, mean_tpr, label='Mean AUC = %0.2f (%s)' % (mean_auc, finding), lw=2)
  186. pl.xlim([-0.05, 1.05])
  187. pl.ylim([-0.05, 1.05])
  188. pl.xlabel('False Positive Rate')
  189. pl.ylabel('True Positive Rate')
  190. pl.title('Mean ROC Curve')
  191. pl.legend(loc="lower right")
  192. pl.savefig('../plots/invitro_pathology.pdf')
  193. pl.show()
  194. def main():
  195. #invivo()
  196. invitro()
  197. if __name__ == "__main__":
  198. main()