PageRenderTime 36ms CodeModel.GetById 18ms app.highlight 13ms RepoModel.GetById 1ms app.codeStats 0ms

/archive/source.archive/bm.and.ou/trait.rjmcmc.09042010.R

http://github.com/eastman/auteur
R | 355 lines | 301 code | 32 blank | 22 comment | 74 complexity | 4beb9e3517f6a4403fb5ecfed0ad67ac MD5 | raw file
  1## Written by LJ HARMON, AH: 2008
  2## Updated by JM EASTMAN: 09.2010
  3
  ## Reversible-jump MCMC over rate categories (and, for OU, selective regimes
  ## and alpha) on a phylogeny.
  ##
  ## Arguments:
  ##   phy, data       tree and trait data, handed to prepare.ouchdata()
  ##   ngen            number of MCMC generations
  ##   model           exactly one of "OU" or "BM" (the two-element default is
  ##                   rejected on purpose: the caller must choose)
  ##   probAlpha       OU only: when no split/merge is proposed, probability of
  ##                   updating alpha instead of a rate
  ##   probRegimes     accepted for interface compatibility; not read here
  ##   probMergeSplit  a split-or-merge move is proposed with probability
  ##                   2 * probMergeSplit
  ##   heat            multiplier applied to the log-likelihood and log-prior
  ##                   ratios before acceptance
  ##   fileBase        accepted for interface compatibility; not read here
  ##
  ## Side effects: prints a header, one diagnostic value per generation, and the
  ## proposal / acceptance tallies.
  ## Value: the acceptance tally vector, as returned by the final print().
  rjMCMC.trait <- function (phy, data, ngen = 1000, model = c("OU","BM"), probAlpha=0.05, probRegimes=0.10, probMergeSplit = 0.05, heat=1, fileBase = "result")
  {
    # Reject a missing, multiple or unknown model up front instead of failing
    # later inside the sampling loop.
    if (length(model) != 1 || !(model %in% c("OU", "BM"))) {
      stop(paste("Please specify either ", sQuote("OU"), " or ", sQuote("BM")))
    }

    # require() only returns FALSE when a package is missing; turn that into an error.
    if (!require(ouch)) stop("package 'ouch' is required but could not be loaded")
    if (!require(geiger)) stop("package 'geiger' is required but could not be loaded")

    ### prepare objects for rjMCMC
    dataList <- prepare.ouchdata(phy, data)       # check tree and data; convert into S4
    ape.tre <- dataList$ape.tre                   # S3 phylogeny
    ouch.dat <- dataList$ouch.dat                 # S4 OUCH data
    nn <- length(ape.tre$edge.length)             # max number of rates (one per branch)
    currRates.c <- rep(1, nn)                     # rate category of each branch; one shared category initially
    currRates <- rep(0.001, nn)                   # branchwise rate estimates, initialized small
    currAlpha <- 0.001                            # initial alpha
    currRegimes <- rep(1, nrow(ouch.dat))         # initial selective regimes: one global regime
    currModel <- model                            # family of models explored
    nRateProp <- nRcatProp <- nAlphaProp <- nRegimesProp <- 0   # proposal counts
    cOK <- numeric(4)                             # acceptance counts: rates, rcats, alpha, sRegimes
    l <- rep(NA_real_, ngen)                      # lnL trace (recorded; not currently returned or written)

    # Log-likelihood of a state under the chosen model.
    # OU: thetas currently optimized by ouch:::glssoln(); selective regimes are choosable (see Beta)
    lnl.state <- function(rates, alpha, regimes) {
      if (currModel == "OU") {
        lnl.OU(rates, alpha, regimes, ape.tre, ouch.dat)$lnL
      } else {
        lnl.BM(rates, ape.tre, ouch.dat)$lnL
      }
    }

    # The likelihood of the current state only changes when a proposal is
    # accepted, so it is cached rather than recomputed every generation.
    currLnL <- lnl.state(currRates, currAlpha, currRegimes)

    ### Begin rjMCMC
    cat("\trates\tregimes\t\talpha\t\tlnL\t\t\tmodel")
    for (i in seq_len(ngen)) {
      lnProposalRatio <- lnPriorRatio <- 0
      startparms <- list(currRates, currRates.c, currAlpha, currRegimes)

      # A proposal starts as a copy of the current state; each move below
      # overrides only the components it changes.
      newRates <- currRates
      newRates.c <- currRates.c
      newAlpha <- currAlpha
      newRegimes <- currRegimes

      if (runif(1) < (2 * probMergeSplit)) {
        # Dimension-changing move. Under OU it targets regimes or rate
        # categories with equal probability; under BM only rate categories.
        adj.regimes <- currModel == "OU" && runif(1) < 0.5
        if (adj.regimes) {
          task <- decide.s.or.m(currRegimes)
          nr <- split.or.merge(currRegimes, currRegimes, task)
          newRegimes <- nr$new.categories
          nRegimesProp <- nRegimesProp + 1
        } else {
          task <- decide.s.or.m(currRates.c)
          nr <- split.or.merge(currRates, currRates.c, task)
          newRates <- nr$new.values
          newRates.c <- nr$new.categories
          nRcatProp <- nRcatProp + 1
        }
        lnProposalRatio <- nr$lnl.prop
        lnPriorRatio <- nr$lnl.prior
      } else if (currModel == "OU" && runif(1) < probAlpha) {
        # adjust alpha
        newAlpha <- adjust.rates(currAlpha)
        nAlphaProp <- nAlphaProp + 1
      } else {
        # adjust rates
        newRates <- adjust.rates(currRates)
        newRates.c <- fix.categories(newRates)
        nRateProp <- nRateProp + 1
      }

      newLnL <- lnl.state(newRates, newAlpha, newRegimes)
      lnLikelihoodRatio <- newLnL - currLnL
      r <- assess.lnR(heat * lnLikelihoodRatio + heat * lnPriorRatio + lnProposalRatio)
      print(lnPriorRatio + lnProposalRatio)       # per-generation diagnostic, kept from the original

      if (runif(1) <= r) {                        # adopt proposal
        currRates <- newRates
        currRates.c <- newRates.c
        currAlpha <- newAlpha
        currRegimes <- newRegimes
        currLnL <- newLnL
      }                                           # otherwise the current state is retained

      l[i] <- currLnL
#     if(currModel=="BM") {
#       cat(paste("\n", max(currRates.c),"\t\t\t", round(best.lnL,2), currModel, sep="\t\t"))
#     } else {
#       cat(paste("\n", max(currRates.c), max(currRegimes), round(currAlpha,5), round(best.lnL,2), currModel, sep="\t\t"))
#     }

      endparms <- list(currRates, currRates.c, currAlpha, currRegimes)
      cOK <- tally.mcmc.parms(startparms, endparms, cOK)
    }
    # End rjMCMC

    cProp <- c(nRateProp, nRcatProp, nAlphaProp, nRegimesProp)
    names(cProp) <- names(cOK) <- c("rates", "rcats", "alpha", "sRegimes")
    names(cProp) <- paste("prop.", names(cProp), sep="")
    names(cOK) <- paste("okay.", names(cOK), sep="")
    cat("\n")
    print(cProp)
    print(cOK)
  }
