### Source file: /src/R/examples/multicontext_model.R
### Repository:  http://github.com/beechung/Latent-Factor-Models
### Copyright (c) 2011, Yahoo! Inc.  All rights reserved.
### Copyrights licensed under the New BSD License. See the accompanying LICENSE file for terms.
###
### Author: Bee-Chung Chen

###
### Preparation:
###    (1) Set your path/alias to run the right version of R
###    (2) make  (in public-factor-models/, not in any subdirectory)
###    (3) Take a look at public-factor-models/src/R/model/Notation-multicontext.txt
###        for the specification of the model.
###    (4) Run R (in public-factor-models/, not in any subdirectory)
###

###
### Example 1: Run the fitting code with synthetic data
###
 # (1) Generate some data
#     See src/R/model/multicontext_model_genData.R for details
# Load the Matrix package, the compiled C routines and the R helper code.
# All paths are relative, so R has to be started in public-factor-models/.
library(Matrix);
dyn.load("lib/c_funcs.so");
source("src/R/c_funcs.R");
source("src/R/util.R");
source("src/R/model/util.R");
source("src/R/model/multicontext_model_genData.R");
source("src/R/model/multicontext_model_utils.R");
# Fix the random seed so that the same synthetic data set is drawn on every run.
set.seed(0);
# The result d is used below through d$obs (the observation table) and
# d$feature$x_obs, d$feature$x_src, d$feature$x_dst (the feature matrices).
d = generate.GaussianData(
		nSrcNodes=203, nDstNodes=203, nObs=10003,
		nSrcContexts=4, nDstContexts=5, nEdgeContexts=3, nFactors=2, has.gamma=FALSE, has.u=TRUE,
		nObsFeatures=2, nSrcFeatures=3, nDstFeatures=3, nCtxFeatures=1,
		b.sd=1, g0.sd=1, d0.sd=1, h0.sd=0, G.sd=1, D.sd=1, H.sd=0, q.sd=1, r.sd=1,
		q.mean=5, r.mean=5,
		var_y=0.1, var_alpha=0.5, var_beta=0.5, var_gamma=1, var_v=1, var_u=1, var_w=1,
		var_alpha_global=0.2, var_beta_global=0.2,
		has.intercept=FALSE,
		sparse.matrices=TRUE, frac.zeroFeatures=0.2
);
 # (2) Create training/test split
# Each observation is put into the training set with probability 0.75;
# all remaining observations form the test set.
select.train = runif(nrow(d$obs),min=0,max=1) < 0.75;
obs = d$obs;
names(obs) = c("src_id", "dst_id", "src_context", "dst_context", "ctx_id", "y");
# The feature matrices are converted with as.matrix() before being wrapped in data frames
# (d was generated with sparse.matrices=TRUE).  drop=FALSE keeps the result a matrix
# even when only one row or column is selected.
obs.train   = obs[ select.train,];
x_obs.train = data.frame(as.matrix(d$feature$x_obs)[ select.train,,drop=FALSE]);
obs.test    = obs[!select.train,];
x_obs.test  = data.frame(as.matrix(d$feature$x_obs)[!select.train,,drop=FALSE]);
# Node IDs are simply the row numbers of the node feature matrices.
# seq_len() is used instead of 1:nrow() so that an empty matrix gives an empty ID vector.
x_src = data.frame(src_id=seq_len(nrow(d$feature$x_src)), as.matrix(d$feature$x_src));
x_dst = data.frame(dst_id=seq_len(nrow(d$feature$x_dst)), as.matrix(d$feature$x_dst));
# The following are input data tables:
#     obs.train, obs.test, x_src, x_dst
# obs.train and obs.test contain the training and test rating data
#      The columns of these two tables are:
#      1. src_id: e.g., user_id
#      2. dst_id: e.g., item_id
#      3. src_context: (optional) This is the context in which the source node gives the rating
#      4. dst_context: (optional) This is the context in which the destination node receives the rating
#      5. ctx_id:      (optional) This is the context of this (src_id, dst_id) pair
#      6. y: This is the rating that the source node gives the destination node
# Note: You may set all/any of src_context, dst_context, ctx_id to NULL if there is no context info
#            or set all of them to the same vector
#       The number of contexts should not be too large; otherwise, the program will be very slow
str(obs.train); # to see the data structure
str(obs.test);  # to see the data structure
# x_src is the source node (e.g., user) feature table
#       The first column src_id specifies the source node ID
str(x_src); # to see the data structure
# x_dst is the destination node (e.g., item) feature table
#       The first column dst_id specifies the destination node ID
str(x_dst); # to see the data structure

 # (3) Index training data
