/tools/multivariate_stats/pca.py

https://bitbucket.org/cistrome/cistrome-harvard/ · Python · 130 lines · 111 code · 18 blank · 1 comment · 43 complexity · ba6e78bf4d6a599a43b60035671d7d52 MD5 · raw file

  1. #!/usr/bin/env python
  2. from galaxy import eggs
  3. import sys, string
  4. from rpy import *
  5. import numpy
  6. def stop_err(msg):
  7. sys.stderr.write(msg)
  8. sys.exit()
  9. infile = sys.argv[1]
  10. x_cols = sys.argv[2].split(',')
  11. method = sys.argv[3]
  12. outfile = sys.argv[4]
  13. outfile2 = sys.argv[5]
  14. if method == 'svd':
  15. scale = center = "FALSE"
  16. if sys.argv[6] == 'both':
  17. scale = center = "TRUE"
  18. elif sys.argv[6] == 'center':
  19. center = "TRUE"
  20. elif sys.argv[6] == 'scale':
  21. scale = "TRUE"
  22. fout = open(outfile,'w')
  23. elems = []
  24. for i, line in enumerate( file ( infile )):
  25. line = line.rstrip('\r\n')
  26. if len( line )>0 and not line.startswith( '#' ):
  27. elems = line.split( '\t' )
  28. break
  29. if i == 30:
  30. break # Hopefully we'll never get here...
  31. if len( elems )<1:
  32. stop_err( "The data in your input dataset is either missing or not formatted properly." )
  33. x_vals = []
  34. for k,col in enumerate(x_cols):
  35. x_cols[k] = int(col)-1
  36. x_vals.append([])
  37. NA = 'NA'
  38. skipped = 0
  39. for ind,line in enumerate( file( infile )):
  40. if line and not line.startswith( '#' ):
  41. try:
  42. fields = line.strip().split("\t")
  43. valid_line = True
  44. for k,col in enumerate(x_cols):
  45. try:
  46. xval = float(fields[col])
  47. except:
  48. skipped += 1
  49. valid_line = False
  50. break
  51. if valid_line:
  52. for k,col in enumerate(x_cols):
  53. xval = float(fields[col])
  54. x_vals[k].append(xval)
  55. except:
  56. skipped += 1
  57. x_vals1 = numpy.asarray(x_vals).transpose()
  58. dat= r.list(array(x_vals1))
  59. set_default_mode(NO_CONVERSION)
  60. try:
  61. if method == "cor":
  62. pc = r.princomp(r.na_exclude(dat), cor = r("TRUE"))
  63. elif method == "cov":
  64. pc = r.princomp(r.na_exclude(dat), cor = r("FALSE"))
  65. elif method=="svd":
  66. pc = r.prcomp(r.na_exclude(dat), center = r(center), scale = r(scale))
  67. except RException, rex:
  68. stop_err("Encountered error while performing PCA on the input data: %s" %(rex))
  69. set_default_mode(BASIC_CONVERSION)
  70. summary = r.summary(pc, loadings="TRUE")
  71. ncomps = len(summary['sdev'])
  72. if type(summary['sdev']) == type({}):
  73. comps_unsorted = summary['sdev'].keys()
  74. comps=[]
  75. sd = summary['sdev'].values()
  76. for i in range(ncomps):
  77. sd[i] = summary['sdev'].values()[comps_unsorted.index('Comp.%s' %(i+1))]
  78. comps.append('Comp.%s' %(i+1))
  79. elif type(summary['sdev']) == type([]):
  80. comps=[]
  81. for i in range(ncomps):
  82. comps.append('Comp.%s' %(i+1))
  83. sd = summary['sdev']
  84. print >>fout, "#Component\t%s" %("\t".join(["%s" % el for el in range(1,ncomps+1)]))
  85. print >>fout, "#Std. deviation\t%s" %("\t".join(["%.4g" % el for el in sd]))
  86. total_var = 0
  87. vars = []
  88. for s in sd:
  89. var = s*s
  90. total_var += var
  91. vars.append(var)
  92. for i,var in enumerate(vars):
  93. vars[i] = vars[i]/total_var
  94. print >>fout, "#Proportion of variance explained\t%s" %("\t".join(["%.4g" % el for el in vars]))
  95. print >>fout, "#Loadings\t%s" %("\t".join(["%s" % el for el in range(1,ncomps+1)]))
  96. xcolnames = ["c%d" %(el+1) for el in x_cols]
  97. if 'loadings' in summary: #in case of princomp
  98. loadings = 'loadings'
  99. elif 'rotation' in summary: #in case of prcomp
  100. loadings = 'rotation'
  101. for i,val in enumerate(summary[loadings]):
  102. print >>fout, "%s\t%s" %(xcolnames[i], "\t".join(["%.4g" % el for el in val]))
  103. print >>fout, "#Scores\t%s" %("\t".join(["%s" % el for el in range(1,ncomps+1)]))
  104. if 'scores' in summary: #in case of princomp
  105. scores = 'scores'
  106. elif 'x' in summary: #in case of prcomp
  107. scores = 'x'
  108. for obs,sc in enumerate(summary[scores]):
  109. print >>fout, "%s\t%s" %(obs+1, "\t".join(["%.4g" % el for el in sc]))
  110. r.pdf( outfile2, 8, 8 )
  111. r.biplot(pc)
  112. r.dev_off()