144
145
## AUXILIARY FUNCTIONS
147
## Log-likelihood of an OU model on a rate-scaled tree (from OUCH 2.7-1:::hansen).
##   rates    per-branch multipliers applied to the edge lengths of `apetree`
##   alpha    OU alpha, passed to ouch:::ou.lik.fn as a 1x1 matrix
##   regimes  regime assignment stored in the `regimes` column of `ouch.dat`
## Returns list(theta, lnL) where lnL is -0.5 times the deviance reported by ouch.
lnl.OU <- function(rates, alpha, regimes, apetree, ouch.dat) {
  scaled.tree <- apply.BMM.to.tree(rates, apetree)
  ouch.dat$regimes <- regimes
  regime.design <- ouch:::regime.spec(scaled.tree, ouch.dat['regimes'])
  # trait values restricted to the tree's terminal nodes
  tip.values <- do.call(c, lapply(ouch.dat['trait'], function(column) column[scaled.tree@term]))
  fit <- ouch:::ou.lik.fn(scaled.tree, as.matrix(alpha), as.matrix(1), regime.design, tip.values)
  list(theta = fit$coef, lnL = -0.5 * fit$deviance)
}
156
## Log-likelihood of a Brownian-motion model on a rate-scaled tree
## (from OUCH 2.7-1:::brown).
##   rates    per-branch multipliers applied to the edge lengths of `apetree`
##   ouch.dat data frame whose `trait` column is indexed by the tree's terminal nodes
## Returns list(lnL = -0.5 * deviance).
lnl.BM <- function(rates, apetree, ouch.dat) {
  ouch.tre <- apply.BMM.to.tree(rates, apetree)
  dat <- do.call(c, lapply(ouch.dat['trait'], function(y) y[ouch.tre@term]))
  nterm <- ouch.tre@nterm
  n.traits <- 1                                  # single trait
  w <- matrix(data = 1, nrow = nterm, ncol = 1)
  b <- ouch.tre@branch.times
  sols <- ouch:::glssoln(w, dat, b)
  e <- sols$residuals                            # residuals
  q <- t(e) %*% solve(b, e)
  v <- q / nterm

  # log(det(m)) computed on the log scale: det() itself under- or overflows
  # for large matrices, which would make the deviance infinite. A negative
  # determinant still yields NaN, as log(det(m)) would.
  log.det <- function(m) {
    d <- determinant(m, logarithm = TRUE)
    if (d$sign < 0) NaN else as.numeric(d$modulus)
  }

  dev <- n.traits * nterm * (1 + log(2 * pi)) + n.traits * log.det(b) + nterm * log.det(v)
  list(lnL = -0.5 * dev)
}
171
## Convert a log acceptance ratio into an acceptance probability:
## 1 when lnR is positive, 0 when lnR is below -100, exp(lnR) otherwise.
assess.lnR <- function(lnR) {
  if (lnR > 0) return(1)
  if (lnR < -100) return(0)
  exp(lnR)
}
182
## Choose between a "split" and a "merge" move for a category vector.
## A single category forces "split"; as many categories as elements forces
## "merge"; otherwise the choice is a fair coin flip (one runif draw).
decide.s.or.m <- function(categories) {
  n.categories <- max(categories)
  if (n.categories == 1) return("split")
  if (n.categories == length(categories)) return("merge")
  if (runif(1) < 0.5) "split" else "merge"
}
195	
## Propose a dimension-changing move on a partition of elements into categories.
##   values      per-element values (elements sharing a category share a value)
##   categories  per-element category labels, numbered 1..max(categories)
##   task        "split" one category in two, or "merge" two categories into one
## Returns a list with the proposed values and categories (relabelled by
## fix.categories) plus the log proposal ratio (lnl.prop) and log prior
## ratio (lnl.prior) of the move.
split.or.merge <- function(values, categories, task=c("split","merge"))
{
  task <- match.arg(task)                        # a single, valid task is required below
  nn <- length(values)
  ncat <- max(categories)
  new.values <- values
  new.categories <- categories

  # log(2^k - 2) evaluated without forming 2^k, which overflows for k > 1023.
  log.pow2.minus2 <- function(k) k * log(2) + log1p(-2^(1 - k))

  if (task == "merge") {
    if (ncat < 2) stop("cannot merge: fewer than two categories are present")

    pair <- sample(1:ncat)[1:2]                  # two distinct categories, in random order
    in.first <- categories == pair[1]
    in.second <- categories == pair[2]
    new.categories[in.first] <- pair[2]
    new.categories <- fix.categories(new.categories)

    n1 <- sum(in.first)
    n2 <- sum(in.second)
    # merged value is the size-weighted mean of the two category values
    merged <- (n1 * values[in.first][1] + n2 * values[in.second][1]) / (n1 + n2)
    new.values[in.first | in.second] <- merged

    n.splittable.after <- sum(table(new.categories) > 1)
    probMerge <- if (ncat == nn) 1 else 0.5                   # merge was forced if nothing could be split
    probSplit <- if (max(new.categories) == 1) 1 else 0.5     # reverse split is forced from one category

    lnProposalRatio <- log(probSplit / probMerge) + log(nn) + log(ncat) -
      log(n.splittable.after) - log.pow2.minus2(n1 + n2) -
      log(merged / mean(values)) - log(n1 + n2)
    lnPriorRatio <- log(Stirling2(nn, ncat)) - log(Stirling2(nn, ncat - 1))
  } else {
    # Without this guard the search below would never terminate.
    if (!any(table(categories) > 1)) {
      stop("cannot split: no category has more than one member")
    }

    repeat {                                     # pick a category with at least two members
      rcat <- round(runif(1) * ncat + 0.5)
      nc <- sum(categories == rcat)
      if (nc > 1) break
    }
    repeat {                                     # assign members to two non-empty halves
      half <- round(runif(nc) * 2 + 0.5)
      if (length(table(half)) == 2) break
    }
    new.categories[categories == rcat][half == 2] <- ncat + 1

    old <- values[categories == rcat][1]
    n1 <- sum(half == 1)
    n2 <- sum(half == 2)
    # the two new values move in opposite directions so that their
    # size-weighted mean stays at the old value
    u <- runif(1, min = -0.5 * n1 * old, max = 0.5 * n2 * old)
    new.values[new.categories == rcat] <- old + u / n1
    new.values[new.categories == ncat + 1] <- old - u / n2
    new.categories <- fix.categories(new.categories)

    n.splittable.before <- sum(table(categories) > 1)
    probSplit <- if (ncat == 1) 1 else 0.5                    # split was forced from one category
    probMerge <- if (max(new.categories) == nn) 1 else 0.5    # reverse merge is forced if all are singletons

    lnProposalRatio <- log(probMerge / probSplit) + log(n.splittable.before) +
      log.pow2.minus2(n1 + n2) + log(old / mean(values)) + log(n1 + n2) -
      log(nn) - log(ncat + 1)
    lnPriorRatio <- log(Stirling2(nn, ncat)) - log(Stirling2(nn, ncat + 1))
  }

  list(new.values=new.values, new.categories=new.categories, lnl.prop=lnProposalRatio, lnl.prior=lnPriorRatio)
}
275
## Relabel the distinct values of x as 1, 2, 3, ... in order of first
## appearance, returning one numeric label per element of x.
fix.categories <- function(x)
{
  groups <- factor(x)
  # position at which each level is first seen
  first.seen <- vapply(levels(groups), function(lev) which(groups == lev)[1], integer(1))
  relabelled <- ordered(groups, levels(groups)[order(first.seen)])
  as.numeric(relabelled)
}
286
## Scale each edge length of `phy` by the matching entry of `rates` and
## convert the result with ape2ouch().
## SLOW STEP: convert -> deconvert tree
apply.BMM.to.tree <- function(rates, phy) {
  scaled <- phy
  scaled$edge.length <- scaled$edge.length * rates
  ape2ouch(scaled)
}
292
## Shift `value` by a single uniform draw from [-0.5, 0.5].
adjust.value <- function(value) {
  value + runif(1, min = -0.5, max = 0.5)
}
299
## Perturb one randomly chosen group of equal values.
## Groups are derived from the values themselves via fix.categories(); the
## chosen group's value is multiplied by exp(u) with u uniform on [-2, 2].
adjust.rates <- function(values) {
  categories <- fix.categories(values)
  n.categories <- max(categories)
  chosen <- round(runif(1) * n.categories + 0.5)   # category index in 1..n.categories
  in.chosen <- categories == chosen
  log.rate <- log(values[in.chosen][1])
  step <- runif(1, min = -2, max = 2)
  values[in.chosen] <- exp(log.rate + step)
  values
}
310
## Match tree and data with treedata(), convert the tree with ape2ouch(), and
## build a per-node data frame (nodes, labels, trait) for the ouch routines.
## Trait values are looked up in `data` by name; nodes whose label has no
## match in names(data) keep trait = NA.
## Returns list(ouch.tre, ouch.dat, ape.tre, orig.dat), with orig.dat the
## `data` argument as supplied.
prepare.ouchdata <- function(phy, data) {
  td <- treedata(phy, data, sort = TRUE)
  ape.tre <- td$phy
  ouch.tre <- ape2ouch(ape.tre)

  ouch.dat <- data.frame(as.numeric(ouch.tre@nodes), as.character(ouch.tre@nodelabels))
  names(ouch.dat) <- c("nodes", "labels")
  ouch.dat[ouch.dat[, 2] == "", 2] <- NA         # unlabelled nodes get NA labels
  ouch.dat$trait <- NA

  hit <- match(ouch.dat$labels, names(data))
  for (row in which(!is.na(hit))) {
    ouch.dat$trait[row] <- data[hit[row]]
  }

  list(ouch.tre = ouch.tre, ouch.dat = ouch.dat, ape.tre = ape.tre, orig.dat = data)
}
326
## Increment the acceptance counter of every parameter that changed.
##   startparms, endparms  parallel lists of parameter values before and
##                         after a generation
##   cOK                   counter vector, one entry per list position
## A parameter counts as changed when any element differs under `==`.
## Returns the updated counter vector (names preserved).
tally.mcmc.parms <- function(startparms, endparms, cOK) {
  # seq_along() keeps an empty parameter list from being indexed at 1 and 0
  unchanged <- vapply(
    seq_along(startparms),
    function(k) all(startparms[[k]] == endparms[[k]]),
    logical(1)
  )
  changed <- which(!unchanged)
  cOK[changed] <- cOK[changed] + 1
  cOK
}
335	
Stirling2 <- function(n,m)
{
    ## Purpose:  Stirling Numbers of the 2-nd kind
    ## 		S^{(m)}_n = number of ways of partitioning a set of
    ##                      $n$ elements into $m$ non-empty subsets
    ## Original closed-form version: Martin Maechler, May 28 1992
    ## (Abramowitz/Stegun: 24,1,4, p. 824-5; Table 24.4, p.835).
    ## ----------------------------------------------------------------
    ## The closed form is an alternating sum: it loses all precision when
    ## m is close to n and breaks once gamma(m+1) overflows (m > 170).
    ## This version uses the recurrence
    ##     S(i, j) = j * S(i-1, j) + S(i-1, j-1),
    ## whose terms are all non-negative, so there is no cancellation.
    ## Values beyond the double range still overflow to Inf.
    ## ----------------------------------------------------------------

    if (0 > m || m > n) stop("'m' must be in 0..n !")
    if (m == 0) return(as.numeric(n == 0))       # S(0,0) = 1, S(n,0) = 0 for n > 0

    j <- seq_len(m)
    row <- numeric(m)                            # row[j] holds S(i, j); start at i = 1
    row[1] <- 1
    for (step in seq_len(n - 1)) {               # advance from row i to row i + 1
        row <- j * row + c(0, row[-m])
    }
    round(row[m])
}
355