PageRenderTime 75ms CodeModel.GetById 16ms app.highlight 53ms RepoModel.GetById 1ms app.codeStats 0ms

/src/multi-app/R/model/fit-EM.R

http://github.com/beechung/Latent-Factor-Models
R | 439 lines | 342 code | 43 blank | 54 comment | 113 complexity | 05fa1c9f713972d9f02422a1295bb2f9 MD5 | raw file
  1### Copyright (c) 2011, Yahoo! Inc.  All rights reserved.
  2### Copyrights licensed under the New BSD License. See the accompanying LICENSE file for terms.
  3### 
  4### Author: Bee-Chung Chen
  5
  6###
  7### Use the EM algorithm to fit the model
  8### NOTE:
  9###   * ridge.lambda[c("A", "B", "beta")] are the numbers to be added to the
 10###     diagonal when fitting the regression to get A, B, and beta
 11###
 12fit.EM <- function(
 13	# Input data
 14	feature,  # data.frame(user, app, index, x, w)
 15	response, # data.frame(user, app, item,  y, w)
 16	# Model setup
 17	nGlobalFactors,     # num of global factors per user
 18	nLocalFactors=NULL, # num of local  factors per user (may be a vector of length: #app)
 19	identity.A=FALSE, identity.B=FALSE, # whether A/B is a identity matrix
 20	is.y.logistic=FALSE, # whether to use a logistic model for y
 21	is.x.logistic=FALSE, # whether to use a logistic model for x
 22	# Test data (optional)
 23	test.feature=NULL,
 24	test.response=NULL,
 25	# Model-fitting parameters
 26	nIter=30, # num of EM iterations
 27	ridge.lambda=c(A=0, B=0, beta=0), # lambda values for ridge regression for A, B, beta
 28	keep.users.without.obs.in.E.step=FALSE,
 29	keep.users.without.obs.in.M.step=FALSE,
 30	fix.var_u=1, # Fix var_u to some number (set to NULL if you do not want to fix it)
 31	# Initialization parameters
 32	param=NULL, # directly set all the parameters (if not NULL, the following will not be used)
 33	var_x=1, var_y=1, var_z=1, # initial var (may be vectors of length: #app)
 34	var_u=1,                   # initial var (length: 1)
 35	A.sd=1, B.sd=1, beta.sd=1, # A_k ~ N(mean=0, sd=A.sd), and so on.
 36	# Output options
 37	out.level=0,  # out.level=1: Save the model in out.dir/model.last and out.dir/model.minTestLoss
 38	out.dir=NULL, # out.level=2: Save the model after each iteration i in out.dir/model.i
 39	out.overwrite=FALSE, # whether to allow overwriting existing files
 40	# Debugging options
 41	debug=0, verbose=0, use.C=TRUE,
 42	show.marginal.loglik=FALSE
 43){
 44	option = 0;
 45	if(keep.users.without.obs.in.E.step) option = option + 1;
 46	if(keep.users.without.obs.in.M.step) option = option + 2;
 47	option = as.integer(option);
 48	
 49	if(!is.null(test.response) && out.level <= 0 && verbose <= 0) stop("!is.null(test.response) && out.level <= 0 && verbose <= 0");
 50	if(out.level > 0){
 51		if(is.null(out.dir)) stop("out.dir = NULL");
 52		if(file.exists(paste(out.dir,"/model.last",sep="")) && !out.overwrite) stop(out.dir," already exists!!");
 53		if(!file.exists(out.dir)) dir.create(out.dir, recursive=TRUE, mode="0755");
 54	}
 55	
 56	if(is.x.logistic){
 57		if(!is.null(feature$w)) stop("When is.x.logistic is TRUE, you cannot specify feature$w")
 58		feature = init.obs.logistic(feature,  target="x");
 59	}
 60	if(is.y.logistic){
 61		if(!is.null(response$w)) stop("When is.y.logistic is TRUE, you cannot specify response$w")
 62		response = init.obs.logistic(response, target="y");
 63	}
 64	
 65	if(verbose > 0) cat("INITIALIZE THE PARAMETERS\n");
 66	begin.time.entire = proc.time();
 67	begin.time = proc.time();
 68	model = init.simple(
 69		feature=feature, response=response, nGlobalFactors=nGlobalFactors,
 70		param=param, nLocalFactors=nLocalFactors, 
 71		var_x=var_x, var_y=var_y, var_z=var_z, var_u=var_u, 
 72		identity.A=identity.A, identity.B=identity.B,
 73		is.x.logistic=is.x.logistic, is.y.logistic=is.y.logistic,
 74		A.sd=A.sd, B.sd=B.sd, beta.sd=beta.sd
 75	);
 76	if(!is.null(fix.var_u)) model$param$var_u = as.double(fix.var_u);
 77	time.used = proc.time() - begin.time;
 78	if(verbose > 0) cat("time used: ",time.used[3]," sec\n",sep="");
 79	size = check.syntax.all(feature=feature, response=response, param=model$param, factor=model$factor, check.indices=TRUE);
 80	if(!is.null(test.response)){
 81		check.syntax.all(feature=test.feature, response=test.response, param=model$param, factor=model$factor, check.indices=TRUE, test.data=TRUE);
 82	}
 83	if(verbose >= 2){
 84		cat("--------------------------------------------------------\n",
 85			"   Problem Dimensionality:\n",
 86			"--------------------------------------------------------\n",sep="");
 87		print(size);
 88	}
 89	
 90	if(verbose > 0) cat("START  THE EM-PROCEDURE\n");
 91	buffer = NULL; buffer.addr = NULL;
 92	
 93	prediction = NULL;
 94	best.model = NULL;  best.testLoss = Inf;
 95	time.pred  = NA;
 96	test.loss.y = rep(NA, nIter+1);  test.loss.x = rep(NA, nIter+1);
 97	marginal.loglik = rep(NA, nIter+1);
 98	
 99	if(!is.null(test.response)){
100		begin.time = proc.time();
101		prediction = predict.x.and.y(feature=test.feature, response=test.response, param=model$param, factor=model$factor);
102		time.pred = proc.time() - begin.time;
103		test.loss.y[1] = prediction$loss.y;
104		test.loss.x[1] = prediction$loss.x;
105		if(prediction$loss.y < best.testLoss){
106			best.testLoss = prediction$loss.y;
107			best.model = deepCopy(model);
108		}
109	}
110	ans = output.results(
111		method="EM", feature=feature, response=response, model=model, prediction=prediction,
112		out.dir=out.dir, out.level=out.level,
113		minTestLoss=best.testLoss, iter=0, show.marginal.loglik=show.marginal.loglik,
114		TimeEM=time.used, TimeTest=time.pred, verbose=verbose,
115		other=NULL, name="model"
116	);
117	marginal.loglik[1] = ans$marginal.loglik;
118	
119	for(iter in seq_len(nIter)){
120		if(verbose > 0){
121			cat("--------------------------------------------------------\n",
122				"   EM-Iteration: ",iter,"\n",
123				"--------------------------------------------------------\n",sep="");
124		}
125		begin.time = proc.time();
126		
127		# variational approx for logistic model
128		if(is.x.logistic){
129			feature$w = get.var.logistic(     obs=feature, param=model$param, target="x", verbose=verbose);
130			feature$x = get.response.logistic(obs=feature, param=model$param, target="x", verbose=verbose);
131			model$param$var_x[] = 1;
132		}
133		if(is.y.logistic){
134			response$w = get.var.logistic(     obs=response, param=model$param, target="y", verbose=verbose);
135			response$y = get.response.logistic(obs=response, param=model$param, target="y", verbose=verbose);
136			model$param$var_y[] = 1;
137		}
138		
139		if(use.C){
140			model.addr = get.address(model);
141			buffer = fit.EM.one.iteration.C(
142				feature=feature, response=response, param=model$param, factor=model$factor,
143				ridge.lambda=ridge.lambda, 
144				option=option, debug=debug, verbose=verbose, buffer=buffer
145			);
146			# Now, model$param and model$factor contain the result after one EM iteration
147			# (this is call by reference, not the behavior of regular R functions)
148		
149			# sanity check
150			temp = get.address(model);
151			if(is.diff(model.addr,temp,precision=0)) stop("model address changed!!");
152			if(is.null(buffer.addr)) buffer.addr = get.address(buffer);
153			temp = get.address(buffer);
154			if(is.diff(buffer.addr,temp,precision=0)) stop("buffer address changed!!");
155			
156		}else{
157			model = fit.EM.one.iteration.R(feature=feature, response=response, param=model$param)
158		}
159		if(!is.null(fix.var_u)) model$param$var_u = as.double(fix.var_u);
160		
161		# variational approx for logistic model
162		if(is.x.logistic){
163			mean.score = predict.x.from.z(feature=feature, param=model$param, z=model$factor$z, add.noise=FALSE);
164			model$param = update.param.logistic(param=model$param, mean.score=mean.score, var.score=model$factor$var.x.score, target="x");
165		}
166		if(is.y.logistic){
167			mean.score = predict.y.from.z(response=response, param=model$param, z=model$factor$z, add.noise=FALSE);
168			model$param = update.param.logistic(param=model$param, mean.score=mean.score, var.score=model$factor$var.y.score, target="y");
169		}
170		
171		time.used = proc.time() - begin.time;
172		if(verbose > 0) cat("time used: ",time.used[3]," sec\n",sep="");
173		
174		if(!is.null(test.response)){
175			begin.time = proc.time();
176			prediction = predict.x.and.y(feature=test.feature, response=test.response, param=model$param, factor=model$factor);
177			time.pred = proc.time() - begin.time;
178			test.loss.y[iter+1] = prediction$loss.y;
179			test.loss.x[iter+1] = prediction$loss.x;
180			if(prediction$loss.y < best.testLoss){
181				best.testLoss = prediction$loss.y;
182				best.model = deepCopy(model);
183			}
184		}
185		ans = output.results(
186				method="EM", feature=feature, response=response, model=model, prediction=prediction,
187				out.dir=out.dir, out.level=out.level,
188				minTestLoss=best.testLoss, iter=iter, show.marginal.loglik=show.marginal.loglik,
189				TimeEM=time.used, TimeTest=time.pred, verbose=verbose,
190				other=NULL, name="model"
191		);
192		marginal.loglik[iter+1] = ans$marginal.loglik;
193	}
194		
195	time.used = proc.time() - begin.time.entire;
196	if(verbose > 0) cat("END OF THE EM-PROCEDURE\nTotal time used: ",time.used[3]," sec\n",sep="");
197	
198	out = list(model=model, model.min.test.loss=best.model, test.loss.x=test.loss.x, test.loss.y=test.loss.y, marginal.loglik=marginal.loglik);
199	
200	return(out);
201}
202
###
### Initialization
###
### Build the starting model list(param, factor) for fit.EM().
###   * param = NULL: param is created from the data. For each application k,
###       - b[[k]] (alpha[[k]]) is the observed mean of x (y) per feature index
###         (item), and 0 for indices without observations;
###       - B[[k]], A[[k]] and beta[[k]] are drawn from N(0, sd^2) with
###         sd = B.sd, A.sd, beta.sd; B[[k]] = 1 if identity.B, A[[k]] = 1 if
###         identity.A.
###     With identity.B, nLocalFactors[k] defaults to the number of features of
###     application k (0 if it has no feature observations).
###   * param != NULL: param is used as given, after checking that its sizes
###     agree with the data and with nGlobalFactors / nLocalFactors.
### In both cases factor = list(u, z) is all zeros (u: nUsers x nGlobalFactors,
### z[[k]]: nUsers x nLocalFactors[k]). When is.x.logistic (is.y.logistic) is
### TRUE, factor$var.x.score (factor$var.y.score) is added, with one entry equal
### to 1 per row of feature (response).
###
init.simple <- function(
	feature, response, nGlobalFactors, 
	param=NULL, # directly set all the parameters
	nLocalFactors=NULL,        # which may be an array of length: nApps
	var_x=1, var_y=1, var_z=1, # which may be arrays of length: nApps
	var_u=1, identity.A=FALSE, identity.B=FALSE,
	is.y.logistic=FALSE, # whether to use a logistic model for y
	is.x.logistic=FALSE, # whether to use a logistic model for x
	A.sd=1, B.sd=1, beta.sd=1
){
	check.syntax.obs(feature=feature, response=response);
	nApps  = max(feature$app,  response$app);
	nUsers = max(feature$user, response$user);
	param.supplied = !is.null(param);
	
	if(param.supplied){
		size = check.syntax.param(param);
		# Report the sizes computed by check.syntax.param(). The param list
		# (A, B, b, alpha, beta, var_*, ...) has no nApps / nGlobalFactors
		# entries, so param$nApps would silently drop the number from the message.
		if(nApps != size$nApps) stop("Input param has ",size$nApps," applications, but the input data has ",nApps);
		if(nGlobalFactors != size$nGlobalFactors) stop("Input param has ",size$nGlobalFactors," global factors per user, but you specify ",nGlobalFactors);
		if(is.null(nLocalFactors)){
			nLocalFactors = size$nLocalFactors;
		}else{
			if(length(nLocalFactors) == 1) nLocalFactors = rep(nLocalFactors, nApps);
			if(length(nLocalFactors) != nApps) stop("length(nLocalFactors) != nApps");
			if(any(nLocalFactors != size$nLocalFactors)) stop("Input param has different nLocalFactors than your specification");
		}
	}else{
		if(!identity.B && is.null(nLocalFactors)) stop("Please specify nLocalFactors");
		# NA = "to be determined from the data" (only possible with identity.B).
		if(is.null(nLocalFactors)) nLocalFactors = rep(NA, nApps);
		
		# Scalars are recycled to one value per application.
		if(length(nLocalFactors) == 1) nLocalFactors = rep(nLocalFactors, nApps);
		if(length(var_x) == 1) var_x = rep(var_x, nApps);
		if(length(var_y) == 1) var_y = rep(var_y, nApps);
		if(length(var_z) == 1) var_z = rep(var_z, nApps);
		if(length(nLocalFactors) != nApps) stop("length(nLocalFactors) != nApps");
		if(length(var_x) != nApps) stop("length(var_x) != nApps");
		if(length(var_y) != nApps) stop("length(var_y) != nApps");
		if(length(var_z) != nApps) stop("length(var_z) != nApps");
		
		A = list(); B = list(); b = list(); alpha = list(); beta = list();
		
		# Row numbers of the observations of each application, keyed by the
		# application id as a string.
		temp.f = tapply(seq_along(feature$app),  list(feature$app),  FUN=c, simplify=FALSE);
		temp.r = tapply(seq_along(response$app), list(response$app), FUN=c, simplify=FALSE);
		
		for(k in seq_len(nApps)){
			m = as.character(k); # IMPORTANT: use string to access temp.f and temp.r
			
			# Feature side: B[[k]] and b[[k]]. An application without feature
			# rows gets an empty selection (length(select) == 0 branch).
			select = temp.f[[m]];
			obs = feature[select,];
			if(length(select) == 0){
				if(identity.B) nLocalFactors[k] = 0;
				B[[k]] = matrix(0.0, nrow=0, ncol=nLocalFactors[k]);
				b[[k]] = rep(0.0, 0);
			}else if(identity.B){
				nFeatures = max(obs$index);
				if(is.na(nLocalFactors[k])) nLocalFactors[k] = nFeatures
				else if(nLocalFactors[k] != nFeatures) stop("Misspecification of nLocalFactors[",k,"] = ",nLocalFactors[k]," with identity.B");
				B[[k]] = 1.0
				b[[k]] = rep(0.0, nFeatures);
			}else{
				nFeatures = max(obs$index);
				B[[k]] = matrix(rnorm(nFeatures*nLocalFactors[k],sd=B.sd), nrow=nFeatures, ncol=nLocalFactors[k]);
				b[[k]] = rep(0.0, nFeatures);
				agg = aggregate(obs$x, by=list(by=obs$index), FUN=mean);
				b[[k]][agg$by] = agg$x;
			}
			
			# Global-to-local map A[[k]] (nLocalFactors[k] x nGlobalFactors).
			if(identity.A){
				A[[k]] = 1;
				if(nLocalFactors[k] != nGlobalFactors) stop("Misspecification of nLocalFactors[",k,"] = ",nLocalFactors[k]," with identity.A");
			}else{
				A[[k]] = matrix(rnorm(nLocalFactors[k]*nGlobalFactors,sd=A.sd), nrow=nLocalFactors[k], ncol=nGlobalFactors);	
			}
			
			# Response side: beta[[k]] and alpha[[k]].
			select = temp.r[[m]];
			obs = response[select,];
			if(length(select) == 0){
				beta[[k]]  = matrix(0.0, nrow=0, ncol=nLocalFactors[k]);
				alpha[[k]] = rep(0.0, 0);
			}else{
				nItems = max(obs$item);
				beta[[k]]  = matrix(rnorm(nItems*nLocalFactors[k],sd=beta.sd), nrow=nItems, ncol=nLocalFactors[k]);
				alpha[[k]] = rep(0.0, nItems);
				agg = aggregate(obs$y, by=list(by=obs$item), FUN=mean);
				alpha[[k]][agg$by] = agg$x;
			}
		}
		param = list(A=A, B=B, b=b, alpha=alpha, beta=beta, var_x=var_x, var_y=var_y, var_z=var_z, var_u=var_u,
		             is.x.logistic=is.x.logistic, is.y.logistic=is.y.logistic);
	}
	
	# The factors always start at zero, whether or not param was supplied.
	z = lapply(seq_len(nApps), function(k) matrix(0.0, nrow=nUsers, ncol=nLocalFactors[k]));
	factor = list(u=matrix(0.0, nrow=nUsers, ncol=nGlobalFactors), z=z);
	
	# Logistic models. A freshly built param gets its logistic entries from
	# init.param.logistic(); a supplied param is assumed to carry them already
	# (NOTE(review): confirm). The per-observation score variances are needed in
	# both cases, because fit.EM() passes factor$var.x.score / var.y.score to the
	# C E-step and to update.param.logistic().
	if(is.x.logistic){
		if(!param.supplied) param = init.param.logistic(param, feature, value=1.0, target="x");
		factor$var.x.score = rep(1.0, nrow(feature));
	}
	if(is.y.logistic){
		if(!param.supplied) param = init.param.logistic(param, response, value=1.0, target="y");
		factor$var.y.score = rep(1.0, nrow(response));
	}
	
	out = list(param=param, factor=factor);
	return(out);
}
321
###
### One EM iteration (pure-R version)
###   Runs the E-step once and the M-step once.
###   Input:
###     feature  = data.frame(user, app, index, x, w)
###     response = data.frame(user, app, item,  y, w)
###     param    = list(A, B, b, alpha, beta, var_x, var_y, var_z, var_u)
###   Output: list(param, factor)
###     param: the parameter values after the M-step
###     factor = list(u, z): the posterior mean of u and z
###     if(is.x.logistic) factor$var.x.score[ikm] = Var[B_{k,m} z_{ik}]
###     if(is.y.logistic) factor$var.y.score[ijk] = Var[beta_{jk} z_{ik}]
###   STATUS: not implemented. Calling it always raises an error, so
###   fit.EM() can only be run with use.C=TRUE.
###
fit.EM.one.iteration.R <- function(feature, response, param){
	not.implemented = "Please implement this function";
	stop(not.implemented)
}
### IMPORTANT NOTE:
###   * In the C version, to reduce the memory footprint, the content of the input
###     param and factor will be changed to the updated values.
###     This is call by reference, and is NOT the behavior of regular R functions.
###     buffer is the temp space (which can be NULL) and is also call by reference.
###   * ridge.lambda[c("A", "B", "beta")] are the numbers to be added to the
###     diagonal when fitting the regression to get A, B, and beta.
###     An unnamed length-3 ridge.lambda is read positionally as (A, B, beta);
###     a named one must use exactly the names A, B, beta (in any order).
###   * buffer = list(A, B, b, alpha, beta, z)
###     contains the packed version of A, B, b, ... for C/C++ functions
###   * option is passed to the C code as an integer (fit.EM() adds 1 for
###     keep.users.without.obs.in.E.step and 2 for keep.users.without.obs.in.M.step).
###   Returns the buffer (newly allocated when buffer=NULL); the updated values
###   are written into param and factor.
fit.EM.one.iteration.C <- function(
	feature, response, param, factor,
	ridge.lambda=c(A=0, B=0, beta=0),
	buffer=NULL, option=0, debug=0, verbose=0
){
	size = check.syntax.all(feature=feature, response=response, param=param, factor=factor);
	
	if(is.null(buffer)) buffer = list()
	else                check_names(buffer, "buffer", required=c("A", "B", "b", "alpha", "beta", "z"));

	# ridge.lambda is handed to C positionally as (A, B, beta). Names are only
	# validated when present: comparing against NULL names yields logical(0),
	# which would make any(...) pass silently. A named vector is reordered by
	# name so that any order of A, B, beta is accepted.
	if(length(ridge.lambda) != 3) stop("length(ridge.lambda) != 3");
	lambda.names = c("A", "B", "beta");
	if(!is.null(names(ridge.lambda))){
		if(!setequal(names(ridge.lambda), lambda.names)) stop("ridge.lambda must be a named vector with names: A, B, beta");
		ridge.lambda = ridge.lambda[lambda.names];
	}
	
	# Pack the per-application parameter matrices into the (data, dim) form
	# used by the C code, reusing the existing buffers when given.
	for(name in c("A", "B", "b", "alpha", "beta")){
		if(is.null(buffer[[name]])) buffer[[name]] = pack.list.of.matrices(param[[name]])
		else                  pack.list.of.matrices(param[[name]], output=buffer[[name]]);
		check_type_size(buffer[[name]]$data, "double", NA);
		check_type_size(buffer[[name]]$dim,  "int",    NA);
	}
	for(name in c("z")){
		if(is.null(buffer[[name]])) buffer[[name]] = pack.list.of.matrices(factor[[name]])
		else                  pack.list.of.matrices(factor[[name]], output=buffer[[name]]);
		check_type_size(buffer[[name]]$data, "double", NA);
		check_type_size(buffer[[name]]$dim,  "int",    NA);
	}
	for(name in c("var_x", "var_y", "var_z")){
		check_type_size(param[[name]], "double", size$nApps);
	}
	check_type_size(param[["var_u"]], "double", 1);
	check_type_size(factor[["u"]], "double", c(size$nUsers, size$nGlobalFactors));

	# Feature table; the weight column is optional.
	for(name in c("user", "app", "index")){
		check_type_size(feature[[name]], "int", nrow(feature));
	}
	check_type_size(feature[["x"]], "double", nrow(feature));
	if(is.null(feature[["w"]])){
		feature_has_w = as.integer(0);
	}else{
		feature_has_w = as.integer(1);
		check_type_size(feature[["w"]], "double", nrow(feature));
	}

	# Response table; the weight column is optional.
	for(name in c("user", "app", "item")){
		check_type_size(response[[name]], "int", nrow(response));
	}
	check_type_size(response[["y"]], "double", nrow(response));
	if(is.null(response[["w"]])){
		response_has_w = as.integer(0);
	}else{
		response_has_w = as.integer(1);
		check_type_size(response[["w"]], "double", nrow(response));
	}
	
	for(name in c("nFeatures", "nItems", "nLocalFactors")){
		check_type_size(size[[name]], "int", size$nApps)
	}
	
	# The expected size is the vector's own length, so these two calls only
	# check the storage type.
	check_type_size(factor[["var.y.score"]], "double", length(factor[["var.y.score"]]));
	check_type_size(factor[["var.x.score"]], "double", length(factor[["var.x.score"]]));
	
	# NOTE(review): this wrapper relies on .C(..., DUP=FALSE) writing the C
	# results directly into buffer / factor (call by reference); the value
	# returned by .C is discarded. Current R documentation for .C states that
	# DUP is accepted only for back-compatibility and is ignored, in which case
	# the updates would not propagate. Confirm against the R version in use.
	.C("EM_one_iteration",
		# Parameters: INPUT and OUTPUT
		buffer[["A"]]$data, buffer[["A"]]$dim,  buffer[["B"]]$data, buffer[["B"]]$dim,  buffer[["b"]]$data, buffer[["b"]]$dim,
		buffer[["alpha"]]$data, buffer[["alpha"]]$dim,  buffer[["beta"]]$data, buffer[["beta"]]$dim,
		param[["var_x"]], param[["var_y"]], param[["var_z"]], param[["var_u"]],
		# Posterior mean of factors: OUTPUT
		factor[["u"]],  buffer[["z"]]$data, buffer[["z"]]$dim,
		# Posterior variance for logistic regression: OUTPUT
		factor[["var.x.score"]], as.integer(length(factor[["var.x.score"]])),
		factor[["var.y.score"]], as.integer(length(factor[["var.y.score"]])),
		# Feature table: INPUT
		feature[["user"]], feature[["app"]], feature[["index"]], feature[["x"]], feature[["w"]],
		as.integer(nrow(feature)), feature_has_w,
		# Response table: INPUT
		response[["user"]], response[["app"]], response[["item"]], response[["y"]], response[["w"]],
		as.integer(nrow(response)), response_has_w,
		# Ridge regression parameters: INPUT (order: A, B, beta)
		as.double(ridge.lambda), as.integer(length(ridge.lambda)),
		# Size information: INPUT
		as.integer(size$nApps), as.integer(size$nUsers), as.integer(size$nGlobalFactors),
		size$nFeatures, size$nItems, size$nLocalFactors,
		# Others
		as.integer(option), as.integer(verbose), as.integer(debug),
		DUP=FALSE
	);
	
	# Write the updated packed values back into param[[name]] and factor$z
	# (in place, see IMPORTANT NOTE above this function).
	for(name in c("A", "B", "b", "alpha", "beta")){
		unpack.list.of.matrices(buffer[[name]], output=param[[name]]);
	}
	unpack.list.of.matrices(buffer[["z"]], output=factor[["z"]]);
	
	return(buffer);
}
439