// /src/models/neural/graph.js
// Source: https://github.com/aslanides/aixijs (scraped page metadata condensed into this comment)

  1. // thx Karpathy :)
  2. class Graph {
  3. constructor(needsBackprop) {
  4. this.needsBackprop = needsBackprop;
  5. this.backprop = [];
  6. }
  7. backward() {
  8. for (let i = this.backprop.length - 1; i >= 0; i--) {
  9. this.backprop[i](); // tick!
  10. }
  11. }
  12. rowPluck(m, idx) {
  13. let d = m.d;
  14. let out = new Matrix(d, 1);
  15. for (let i = 0, n = d; i < n; i++) {
  16. out.w[i] = m.w[d * idx + i];
  17. } // copy over the data
  18. if (this.needsBackprop) {
  19. let backward = function () {
  20. for (let i = 0, n = d; i < n; i++) {
  21. m.dw[d * idx + i] += out.dw[i];
  22. }
  23. };
  24. this.backprop.push(backward);
  25. }
  26. return out;
  27. }
  28. tanh(m) {
  29. let out = new Matrix(m.n, m.d);
  30. let n = m.w.length;
  31. for (let i = 0; i < n; i++) {
  32. out.w[i] = Math.tanh(m.w[i]);
  33. }
  34. if (this.needs_backprop) {
  35. let backward = function () {
  36. for (let i = 0; i < n; i++) {
  37. // grad for z = tanh(x) is (1 - z^2)
  38. let mwi = out.w[i];
  39. m.dw[i] += (1.0 - mwi * mwi) * out.dw[i];
  40. }
  41. };
  42. this.backprop.push(backward);
  43. }
  44. return out;
  45. }
  46. sigmoid(m) {
  47. let out = new Matrix(m.n, m.d);
  48. let n = m.w.length;
  49. for (let i = 0; i < n; i++) {
  50. out.w[i] = sig(m.w[i]);
  51. }
  52. if (this.needs_backprop) {
  53. let backward = function () {
  54. for (let i = 0; i < n; i++) {
  55. // grad for z = tanh(x) is (1 - z^2)
  56. let mwi = out.w[i];
  57. m.dw[i] += mwi * (1.0 - mwi) * out.dw[i];
  58. }
  59. };
  60. this.backprop.push(backward);
  61. }
  62. return out;
  63. }
  64. relu(m) {
  65. let out = new Matrix(m.n, m.d);
  66. let n = m.w.length;
  67. for (let i = 0; i < n; i++) {
  68. out.w[i] = Math.max(0, m.w[i]); // relu
  69. }
  70. if (this.needs_backprop) {
  71. let backward = function () {
  72. for (let i = 0; i < n; i++) {
  73. m.dw[i] += m.w[i] > 0 ? out.dw[i] : 0.0;
  74. }
  75. };
  76. this.backprop.push(backward);
  77. }
  78. return out;
  79. }
  80. mul(m1, m2) {
  81. // multiply matrices m1 * m2
  82. Util.assert(m1.d === m2.n, 'matmul dimensions misaligned');
  83. let n = m1.n;
  84. let d = m2.d;
  85. let out = new Matrix(n, d);
  86. for (let i = 0; i < m1.n; i++) { // loop over rows of m1
  87. for (let j = 0; j < m2.d; j++) { // loop over cols of m2
  88. let dot = 0.0;
  89. for (let k = 0; k < m1.d; k++) { // dot product loop
  90. dot += m1.w[m1.d * i + k] * m2.w[m2.d * k + j];
  91. }
  92. out.w[d * i + j] = dot;
  93. }
  94. }
  95. if (this.needs_backprop) {
  96. let backward = function () {
  97. for (let i = 0; i < m1.n; i++) { // loop over rows of m1
  98. for (let j = 0; j < m2.d; j++) { // loop over cols of m2
  99. for (let k = 0; k < m1.d; k++) { // dot product loop
  100. let b = out.dw[d * i + j];
  101. m1.dw[m1.d * i + k] += m2.w[m2.d * k + j] * b;
  102. m2.dw[m2.d * k + j] += m1.w[m1.d * i + k] * b;
  103. }
  104. }
  105. }
  106. };
  107. this.backprop.push(backward);
  108. }
  109. return out;
  110. }
  111. add(m1, m2) {
  112. Util.assert(m1.w.length === m2.w.length);
  113. let out = new Matrix(m1.n, m1.d);
  114. for (let i = 0, n = m1.w.length; i < n; i++) {
  115. out.w[i] = m1.w[i] + m2.w[i];
  116. }
  117. if (this.needs_backprop) {
  118. let backward = function () {
  119. for (let i = 0, n = m1.w.length; i < n; i++) {
  120. m1.dw[i] += out.dw[i];
  121. m2.dw[i] += out.dw[i];
  122. }
  123. };
  124. this.backprop.push(backward);
  125. }
  126. return out;
  127. }
  128. dot(m1, m2) {
  129. // m1 m2 are both column vectors
  130. assert(m1.w.length === m2.w.length);
  131. let out = new Matrix(1, 1);
  132. let dot = 0.0;
  133. for (let i = 0, n = m1.w.length; i < n; i++) {
  134. dot += m1.w[i] * m2.w[i];
  135. }
  136. out.w[0] = dot;
  137. if (this.needs_backprop) {
  138. let backward = function () {
  139. for (let i = 0, n = m1.w.length; i < n; i++) {
  140. m1.dw[i] += m2.w[i] * out.dw[0];
  141. m2.dw[i] += m1.w[i] * out.dw[0];
  142. }
  143. };
  144. this.backprop.push(backward);
  145. }
  146. return out;
  147. }
  148. eltmul(m1, m2) {
  149. assert(m1.w.length === m2.w.length);
  150. let out = new Matrix(m1.n, m1.d);
  151. for (let i = 0, n = m1.w.length; i < n; i++) {
  152. out.w[i] = m1.w[i] * m2.w[i];
  153. }
  154. if (this.needs_backprop) {
  155. let backward = function () {
  156. for (let i = 0, n = m1.w.length; i < n; i++) {
  157. m1.dw[i] += m2.w[i] * out.dw[i];
  158. m2.dw[i] += m1.w[i] * out.dw[i];
  159. }
  160. };
  161. this.backprop.push(backward);
  162. }
  163. return out;
  164. }
  165. softmax(m) {
  166. let out = new Matrix(m.n, m.d); // probability volume
  167. let maxval = -999999;
  168. for (let i = 0, n = m.w.length; i < n; i++) { if (m.w[i] > maxval) maxval = m.w[i]; }
  169. let s = 0.0;
  170. for (let i = 0, n = m.w.length; i < n; i++) {
  171. out.w[i] = Math.exp(m.w[i] - maxval);
  172. s += out.w[i];
  173. }
  174. for (let i = 0, n = m.w.length; i < n; i++) { out.w[i] /= s; }
  175. // no backward pass here needed
  176. // since we will use the computed probabilities outside
  177. // to set gradients directly on m
  178. return out;
  179. }
  180. }