PageRenderTime 26ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java

https://github.com/alexischr/gatk
Java | 287 lines | 218 code | 34 blank | 35 comment | 74 complexity | 6c4660dff7986c8e3020d19727f3ea4d MD5 | raw file
  1. /*
  2. * Copyright (c) 2011 The Broad Institute
  3. *
  4. * Permission is hereby granted, free of charge, to any person
  5. * obtaining a copy of this software and associated documentation
  6. * files (the "Software"), to deal in the Software without
  7. * restriction, including without limitation the rights to use,
  8. * copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. * copies of the Software, and to permit persons to whom the
  10. * Software is furnished to do so, subject to the following
  11. * conditions:
  12. *
  13. * The above copyright notice and this permission notice shall be
  14. * included in all copies or substantial portions of the Software.
  15. *
  16. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  17. * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
  18. * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  19. * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
  20. * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
  21. * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  22. * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
  23. * THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  24. */
  25. package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
  26. import org.apache.log4j.Logger;
  27. import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
  28. import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
  29. import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
  30. import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
  31. import org.broadinstitute.sting.utils.MathUtils;
  32. import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
  33. import org.broadinstitute.sting.utils.exceptions.UserException;
  34. import org.broadinstitute.sting.utils.variantcontext.VariantContext;
  35. import java.io.PrintStream;
  36. import java.util.ArrayList;
  37. import java.util.Collections;
  38. import java.util.List;