#     See src/R/model/multicontext_model_utils.R: indexData() for details
# The result is used in step (6) through data.train$obs and data.train$feature.
# (No trailing comma after the last argument: a trailing comma passes an extra, empty argument.)
data.train = indexData(
		obs=obs.train, src.dst.same=TRUE, rm.self.link=TRUE,
		x_obs=x_obs.train, x_src=x_src, x_dst=x_dst,
		add.intercept=FALSE
);
# (4) Index test data
#     See src/R/model/multicontext_model_utils.R: indexTestData() for details
# The test data is indexed against data.train; the result is used in step (6)
# through data.test$obs, data.test$feature and data.test$IDs.
data.test = indexTestData(
		data.train=data.train, obs=obs.test,
		x_obs=x_obs.test, x_src=x_src, x_dst=x_dst
);
 # (5) Setup the model(s) to be fitted
#     See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
#     Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
# Each row of this table describes one model to fit; here the two models "wuv" and "wvv".
# TRUE/FALSE are spelled out because T and F are ordinary variables that can be reassigned.
setting = data.frame(
		name          = c("wuv", "wvv"),
		nFactors      = c(    2,     2), # number of interaction factors
		has.u         = c( TRUE, FALSE), # whether to use u_i' v_j or v_i' v_j
		has.gamma     = c(FALSE, FALSE), # just set to FALSE
		nLocalFactors = c(    0,     0), # just set to 0
		is.logistic   = c(FALSE, FALSE)  # whether to use the logistic model for binary rating
);
 # Load the compiled C code and all model sources needed for fitting.
# multicontext_model_MStep.R and multicontext_model_EM.R are loaded here for the first time
# in this example; the other files were already loaded in step (1) and are re-sourced.
dyn.load("lib/c_funcs.so");
source("src/R/c_funcs.R");
source("src/R/util.R");
source("src/R/model/util.R");
source("src/R/model/multicontext_model_genData.R");
source("src/R/model/multicontext_model_utils.R");
source("src/R/model/multicontext_model_MStep.R");
source("src/R/model/multicontext_model_EM.R");
# Fix the random seed before fitting.
set.seed(2);
# (6) Run the fitting code
#     See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
#     Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
#     Meaning of out.level:
#        out.level=1: Save the factor & parameter values to out.dir/model.last and out.dir/model.minTestLoss
#        out.level=2: Save the factor & parameter values of each iteration i to out.dir/model.i
#     NOTE(review): step (8) reads the "wuv" results from /tmp/test_wuv, so the model name is
#     apparently appended to out.dir -- confirm in run.multicontext().
ans = run.multicontext(
		obs=data.train$obs,         # Observation table
		feature=data.train$feature, # Features
		setting=setting,    # Model setting
		nSamples=200,   # Number of samples drawn in each E-step: could be a vector of size nIter.
		nBurnIn=20,     # Number of burn-in draws before taking samples for the E-step: could be a vector of size nIter.
		nIter=20,       # Number of EM iterations
		test.obs=data.test$obs,         # Test data: Observations for testing (optional)
		test.feature=data.test$feature, #            Features for testing     (optional)
		ridge.lambda=1,
		IDs=data.test$IDs,
		out.level=1,         # see "Meaning of out.level" above
		out.dir="/tmp/test", # output location
		out.overwrite=TRUE,     # whether to overwrite the output directory if it exists
		debug=0,      # Set to 0 to disable internal sanity checking; Set to 100 for most detailed sanity checking
		verbose=1,    # Set to 0 to disable console output; Set to 100 to print everything to the console
		verbose.M=2
);
# There may be some warning messages, which are mostly debugging messages and do not mean real problems.

# (7) Check the model summary
#     Only the listed columns of ans$summary are displayed.
ans$summary[,c("name", "nFactors", "has.u", "has.gamma", "nLocalFactors", "is.logistic", "best.test.loss", "last.test.loss")];

# (8) Load the fitted model(s)
#     Here, I only use the "wuv" model as an example
# (8.1) Check the summary file (a tab-separated file with a header row)
read.table("/tmp/test_wuv/summary", header=TRUE, sep="\t", as.is=TRUE);
# (8.2) Load the model
load("/tmp/test_wuv/model.last");
#       Now, factor and param contain the fitted model
str(factor);
str(param);
# (8.3) Make prediction
#       is.logistic=FALSE matches the is.logistic value of the "wuv" model in the setting table.
prediction = predict.multicontext(
	model=list(factor=factor, param=param),
	obs=data.test$obs, feature=data.test$feature, is.logistic=FALSE
);
# Now, prediction$pred.y contains the predicted rating for data.test$obs
str(prediction);


