/working/auteur/R/rjmcmc.bm.R

http://github.com/eastman/auteur · R · 212 lines · 168 code · 30 blank · 14 comment · 42 complexity · 3dcd5d37d5dd984a654748d48ebee4db MD5 · raw file

  # rjmcmc.bm: reversible-jump Markov sampling from a range of model complexities
# (numbers of rate shifts) under the general process of Brownian-motion
# evolution of a continuous trait.
# authors: LJ HARMON 2009, A HIPP 2009, and JM EASTMAN 2010
#
# Arguments
#   phy, dat         tree and trait data; reconciled by prepare.data()
#   SE               presumably trait standard error(s); passed to prepare.data()
#   ngen             number of logged generations. When 'jumpsize' is NULL the
#                    chain first runs max(tuneFreq) (= floor(ngen / 10) for
#                    ngen >= 20) extra tuning generations that are not logged
#   sample.freq      log the chain state every 'sample.freq' generations
#   prob.mergesplit  2 * prob.mergesplit is the probability of proposing a
#                    split/merge of rate categories
#   prob.root        probability of proposing a new root state
#   lambdaK          passed to splitormerge()
#   constrainK       FALSE, or a fixed number of rate shifts (>= 1); when numeric
#                    no split/merge proposals are made
#   jumpsize         proposal width (> 0); NULL means calibrate it with
#                    calibrate.jumpsize() and re-tune it at the tuning generations
#   simplestart      start from one rate (fit.continuous()) on every branch
#   internal.only    restrict rate shifts to internal branches (passed to
#                    splitormerge(); verified every generation)
#   summary          TRUE: keep "<model>.<fileBase>.parameters/" and summarize the
#                    run; FALSE: use a hidden temporary directory that is removed
#                    when the function exits (normally or on error)
#   fileBase         prefix for output files
#
# Value
#   list(acceptance.rate = accepted / proposed moves, jumpsize = final width)
rjmcmc.bm <- function(phy, dat, SE = 0, ngen = 1000, sample.freq = 100,
                      prob.mergesplit = 0.05, prob.root = 0.05,
                      lambdaK = log(2), constrainK = FALSE, jumpsize = NULL,
                      simplestart = FALSE, internal.only = FALSE,
                      summary = TRUE, fileBase = "result") {
  model <- "BM"
  heat <- 1 # heating not currently implemented
  require(geiger)

  # Proposal width: calibrate it when not supplied (and keep re-tuning later).
  if (is.null(jumpsize)) {
    cat("CALIBRATING jumpsize...\n")
    adjustable.jumpsize <- TRUE
    jumpsize <- calibrate.jumpsize(phy, dat, nsteps = ngen / 1000)
  } else {
    if (jumpsize <= 0) stop("please supply a 'jumpsize' larger than 0")
    adjustable.jumpsize <- FALSE
  }

  ### prepare data for rjMCMC
  cur.model <- model
  dataList <- prepare.data(phy, dat, SE)
  ape.tre <- dataList$ape.tre
  orig.dat <- dataList$orig.dat
  SE <- dataList$SE
  # Descendants of edge[1, 1] (presumably the root) and of every other node,
  # named by node number. lapply always yields a list, whatever the lengths.
  node.ids <- c(ape.tre$edge[1, 1], unique(ape.tre$edge[, 2]))
  node.des <- lapply(node.ids, function(x) get.descendants.of.node(x, ape.tre))
  names(node.des) <- node.ids

  # initialize parameters
  # nrow(edge) is the number of branches; length(edge) would double-count.
  if (is.numeric(constrainK) &&
      (constrainK > nrow(ape.tre$edge) || constrainK < 1)) {
    stop("Constraint on rate shifts is nonsensical. Ensure that constrainK is at least 1 and less than the number of available nodes in the tree.")
  }
  if ((simplestart || internal.only) && !is.numeric(constrainK)) {
    # single-rate start: one fitted rate on every branch, no shifts
    nbranch <- length(ape.tre$edge.length)
    init.rate <- list(values = rep(fit.continuous(ape.tre, orig.dat), nbranch),
                      delta = rep(0, nbranch))
  } else {
    init.rate <- generate.starting.point(orig.dat, ape.tre, node.des,
                                         theta = FALSE, K = constrainK,
                                         jumpsize = jumpsize)
  }
  cur.rates <- init.rate$values
  cur.delta.rates <- init.rate$delta
  cur.root <- adjustvalue(mean(orig.dat), jumpsize)
  cur.vcv <- updatevcv(ape.tre, cur.rates)
  mod.cur <- bm.lik.fn(cur.rates, cur.root, orig.dat, cur.vcv, SE)
  # Fail before any output directories or logs are created.
  if (is.infinite(mod.cur$lnL)) stop("starting point has exceptionally poor likelihood")

  # proposal counts
  nRateProp <- nRateSwapProp <- nRcatProp <- nRootProp <- 0
  # accepted-proposal counts, in order: rate, rate swap, rate category, root
  cOK <- c(0, 0, 0, 0)

  # proposal-type probabilities: split/merge, root, rates
  prob.rates <- 1 - sum(c(2 * prob.mergesplit, prob.root))
  if (prob.rates < 0) stop("proposal frequencies must not exceed 1; adjust 'prob.mergesplit' and (or) 'prob.root'")
  prop.cs <- cumsum(c(2 * prob.mergesplit, prob.root, prob.rates))

  # Generations at which jumpsize is re-tuned: ngen/10, ngen/100, ... (floored,
  # each > 1). For ngen < 20 none qualify; use 0 (never matched) so that
  # max(tuneFreq) stays finite instead of -Inf.
  tuneFreq <- 0
  if (adjustable.jumpsize) {
    tf <- c()
    nn <- ngen
    repeat {
      nn <- floor(0.1 * nn)
      if (nn <= 1) break
      tf <- c(nn, tf)
    }
    if (length(tf) > 0) tuneFreq <- tf
  }
  nTune <- max(tuneFreq)     # leading generations that are not logged
  nTotal <- ngen + nTune
  cur.acceptancerate <- 0
  tickerFreq <- ceiling(nTotal / 30)

  # file handling
  if (summary) {
    parmBase <- paste(model, fileBase, "parameters/", sep = ".")
  } else {
    parmBase <- paste0(".", runif(1), model, fileBase, "parameters/")
  }
  if (!file.exists(parmBase)) dir.create(parmBase)
  parlogger(model = model, init = TRUE, node.des, parmBase = parmBase)
  errorLog <- paste(parmBase, paste(cur.model, fileBase, "rjmcmc.errors.log", sep = "."), sep = "/")
  runLog <- file(paste(parmBase, paste(cur.model, fileBase, "rjmcmc.log", sep = "."), sep = "/"), open = "w+")
  runLog.open <- TRUE
  # Release the log connection and the temporary directory even if a stop()
  # below aborts the run; the connection is closed before the directory goes.
  on.exit({
    if (runLog.open) close(runLog)
    if (!summary) unlink(parmBase, recursive = TRUE)
  }, add = TRUE)
  generate.log(bundled.parms = NULL, cur.model, file = runLog, init = TRUE)

  ### Begin rjMCMC
  # At least one generation always runs so that 'endparms' is defined below.
  for (i in seq_len(max(1, floor(nTotal)))) {
    lnLikelihoodRatio <- lnHastingsRatio <- lnPriorRatio <- 0
    startparms <- c(nRateProp, nRateSwapProp, nRcatProp, nRootProp)
    if (internal.only) {
      tips.tmp <- which(as.numeric(names(cur.delta.rates)) <= Ntip(ape.tre))
      if (any(cur.delta.rates[tips.tmp] == 1)) stop("Broken internal only sampling")
    }

    ## BM IMPLEMENTATION ##
    # Redraw until an applicable proposal type comes up (split/merge is not
    # available when constrainK fixes the number of shifts).
    repeat {
      cur.proposal <- min(which(runif(1) < prop.cs))
      if (cur.proposal == 1 && !constrainK) { # adjust rate categories
        nr <- splitormerge(cur.delta.rates, cur.rates, ape.tre, node.des,
                           lambdaK, theta = FALSE, internal.only)
        new.rates <- nr$new.values
        new.delta.rates <- nr$new.delta
        new.root <- cur.root
        nRcatProp <- nRcatProp + 1
        lnHastingsRatio <- nr$lnHastingsRatio
        lnPriorRatio <- nr$lnPriorRatio
        break
      } else if (cur.proposal == 2) { # adjust root
        new.root <- adjustvalue(cur.root, jumpsize)
        new.rates <- cur.rates
        new.delta.rates <- cur.delta.rates
        nRootProp <- nRootProp + 1
        break
      } else if (cur.proposal == 3) {
        new.delta.rates <- cur.delta.rates
        new.root <- cur.root
        if (runif(1) > 0.05 && sum(cur.delta.rates) > 1) { # tune local rate
          new.rates <- tune.rate(cur.rates, jumpsize)
          nRateSwapProp <- nRateSwapProp + 1
        } else { # adjust rates
          new.rates <- adjustrate(cur.rates, jumpsize)
          nRateProp <- nRateProp + 1
        }
        break
      }
    }

    # Only rebuild the VCV when some rate actually changed.
    new.vcv <- if (any(new.rates != cur.rates)) updatevcv(ape.tre, new.rates) else cur.vcv
    mod.new <- try(bm.lik.fn(new.rates, new.root, orig.dat, new.vcv, SE), silent = TRUE)
    if (inherits(mod.new, "try-error")) {
      mod.new <- as.list(mod.new)
      mod.new$lnL <- -Inf
    }
    if (is.infinite(mod.new$lnL)) {
      mod.new$lnL <- -Inf
      lnLikelihoodRatio <- -Inf
    } else {
      lnLikelihoodRatio <- mod.new$lnL - mod.cur$lnL
    }

    # compare likelihoods
    endparms <- c(nRateProp = nRateProp, nRateSwapProp = nRateSwapProp,
                  nRcatProp = nRcatProp, nRootProp = nRootProp)
    lnR <- heat * lnLikelihoodRatio + heat * lnPriorRatio + lnHastingsRatio
    r <- assess.lnR(lnR)
    # potential errors
    if (is.infinite(mod.cur$lnL)) stop("starting point has exceptionally poor likelihood")
    if (r$error) generate.error.message(Ntip(phy), i, mod.cur, mod.new, lnLikelihoodRatio, lnPriorRatio, lnHastingsRatio, cur.delta.rates, new.delta.rates, errorLog)
    if (runif(1) <= r$r) { ## adopt proposal ##
      cur.root <- new.root
      cur.rates <- new.rates
      cur.delta.rates <- new.delta.rates
      mod.cur <- mod.new
      cur.vcv <- new.vcv
      cOK <- determine.accepted.proposal(startparms, endparms, cOK)
    }
    # Either way, the current state's likelihood is the one that is logged.
    curr.lnL <- mod.cur$lnL

    # iteration-specific functions
    if (i %% tickerFreq == 0 && summary) {
      if (i == tickerFreq) cat("|", rep(" ", 9), toupper("generations complete"), rep(" ", 9), "|", "\n")
      cat(". ")
    }
    if (i %% sample.freq == 0 && i > nTune) {
      bundled.parms <- list(gen = i - nTune, mrate = exp(mean(log(cur.rates))),
                            cats = sum(cur.delta.rates) + 1, root = cur.root,
                            lnL = curr.lnL, rates = cur.rates)
      generate.log(bundled.parms, cur.model, file = runLog)
      parlogger(model = model, init = FALSE, node.des = node.des, i = i - nTune,
                curr.lnL = curr.lnL, cur.root = cur.root, cur.rates = cur.rates,
                cur.delta.rates = cur.delta.rates, parmBase = parmBase)
    }
    # Re-tune the proposal width when the acceptance rate has dropped.
    if (adjustable.jumpsize && i %in% tuneFreq) {
      new.acceptancerate <- sum(cOK) / sum(endparms)
      if (new.acceptancerate < cur.acceptancerate) {
        jumpsize <- calibrate.jumpsize(phy, dat, jumpsizes = c(jumpsize / 2, jumpsize * 2))
      }
      cur.acceptancerate <- new.acceptancerate
    }
  }
  # End rjMCMC
  close(runLog)
  runLog.open <- FALSE
  # A calibration run (summary = FALSE) has its directory removed by on.exit.
  if (summary) {
    summarize.run(cOK, endparms, cur.model)
    cleanup.files(parmBase, cur.model, fileBase)
  }
  list(acceptance.rate = sum(cOK) / sum(endparms), jumpsize = jumpsize)
}