/**
 * Manages the variant data (annotation vectors, training/truth/known membership, and
 * normalization statistics) used by the VQSR Gaussian mixture model.
 *
 * @author rpoplin
 * @since Mar 4, 2011
 */
  44. public class VariantDataManager {
  45. private ExpandingArrayList<VariantDatum> data;
  46. private final double[] meanVector;
  47. private final double[] varianceVector; // this is really the standard deviation
  48. public final ArrayList<String> annotationKeys;
  49. private final ExpandingArrayList<TrainingSet> trainingSets;
  50. private final VariantRecalibratorArgumentCollection VRAC;
  51. protected final static Logger logger = Logger.getLogger(VariantDataManager.class);
  52. public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) {
  53. this.data = null;
  54. this.annotationKeys = new ArrayList<String>( annotationKeys );
  55. this.VRAC = VRAC;
  56. meanVector = new double[this.annotationKeys.size()];
  57. varianceVector = new double[this.annotationKeys.size()];
  58. trainingSets = new ExpandingArrayList<TrainingSet>();
  59. }
  60. public void setData( final ExpandingArrayList<VariantDatum> data ) {
  61. this.data = data;
  62. }
  63. public ExpandingArrayList<VariantDatum> getData() {
  64. return data;
  65. }
  66. public void normalizeData() {
  67. boolean foundZeroVarianceAnnotation = false;
  68. for( int iii = 0; iii < meanVector.length; iii++ ) {
  69. final double theMean = mean(iii);
  70. final double theSTD = standardDeviation(theMean, iii);
  71. logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
  72. if( Double.isNaN(theMean) ) {
  73. throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations. See http://www.broadinstitute.org/gsa/wiki/index.php/VariantAnnotator");
  74. }
  75. foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-6);
  76. meanVector[iii] = theMean;
  77. varianceVector[iii] = theSTD;
  78. for( final VariantDatum datum : data ) {
  79. // Transform each data point via: (x - mean) / standard deviation
  80. datum.annotations[iii] = ( datum.isNull[iii] ? GenomeAnalysisEngine.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD );
  81. }
  82. }
  83. if( foundZeroVarianceAnnotation ) {
  84. throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." );
  85. }
  86. // trim data by standard deviation threshold and mark failing data for exclusion later
  87. for( final VariantDatum datum : data ) {
  88. boolean remove = false;
  89. for( final double val : datum.annotations ) {
  90. remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD);
  91. }
  92. datum.failingSTDThreshold = remove;
  93. }
  94. }
  95. public void addTrainingSet( final TrainingSet trainingSet ) {
  96. trainingSets.add( trainingSet );
  97. }
  98. public boolean checkHasTrainingSet() {
  99. for( final TrainingSet trainingSet : trainingSets ) {
  100. if( trainingSet.isTraining ) { return true; }
  101. }
  102. return false;
  103. }
  104. public boolean checkHasTruthSet() {
  105. for( final TrainingSet trainingSet : trainingSets ) {
  106. if( trainingSet.isTruth ) { return true; }
  107. }
  108. return false;
  109. }
  110. public boolean checkHasKnownSet() {
  111. for( final TrainingSet trainingSet : trainingSets ) {
  112. if( trainingSet.isKnown ) { return true; }
  113. }
  114. return false;
  115. }
  116. public ExpandingArrayList<VariantDatum> getTrainingData() {
  117. final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
  118. for( final VariantDatum datum : data ) {
  119. if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
  120. trainingData.add( datum );
  121. }
  122. }
  123. logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
  124. if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) {
  125. logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
  126. }
  127. return trainingData;
  128. }
  129. public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) {
  130. // The return value is the list of training variants
  131. final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
  132. // First add to the training list all sites overlapping any bad sites training tracks
  133. for( final VariantDatum datum : data ) {
  134. if( datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
  135. trainingData.add( datum );
  136. }
  137. }
  138. final int numBadSitesAdded = trainingData.size();
  139. logger.info( "Found " + numBadSitesAdded + " variants overlapping bad sites training tracks." );
  140. // Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
  141. Collections.sort( data );
  142. final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) );
  143. if( numToAdd > data.size() ) {
  144. throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." );
  145. } else if( numToAdd == minimumNumber - trainingData.size() ) {
  146. logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
  147. bottomPercentage = ((float) numToAdd) / ((float) data.size());
  148. }
  149. int index = 0, numAdded = 0;
  150. while( numAdded < numToAdd ) {
  151. final VariantDatum datum = data.get(index++);
  152. if( !datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
  153. datum.atAntiTrainingSite = true;
  154. trainingData.add( datum );
  155. numAdded++;
  156. }
  157. }
  158. logger.info( "Additionally training with worst " + String.format("%.3f", (float) bottomPercentage * 100.0f) + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
  159. return trainingData;
  160. }
  161. public ExpandingArrayList<VariantDatum> getRandomDataForPlotting( int numToAdd ) {
  162. numToAdd = Math.min(numToAdd, data.size());
  163. final ExpandingArrayList<VariantDatum> returnData = new ExpandingArrayList<VariantDatum>();
  164. for( int iii = 0; iii < numToAdd; iii++) {
  165. final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
  166. if( !datum.failingSTDThreshold ) {
  167. returnData.add(datum);
  168. }
  169. }
  170. // Add an extra 5% of points from bad training set, since that set is small but interesting
  171. for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
  172. final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
  173. if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); }
  174. else { iii--; }
  175. }
  176. return returnData;
  177. }
  178. private double mean( final int index ) {
  179. double sum = 0.0;
  180. int numNonNull = 0;
  181. for( final VariantDatum datum : data ) {
  182. if( datum.atTrainingSite && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; }
  183. }
  184. return sum / ((double) numNonNull);
  185. }
  186. private double standardDeviation( final double mean, final int index ) {
  187. double sum = 0.0;
  188. int numNonNull = 0;
  189. for( final VariantDatum datum : data ) {
  190. if( datum.atTrainingSite && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; }
  191. }
  192. return Math.sqrt( sum / ((double) numNonNull) );
  193. }
  194. public void decodeAnnotations( final VariantDatum datum, final VariantContext vc, final boolean jitter ) {
  195. final double[] annotations = new double[annotationKeys.size()];
  196. final boolean[] isNull = new boolean[annotationKeys.size()];
  197. int iii = 0;
  198. for( final String key : annotationKeys ) {
  199. isNull[iii] = false;
  200. annotations[iii] = decodeAnnotation( key, vc, jitter );
  201. if( Double.isNaN(annotations[iii]) ) { isNull[iii] = true; }
  202. iii++;
  203. }
  204. datum.annotations = annotations;
  205. datum.isNull = isNull;
  206. }
  207. private static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) {
  208. double value;
  209. try {
  210. value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) );
  211. if( Double.isInfinite(value) ) { value = Double.NaN; }
  212. if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // Integer valued annotations must be jittered a bit to work in this GMM
  213. value += -0.25 + 0.5 * GenomeAnalysisEngine.getRandomGenerator().nextDouble();
  214. }
  215. if( jitter && annotationKey.equalsIgnoreCase("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.0001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); }
  216. if( jitter && annotationKey.equalsIgnoreCase("FS") && MathUtils.compareDoubles(value, 0.0, 0.001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); }
  217. } catch( Exception e ) {
  218. value = Double.NaN; // The VQSR works with missing data by marginalizing over the missing dimension when evaluating the Gaussian mixture model
  219. }
  220. return value;
  221. }
  222. public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC ) {
  223. datum.isKnown = false;
  224. datum.atTruthSite = false;
  225. datum.atTrainingSite = false;
  226. datum.atAntiTrainingSite = false;
  227. datum.prior = 2.0;
  228. datum.consensusCount = 0;
  229. for( final TrainingSet trainingSet : trainingSets ) {
  230. for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) {
  231. if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() &&
  232. ((evalVC.isSNP() && trainVC.isSNP()) || ((evalVC.isIndel()||evalVC.isMixed()) && (trainVC.isIndel()||trainVC.isMixed()))) &&
  233. (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
  234. datum.isKnown = datum.isKnown || trainingSet.isKnown;
  235. datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
  236. datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;
  237. datum.prior = Math.max( datum.prior, trainingSet.prior );
  238. datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 );
  239. }
  240. if( trainVC != null ) {
  241. datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining;
  242. }
  243. }
  244. }
  245. }
  246. public void writeOutRecalibrationTable( final PrintStream RECAL_FILE ) {
  247. for( final VariantDatum datum : data ) {
  248. RECAL_FILE.println(String.format("%s,%d,%d,%.4f,%s",
  249. datum.contig, datum.start, datum.stop, datum.lod,
  250. (datum.worstAnnotation != -1 ? annotationKeys.get(datum.worstAnnotation) : "NULL")));
  251. }
  252. }
  253. }