/src/multi-app/R/model/fit-EM.R

http://github.com/beechung/Latent-Factor-Models

### Copyright (c) 2011, Yahoo! Inc. All rights reserved.
### Copyrights licensed under the New BSD License. See the accompanying LICENSE file for terms.
###
### Author: Bee-Chung Chen
###
### Use the EM algorithm to fit the model
### NOTE:
###  * ridge.lambda[c("A", "B", "beta")] are the numbers to be added to the
###    diagonal when fitting the regression to get A, B, and beta
###
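
### --------------------------------------------------------------------------
### Illustration (not part of the original code): what adding ridge.lambda to
### the diagonal means.  A minimal, hypothetical sketch of a single ridge-
### regression solve; X, y and lambda below are made-up stand-ins for one of
### the regression sub-problems used to estimate A, B, or beta.
### --------------------------------------------------------------------------
if(FALSE){
    set.seed(1);
    X = matrix(rnorm(50*3), nrow=50, ncol=3);    # design matrix of one sub-problem
    y = X %*% c(1, -2, 0.5) + rnorm(50, sd=0.1); # observed targets
    lambda = 0.1;                                # e.g., ridge.lambda["A"]
    # Ridge solution: (X'X + lambda*I)^{-1} X'y
    coef.ridge = solve(crossprod(X) + lambda*diag(ncol(X)), crossprod(X, y));
    print(coef.ridge);
}
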
fit.EM <- function(
    # Input data
    feature,             # data.frame(user, app, index, x, w); see the usage sketch after this function
    response,            # data.frame(user, app, item, y, w)
    # Model setup
    nGlobalFactors,      # num of global factors per user
    nLocalFactors=NULL,  # num of local factors per user (may be a vector of length: #app)
    identity.A=FALSE, identity.B=FALSE, # whether A/B is an identity matrix
    is.y.logistic=FALSE, # whether to use a logistic model for y
    is.x.logistic=FALSE, # whether to use a logistic model for x
    # Test data (optional)
    test.feature=NULL,
    test.response=NULL,
    # Model-fitting parameters
    nIter=30, # num of EM iterations
    ridge.lambda=c(A=0, B=0, beta=0), # lambda values for ridge regression for A, B, beta
    keep.users.without.obs.in.E.step=FALSE,
    keep.users.without.obs.in.M.step=FALSE,
    fix.var_u=1, # fix var_u to some number (set to NULL if you do not want to fix it)
    # Initialization parameters
    param=NULL, # directly set all the parameters (if not NULL, the following will not be used)
    var_x=1, var_y=1, var_z=1, # initial var (may be vectors of length: #app)
    var_u=1, # initial var (length: 1)
    A.sd=1, B.sd=1, beta.sd=1, # A_k ~ N(mean=0, sd=A.sd), and so on.
    # Output options
    out.level=0,  # out.level=1: Save the model in out.dir/model.last and out.dir/model.minTestLoss
    out.dir=NULL, # out.level=2: Save the model after each iteration i in out.dir/model.i
    out.overwrite=FALSE, # whether to allow overwriting existing files
    # Debugging options
    debug=0, verbose=0, use.C=TRUE,
    show.marginal.loglik=FALSE
){
    option = 0;
    if(keep.users.without.obs.in.E.step) option = option + 1;
    if(keep.users.without.obs.in.M.step) option = option + 2;
    option = as.integer(option);
    if(!is.null(test.response) && out.level <= 0 && verbose <= 0) stop("!is.null(test.response) && out.level <= 0 && verbose <= 0");
    if(out.level > 0){
        if(is.null(out.dir)) stop("out.dir = NULL");
        if(file.exists(paste(out.dir,"/model.last",sep="")) && !out.overwrite) stop(out.dir," already exists!!");
        if(!file.exists(out.dir)) dir.create(out.dir, recursive=TRUE, mode="0755");
    }
    if(is.x.logistic){
        if(!is.null(feature$w)) stop("When is.x.logistic is TRUE, you cannot specify feature$w")
        feature = init.obs.logistic(feature, target="x");
    }
    if(is.y.logistic){
        if(!is.null(response$w)) stop("When is.y.logistic is TRUE, you cannot specify response$w")
        response = init.obs.logistic(response, target="y");
    }
    if(verbose > 0) cat("INITIALIZE THE PARAMETERS\n");
    begin.time.entire = proc.time();
    begin.time = proc.time();
    model = init.simple(
        feature=feature, response=response, nGlobalFactors=nGlobalFactors,
        param=param, nLocalFactors=nLocalFactors,
        var_x=var_x, var_y=var_y, var_z=var_z, var_u=var_u,
        identity.A=identity.A, identity.B=identity.B,
        is.x.logistic=is.x.logistic, is.y.logistic=is.y.logistic,
        A.sd=A.sd, B.sd=B.sd, beta.sd=beta.sd
    );
    if(!is.null(fix.var_u)) model$param$var_u = as.double(fix.var_u);
    time.used = proc.time() - begin.time;
    if(verbose > 0) cat("time used: ",time.used[3]," sec\n",sep="");
    size = check.syntax.all(feature=feature, response=response, param=model$param, factor=model$factor, check.indices=TRUE);
    if(!is.null(test.response)){
        check.syntax.all(feature=test.feature, response=test.response, param=model$param, factor=model$factor, check.indices=TRUE, test.data=TRUE);
    }
    if(verbose >= 2){
        cat("--------------------------------------------------------\n",
            " Problem Dimensionality:\n",
            "--------------------------------------------------------\n",sep="");
        print(size);
    }
    if(verbose > 0) cat("START THE EM-PROCEDURE\n");
    buffer = NULL;  buffer.addr = NULL;
    prediction = NULL;
    best.model = NULL;  best.testLoss = Inf;
    time.pred = NA;
    test.loss.y = rep(NA, nIter+1);  test.loss.x = rep(NA, nIter+1);
    marginal.loglik = rep(NA, nIter+1);
    if(!is.null(test.response)){
        begin.time = proc.time();
        prediction = predict.x.and.y(feature=test.feature, response=test.response, param=model$param, factor=model$factor);
        time.pred = proc.time() - begin.time;
        test.loss.y[1] = prediction$loss.y;
        test.loss.x[1] = prediction$loss.x;
        if(prediction$loss.y < best.testLoss){
            best.testLoss = prediction$loss.y;
            best.model = deepCopy(model);
        }
    }
    ans = output.results(
        method="EM", feature=feature, response=response, model=model, prediction=prediction,
        out.dir=out.dir, out.level=out.level,
        minTestLoss=best.testLoss, iter=0, show.marginal.loglik=show.marginal.loglik,
        TimeEM=time.used, TimeTest=time.pred, verbose=verbose,
        other=NULL, name="model"
    );
    marginal.loglik[1] = ans$marginal.loglik;

    for(iter in seq_len(nIter)){
        if(verbose > 0){
            cat("--------------------------------------------------------\n",
                " EM-Iteration: ",iter,"\n",
                "--------------------------------------------------------\n",sep="");
        }
        begin.time = proc.time();
        # variational approx for logistic model
        if(is.x.logistic){
            feature$w = get.var.logistic(     obs=feature, param=model$param, target="x", verbose=verbose);
            feature$x = get.response.logistic(obs=feature, param=model$param, target="x", verbose=verbose);
            model$param$var_x[] = 1;
        }
        if(is.y.logistic){
            response$w = get.var.logistic(     obs=response, param=model$param, target="y", verbose=verbose);
            response$y = get.response.logistic(obs=response, param=model$param, target="y", verbose=verbose);
            model$param$var_y[] = 1;
        }
        if(use.C){
            model.addr = get.address(model);
            buffer = fit.EM.one.iteration.C(
                feature=feature, response=response, param=model$param, factor=model$factor,
                ridge.lambda=ridge.lambda,
                option=option, debug=debug, verbose=verbose, buffer=buffer
            );
            # Now, model$param and model$factor contain the result after one EM iteration
            # (they are updated by reference, which is not the behavior of regular R functions)
            # sanity check
            temp = get.address(model);
            if(is.diff(model.addr,temp,precision=0)) stop("model address changed!!");
            if(is.null(buffer.addr)) buffer.addr = get.address(buffer);
            temp = get.address(buffer);
            if(is.diff(buffer.addr,temp,precision=0)) stop("buffer address changed!!");
        }else{
            model = fit.EM.one.iteration.R(feature=feature, response=response, param=model$param)
        }
        if(!is.null(fix.var_u)) model$param$var_u = as.double(fix.var_u);
        # variational approx for logistic model
        if(is.x.logistic){
            mean.score = predict.x.from.z(feature=feature, param=model$param, z=model$factor$z, add.noise=FALSE);
            model$param = update.param.logistic(param=model$param, mean.score=mean.score, var.score=model$factor$var.x.score, target="x");
        }
        if(is.y.logistic){
            mean.score = predict.y.from.z(response=response, param=model$param, z=model$factor$z, add.noise=FALSE);
            model$param = update.param.logistic(param=model$param, mean.score=mean.score, var.score=model$factor$var.y.score, target="y");
        }
        time.used = proc.time() - begin.time;
        if(verbose > 0) cat("time used: ",time.used[3]," sec\n",sep="");
        if(!is.null(test.response)){
            begin.time = proc.time();
            prediction = predict.x.and.y(feature=test.feature, response=test.response, param=model$param, factor=model$factor);
            time.pred = proc.time() - begin.time;
            test.loss.y[iter+1] = prediction$loss.y;
            test.loss.x[iter+1] = prediction$loss.x;
            if(prediction$loss.y < best.testLoss){
                best.testLoss = prediction$loss.y;
                best.model = deepCopy(model);
            }
        }
        ans = output.results(
            method="EM", feature=feature, response=response, model=model, prediction=prediction,
            out.dir=out.dir, out.level=out.level,
            minTestLoss=best.testLoss, iter=iter, show.marginal.loglik=show.marginal.loglik,
            TimeEM=time.used, TimeTest=time.pred, verbose=verbose,
            other=NULL, name="model"
        );
        marginal.loglik[iter+1] = ans$marginal.loglik;
    }
    time.used = proc.time() - begin.time.entire;
    if(verbose > 0) cat("END OF THE EM-PROCEDURE\nTotal time used: ",time.used[3]," sec\n",sep="");
    out = list(model=model, model.min.test.loss=best.model, test.loss.x=test.loss.x, test.loss.y=test.loss.y, marginal.loglik=marginal.loglik);
    return(out);
}
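
### --------------------------------------------------------------------------
### Illustration (not part of the original code): a hypothetical call to fit.EM.
### A minimal sketch only; the toy data below merely follow the documented
### schema (feature = data.frame(user, app, index, x, w) with w optional;
### response = data.frame(user, app, item, y, w)), and all sizes, values and
### settings are made up for illustration.
### --------------------------------------------------------------------------
if(FALSE){
    set.seed(0);
    nUsers = 20;  nApps = 2;
    # Toy feature table: one row per observed feature value x
    feature = data.frame(
        user  = as.integer(sample(1:nUsers, 100, replace=TRUE)),
        app   = as.integer(sample(1:nApps,  100, replace=TRUE)),
        index = as.integer(sample(1:5,      100, replace=TRUE)),
        x     = rnorm(100)
    );
    # Toy response table: one row per observed response y
    response = data.frame(
        user = as.integer(sample(1:nUsers, 200, replace=TRUE)),
        app  = as.integer(sample(1:nApps,  200, replace=TRUE)),
        item = as.integer(sample(1:10,     200, replace=TRUE)),
        y    = rnorm(200)
    );
    fit = fit.EM(
        feature=feature, response=response,
        nGlobalFactors=3, nLocalFactors=2,
        nIter=10, ridge.lambda=c(A=0.1, B=0.1, beta=0.1),
        verbose=1
    );
    str(fit$model$param);          # fitted parameters
    str(fit$model.min.test.loss);  # NULL here, since no test data was given
}
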
###
### Initialization
### (the structure of the returned list(param, factor) is sketched after this function)
###
init.simple <- function(
    feature, response, nGlobalFactors,
    param=NULL,          # directly set all the parameters
    nLocalFactors=NULL,  # which may be an array of length: nApps
    var_x=1, var_y=1, var_z=1, # which may be arrays of length: nApps
    var_u=1, identity.A=FALSE, identity.B=FALSE,
    is.y.logistic=FALSE, # whether to use a logistic model for y
    is.x.logistic=FALSE, # whether to use a logistic model for x
    A.sd=1, B.sd=1, beta.sd=1
){
    check.syntax.obs(feature=feature, response=response);
    nApps  = max(feature$app,  response$app);
    nUsers = max(feature$user, response$user);
    if(!is.null(param)){
        size = check.syntax.param(param);
        if(nApps != size$nApps) stop("Input param has ",size$nApps," applications, but the input data has ",nApps);
        if(nGlobalFactors != size$nGlobalFactors) stop("Input param has ",size$nGlobalFactors," global factors per user, but you specify ",nGlobalFactors);
        if(is.null(nLocalFactors)){
            nLocalFactors = size$nLocalFactors;
        }else{
            if(length(nLocalFactors) == 1) nLocalFactors = rep(nLocalFactors, nApps);
            if(length(nLocalFactors) != nApps) stop("length(nLocalFactors) != nApps");
            if(any(nLocalFactors != size$nLocalFactors)) stop("Input param has different nLocalFactors than your specification");
        }
        z = list();
        for(k in seq_len(nApps)){
            z[[k]] = matrix(0.0, nrow=nUsers, ncol=nLocalFactors[k]);
        }
        factor = list(u=matrix(0.0, nrow=nUsers, ncol=nGlobalFactors), z=z);
        out = list(param=param, factor=factor);
        return(out);
    }
    if(!identity.B && is.null(nLocalFactors)) stop("Please specify nLocalFactors");
    if(is.null(nLocalFactors)) nLocalFactors = rep(NA, nApps);
    if(length(nLocalFactors) == 1) nLocalFactors = rep(nLocalFactors, nApps);
    if(length(var_x) == 1) var_x = rep(var_x, nApps);
    if(length(var_y) == 1) var_y = rep(var_y, nApps);
    if(length(var_z) == 1) var_z = rep(var_z, nApps);
    if(length(nLocalFactors) != nApps) stop("length(nLocalFactors) != nApps");
    if(length(var_x) != nApps) stop("length(var_x) != nApps");
    if(length(var_y) != nApps) stop("length(var_y) != nApps");
    if(length(var_z) != nApps) stop("length(var_z) != nApps");
    A = list();  B = list();  b = list();  alpha = list();  beta = list();
    temp.f = tapply(seq_len(length(feature$app)),  list(feature$app),  FUN=c, simplify=FALSE);
    temp.r = tapply(seq_len(length(response$app)), list(response$app), FUN=c, simplify=FALSE);
    for(k in seq_len(nApps)){
        m = as.character(k); # IMPORTANT: use a string to access temp.f and temp.r
        select = temp.f[[m]];
        obs = feature[select,];
        if(length(select) == 0){
            if(identity.B) nLocalFactors[k] = 0;
            B[[k]] = matrix(0.0, nrow=0, ncol=nLocalFactors[k]);
            b[[k]] = rep(0.0, 0);
        }else if(identity.B){
            nFeatures = max(obs$index);
            if(is.na(nLocalFactors[k])) nLocalFactors[k] = nFeatures
            else if(nLocalFactors[k] != nFeatures) stop("Misspecification of nLocalFactors[",k,"] = ",nLocalFactors[k]," with identity.B");
            B[[k]] = 1.0
            b[[k]] = rep(0.0, nFeatures);
        }else{
            nFeatures = max(obs$index);
            B[[k]] = matrix(rnorm(nFeatures*nLocalFactors[k],sd=B.sd), nrow=nFeatures, ncol=nLocalFactors[k]);
            b[[k]] = rep(0.0, nFeatures);
            agg = aggregate(obs$x, by=list(by=obs$index), FUN=mean);
            b[[k]][agg$by] = agg$x;
        }
        if(identity.A){
            A[[k]] = 1;
            if(nLocalFactors[k] != nGlobalFactors) stop("Misspecification of nLocalFactors[",k,"] = ",nLocalFactors[k]," with identity.A");
        }else{
            A[[k]] = matrix(rnorm(nLocalFactors[k]*nGlobalFactors,sd=A.sd), nrow=nLocalFactors[k], ncol=nGlobalFactors);
        }
        select = temp.r[[m]];
        obs = response[select,];
        if(length(select) == 0){
            beta[[k]]  = matrix(0.0, nrow=0, ncol=nLocalFactors[k]);
            alpha[[k]] = rep(0.0, 0);
        }else{
            nItems = max(obs$item);
            beta[[k]]  = matrix(rnorm(nItems*nLocalFactors[k],sd=beta.sd), nrow=nItems, ncol=nLocalFactors[k]);
            alpha[[k]] = rep(0.0, nItems);
            agg = aggregate(obs$y, by=list(by=obs$item), FUN=mean); # note: aggregate() names the value column "x"
            alpha[[k]][agg$by] = agg$x;
        }
    }
    param = list(A=A, B=B, b=b, alpha=alpha, beta=beta, var_x=var_x, var_y=var_y, var_z=var_z, var_u=var_u,
                 is.x.logistic=is.x.logistic, is.y.logistic=is.y.logistic);
    z = list();
    for(k in seq_len(nApps)){
        z[[k]] = matrix(0.0, nrow=nUsers, ncol=nLocalFactors[k]);
    }
    factor = list(u=matrix(0.0, nrow=nUsers, ncol=nGlobalFactors), z=z);
    if(is.x.logistic){
        param = init.param.logistic(param, feature, value=1.0, target="x");
        factor$var.x.score = rep(1.0, nrow(feature));
    }
    if(is.y.logistic){
        param = init.param.logistic(param, response, value=1.0, target="y");
        factor$var.y.score = rep(1.0, nrow(response));
    }
    out = list(param=param, factor=factor);
    return(out);
}
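
### --------------------------------------------------------------------------
### Illustration (not part of the original code): the structure returned by
### init.simple.  A minimal sketch, assuming `feature` and `response` follow
### the documented schema (e.g., the toy tables sketched after fit.EM above);
### the factor sizes are made up for illustration.
### --------------------------------------------------------------------------
if(FALSE){
    model = init.simple(feature=feature, response=response,
                        nGlobalFactors=3, nLocalFactors=2);
    # model$param:  A, B, b, alpha, beta are lists with one element per application;
    #               var_x, var_y, var_z have length nApps; var_u has length 1
    # model$factor: u is nUsers x nGlobalFactors; z[[k]] is nUsers x nLocalFactors[k]
    str(model$param);
    str(model$factor);
}
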
###
### One EM iteration
### Run the E-step once, then the M-step once
### Input:
###   feature  = data.frame(user, app, index, x, w)
###   response = data.frame(user, app, item, y, w)
###   param    = list(A, B, b, alpha, beta, var_x, var_y, var_z, var_u)
### Output: list(param, factor)
###   param: The parameter values after the M-step
###   factor=list(u, z): The posterior mean of u and z
###   if(is.x.logistic) factor$var.x.score[ikm] = Var[B_{k,m} z_{ik}]
###   if(is.y.logistic) factor$var.y.score[ijk] = Var[beta_{jk} z_{ik}]
fit.EM.one.iteration.R <- function(feature, response, param){
    stop("Please implement this function");
}
### IMPORTANT NOTE:
###  * In the C version, to reduce the memory footprint, the content of the input
###    param and factor will be changed to the updated values.
###    This is call-by-reference semantics, which is NOT the behavior of regular R functions.
###    buffer is the temp space (which can be NULL) and is also passed by reference.
###  * ridge.lambda[c("A", "B", "beta")] are the numbers to be added to the
###    diagonal when fitting the regression to get A, B, and beta
###  * buffer = list(A, B, b, alpha, beta, z)
###    contains the packed version of A, B, b, ... for the C/C++ functions
###    (see the usage sketch at the end of this file)
fit.EM.one.iteration.C <- function(
    feature, response, param, factor,
    ridge.lambda=c(A=0, B=0, beta=0),
    buffer=NULL, option=0, debug=0, verbose=0
){
    size = check.syntax.all(feature=feature, response=response, param=param, factor=factor);
    if(is.null(buffer)) buffer = list()
    else check_names(buffer, "buffer", required=c("A", "B", "b", "alpha", "beta", "z"));
    if(length(ridge.lambda) != 3) stop("length(ridge.lambda) != 3");
    if(any(c("A", "B", "beta") != names(ridge.lambda))) stop("ridge.lambda must be a named vector with names: A, B, beta (in this order)");
    for(name in c("A", "B", "b", "alpha", "beta")){
        if(is.null(buffer[[name]])) buffer[[name]] = pack.list.of.matrices(param[[name]])
        else pack.list.of.matrices(param[[name]], output=buffer[[name]]);
        check_type_size(buffer[[name]]$data, "double", NA);
        check_type_size(buffer[[name]]$dim,  "int",    NA);
    }
    for(name in c("z")){
        if(is.null(buffer[[name]])) buffer[[name]] = pack.list.of.matrices(factor[[name]])
        else pack.list.of.matrices(factor[[name]], output=buffer[[name]]);
        check_type_size(buffer[[name]]$data, "double", NA);
        check_type_size(buffer[[name]]$dim,  "int",    NA);
    }
    for(name in c("var_x", "var_y", "var_z")){
        check_type_size(param[[name]], "double", size$nApps);
    }
    check_type_size(param[["var_u"]], "double", 1);
    check_type_size(factor[["u"]], "double", c(size$nUsers, size$nGlobalFactors));
    for(name in c("user", "app", "index")){
        check_type_size(feature[[name]], "int", nrow(feature));
    }
    check_type_size(feature[["x"]], "double", nrow(feature));
    if(is.null(feature[["w"]])){
        feature_has_w = as.integer(0);
    }else{
        feature_has_w = as.integer(1);
        check_type_size(feature[["w"]], "double", nrow(feature));
    }
    for(name in c("user", "app", "item")){
        check_type_size(response[[name]], "int", nrow(response));
    }
    check_type_size(response[["y"]], "double", nrow(response));
    if(is.null(response[["w"]])){
        response_has_w = as.integer(0);
    }else{
        response_has_w = as.integer(1);
        check_type_size(response[["w"]], "double", nrow(response));
    }
    for(name in c("nFeatures", "nItems", "nLocalFactors")){
        check_type_size(size[[name]], "int", size$nApps)
    }
    check_type_size(factor[["var.y.score"]], "double", length(factor[["var.y.score"]]));
    check_type_size(factor[["var.x.score"]], "double", length(factor[["var.x.score"]]));

    .C("EM_one_iteration",
        # Parameters: INPUT and OUTPUT
        buffer[["A"]]$data, buffer[["A"]]$dim, buffer[["B"]]$data, buffer[["B"]]$dim, buffer[["b"]]$data, buffer[["b"]]$dim,
        buffer[["alpha"]]$data, buffer[["alpha"]]$dim, buffer[["beta"]]$data, buffer[["beta"]]$dim,
        param[["var_x"]], param[["var_y"]], param[["var_z"]], param[["var_u"]],
        # Posterior mean of factors: OUTPUT
        factor[["u"]], buffer[["z"]]$data, buffer[["z"]]$dim,
        # Posterior variance for logistic regression: OUTPUT
        factor[["var.x.score"]], as.integer(length(factor[["var.x.score"]])),
        factor[["var.y.score"]], as.integer(length(factor[["var.y.score"]])),
        # Feature table: INPUT
        feature[["user"]], feature[["app"]], feature[["index"]], feature[["x"]], feature[["w"]],
        as.integer(nrow(feature)), feature_has_w,
        # Response table: INPUT
        response[["user"]], response[["app"]], response[["item"]], response[["y"]], response[["w"]],
        as.integer(nrow(response)), response_has_w,
        # Ridge regression parameters: INPUT
        as.double(ridge.lambda), as.integer(length(ridge.lambda)),
        # Size information: INPUT
        as.integer(size$nApps), as.integer(size$nUsers), as.integer(size$nGlobalFactors),
        size$nFeatures, size$nItems, size$nLocalFactors,
        # Others
        as.integer(option), as.integer(verbose), as.integer(debug),
        DUP=FALSE
    );
    # Unpack the updated parameters and factors back into param and factor (in place)
    for(name in c("A", "B", "b", "alpha", "beta")){
        unpack.list.of.matrices(buffer[[name]], output=param[[name]]);
    }
    unpack.list.of.matrices(buffer[["z"]], output=factor[["z"]]);
    return(buffer);
}
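
### --------------------------------------------------------------------------
### Illustration (not part of the original code): driving fit.EM.one.iteration.C
### directly, mirroring the loop inside fit.EM.  A minimal sketch, assuming
### `model` came from init.simple() and `feature`/`response` follow the
### documented schema; the ten iterations below are an arbitrary choice.
### Remember that model$param and model$factor are updated IN PLACE, and that
### the buffer returned by the first call should be passed back in to reuse
### the packed temp space.
### --------------------------------------------------------------------------
if(FALSE){
    buffer = NULL;
    for(iter in 1:10){
        buffer = fit.EM.one.iteration.C(
            feature=feature, response=response,
            param=model$param, factor=model$factor,
            ridge.lambda=c(A=0, B=0, beta=0),
            buffer=buffer, option=0, debug=0, verbose=1
        );
        # model$param and model$factor now hold the values after this iteration
    }
}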