###
### Example 2: Run the fitting code with synthetic data using SPARSE feature matrix
###
# (1) Generate some data
#     See src/R/model/multicontext_model_genData.R for details
# Load the Matrix package, the compiled C routines and the R helper code.
# All paths are relative, so R has to be started in public-factor-models/.
library(Matrix);
dyn.load("lib/c_funcs.so");
source("src/R/c_funcs.R");
source("src/R/util.R");
source("src/R/model/util.R");
source("src/R/model/multicontext_model_genData.R");
source("src/R/model/multicontext_model_utils.R");
# Fix the random seed so that the same synthetic data set is drawn on every run.
set.seed(0);
# Unlike Example 1, index.value.format=TRUE is requested here: the feature tables in
# d$feature are then accessed below through their columns row, col (and a third column
# that is renamed to "value").
d = generate.GaussianData(
		nSrcNodes=1003, nDstNodes=1003, nObs=100003,
		nSrcContexts=3, nDstContexts=3, nEdgeContexts=1, nFactors=3, has.gamma=FALSE, has.u=FALSE,
		nObsFeatures=13, nSrcFeatures=19, nDstFeatures=23, nCtxFeatures=1,
		b.sd=1, g0.sd=1, d0.sd=1, h0.sd=0, G.sd=1, D.sd=1, H.sd=0, q.sd=1, r.sd=1,
		q.mean=5, r.mean=5,
		var_y=0.1, var_alpha=0.5, var_beta=0.5, var_gamma=1, var_v=1, var_u=1, var_w=1,
		var_alpha_global=0.2, var_beta_global=0.2,
		has.intercept=FALSE,
		sparse.matrices=TRUE, index.value.format=TRUE, frac.zeroFeatures=0.5
);
names(d$obs) = c("src_id", "dst_id", "src_context", "dst_context", "ctx_id", "y");
d$obs$ctx_id = NULL;  # drop the ctx_id column: it is not used in this example
rating.data = d$obs;
# Sort each feature table by (row, col) and rename its three columns to (<id>, index, value).
x_obs = d$feature$x_obs[order(d$feature$x_obs$row,d$feature$x_obs$col),];
names(x_obs) = c("obs_id", "index", "value");
x_src = d$feature$x_src[order(d$feature$x_src$row,d$feature$x_src$col),];
names(x_src) = c("src_id", "index", "value");
x_dst = d$feature$x_dst[order(d$feature$x_dst$row,d$feature$x_dst$col),];
names(x_dst) = c("dst_id", "index", "value");

#
# Input data: rating.data, x_obs, x_src, x_dst (you need to prepare these four tables for your data)
# Note: All ID numbers start from 1 (not 0)
#
str(rating.data); # see the data structure
# rating.data is the rating data table with the following columns:
#      1. src_id: e.g., user_id or voter_id
#      2. dst_id: e.g., item_id or author_id
#      3. src_context: (optional) This is the context in which the source node gives the rating
#      4. dst_context: (optional) This is the context in which the destination node receives the rating
#      5. y: This is the rating that the source node gives the destination node
#      6. ctx_id: (optional) This is the context of this (src_id, dst_id) pair
#                 (not present in this example because it was removed above)
# Note: You may set all/any of src_context, dst_context, ctx_id to NULL if there is no context info
#       The number of contexts should not be too large; otherwise, the program will be very slow
str(x_obs);
# x_obs is the feature table for observations with the following columns
#      1. obs_id: observation ID (obs_id=n corresponds to the nth row of rating.data)
#      2. index:  feature index
#      3. value:  feature value
str(x_src);
# x_src is the feature table for source nodes with the following columns
#      1. src_id: source node ID (this corresponds to the src_id column in rating.data)
#      2. index:  feature index
#      3. value:  feature value
str(x_dst);
# x_dst is the feature table for destination nodes with the following columns
#      1. dst_id: destination node ID (this corresponds to the dst_id column in rating.data)
#      2. index:  feature index
#      3. value:  feature value

