PageRenderTime 85ms CodeModel.GetById 19ms RepoModel.GetById 1ms app.codeStats 0ms

/lib/mlbackend/php/phpml/src/Phpml/Classification/DecisionTree.php

https://github.com/mackensen/moodle
PHP | 484 lines | 302 code | 78 blank | 104 comment | 46 complexity | f85f1afc84f56058096d35ece9aa2340 MD5 | raw file
  1. <?php
  2. declare(strict_types=1);
  3. namespace Phpml\Classification;
  4. use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
  5. use Phpml\Exception\InvalidArgumentException;
  6. use Phpml\Helper\Predictable;
  7. use Phpml\Helper\Trainable;
  8. use Phpml\Math\Statistic\Mean;
  /**
   * Binary decision tree classifier whose splits are chosen by Gini impurity.
   *
   * Every node tests a single column against one "base" value: the most
   * frequent value of that column among the records that reached the node.
   * Continuous columns are first discretised around the median of those
   * records, so their candidate values are the strings "<= median" and
   * "> median"; the parsed operator and threshold are stored on the leaf.
   */
  class DecisionTree implements Classifier
  {
      use Trainable;
      use Predictable;

      /** Column type marker: values are discretised around a median before splitting. */
      public const CONTINUOUS = 1;

      /** Column type marker: values are treated as discrete categories. */
      public const NOMINAL = 2;

      /**
       * Deepest level reached while building the tree (the root is level 0).
       *
       * @var int
       */
      public $actualDepth = 0;

      /**
       * Type of each column (self::CONTINUOUS or self::NOMINAL), indexed by column.
       *
       * @var array
       */
      protected $columnTypes = [];

      /**
       * Root node of the trained tree; not set until train() has been called.
       *
       * @var DecisionTreeLeaf
       */
      protected $tree;

      /**
       * Depth at which a node is forced to become terminal.
       *
       * @var int
       */
      protected $maxDepth;

      /**
       * Distinct target values seen during training.
       *
       * @var array
       */
      private $labels = [];

      /**
       * Number of columns, taken from the first training sample.
       *
       * @var int
       */
      private $featureCount = 0;

      /**
       * Maximum number of randomly chosen columns per split; 0 means "all columns".
       *
       * @var int
       */
      private $numUsableFeatures = 0;

      /**
       * Explicit list of column indexes to consider; empty means "no restriction".
       *
       * @var array
       */
      private $selectedFeatures = [];

      /**
       * Cached result of getFeatureImportances(); null until computed.
       *
       * @var array|null
       */
      private $featureImportances;

      /**
       * Display name of each column, indexed by column.
       *
       * @var array
       */
      private $columnNames = [];

      /**
       * @param int $maxDepth Depth at which nodes stop being split
       */
      public function __construct(int $maxDepth = 10)
      {
          $this->maxDepth = $maxDepth;
      }

      /**
       * Appends the given samples/targets to those from earlier calls and
       * rebuilds the whole tree from the accumulated data.
       *
       * @throws InvalidArgumentException When there is no sample, or the first sample has no feature
       */
      public function train(array $samples, array $targets): void
      {
          // Validate before touching any state so a failed call leaves the
          // classifier exactly as it was.
          $allSamples = array_merge($this->samples, $samples);
          $firstSample = $allSamples[0] ?? null;
          if (!is_array($firstSample) || $firstSample === []) {
              throw new InvalidArgumentException('Cannot train a decision tree without at least one sample containing at least one feature');
          }

          $this->samples = $allSamples;
          $this->targets = array_merge($this->targets, $targets);

          $this->featureCount = count($firstSample);
          $this->columnTypes = self::getColumnTypes($this->samples);
          $this->labels = array_keys(array_count_values($this->targets));
          $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

          // Each time the tree is trained, feature importances are reset so that
          // they are computed again from the new data
          $this->featureImportances = null;

          // If column names were given or computed before, keep them and only
          // trim or extend the list so that it matches the feature count
          $nameCount = count($this->columnNames);
          if ($this->columnNames === []) {
              $this->columnNames = range(0, $this->featureCount - 1);
          } elseif ($nameCount > $this->featureCount) {
              $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
          } elseif ($nameCount < $this->featureCount) {
              $this->columnNames = array_merge(
                  $this->columnNames,
                  range($nameCount, $this->featureCount - 1)
              );
          }
      }

      /**
       * Classifies every column of the dataset as self::NOMINAL or
       * self::CONTINUOUS. The column count is taken from the first sample.
       *
       * @return array One type constant per column; empty when no samples are given
       */
      public static function getColumnTypes(array $samples): array
      {
          if ($samples === []) {
              return [];
          }

          $types = [];
          $featureCount = count($samples[0]);
          for ($i = 0; $i < $featureCount; ++$i) {
              $values = array_column($samples, $i);
              $types[] = self::isCategoricalColumn($values) ? self::NOMINAL : self::CONTINUOUS;
          }

          return $types;
      }

      /**
       * Computes the weighted Gini impurity of splitting a column into
       * "equals $baseValue" and "everything else".
       *
       * @param mixed $baseValue Value that defines the first partition
       * @param array $colValues Column values; must be non-empty because the
       *                         result is divided by its size
       * @param array $targets   Target of each value, looked up by the same keys as $colValues
       */
      public function getGiniIndex($baseValue, array $colValues, array $targets): float
      {
          $countMatrix = [];
          foreach ($this->labels as $label) {
              $countMatrix[$label] = [0, 0];
          }

          // getBestSplit() derives the base value from array_count_values() keys,
          // and PHP turns numeric-string array keys into integers. Comparing
          // int/string pairs by their string form keeps "1" and 1 in the same
          // partition, exactly as array_count_values() grouped them.
          $baseIsKeyLike = is_int($baseValue) || is_string($baseValue);
          $baseAsString = $baseIsKeyLike ? (string) $baseValue : null;

          foreach ($colValues as $index => $value) {
              $label = $targets[$index];
              $isBase = $value === $baseValue
                  || ($baseIsKeyLike
                      && (is_int($value) || is_string($value))
                      && (string) $value === $baseAsString);
              ++$countMatrix[$label][$isBase ? 0 : 1];
          }

          $giniParts = [0, 0];
          for ($i = 0; $i <= 1; ++$i) {
              $part = 0;
              $sum = array_sum(array_column($countMatrix, $i));
              if ($sum > 0) {
                  foreach ($this->labels as $label) {
                      $part += ($countMatrix[$label][$i] / (float) $sum) ** 2;
                  }
              }

              // Impurity of the partition, weighted by the partition size
              $giniParts[$i] = (1 - $part) * $sum;
          }

          return array_sum($giniParts) / count($colValues);
      }

      /**
       * This method is used to set number of columns to be used
       * when deciding a split at an internal node of the tree. <br>
       * If the value is given 0, then all features are used (default behaviour),
       * otherwise the given value will be used as a maximum for number of columns
       * randomly selected for each split operation.
       *
       * @return $this
       *
       * @throws InvalidArgumentException When $numFeatures is negative
       */
      public function setNumFeatures(int $numFeatures)
      {
          if ($numFeatures < 0) {
              throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
          }

          $this->numUsableFeatures = $numFeatures;

          return $this;
      }

      /**
       * A string array to represent columns. Useful when HTML output or
       * column importances are desired to be inspected.
       *
       * @return $this
       *
       * @throws InvalidArgumentException When the tree is already trained and the
       *                                  number of names differs from the feature count
       */
      public function setColumnNames(array $names)
      {
          if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
              throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
          }

          $this->columnNames = $names;

          return $this;
      }

      /**
       * Returns the HTML produced by the root node for the current column names.
       */
      public function getHtml(): string
      {
          return $this->tree->getHTML($this->columnNames);
      }

      /**
       * This will return an array including an importance value for
       * each column in the given dataset, keyed by column name. When the total
       * importance is positive the values are normalized so that they sum to 1
       * and are sorted in descending order. The result is cached until the
       * next call to train().
       */
      public function getFeatureImportances(): array
      {
          if ($this->featureImportances !== null) {
              return $this->featureImportances;
          }

          $sampleCount = count($this->samples);
          $this->featureImportances = [];
          foreach ($this->columnNames as $column => $columnName) {
              $importance = 0;
              foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
                  $importance += $node->getNodeImpurityDecrease($sampleCount);
              }

              $this->featureImportances[$columnName] = $importance;
          }

          // Normalize & sort the importances
          $total = array_sum($this->featureImportances);
          if ($total > 0) {
              array_walk($this->featureImportances, static function (&$importance) use ($total): void {
                  $importance /= $total;
              });
              arsort($this->featureImportances);
          }

          return $this->featureImportances;
      }

      /**
       * Builds the node for the given record indexes and, unless the node is
       * terminal, recursively builds its children.
       *
       * @param array $records Indexes into the training samples that reach this node
       * @param int   $depth   Level of this node (0 for the root)
       */
      protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
      {
          $split = $this->getBestSplit($records);
          $split->level = $depth;
          if ($this->actualDepth < $depth) {
              $this->actualDepth = $depth;
          }

          // Traverse all records to see if all records belong to the same class,
          // otherwise group the records so that we can classify the leaf
          // in case maximum depth is reached
          $leftRecords = [];
          $rightRecords = [];
          $remainingTargets = [];
          $prevRecord = null;
          $allSame = true;

          foreach ($records as $recordNo) {
              // Check if the previous record is the same with the current one
              $record = $this->samples[$recordNo];
              if ($prevRecord !== null && $prevRecord != $record) {
                  $allSame = false;
              }

              $prevRecord = $record;

              // According to the split criterion, this record will
              // belong to either left or the right side in the next split
              if ($split->evaluate($record)) {
                  $leftRecords[] = $recordNo;
              } else {
                  $rightRecords[] = $recordNo;
              }

              // Group remaining targets
              $target = $this->targets[$recordNo];
              if (!array_key_exists($target, $remainingTargets)) {
                  $remainingTargets[$target] = 1;
              } else {
                  ++$remainingTargets[$target];
              }
          }

          if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
              // Terminal node: predict the most frequent remaining target
              $split->isTerminal = true;
              arsort($remainingTargets);
              $split->classValue = (string) key($remainingTargets);
          } else {
              if (isset($leftRecords[0])) {
                  $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
              }

              if (isset($rightRecords[0])) {
                  $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
              }
          }

          return $split;
      }

      /**
       * Finds, among the available columns, the "most frequent value vs. rest"
       * split with the lowest Gini index for the given records.
       *
       * @param array $records Indexes into the training samples
       */
      protected function getBestSplit(array $records): DecisionTreeLeaf
      {
          $recordKeys = array_flip($records);
          $targets = array_intersect_key($this->targets, $recordKeys);
          $samples = (array) array_combine(
              $records,
              $this->preprocess(array_intersect_key($this->samples, $recordKeys))
          );

          $bestGiniVal = 1;
          $bestSplit = null;
          $features = $this->getSelectedFeatures();
          foreach ($features as $i) {
              $colValues = [];
              foreach ($samples as $index => $row) {
                  $colValues[$index] = $row[$i];
              }

              // The most frequent value of the column is the split candidate
              $counts = array_count_values($colValues);
              arsort($counts);
              $baseValue = key($counts);
              if ($baseValue === null) {
                  continue;
              }

              $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
              if ($bestSplit === null || $bestGiniVal > $gini) {
                  $isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;

                  $split = new DecisionTreeLeaf();
                  $split->value = $baseValue;
                  $split->giniIndex = $gini;
                  $split->columnIndex = $i;
                  $split->isContinuous = $isContinuous;
                  $split->records = $records;

                  // If a numeric column is to be selected, then
                  // the original numeric value and the selected operator
                  // will also be saved into the leaf for future access
                  if ($isContinuous) {
                      $matches = [];
                      preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                      $split->operator = $matches[1];
                      $split->numericValue = (float) $matches[2];
                  }

                  $bestSplit = $split;
                  $bestGiniVal = $gini;
              }
          }

          return $bestSplit;
      }

      /**
       * Returns available features/columns to the tree for the decision making
       * process. <br>
       *
       * If some features are manually selected by use of setSelectedFeatures(),
       * then only these features are returned <br>
       *
       * Otherwise, if a number is given with setNumFeatures() method, then a
       * random selection of features up to this number is returned, sorted by
       * column index. <br>
       *
       * If any of above methods were not called beforehand, then all features
       * are returned by default.
       */
      protected function getSelectedFeatures(): array
      {
          $allFeatures = range(0, $this->featureCount - 1);
          if ($this->numUsableFeatures === 0 && count($this->selectedFeatures) === 0) {
              return $allFeatures;
          }

          if (count($this->selectedFeatures) > 0) {
              return $this->selectedFeatures;
          }

          $numFeatures = $this->numUsableFeatures;
          if ($numFeatures > $this->featureCount) {
              $numFeatures = $this->featureCount;
          }

          shuffle($allFeatures);
          $selectedFeatures = array_slice($allFeatures, 0, $numFeatures);
          sort($selectedFeatures);

          return $selectedFeatures;
      }

      /**
       * Converts continuous column values into the discrete strings
       * "<= median" / "> median", using the median of the given samples as the
       * threshold, and returns the samples as a zero-indexed list of rows.
       */
      protected function preprocess(array $samples): array
      {
          $columns = [];
          for ($i = 0; $i < $this->featureCount; ++$i) {
              $values = array_column($samples, $i);
              if ($this->columnTypes[$i] === self::CONTINUOUS) {
                  $median = Mean::median($values);
                  foreach ($values as &$value) {
                      $value = $value <= $median ? "<= {$median}" : "> {$median}";
                  }

                  // Break the reference so later writes cannot leak into $values
                  unset($value);
              }

              $columns[] = $values;
          }

          // array_map(null, $onlyColumn) hands the column back unchanged instead
          // of wrapping each value in a row, so a single-feature dataset has to
          // be transposed explicitly.
          if (count($columns) === 1) {
              return array_map(
                  static function ($value): array {
                      return [$value];
                  },
                  $columns[0]
              );
          }

          // With two or more arrays, array_map(null, ...) zips them together,
          // which is the transpose of the column list
          return array_map(null, ...$columns);
      }

      /**
       * Heuristically decides whether a column holds discrete values.
       */
      protected static function isCategoricalColumn(array $columnValues): bool
      {
          $count = count($columnValues);

          // There are two main indicators that *may* show whether a
          // column is composed of discrete set of values:
          // 1- Column may contain string values and non-float values
          // 2- Number of unique values in the column is only a small fraction of
          //    all values in that column (Lower than or equal to %20 of all values)
          $numericValues = array_filter($columnValues, 'is_numeric');
          $floatValues = array_filter($columnValues, 'is_float');
          if (count($floatValues) > 0) {
              return false;
          }

          if (count($numericValues) !== $count) {
              return true;
          }

          $distinctValues = array_count_values($columnValues);

          return count($distinctValues) <= $count / 5;
      }

      /**
       * Used to set predefined features to consider while deciding which column to use for a split
       */
      protected function setSelectedFeatures(array $selectedFeatures): void
      {
          $this->selectedFeatures = $selectedFeatures;
      }

      /**
       * Collects and returns an array of internal nodes that use the given
       * column as a split criterion
       */
      protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
      {
          if ($node->isTerminal) {
              return [];
          }

          $nodes = [];
          if ($node->columnIndex === $column) {
              $nodes[] = $node;
          }

          $lNodes = [];
          $rNodes = [];
          if ($node->leftLeaf !== null) {
              $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
          }

          if ($node->rightLeaf !== null) {
              $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
          }

          return array_merge($nodes, $lNodes, $rNodes);
      }

      /**
       * Walks the tree from the root until a terminal node is reached and
       * returns its class value. If the walk runs into a missing child, the
       * first known label is returned instead.
       *
       * @return mixed
       */
      protected function predictSample(array $sample)
      {
          $node = $this->tree;
          do {
              if ($node->isTerminal) {
                  return $node->classValue;
              }

              if ($node->evaluate($sample)) {
                  $node = $node->leftLeaf;
              } else {
                  $node = $node->rightLeaf;
              }
          } while ($node);

          return $this->labels[0];
      }
  }