PageRenderTime 66ms CodeModel.GetById 27ms RepoModel.GetById 0ms app.codeStats 0ms

/extapis/java_jni/src/org/graphlab/toolkits/matrix/als/AlsUpdater.java

https://github.com/michaelkook/GraphLab-2
Java | 95 lines | 54 code | 24 blank | 17 comment | 7 complexity | 6d89cb307fe498b7695419baf5c9d4cd MD5 | raw file
Possible License(s): ISC, Apache-2.0
  1. package org.graphlab.toolkits.matrix.als;
  2. import java.util.Set;
  3. import no.uib.cipr.matrix.DenseCholesky;
  4. import no.uib.cipr.matrix.DenseMatrix;
  5. import no.uib.cipr.matrix.DenseVector;
  6. import no.uib.cipr.matrix.Vector;
  7. import org.graphlab.Context;
  8. import org.graphlab.Updater;
  9. import org.jgrapht.Graphs;
  10. import org.jgrapht.graph.DefaultWeightedEdge;
  11. /**
  12. * Alternating Least Squares updater.
  13. * @author Jiunn Haur Lim <jiunnhal@cmu.edu>
  14. */
  15. public class AlsUpdater
  16. extends Updater<AlsVertex, DefaultWeightedEdge, AlsUpdater> {
  17. /** Regularization parameter */
  18. private static final double LAMBDA = 0.065;
  19. /** Convergence tolerance */
  20. private static final double TOLERANCE = 1e-2;
  21. /** Number of latent factors */
  22. protected static final int NLATENT = 20;
  23. private AlsGraph mGraph;
  24. protected AlsUpdater(AlsGraph graph) {
  25. mGraph = graph;
  26. }
  27. @Override
  28. protected AlsUpdater clone() {
  29. return new AlsUpdater(mGraph);
  30. }
  31. @Override
  32. public void update(Context context, AlsVertex vertex) {
  33. vertex.mSSE = 0;
  34. // if there are no neighbors just return -----------------------------------
  35. Set<DefaultWeightedEdge> edges = mGraph.edgesOf(vertex);
  36. if (edges.isEmpty()) return;
  37. // compute X ---------------------------------------------------------------
  38. DenseMatrix X = new DenseMatrix(edges.size(), NLATENT);
  39. DenseVector y = new DenseVector(edges.size());
  40. int i=0;
  41. for (final DefaultWeightedEdge edge : edges) {
  42. // set x values
  43. Vector neighbor = Graphs.getOppositeVertex(mGraph, edge, vertex).vector();
  44. for (int j = 0; j < NLATENT; j++) X.set(i, j, neighbor.get(j));
  45. // set rating
  46. y.set(i, mGraph.getEdgeWeight(edge));
  47. i++;
  48. }
  49. // compute X'X and X'y -----------------------------------------------------
  50. DenseMatrix XtX = new DenseMatrix(NLATENT, NLATENT);
  51. X.transAmult(X, XtX);
  52. DenseVector Xty = new DenseVector(NLATENT);
  53. X.transMult(y, Xty);
  54. // regularization
  55. for (i = 0; i < NLATENT; i++) XtX.add(i, i, (LAMBDA) * edges.size());
  56. // solve the least squares problem -----------------------------------------
  57. double[] weights = DenseCholesky.factorize(XtX).solve(new DenseMatrix(Xty)).getData();
  58. vertex.setVector(new DenseVector(weights));
  59. // update the RMSE and reschedule neighbors --------------------------------
  60. for (final DefaultWeightedEdge edge : edges) {
  61. // get the neighbor id
  62. AlsVertex neighbor = Graphs.getOppositeVertex(mGraph, edge, vertex);
  63. final double pred = vertex.vector().dot(neighbor.vector());
  64. final double error = Math.abs(mGraph.getEdgeWeight(edge) - pred);
  65. vertex.mSSE += error * error;
  66. // reschedule neighbors
  67. if (error > TOLERANCE && vertex.mResidual > TOLERANCE)
  68. context.schedule(neighbor, new AlsUpdater(mGraph));
  69. }
  70. } // end of operator()
  71. }