/src/R/examples/multicontext_model.R

http://github.com/beechung/Latent-Factor-Models · R · 270 lines · 164 code · 11 blank · 95 comment · 0 complexity · cc1b1b8d1299df22f682c83a8559de94 MD5 · raw file

  1. ### Copyright (c) 2011, Yahoo! Inc. All rights reserved.
  2. ### Copyrights licensed under the New BSD License. See the accompanying LICENSE file for terms.
  3. ###
  4. ### Author: Bee-Chung Chen
  5. ###
  6. ### Preparation:
  7. ### (1) Set your path/alias to run the right version of R
  8. ### (2) make (in public-factor-models/, not in any subdirectory)
  9. ### (3) Take a look at public-factor-models/src/R/model/Notation-multicontext.txt
  10. ### for the specification of the model.
  11. ### (4) Run R (in public-factor-models/, not in any subdirectory)
  12. ###
  13. ###
  14. ### Example 1: Run the fitting code with synthetic data
  15. ###
  16. # (1) Generate some data
  17. # See src/R/model/multicontext_model_genData.R for details
  18. library(Matrix);
  19. dyn.load("lib/c_funcs.so");
  20. source("src/R/c_funcs.R");
  21. source("src/R/util.R");
  22. source("src/R/model/util.R");
  23. source("src/R/model/multicontext_model_genData.R");
  24. source("src/R/model/multicontext_model_utils.R");
  25. set.seed(0);
  26. d = generate.GaussianData(
  27. nSrcNodes=203, nDstNodes=203, nObs=10003,
  28. nSrcContexts=4, nDstContexts=5, nEdgeContexts=3, nFactors=2, has.gamma=FALSE, has.u=TRUE,
  29. nObsFeatures=2, nSrcFeatures=3, nDstFeatures=3, nCtxFeatures=1,
  30. b.sd=1, g0.sd=1, d0.sd=1, h0.sd=0, G.sd=1, D.sd=1, H.sd=0, q.sd=1, r.sd=1,
  31. q.mean=5, r.mean=5,
  32. var_y=0.1, var_alpha=0.5, var_beta=0.5, var_gamma=1, var_v=1, var_u=1, var_w=1,
  33. var_alpha_global=0.2, var_beta_global=0.2,
  34. has.intercept=FALSE,
  35. sparse.matrices=TRUE, frac.zeroFeatures=0.2
  36. );
  37. # (2) Create training/test split
  38. select.train = runif(nrow(d$obs),min=0,max=1) < 0.75;
  39. obs = d$obs; names(obs) = c("src_id", "dst_id", "src_context", "dst_context", "ctx_id", "y");
  40. obs.train = obs[ select.train,]; x_obs.train = data.frame(as.matrix(d$feature$x_obs)[ select.train,,drop=FALSE]);
  41. obs.test = obs[!select.train,]; x_obs.test = data.frame(as.matrix(d$feature$x_obs)[!select.train,,drop=FALSE]);
  42. x_src = data.frame(src_id=1:nrow(d$feature$x_src), as.matrix(d$feature$x_src));
  43. x_dst = data.frame(dst_id=1:nrow(d$feature$x_dst), as.matrix(d$feature$x_dst));
  44. # The following are input data tables:
  45. # obs.train, obs.test, x_src, x_dst
  46. # obs.train and obs.test contain the training and test rating data
  47. # The columns of these two tables are:
  48. # 1. src_id: e.g., user_id
  49. # 2. dst_id: e.g., item_id
  50. # 3. src_context: (optional) This is the context in which the source node gives the rating
  51. # 4. dst_context: (optional) This is the context in which the destination node receives the rating
  52. # 5. ctx_id: (optional) This is the context of this (src_id, dst_id) pair
  53. # 6. y: This is the rating that the source node gives the destination node
  54. # Note: You may set all/any of src_context, dst_context, ctx_id to NULL if there is no context info
  55. # or set all of them to the same vector
  56. # The number of contexts cannot be too many; otherwise, the program will be very slow
  57. str(obs.train); # to see the data structure
  58. str(obs.test); # to see the data structure
  59. # x_src is the source node (e.g., user) feature table
  60. # The first column src_id specifies the source node ID
  61. str(x_src); # to see the data structure
  62. # x_dst is the destination node (e.g., item) feature table
  63. # The first column dst_id specifies the destination node ID
  64. str(x_dst); # to see the data structure
  65. # (3) Index training data
  66. # See src/R/model/multicontext_model_utils.R: indexData() for details
  67. data.train = indexData(
  68. obs=obs.train, src.dst.same=TRUE, rm.self.link=TRUE,
  69. x_obs=x_obs.train, x_src=x_src, x_dst=x_dst,
  70. add.intercept=FALSE,
  71. );
  72. # (4) Index test data
  73. # See src/R/model/multicontext_model_utils.R: indexTestData() for details
  74. data.test = indexTestData(
  75. data.train=data.train, obs=obs.test,
  76. x_obs=x_obs.test, x_src=x_src, x_dst=x_dst,
  77. );
  78. # (5) Setup the model(s) to be fitted
  79. # See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
  80. # Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
  81. setting = data.frame(
  82. name = c("wuv", "wvv"),
  83. nFactors = c( 2, 2), # number of interaction factors
  84. has.u = c( T, F), # whether to use u_i' v_j or v_i' v_j
  85. has.gamma = c( F, F), # just set to F
  86. nLocalFactors = c( 0, 0), # just set to 0
  87. is.logistic = c( F, F) # whether to use the logistic model for binary rating
  88. );
  89. dyn.load("lib/c_funcs.so");
  90. source("src/R/c_funcs.R");
  91. source("src/R/util.R");
  92. source("src/R/model/util.R");
  93. source("src/R/model/multicontext_model_genData.R");
  94. source("src/R/model/multicontext_model_utils.R");
  95. source("src/R/model/multicontext_model_MStep.R");
  96. source("src/R/model/multicontext_model_EM.R");
  97. set.seed(2);
  98. # (6) Run the fitting code
  99. # See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
  100. # Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
  101. ans = run.multicontext(
  102. obs=data.train$obs, # Observation table
  103. feature=data.train$feature, # Features
  104. setting=setting, # Model setting
  105. nSamples=200, # Number of samples drawn in each E-step: could be a vector of size nIter.
  106. nBurnIn=20, # Number of burn-in draws before take samples for the E-step: could be a vector of size nIter.
  107. nIter=20, # Number of EM iterations
  108. test.obs=data.test$obs, # Test data: Observations for testing (optional)
  109. test.feature=data.test$feature, # Features for testing (optional)
  110. ridge.lambda=1,
  111. IDs=data.test$IDs,
  112. out.level=1, # out.level=1: Save the factor & parameter values to out.dir/model.last and out.dir/model.minTestLoss
  113. out.dir="/tmp/test", # out.level=2: Save the factor & parameter values of each iteration i to out.dir/model.i
  114. out.overwrite=TRUE, # whether to overwrite the output directory if it exists
  115. debug=0, # Set to 0 to disable internal sanity checking; Set to 100 for most detailed sanity checking
  116. verbose=1, # Set to 0 to disable console output; Set to 100 to print everything to the console
  117. verbose.M=2
  118. );
  119. # There may be some warning messages, which are mostly debugging messages and do not mean real problems.
  120. # (7) Checking the model summary
  121. ans$summary[,c("name", "nFactors", "has.u", "has.gamma", "nLocalFactors", "is.logistic", "best.test.loss", "last.test.loss")];
  122. # (8) Load the fitted model(s)
  123. # Here, I only use the "wuv" model as an example
  124. # (8.1) Check the summary file
  125. read.table("/tmp/test_wuv/summary", header=TRUE, sep="\t", as.is=TRUE);
  126. # (8.2) Load the model
  127. load("/tmp/test_wuv/model.last");
  128. # Now, factor and param contain the fitted model
  129. str(factor);
  130. str(param);
  131. # (8.3) Make prediction
  132. prediction = predict.multicontext(
  133. model=list(factor=factor, param=param),
  134. obs=data.test$obs, feature=data.test$feature, is.logistic=FALSE
  135. );
  136. # Now, prediction$pred.y contains the predicted rating for data.test$obs
  137. str(prediction);
  138. ###
  139. ### Example 2: Run the fitting code with synthetic data using SPARSE feature matrix
  140. ###
  141. # (1) Generate some data
  142. # See src/R/model/multicontext_model_genData.R for details
  143. library(Matrix);
  144. dyn.load("lib/c_funcs.so");
  145. source("src/R/c_funcs.R");
  146. source("src/R/util.R");
  147. source("src/R/model/util.R");
  148. source("src/R/model/multicontext_model_genData.R");
  149. source("src/R/model/multicontext_model_utils.R");
  150. set.seed(0);
  151. d = generate.GaussianData(
  152. nSrcNodes=1003, nDstNodes=1003, nObs=100003,
  153. nSrcContexts=3, nDstContexts=3, nEdgeContexts=1, nFactors=3, has.gamma=FALSE, has.u=FALSE,
  154. nObsFeatures=13, nSrcFeatures=19, nDstFeatures=23, nCtxFeatures=1,
  155. b.sd=1, g0.sd=1, d0.sd=1, h0.sd=0, G.sd=1, D.sd=1, H.sd=0, q.sd=1, r.sd=1,
  156. q.mean=5, r.mean=5,
  157. var_y=0.1, var_alpha=0.5, var_beta=0.5, var_gamma=1, var_v=1, var_u=1, var_w=1,
  158. var_alpha_global=0.2, var_beta_global=0.2,
  159. has.intercept=FALSE,
  160. sparse.matrices=TRUE, index.value.format=TRUE, frac.zeroFeatures=0.5
  161. );
  162. names(d$obs) = c("src_id", "dst_id", "src_context", "dst_context", "ctx_id", "y");
  163. d$obs$ctx_id = NULL;
  164. rating.data = d$obs;
  165. x_obs=d$feature$x_obs[order(d$feature$x_obs$row,d$feature$x_obs$col),]; names(x_obs) = c("obs_id", "index", "value");
  166. x_src=d$feature$x_src[order(d$feature$x_src$row,d$feature$x_src$col),]; names(x_src) = c("src_id", "index", "value");
  167. x_dst=d$feature$x_dst[order(d$feature$x_dst$row,d$feature$x_dst$col),]; names(x_dst) = c("dst_id", "index", "value");
  168. #
  169. # Input data: rating.data, x_obs, x_src, x_dst (you need to prepare these four tables for your data)
  170. # Note: All ID numbers start from 1 (not 0)
  171. #
  172. str(rating.data); # see the data structure
  173. # rating.data is the rating data table with the following columns:
  174. # 1. src_id: e.g., user_id or voter_id
  175. # 2. dst_id: e.g., item_id or author_id
  176. # 3. src_context: (optional) This is the context in which the source node gives the rating
  177. # 4. dst_context: (optional) This is the context in which the destination node receives the rating
  178. # 5. y: This is the rating that the source node gives the destination node
  179. # 6. ctx_id: (optional) This is the context of this (src_id, dst_id) pair
  180. # Note: You may set all/any of src_context, dst_context, ctx_id to NULL if there is no context info
  181. # The number of contexts cannot be too many; otherwise, the program will be very slow
  182. str(x_obs);
  183. # x_obs is the feature table for observations with the following columns
  184. # 1. obs_id: observation ID (obs_id=n corresponds to the nth row of rating.data)
  185. # 2. index: feature index
  186. # 3. value: feature value
  187. str(x_src);
  188. # x_src is the feature table for source nodes with the following columns
  189. # 1. src_id: source node ID (this correspond to the src_id column in rating.data)
  190. # 2. index: feature index
  191. # 3. value: feature value
  192. str(x_dst);
  193. # x_dst is the feature table for destination nodes with the following columns
  194. # 1. dst_id: destination node ID (this correspond to the dst_id column in rating.data)
  195. # 2. index: feature index
  196. # 3. value: feature value
  197. # (2) Create training/test split
  198. set.seed(1);
  199. select.train = sample(nrow(rating.data), floor(nrow(rating.data)*0.75));
  200. select.test = setdiff(1:nrow(rating.data), select.train);
  201. obs.train = rating.data[select.train,]; x_obs.train = x_obs[x_obs$obs_id %in% select.train,]; x_obs.train$obs_id = match(x_obs.train$obs_id, select.train);
  202. obs.test = rating.data[select.test, ]; x_obs.test = x_obs[x_obs$obs_id %in% select.test, ]; x_obs.test$obs_id = match(x_obs.test$obs_id, select.test);
  203. # (3) Index training data
  204. # See src/R/model/multicontext_model_utils.R: indexData() for details
  205. data.train = indexData(
  206. obs=obs.train, src.dst.same=TRUE, rm.self.link=TRUE,
  207. x_obs=x_obs.train, x_src=x_src, x_dst=x_dst,
  208. add.intercept=FALSE,
  209. );
  210. # (4) Index test data
  211. # See src/R/model/multicontext_model_utils.R: indexTestData() for details
  212. data.test = indexTestData(
  213. data.train=data.train, obs=obs.test,
  214. x_obs=x_obs.test, x_src=x_src, x_dst=x_dst,
  215. );
  216. # (5) Setup the model(s) to be fitted
  217. # See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
  218. # Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
  219. setting = data.frame(
  220. name = c( "uv", "vv"),
  221. nFactors = c( 3, 3), # number of interaction factors
  222. has.u = c( T, F), # whether to use u_i' v_j or v_i' v_j
  223. has.gamma = c( F, F), # just set to F
  224. nLocalFactors = c( 0, 0), # just set to 0
  225. is.logistic = c( F, F) # whether to use the logistic model for binary rating
  226. );
  227. # (6) Run the fitting code
  228. # See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
  229. # Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
  230. dyn.load("lib/c_funcs.so");
  231. source("src/R/c_funcs.R");
  232. source("src/R/util.R");
  233. source("src/R/model/util.R");
  234. source("src/R/model/multicontext_model_genData.R");
  235. source("src/R/model/multicontext_model_utils.R");
  236. source("src/R/model/multicontext_model_MStep.R");
  237. source("src/R/model/multicontext_model_EM.R");
  238. source("src/R/model/GLMNet.R");
  239. rnd.seed=1;
  240. ans = run.multicontext(
  241. obs=data.train$obs, # Observation table
  242. feature=data.train$feature, # Features
  243. setting=setting, # Model setting
  244. nSamples=200, # Number of samples drawn in each E-step: could be a vector of size nIter.
  245. nBurnIn=20, # Number of burn-in draws before take samples for the E-step: could be a vector of size nIter.
  246. nIter=10, # Number of EM iterations
  247. test.obs=data.test$obs, # Test data: Observations for testing (optional)
  248. test.feature=data.test$feature, # Features for testing (optional)
  249. reg.algo=GLMNet,
  250. IDs=data.test$IDs,
  251. rnd.seed.init=rnd.seed, rnd.seed.fit=rnd.seed,
  252. out.level=1, # out.level=1: Save the factor & parameter values to out.dir/model.last and out.dir/model.minTestLoss
  253. out.dir="/tmp/test", # out.level=2: Save the factor & parameter values of each iteration i to out.dir/model.i
  254. out.overwrite=TRUE, # whether to overwrite the output directory if it exists
  255. debug=0, # Set to 0 to disable internal sanity checking; Set to 100 for most detailed sanity checking
  256. verbose=1, # Set to 0 to disable console output; Set to 100 to print everything to the console
  257. verbose.M=2
  258. );
  259. ans$summary[,c("name", "nFactors", "has.u", "has.gamma", "nLocalFactors", "is.logistic", "best.test.loss", "last.test.loss")];