PageRenderTime 61ms CodeModel.GetById 32ms RepoModel.GetById 0ms app.codeStats 0ms

/oncodrivefm/command/full.py

https://bitbucket.org/bbglab/oncodrivefm
Python | 289 lines | 200 code | 71 blank | 18 comment | 26 complexity | 4e7daf66a91ba6f2e1bf2058cbd87593 MD5 | raw file
Possible License(s): LGPL-2.1
  1. import datetime
  2. import os
  3. import os.path
  4. import numpy as np
  5. import pandas as pd
  6. import matplotlib.pyplot as plt
  7. import operator
  8. from oncodrivefm import tdm
  9. from oncodrivefm.analysis import OncodriveFmAnalysis
  10. from oncodrivefm.command.base import Command
  11. from oncodrivefm.method.factory import create_method
  12. from oncodrivefm.method.empirical import EmpiricalTest
  13. from oncodrivefm.method.zscore import ZscoreTest
  14. from oncodrivefm.mapping import MatrixMapping
  15. from oncodrivefm.utils.labelfilter import LabelFilter
  16. class FullCommand(Command):
  17. def __init__(self):
  18. Command.__init__(self, prog="oncodrivefm", desc="Compute the FM bias for genes and pathways")
  19. def _add_arguments(self, parser):
  20. Command._add_arguments(self, parser)
  21. parser.add_argument("data_path", metavar="DATA",
  22. help="File containing the data matrix in TDM format")
  23. parser.add_argument("-N", "--samplings", dest="num_samplings", type=int, default=10000, metavar="NUMBER",
  24. help="Number of samplings to compute the FM bias pvalue")
  25. parser.add_argument("-e", "--estimator", dest="estimator", metavar="ESTIMATOR",
  26. choices=["mean", "median"], default="mean",
  27. help="Test estimator for computation.")
  28. parser.add_argument("--gt", "--gene-threshold", dest="mut_gene_threshold", type=int, default=2,
  29. metavar="THRESHOLD",
  30. help="Minimum number of mutations per gene to compute the FM bias")
  31. parser.add_argument("--pt", "--pathway-threshold", dest="mut_pathway_threshold", type=int, default=10,
  32. metavar="THRESHOLD",
  33. help="Minimum number of mutations per pathway to compute the FM bias")
  34. parser.add_argument("-s", "--slices", dest="slices", metavar="SLICES",
  35. help="Slices to process separated by commas")
  36. parser.add_argument("-m", "--mapping", dest="mapping", metavar="PATH",
  37. help="File with mappings between genes and pathways to be analysed")
  38. parser.add_argument("-f", "--filter", dest="filter", metavar="PATH",
  39. help="File containing the features to be filtered. By default labels are includes,"
  40. " labels preceded with - are excludes.")
  41. parser.add_argument("--save-data", dest="save_data", default=False, action="store_true",
  42. help="The input data matrix will be saved")
  43. parser.add_argument("--save-analysis", dest="save_analysis", default=False, action="store_true",
  44. help="The analysis results will be saved")
  45. parser.add_argument("--pathways_only", dest="pathways_only", default=False, action="store_true",
  46. help="Run only the pathways analysis")
  47. parser.add_argument("--plots", dest="plots", default=False, action="store_true",
  48. help="Print plots for quality control")
  49. def _check_args(self):
  50. Command._check_args(self)
  51. if self.args.analysis_name is None:
  52. self.args.analysis_name, ext = os.path.splitext(os.path.basename(self.args.data_path))
  53. if self.args.num_samplings < 1:
  54. self._error("Number of samplings out of range [2, ..)")
  55. if self.args.mut_gene_threshold < 1:
  56. self._error("Minimum number of mutations per gene out of range [1, ..)")
  57. if self.args.mut_pathway_threshold < 1:
  58. self._error("Minimum number of mutations per pathway out of range [1, ..)")
  59. if self.args.mapping is not None and not os.path.isfile(self.args.mapping):
  60. self._error("Pathways mapping file not found: {0}".format(self.args.mapping))
  61. def run(self):
  62. Command.run(self)
  63. # Load filter
  64. self.filter = LabelFilter()
  65. if self.args.filter is not None:
  66. self.log.info("Loading filter ...")
  67. self.log.debug(" > {0}".format(self.args.filter))
  68. self.filter.load(self.args.filter)
  69. self.log.debug(" {0} includes, {1} excludes".format(
  70. self.filter.include_count, self.filter.exclude_count))
  71. # Load data
  72. self.log.info("Loading data ...")
  73. self.log.debug(" > {0}".format(self.args.data_path))
  74. # TODO: Support loading plain matrices: /file.tsv#slice=SIFT
  75. self.matrix = tdm.load_matrix(self.args.data_path)
  76. self.log.debug(" {0} rows, {1} columns and {2} slices".format(
  77. self.matrix.num_rows, self.matrix.num_cols, self.matrix.num_slices))
  78. # Get selected slice indices
  79. if self.args.slices is not None:
  80. slices = []
  81. for name in self.args.slices.split(","):
  82. name = name.strip()
  83. if name not in self.matrix.slice_name_index:
  84. raise Exception("Slice not found: {0}".format(name))
  85. slices += [self.matrix.slice_name_index[name]]
  86. else:
  87. slices = range(self.matrix.num_slices)
  88. col_names = [self.matrix.slice_names[i] for i in slices]
  89. if self.args.save_data:
  90. for i in slices:
  91. slice_name = self.matrix.slice_names[i]
  92. self.log.info("Saving {0} data matrix ...".format(slice_name))
  93. self.save_matrix(self.args.output_path, self.args.analysis_name, self.args.output_format,
  94. self.matrix.row_names, self.matrix.col_names, self.matrix.data[i],
  95. suffix="data-{0}".format(slice_name))
  96. if not self.args.pathways_only:
  97. # GENES ---------------------------------------
  98. # One to one mapping for genes
  99. map = {}
  100. for row_name in self.matrix.row_names:
  101. if self.filter.valid(row_name):
  102. map[row_name] = (row_name,)
  103. genes_mapping = MatrixMapping(self.matrix, map)
  104. genes_method_name = "{0}-{1}".format(self.args.estimator, EmpiricalTest.NAME)
  105. # Analysis for genes
  106. self.log.info("Analysing genes with '{0}' ...".format(genes_method_name))
  107. analysis = OncodriveFmAnalysis(
  108. "oncodrivefm.genes",
  109. num_samplings=self.args.num_samplings,
  110. mut_threshold=self.args.mut_gene_threshold,
  111. num_cores=self.args.num_cores)
  112. results = analysis.compute(self.matrix, genes_mapping, genes_method_name, slices)
  113. method = create_method(genes_method_name)
  114. self.log.info("Adding individual results to results ...")
  115. # sort genes and slices according to their mapping dicts {name -> index}
  116. sorted_gene_names = [tup[0] for tup in sorted(genes_mapping.group_name_index.items(), key=operator.itemgetter(1))]
  117. sorted_slices = [self.matrix.slice_names[s].upper() + "_PVALUE" for s in slices]
  118. results_pandas = pd.DataFrame(results, columns=sorted_gene_names).transpose()
  119. results_pandas.sort_index(inplace=True)
  120. results_pandas.columns = sorted_slices
  121. # Combination for genes
  122. self.log.info("Combining analysis results ...")
  123. combined_results = method.combine(np.ma.masked_invalid(results.T))
  124. combined_results_pandas = pd.DataFrame(combined_results, columns=sorted_gene_names).transpose()
  125. combined_results_pandas.columns = method.combination_columns
  126. combined_results_pandas.sort_index(inplace=True)
  127. combined_results_pandas._metadata = ['filter', 'test']
  128. combined_results_pandas = combined_results_pandas.join(results_pandas)
  129. combined_results_pandas.dropna(how='all', inplace=True)
  130. self.log.info("Saving genes combined results ...")
  131. self.save_pandas_matrix(self.args.output_path, self.args.analysis_name, self.args.output_format,
  132. genes_mapping.group_names, method.combination_columns, combined_results_pandas,
  133. suffix="genes", params=[("slices", ",".join(col_names)), ("method", method.name)])
  134. # self.save_matrix(self.args.output_path, self.args.analysis_name, self.args.output_format,
  135. # genes_mapping.group_names, method.combination_columns, combined_results.T,
  136. # params=[("slices", ",".join(col_names)), ("method", method.name)], suffix="genes",
  137. # valid_row=lambda row: sum([1 if np.isnan(v) else 0 for v in row]) == 0)
  138. if self.args.plots:
  139. self.qqplot(combined_results_pandas, self.args.output_path, self.args.analysis_name, suffix="genes")
  140. # Exit if there is no mapping
  141. if self.args.mapping is None:
  142. return
  143. # PATHWAYS ---------------------------------------
  144. # Load pathways mappping
  145. self.log.info("Loading pathways mapping ...")
  146. self.log.debug(" > {0}".format(self.args.mapping))
  147. pathways_mapping = self.load_mapping(self.matrix, self.args.mapping, filt=self.filter)
  148. self.log.debug(" {0} pathways".format(pathways_mapping.num_groups))
  149. pathways_method_name = "{0}-{1}".format(self.args.estimator, ZscoreTest.NAME)
  150. # Analysis for pathways
  151. self.log.info("Analysing pathways with '{0}' ...".format(pathways_method_name))
  152. analysis = OncodriveFmAnalysis(
  153. "oncodrivefm.pathways",
  154. num_samplings=self.args.num_samplings,
  155. mut_threshold=self.args.mut_pathway_threshold,
  156. num_cores=self.args.num_cores)
  157. results = analysis.compute(self.matrix, pathways_mapping, pathways_method_name, slices)
  158. method = create_method(pathways_method_name)
  159. if self.args.save_analysis:
  160. self.log.info("Saving pathways analysis results ...")
  161. self.save_splitted_results(
  162. self.args.output_path, self.args.analysis_name, self.args.output_format,
  163. self.matrix, pathways_mapping,
  164. method, results, slices, suffix="pathways")
  165. # Combination for pathways
  166. self.log.info("Combining analysis results ...")
  167. combined_results = method.combine(np.ma.masked_invalid(results.T))
  168. self.log.info("Saving pathways combined results ...")
  169. self.save_matrix(self.args.output_path, self.args.analysis_name, self.args.output_format,
  170. pathways_mapping.group_names, method.combination_columns, combined_results.T,
  171. params=[("slices", ",".join(col_names)), ("method", method.name)], suffix="pathways",
  172. valid_row=lambda row: sum([1 if np.isnan(v) else 0 for v in row]) == 0)
  173. def qqplot(self, data_frame, output_path, analysis_name, suffix=""):
  174. ## Courtesy of Loris Mularoni
  175. pvalue_cols = [hit for hit in filter(lambda x:'PVALUE' in x, data_frame.columns)]
  176. NCOLS = 3
  177. NROWS = int(len(pvalue_cols) / NCOLS) + 1
  178. WIDTH = 16
  179. fig = plt.figure(figsize=(WIDTH, WIDTH / float(NCOLS) * NROWS))
  180. axs = [plt.subplot2grid((NROWS, NCOLS), (N // NCOLS, N % NCOLS)) for N in range(NCOLS * NROWS)]
  181. self.log.info("Plotting for {},{} ...".format(analysis_name, suffix))
  182. upper_limit = -np.log10(1.0/self.args.num_samplings)
  183. for i, pvalue_col in enumerate(pvalue_cols):
  184. ylabel = pvalue_col
  185. ax = axs[i]
  186. obs_pvalues = data_frame[pvalue_col].map(lambda x: -np.log10(x))
  187. obs_pvalues.sort()
  188. exp_pvalues = -1 * np.log10(np.arange(1, len(data_frame) + 1) / float(len(data_frame)))
  189. exp_pvalues.sort()
  190. ax.scatter(exp_pvalues, obs_pvalues, alpha=0.5)
  191. ax.set_xlabel("expected pvalues")
  192. ax.set_ylabel("observed pvalues")
  193. ax.plot(np.linspace(0, upper_limit), np.linspace(0, upper_limit), 'r--')
  194. ax.set_xlim(-0.2, upper_limit)
  195. ax.set_ylim(-0.2, upper_limit)
  196. ax.set_title("{} ({})".format(pvalue_col, analysis_name))
  197. ax.set_ylabel(ylabel)
  198. plt.tight_layout()
  199. suffix = "-" + suffix if len(suffix) > 0 else ""
  200. output_file = os.path.join(output_path, analysis_name + suffix + ".png")
  201. plt.savefig(output_file, bbox_inches='tight')
  202. def main():
  203. FullCommand().run()
  204. if __name__ == "__main__":
  205. start = datetime.datetime.now()
  206. main()
  207. end = datetime.datetime.now()
  208. print(end - start)