/archive/source.archive/bm.and.ou/trait.rjmcmc.09042010.R

http://github.com/eastman/auteur · R · 355 lines · 301 code · 32 blank · 22 comment · 74 complexity · 4beb9e3517f6a4403fb5ecfed0ad67ac MD5 · raw file

  1. ## Written by LJ HARMON, AH: 2008
  2. ## Updated by JM EASTMAN: 09.2010
  3. rjMCMC.trait<-function (phy, data, ngen = 1000, model = c("OU","BM"), probAlpha=0.05, probRegimes=0.10, probMergeSplit = 0.05, heat=1, fileBase = "result")
  4. {
  5. if(length(model)>1)stop(paste("Please specify either ", sQuote("OU"), " or ", sQuote("BM")))
  6. require(ouch)
  7. require(geiger)
  8. ### prepare objects for rjMCMC
  9. dataList <- prepare.ouchdata(phy, data) # check tree and data; convert into S4
  10. ape.tre <- dataList$ape.tre # S3 phylogeny
  11. orig.dat <- dataList$orig.dat # S3 data
  12. ouch.tre <- dataList$ouch.tre # S4 OUCH phylogeny
  13. ouch.dat <- dataList$ouch.dat # S4 OUCH data
  14. nn <- length(ape.tre$edge.length) # max number of rates (one per branch)
  15. currRates.c <- numeric(nn) # initialize array for recording which branch belongs to which rate category
  16. currRates.c[] <- 1 # assign one rate to all branches initially
  17. currRates <- numeric(nn) # initialize array for rate estimates branchwise
  18. currRates[] <- init <- 0.001 # initialize BMMrate to be small
  19. currAlpha <- 0.001 # initialize alpha
  20. currRegimes <- numeric(nrow(ouch.dat)) # initial set of selective regimes
  21. currRegimes[] <- 1 # initialize regime to be global
  22. currModel <- model # allow user to define family of models explored
  23. nRateProp = 0 # proposal count: update rate parameter(s)
  24. nRcatProp = 0
  25. nAlphaProp = 0 # proposal count: update alpha parameter
  26. nRegimesProp = 0
  27. nRateOK = 0 # proposal count: successful rate changes
  28. nRcatOK = 0
  29. nAlphaOK = 0 # proposal count: successful alpha updates
  30. nRegimesOK = 0
  31. cOK=c(nRateOK, nRcatOK, nAlphaOK, nRegimesOK)
  32. l <- array(dim=ngen) # number of MCMC generations
  33. l[] <- NA # initialize array for storing lnLs
  34. ### Begin rjMCMC
  35. cat("\trates\tregimes\t\talpha\t\tlnL\t\t\tmodel")
  36. for (i in 1:ngen) {
  37. lnLikelihoodRatio <- lnProposalRatio <- lnPriorRatio <- 0
  38. startparms=list(currRates, currRates.c, currAlpha, currRegimes)
  39. ## OU IMPLEMENTATION ## thetas currently optimized by ouch:::glssoln(); selective regimes are choosable (see Beta)
  40. if(currModel=="OU") {
  41. if (runif(1) < (2 * probMergeSplit)) {
  42. adj.regimes <- adj.rates <- FALSE
  43. if(runif(1) < 0.5) adj.regimes=TRUE else adj.rates=TRUE
  44. if(adj.regimes) {
  45. nr=split.or.merge(currRegimes, currRegimes, decide.s.or.m(currRegimes)->s.m)
  46. newRegimes=nr$new.categories
  47. newRates=currRates
  48. newRates.c=currRates.c
  49. nRegimesProp=nRegimesProp+1
  50. } else { # adjust rate categories
  51. nr=split.or.merge(currRates, currRates.c, decide.s.or.m(currRates.c)->s.m)
  52. newRates=nr$new.values
  53. newRates.c=nr$new.categories
  54. newRegimes=currRegimes
  55. nRcatProp=nRcatProp+1
  56. }
  57. newAlpha=currAlpha
  58. lnProposalRatio=nr$lnl.prop
  59. lnPriorRatio=nr$lnl.prior
  60. } else { # neither split nor merge
  61. if(runif(1)<probAlpha) { # adjust alpha
  62. newAlpha=adjust.rates(currAlpha)
  63. newRates=currRates
  64. newRates.c=currRates.c
  65. nAlphaProp=nAlphaProp+1
  66. } else { # adjust rates
  67. newRates=adjust.rates(currRates)
  68. newRates.c=fix.categories(newRates)
  69. newAlpha=currAlpha
  70. nRateProp=nRateProp+1
  71. }
  72. newRegimes=currRegimes
  73. }
  74. modCurr=lnl.OU(currRates, currAlpha, currRegimes, ape.tre, ouch.dat)
  75. modNew=lnl.OU(newRates, newAlpha, newRegimes, ape.tre, ouch.dat)
  76. lnLikelihoodRatio = modNew$lnL - modCurr$lnL
  77. r=assess.lnR(heat * lnLikelihoodRatio + heat * lnPriorRatio + lnProposalRatio)
  78. print(lnPriorRatio+lnProposalRatio)
  79. if (runif(1) <= r) { # adopt proposal
  80. currRates <- newRates
  81. currRates.c <- newRates.c
  82. currAlpha <- newAlpha
  83. currRegimes <- newRegimes
  84. best.lnL <- modNew$lnL
  85. } else { # deny proposal
  86. best.lnL <- modCurr$lnL
  87. }
  88. } else if(currModel=="BM"){
  89. if (runif(1) < (2 * probMergeSplit)) {
  90. nr=split.or.merge(currRates, currRates.c, decide.s.or.m(currRates.c)->s.m)
  91. newRates=nr$new.values
  92. newRates.c=nr$new.categories
  93. nRcatProp=nRcatProp+1
  94. lnProposalRatio=nr$lnl.prop
  95. lnPriorRatio=nr$lnl.prior
  96. } else { # neither split nor merge
  97. newRates=adjust.rates(currRates)
  98. newRates.c=fix.categories(newRates)
  99. nRateProp=nRateProp+1
  100. }
  101. modCurr=lnl.BM(currRates, ape.tre, ouch.dat)
  102. modNew=lnl.BM(newRates, ape.tre, ouch.dat)
  103. lnLikelihoodRatio = modNew$lnL - modCurr$lnL
  104. r=assess.lnR(heat * lnLikelihoodRatio + heat * lnPriorRatio + lnProposalRatio)
  105. print(lnPriorRatio+lnProposalRatio)
  106. if (runif(1) <= r) { # adopt proposal
  107. currRates <- newRates
  108. currRates.c <- newRates.c
  109. best.lnL <- modNew$lnL
  110. } else { # deny proposal
  111. best.lnL <- modCurr$lnL
  112. }
  113. }
  114. l[i]=best.lnL
  115. # if(currModel=="BM") {
  116. # cat(paste("\n", max(currRates.c),"\t\t\t", round(best.lnL,2), currModel, sep="\t\t"))
  117. # } else {
  118. # cat(paste("\n", max(currRates.c), max(currRegimes), round(currAlpha,5), round(best.lnL,2), currModel, sep="\t\t"))
  119. # }
  120. endparms=list(currRates, currRates.c, currAlpha, currRegimes)
  121. cOK=tally.mcmc.parms(startparms, endparms, cOK)
  122. }
  123. # End rjMCMC
  124. cProp=c(nRateProp, nRcatProp, nAlphaProp, nRegimesProp)
  125. names(cProp)<-names(cOK)<-c("rates", "rcats", "alpha", "sRegimes")
  126. names(cProp)=paste("prop.",names(cProp),sep="")
  127. names(cOK)=paste("okay.",names(cOK),sep="")
  128. cat("\n")
  129. print(cProp)
  130. print(cOK)
  131. }
  132. ## AUXILIARY FUNCTIONS
  133. lnl.OU <- function(rates, alpha, regimes, apetree, ouch.dat) { # from OUCH 2.7-1:::hansen
  134. ouch.dat$regimes=regimes
  135. ouch.tre=apply.BMM.to.tree(rates, apetree)
  136. beta=ouch:::regime.spec(ouch.tre, ouch.dat['regimes'])
  137. dat=do.call(c,lapply(ouch.dat['trait'],function(y)y[ouch.tre@term]))
  138. res=ouch:::ou.lik.fn(ouch.tre, as.matrix(alpha), as.matrix(1), beta, dat)
  139. list(theta=res$coef, lnL=-0.5*res$deviance)
  140. }
  141. lnl.BM <- function(rates, apetree, ouch.dat) { # from OUCH 2.7-1:::brown
  142. ouch.tre=apply.BMM.to.tree(rates, apetree)
  143. dat=do.call(c,lapply(ouch.dat['trait'],function(y)y[ouch.tre@term]))
  144. nterm <- ouch.tre@nterm
  145. nchar <- 1
  146. w <- matrix(data=1,nrow=nterm,ncol=1)
  147. b <- ouch.tre@branch.times
  148. sols <- ouch:::glssoln(w,dat,b)
  149. e <- sols$residuals # residuals
  150. q <- t(e)%*%solve(b,e)
  151. v <- q/nterm
  152. dev <- nchar*nterm*(1+log(2*pi))+nchar*log(det(b))+nterm*log(det(v))
  153. list(lnL = -0.5*dev)
  154. }
  155. assess.lnR <- function(lnR) {
  156. if(lnR > 0) {
  157. r=1
  158. } else if(lnR < -100) {
  159. r=0
  160. } else {
  161. r=exp(lnR)
  162. }
  163. r
  164. }
  165. decide.s.or.m <- function(categories) {
  166. if (max(categories) == 1) {
  167. return("split")
  168. }
  169. else if (max(categories) == length(categories)) {
  170. return("merge")
  171. }
  172. else if (runif(1) < 0.5) {
  173. return("split")
  174. }
  175. else return("merge")
  176. }
  177. split.or.merge <- function(values, categories, task=c("split","merge"))
  178. {
  179. vv=values
  180. cc=categories
  181. new.vv=vv
  182. new.cc=cc
  183. nn=length(values)
  184. if(task=="merge") {
  185. ncat <- max(cc)
  186. m.cc <- sample(1:ncat)[1:2]
  187. new.cc[new.cc == m.cc[1]] <- m.cc[2]
  188. new.cc <- fix.categories(new.cc)
  189. or1 <- vv[cc == m.cc[1]][1]
  190. or2 <- vv[cc == m.cc[2]][1]
  191. n1 <- sum(cc == m.cc[1])
  192. n2 <- sum(cc == m.cc[2])
  193. nr <- (n1 * or1 + n2 * or2)/(n1 + n2)
  194. new.vv[cc == m.cc[1] | cc == m.cc[2]] <- nr
  195. numSplittableBeforeMerge <- sum(table(cc) > 1)
  196. numSplittableAfterMerge <- sum(table(new.cc) > 1)
  197. if (max(cc) == nn) {
  198. probMerge = 1
  199. } else {
  200. probMerge = 0.5
  201. }
  202. if (max(new.cc) == 1) {
  203. probSplit = 1
  204. } else {
  205. probSplit = 0.5
  206. }
  207. factor = probSplit/probMerge
  208. lnProposalRatio = log(factor) + log(nn) + log(ncat) - log(numSplittableAfterMerge) - log(2^(n1 + n2) - 2) - log(nr/mean(vv)) - log(n1 + n2)
  209. lnPriorRatio = log(Stirling2(nn, ncat)) - log(Stirling2(nn, ncat - 1))
  210. }
  211. if(task=="split") {
  212. ncat <- max(cc)
  213. while (1) {
  214. rcat <- round(runif(1) * ncat + 0.5)
  215. nc <- sum(cc == rcat)
  216. if (nc > 1)
  217. break
  218. }
  219. while (1) {
  220. new <- round(runif(nc) * 2 + 0.5)
  221. if (length(table(new)) == 2)
  222. break
  223. }
  224. new.cc[cc == rcat][new == 2] <- ncat + 1
  225. or <- vv[cc == rcat][1]
  226. n1 <- sum(new == 1)
  227. n2 <- sum(new == 2)
  228. u <- runif(1, min = -0.5 * n1 * or, max = 0.5 * n2 * or)
  229. nr1 <- or + u/n1
  230. nr2 <- or - u/n2
  231. new.vv[new.cc == rcat] <- nr1
  232. new.vv[new.cc == ncat + 1] <- nr2
  233. new.cc <- fix.categories(new.cc)
  234. numSplittableBeforeSplit <- sum(table(cc) > 1)
  235. numSplittableAfterSplit <- sum(table(new.cc) > 1)
  236. if (max(cc) == 1) {
  237. probSplit = 1
  238. } else {
  239. probSplit = 0.5
  240. }
  241. if (max(new.cc) == nn) {
  242. probMerge = 1
  243. } else {
  244. probMerge = 0.5
  245. }
  246. factor = probMerge/probSplit
  247. lnProposalRatio = log(factor) + log(numSplittableBeforeSplit) + log(2^(n1 + n2) - 2) + log(or/mean(vv)) + log(n1 + n2) - log(nn) - log(ncat + 1)
  248. lnPriorRatio = log(Stirling2(nn, ncat)) - log(Stirling2(nn, ncat + 1))
  249. }
  250. list(new.values=new.vv, new.categories=new.cc, lnl.prop=lnProposalRatio, lnl.prior=lnPriorRatio)
  251. }
  252. fix.categories <- function(x)
  253. {
  254. f<-factor(x)
  255. n<-nlevels(f)
  256. o<-numeric(n)
  257. for(i in 1:n) o[i]<-which(f==levels(f)[i])[1]
  258. nf<-ordered(f, levels(f)[order(o)])
  259. nx<-as.numeric(nf)
  260. nx
  261. }
  262. apply.BMM.to.tree <- function(rates, phy) { # SLOW STEP: convert -> deconvert tree
  263. phy$edge.length=phy$edge.length*rates
  264. ouch.tre=ape2ouch(phy)
  265. ouch.tre
  266. }
  267. adjust.value <- function(value) {
  268. rr=value
  269. rch <- runif(1, min = -0.5, max = 0.5)
  270. rr <- rr + rch
  271. rr
  272. }
  273. adjust.rates <- function(values) {
  274. vv=values
  275. ncat <- max(fix.categories(vv)->cc)
  276. rcat <- round(runif(1) * (ncat) + 0.5)
  277. r <- log(vv[cc == rcat][1])
  278. rch <- runif(1, min = -2, max = 2)
  279. nr <- r + rch
  280. vv[cc == rcat] <- exp(nr)
  281. vv
  282. }
  283. prepare.ouchdata <- function(phy, data) {
  284. td <- treedata(phy, data, sort = T)
  285. ouch.tre <- ape2ouch(td$phy->ape.tre)
  286. ouch.dat=data.frame(as.numeric(ouch.tre@nodes), as.character(ouch.tre@nodelabels))
  287. names(ouch.dat)=c("nodes","labels")
  288. ouch.dat[ouch.dat[,2]=="",2]=NA
  289. ouch.dat$trait=NA
  290. mM=match(ouch.dat$labels,names(data))
  291. for(i in 1:nrow(ouch.dat)){
  292. if(!is.na(mM[i])){
  293. ouch.dat$trait[i]=data[mM[i]]
  294. }
  295. }
  296. return(list(ouch.tre=ouch.tre, ouch.dat=ouch.dat, ape.tre=ape.tre, orig.dat=data))
  297. }
  298. tally.mcmc.parms<-function(startparms, endparms, cOK) {
  299. index=array(dim=length(startparms))
  300. for(i in 1:length(startparms)) {
  301. index[i]=all(startparms[[i]]==endparms[[i]])
  302. }
  303. if(!all(index)) cOK[which(index==FALSE)]=cOK[which(index==FALSE)]+1
  304. cOK
  305. }
  306. Stirling2 <- function(n,m)
  307. {
  308. ## Purpose: Stirling Numbers of the 2-nd kind
  309. ## S^{(m)}_n = number of ways of partitioning a set of
  310. ## $n$ elements into $m$ non-empty subsets
  311. ## Author: Martin Maechler, Date: May 28 1992, 23:42
  312. ## ----------------------------------------------------------------
  313. ## Abramowitz/Stegun: 24,1,4 (p. 824-5 ; Table 24.4, p.835)
  314. ## Closed Form : p.824 "C."
  315. ## ----------------------------------------------------------------
  316. if (0 > m || m > n) stop("'m' must be in 0..n !")
  317. k <- 0:m
  318. sig <- rep(c(1,-1)*(-1)^m, length= m+1)# 1 for m=0; -1 1 (m=1)
  319. ## The following gives rounding errors for (25,5) :
  320. ## r <- sum( sig * k^n /(gamma(k+1)*gamma(m+1-k)) )
  321. ga <- gamma(k+1)
  322. round(sum( sig * k^n /(ga * rev(ga))))
  323. }