# (2) Create training/test split
set.seed(1);
# Draw 75% of the row numbers of rating.data (rounded down) for training;
# all remaining row numbers form the test set.
# seq_len() is used instead of 1:nrow() so that an empty table gives an empty index vector.
select.train = sample(nrow(rating.data), floor(nrow(rating.data)*0.75));
select.test  = setdiff(seq_len(nrow(rating.data)), select.train);
# Keep only the observation features that belong to the selected rows, then re-number
# obs_id with match() so that it refers to the row position within obs.train / obs.test
# (row k of obs.train is row select.train[k] of rating.data).
obs.train   = rating.data[select.train,];
x_obs.train = x_obs[x_obs$obs_id %in% select.train,];
x_obs.train$obs_id = match(x_obs.train$obs_id, select.train);
obs.test    = rating.data[select.test, ];
x_obs.test  = x_obs[x_obs$obs_id %in% select.test, ];
x_obs.test$obs_id  = match(x_obs.test$obs_id,  select.test);

# (3) Index training data
#     See src/R/model/multicontext_model_utils.R: indexData() for details
# The result is used in step (6) through data.train$obs and data.train$feature.
# (No trailing comma after the last argument: a trailing comma passes an extra, empty argument.)
data.train = indexData(
		obs=obs.train, src.dst.same=TRUE, rm.self.link=TRUE,
		x_obs=x_obs.train, x_src=x_src, x_dst=x_dst,
		add.intercept=FALSE
);
# (4) Index test data
#     See src/R/model/multicontext_model_utils.R: indexTestData() for details
# The test data is indexed against data.train; the result is used in step (6)
# through data.test$obs, data.test$feature and data.test$IDs.
data.test = indexTestData(
		data.train=data.train, obs=obs.test,
		x_obs=x_obs.test, x_src=x_src, x_dst=x_dst
);
# (5) Setup the model(s) to be fitted
#     See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
#     Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
# Each row of this table describes one model to fit; here the two models "uv" and "vv".
# TRUE/FALSE are spelled out because T and F are ordinary variables that can be reassigned.
setting = data.frame(
		name          = c( "uv",  "vv"),
		nFactors      = c(    3,     3), # number of interaction factors
		has.u         = c( TRUE, FALSE), # whether to use u_i' v_j or v_i' v_j
		has.gamma     = c(FALSE, FALSE), # just set to FALSE
		nLocalFactors = c(    0,     0), # just set to 0
		is.logistic   = c(FALSE, FALSE)  # whether to use the logistic model for binary rating
);
# (6) Run the fitting code
#     See src/R/model/multicontext_model_EM.R: run.multicontext(), fit.multicontext()
#     Note run.multicontext() is a wrapper to fit multiple models using fit.multicontext().
#     Meaning of out.level:
#        out.level=1: Save the factor & parameter values to out.dir/model.last and out.dir/model.minTestLoss
#        out.level=2: Save the factor & parameter values of each iteration i to out.dir/model.i
# Load the compiled C code and all model sources needed for fitting.
dyn.load("lib/c_funcs.so");
source("src/R/c_funcs.R");
source("src/R/util.R");
source("src/R/model/util.R");
source("src/R/model/multicontext_model_genData.R");
source("src/R/model/multicontext_model_utils.R");
source("src/R/model/multicontext_model_MStep.R");
source("src/R/model/multicontext_model_EM.R");
source("src/R/model/GLMNet.R");
# The same seed is passed as rnd.seed.init and rnd.seed.fit below.
rnd.seed=1;
ans = run.multicontext(
		obs=data.train$obs,         # Observation table
		feature=data.train$feature, # Features
		setting=setting,    # Model setting
		nSamples=200,   # Number of samples drawn in each E-step: could be a vector of size nIter.
		nBurnIn=20,     # Number of burn-in draws before taking samples for the E-step: could be a vector of size nIter.
		nIter=10,       # Number of EM iterations
		test.obs=data.test$obs,         # Test data: Observations for testing (optional)
		test.feature=data.test$feature, #            Features for testing     (optional)
		reg.algo=GLMNet,     # NOTE(review): GLMNet is presumably defined by src/R/model/GLMNet.R sourced above -- confirm
		IDs=data.test$IDs,
		rnd.seed.init=rnd.seed, rnd.seed.fit=rnd.seed,
		out.level=1,         # see "Meaning of out.level" above
		out.dir="/tmp/test", # output location
		out.overwrite=TRUE,     # whether to overwrite the output directory if it exists
		debug=0,      # Set to 0 to disable internal sanity checking; Set to 100 for most detailed sanity checking
		verbose=1,    # Set to 0 to disable console output; Set to 100 to print everything to the console
		verbose.M=2
);
# Show the selected columns of the model summary.
ans$summary[,c("name", "nFactors", "has.u", "has.gamma", "nLocalFactors", "is.logistic", "best.test.loss", "last.test.loss")];
