/lib/mlbackend/php/phpml/src/Phpml/DimensionReduction/LDA.php

https://github.com/markn86/moodle · PHP · 223 lines · 120 code · 42 blank · 61 comment · 12 complexity · f252a49e7fed9c85de415eb116c9d286 MD5 · raw file

  1. <?php
  2. declare(strict_types=1);
  3. namespace Phpml\DimensionReduction;
  4. use Phpml\Exception\InvalidArgumentException;
  5. use Phpml\Exception\InvalidOperationException;
  6. use Phpml\Math\Matrix;
  7. class LDA extends EigenTransformerBase
  8. {
  9. /**
  10. * @var bool
  11. */
  12. public $fit = false;
  13. /**
  14. * @var array
  15. */
  16. public $labels = [];
  17. /**
  18. * @var array
  19. */
  20. public $means = [];
  21. /**
  22. * @var array
  23. */
  24. public $counts = [];
  25. /**
  26. * @var float[]
  27. */
  28. public $overallMean = [];
  29. /**
  30. * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
  31. * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
  32. * technique that requires the class labels in order to fit the data to a
  33. * lower dimensional space. <br><br>
  34. * The algorithm can be initialized by speciyfing
  35. * either with the totalVariance(a value between 0.1 and 0.99)
  36. * or numFeatures (number of features in the dataset) to be preserved.
  37. *
  38. * @param float|null $totalVariance Total explained variance to be preserved
  39. * @param int|null $numFeatures Number of features to be preserved
  40. *
  41. * @throws InvalidArgumentException
  42. */
  43. public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
  44. {
  45. if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
  46. throw new InvalidArgumentException('Total variance can be a value between 0.1 and 0.99');
  47. }
  48. if ($numFeatures !== null && $numFeatures <= 0) {
  49. throw new InvalidArgumentException('Number of features to be preserved should be greater than 0');
  50. }
  51. if (($totalVariance !== null) === ($numFeatures !== null)) {
  52. throw new InvalidArgumentException('Either totalVariance or numFeatures should be specified in order to run the algorithm');
  53. }
  54. if ($numFeatures !== null) {
  55. $this->numFeatures = $numFeatures;
  56. }
  57. if ($totalVariance !== null) {
  58. $this->totalVariance = $totalVariance;
  59. }
  60. }
  61. /**
  62. * Trains the algorithm to transform the given data to a lower dimensional space.
  63. */
  64. public function fit(array $data, array $classes): array
  65. {
  66. $this->labels = $this->getLabels($classes);
  67. $this->means = $this->calculateMeans($data, $classes);
  68. $sW = $this->calculateClassVar($data, $classes);
  69. $sB = $this->calculateClassCov();
  70. $S = $sW->inverse()->multiply($sB);
  71. $this->eigenDecomposition($S->toArray());
  72. $this->fit = true;
  73. return $this->reduce($data);
  74. }
  75. /**
  76. * Transforms the given sample to a lower dimensional vector by using
  77. * the eigenVectors obtained in the last run of <code>fit</code>.
  78. *
  79. * @throws InvalidOperationException
  80. */
  81. public function transform(array $sample): array
  82. {
  83. if (!$this->fit) {
  84. throw new InvalidOperationException('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
  85. }
  86. if (!is_array($sample[0])) {
  87. $sample = [$sample];
  88. }
  89. return $this->reduce($sample);
  90. }
  91. /**
  92. * Returns unique labels in the dataset
  93. */
  94. protected function getLabels(array $classes): array
  95. {
  96. $counts = array_count_values($classes);
  97. return array_keys($counts);
  98. }
  99. /**
  100. * Calculates mean of each column for each class and returns
  101. * n by m matrix where n is number of labels and m is number of columns
  102. */
  103. protected function calculateMeans(array $data, array $classes): array
  104. {
  105. $means = [];
  106. $counts = [];
  107. $overallMean = array_fill(0, count($data[0]), 0.0);
  108. foreach ($data as $index => $row) {
  109. $label = array_search($classes[$index], $this->labels, true);
  110. foreach ($row as $col => $val) {
  111. if (!isset($means[$label][$col])) {
  112. $means[$label][$col] = 0.0;
  113. }
  114. $means[$label][$col] += $val;
  115. $overallMean[$col] += $val;
  116. }
  117. if (!isset($counts[$label])) {
  118. $counts[$label] = 0;
  119. }
  120. ++$counts[$label];
  121. }
  122. foreach ($means as $index => $row) {
  123. foreach ($row as $col => $sum) {
  124. $means[$index][$col] = $sum / $counts[$index];
  125. }
  126. }
  127. // Calculate overall mean of the dataset for each column
  128. $numElements = array_sum($counts);
  129. $map = function ($el) use ($numElements) {
  130. return $el / $numElements;
  131. };
  132. $this->overallMean = array_map($map, $overallMean);
  133. $this->counts = $counts;
  134. return $means;
  135. }
  136. /**
  137. * Returns in-class scatter matrix for each class, which
  138. * is a n by m matrix where n is number of classes and
  139. * m is number of columns
  140. */
  141. protected function calculateClassVar(array $data, array $classes): Matrix
  142. {
  143. // s is an n (number of classes) by m (number of column) matrix
  144. $s = array_fill(0, count($data[0]), array_fill(0, count($data[0]), 0));
  145. $sW = new Matrix($s, false);
  146. foreach ($data as $index => $row) {
  147. $label = array_search($classes[$index], $this->labels, true);
  148. $means = $this->means[$label];
  149. $row = $this->calculateVar($row, $means);
  150. $sW = $sW->add($row);
  151. }
  152. return $sW;
  153. }
  154. /**
  155. * Returns between-class scatter matrix for each class, which
  156. * is an n by m matrix where n is number of classes and
  157. * m is number of columns
  158. */
  159. protected function calculateClassCov(): Matrix
  160. {
  161. // s is an n (number of classes) by m (number of column) matrix
  162. $s = array_fill(0, count($this->overallMean), array_fill(0, count($this->overallMean), 0));
  163. $sB = new Matrix($s, false);
  164. foreach ($this->means as $index => $classMeans) {
  165. $row = $this->calculateVar($classMeans, $this->overallMean);
  166. $N = $this->counts[$index];
  167. $sB = $sB->add($row->multiplyByScalar($N));
  168. }
  169. return $sB;
  170. }
  171. /**
  172. * Returns the result of the calculation (x - m)T.(x - m)
  173. */
  174. protected function calculateVar(array $row, array $means): Matrix
  175. {
  176. $x = new Matrix($row, false);
  177. $m = new Matrix($means, false);
  178. $diff = $x->subtract($m);
  179. return $diff->transpose()->multiply($diff);
  180. }
  181. }