/lib/mlbackend/php/phpml/src/Phpml/Clustering/KMeans/Space.php

https://github.com/markn86/moodle · PHP · 259 lines · 163 code · 57 blank · 39 comment · 17 complexity · 15d9435d9c6b8917805584fd1710739d MD5 · raw file

  1. <?php
  2. declare(strict_types=1);
  3. namespace Phpml\Clustering\KMeans;
  4. use InvalidArgumentException;
  5. use LogicException;
  6. use Phpml\Clustering\KMeans;
  7. use SplObjectStorage;
  8. class Space extends SplObjectStorage
  9. {
  10. /**
  11. * @var int
  12. */
  13. protected $dimension;
  14. public function __construct(int $dimension)
  15. {
  16. if ($dimension < 1) {
  17. throw new LogicException('a space dimension cannot be null or negative');
  18. }
  19. $this->dimension = $dimension;
  20. }
  21. public function toArray(): array
  22. {
  23. $points = [];
  24. /** @var Point $point */
  25. foreach ($this as $point) {
  26. $points[] = $point->toArray();
  27. }
  28. return ['points' => $points];
  29. }
  30. /**
  31. * @param mixed $label
  32. */
  33. public function newPoint(array $coordinates, $label = null): Point
  34. {
  35. if (count($coordinates) !== $this->dimension) {
  36. throw new LogicException('('.implode(',', $coordinates).') is not a point of this space');
  37. }
  38. return new Point($coordinates, $label);
  39. }
  40. /**
  41. * @param mixed $label
  42. * @param mixed $data
  43. */
  44. public function addPoint(array $coordinates, $label = null, $data = null): void
  45. {
  46. $this->attach($this->newPoint($coordinates, $label), $data);
  47. }
  48. /**
  49. * @param object $point
  50. * @param mixed $data
  51. */
  52. public function attach($point, $data = null): void
  53. {
  54. if (!$point instanceof Point) {
  55. throw new InvalidArgumentException('can only attach points to spaces');
  56. }
  57. parent::attach($point, $data);
  58. }
  59. public function getDimension(): int
  60. {
  61. return $this->dimension;
  62. }
  63. /**
  64. * @return array|bool
  65. */
  66. public function getBoundaries()
  67. {
  68. if (count($this) === 0) {
  69. return false;
  70. }
  71. $min = $this->newPoint(array_fill(0, $this->dimension, null));
  72. $max = $this->newPoint(array_fill(0, $this->dimension, null));
  73. /** @var self $point */
  74. foreach ($this as $point) {
  75. for ($n = 0; $n < $this->dimension; ++$n) {
  76. if ($min[$n] === null || $min[$n] > $point[$n]) {
  77. $min[$n] = $point[$n];
  78. }
  79. if ($max[$n] === null || $max[$n] < $point[$n]) {
  80. $max[$n] = $point[$n];
  81. }
  82. }
  83. }
  84. return [$min, $max];
  85. }
  86. public function getRandomPoint(Point $min, Point $max): Point
  87. {
  88. $point = $this->newPoint(array_fill(0, $this->dimension, null));
  89. for ($n = 0; $n < $this->dimension; ++$n) {
  90. $point[$n] = random_int($min[$n], $max[$n]);
  91. }
  92. return $point;
  93. }
  94. /**
  95. * @return Cluster[]
  96. */
  97. public function cluster(int $clustersNumber, int $initMethod = KMeans::INIT_RANDOM): array
  98. {
  99. $clusters = $this->initializeClusters($clustersNumber, $initMethod);
  100. do {
  101. } while (!$this->iterate($clusters));
  102. return $clusters;
  103. }
  104. /**
  105. * @return Cluster[]
  106. */
  107. protected function initializeClusters(int $clustersNumber, int $initMethod): array
  108. {
  109. switch ($initMethod) {
  110. case KMeans::INIT_RANDOM:
  111. $clusters = $this->initializeRandomClusters($clustersNumber);
  112. break;
  113. case KMeans::INIT_KMEANS_PLUS_PLUS:
  114. $clusters = $this->initializeKMPPClusters($clustersNumber);
  115. break;
  116. default:
  117. return [];
  118. }
  119. $clusters[0]->attachAll($this);
  120. return $clusters;
  121. }
  122. /**
  123. * @param Cluster[] $clusters
  124. */
  125. protected function iterate(array $clusters): bool
  126. {
  127. $convergence = true;
  128. $attach = new SplObjectStorage();
  129. $detach = new SplObjectStorage();
  130. foreach ($clusters as $cluster) {
  131. foreach ($cluster as $point) {
  132. $closest = $point->getClosest($clusters);
  133. if ($closest !== $cluster) {
  134. $attach[$closest] ?? $attach[$closest] = new SplObjectStorage();
  135. $detach[$cluster] ?? $detach[$cluster] = new SplObjectStorage();
  136. $attach[$closest]->attach($point);
  137. $detach[$cluster]->attach($point);
  138. $convergence = false;
  139. }
  140. }
  141. }
  142. /** @var Cluster $cluster */
  143. foreach ($attach as $cluster) {
  144. $cluster->attachAll($attach[$cluster]);
  145. }
  146. /** @var Cluster $cluster */
  147. foreach ($detach as $cluster) {
  148. $cluster->detachAll($detach[$cluster]);
  149. }
  150. foreach ($clusters as $cluster) {
  151. $cluster->updateCentroid();
  152. }
  153. return $convergence;
  154. }
  155. /**
  156. * @return Cluster[]
  157. */
  158. protected function initializeKMPPClusters(int $clustersNumber): array
  159. {
  160. $clusters = [];
  161. $this->rewind();
  162. /** @var Point $current */
  163. $current = $this->current();
  164. $clusters[] = new Cluster($this, $current->getCoordinates());
  165. $distances = new SplObjectStorage();
  166. for ($i = 1; $i < $clustersNumber; ++$i) {
  167. $sum = 0;
  168. /** @var Point $point */
  169. foreach ($this as $point) {
  170. $closest = $point->getClosest($clusters);
  171. if ($closest === null) {
  172. continue;
  173. }
  174. $distance = $point->getDistanceWith($closest);
  175. $sum += $distances[$point] = $distance;
  176. }
  177. $sum = random_int(0, (int) $sum);
  178. /** @var Point $point */
  179. foreach ($this as $point) {
  180. $sum -= $distances[$point];
  181. if ($sum > 0) {
  182. continue;
  183. }
  184. $clusters[] = new Cluster($this, $point->getCoordinates());
  185. break;
  186. }
  187. }
  188. return $clusters;
  189. }
  190. /**
  191. * @return Cluster[]
  192. */
  193. private function initializeRandomClusters(int $clustersNumber): array
  194. {
  195. $clusters = [];
  196. [$min, $max] = $this->getBoundaries();
  197. for ($n = 0; $n < $clustersNumber; ++$n) {
  198. $clusters[] = new Cluster($this, $this->getRandomPoint($min, $max)->getCoordinates());
  199. }
  200. return $clusters;
  201. }
  202. }