Your IP : 192.168.165.1


Current Path : C:/xampp/htdocs/moodle/lib/mlbackend/php/phpml/src/Phpml/Classification/
Upload File :
Current File : C:/xampp/htdocs/moodle/lib/mlbackend/php/phpml/src/Phpml/Classification/DecisionTree.php

<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;

class DecisionTree implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * Column type assigned by getColumnTypes() to columns that are not
     * detected as categorical.
     */
    public const CONTINUOUS = 1;

    /**
     * Column type assigned by getColumnTypes() to columns that are
     * detected as categorical.
     */
    public const NOMINAL = 2;

    /**
     * Deepest level reached so far while building the tree
     * (updated by getSplitLeaf()).
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Type of every column (self::CONTINUOUS or self::NOMINAL), indexed by
     * column position. Filled in by train().
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Root node of the tree; assigned by train().
     *
     * @var DecisionTreeLeaf
     */
    protected $tree;

    /**
     * Depth at which a node is made terminal even if it is still impure.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * Distinct target values seen in the training data.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of columns, taken from the first training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Maximum number of randomly picked columns considered per split;
     * 0 means "no limit" (see getSelectedFeatures()).
     *
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * Explicit list of column indices to consider for splits; when it is
     * non-empty it takes precedence over $numUsableFeatures.
     *
     * @var array
     */
    private $selectedFeatures = [];

    /**
     * Cached result of getFeatureImportances(); null means it has to be
     * (re)computed. train() resets it to null.
     *
     * @var array|null
     */
    private $featureImportances;

    /**
     * Display name of every column. Defaults to the column indices when
     * no names were given before training.
     *
     * @var array
     */
    private $columnNames = [];

    /**
     * @param int $maxDepth Depth at which nodes become terminal regardless of their purity
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Trains the tree. The given samples and targets are appended to the ones
     * received by earlier calls and the whole tree is rebuilt from the
     * accumulated data.
     *
     * @param array $samples List of rows; every row is a list of column values
     * @param array $targets Class label of every row, in the same order as $samples
     *
     * @throws InvalidArgumentException when there is no sample to train on
     */
    public function train(array $samples, array $targets): void
    {
        // Validate before touching any state so that a failed call leaves
        // the classifier exactly as it was
        $mergedSamples = array_merge($this->samples, $samples);
        if ($mergedSamples === []) {
            throw new InvalidArgumentException('Cannot train a decision tree without any samples');
        }

        $this->samples = $mergedSamples;
        $this->targets = array_merge($this->targets, $targets);

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // The tree is rebuilt from scratch, so the depth recorded for the
        // previous tree must not leak into the new one
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // they are computed again from the new data
        $this->featureImportances = null;

        // Column names that were given or computed before are kept; they are
        // only trimmed or padded with indices to match the feature count
        $knownNames = count($this->columnNames);
        if ($knownNames === 0) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif ($knownNames > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif ($knownNames < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range($knownNames, $this->featureCount - 1)
            );
        }
    }

    /**
     * Detects the type of every column of the given dataset.
     *
     * The number of columns is taken from the first sample, whatever its key
     * is. An empty dataset has no columns and yields an empty array.
     *
     * @param array $samples List of rows; every row is a list of column values
     *
     * @return array self::NOMINAL or self::CONTINUOUS for every column, by column position
     */
    public static function getColumnTypes(array $samples): array
    {
        if ($samples === []) {
            return [];
        }

        // reset() instead of $samples[0]: the first row is not necessarily
        // stored under key 0 (e.g. after array_filter())
        $featureCount = count(reset($samples));

        $types = [];
        for ($i = 0; $i < $featureCount; ++$i) {
            $values = array_column($samples, $i);
            $types[] = self::isCategoricalColumn($values) ? self::NOMINAL : self::CONTINUOUS;
        }

        return $types;
    }

    /**
     * Computes the weighted Gini impurity of splitting a column into
     * "equal to $baseValue" and "everything else".
     *
     * Values are compared to $baseValue with strict equality. Targets that
     * are not among the labels known to the tree are still counted. An empty
     * column yields 0.0.
     *
     * @param mixed $baseValue Value that defines the first partition
     * @param array $colValues Column values, keyed like $targets
     * @param array $targets   Class label of every value in $colValues
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        $valueCount = count($colValues);
        if ($valueCount === 0) {
            // Nothing to split; also avoids dividing by zero below
            return 0.0;
        }

        // $countMatrix[label] = [count in the "equal" part, count in the "other" part]
        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }

        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            if (!isset($countMatrix[$label])) {
                $countMatrix[$label] = [0, 0];
            }

            $rowIndex = $value === $baseValue ? 0 : 1;
            ++$countMatrix[$label][$rowIndex];
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                // Iterate over the matrix rather than $this->labels so that
                // every counted label contributes to the purity term
                foreach ($countMatrix as $labelCounts) {
                    $part += ($labelCounts[$i] / (float) $sum) ** 2;
                }
            }

            // Impurity of the partition, weighted by its size
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / $valueCount;
    }

    /**
     * Sets the upper bound on how many randomly chosen columns are examined
     * when a split is decided. Zero (the default) means every column is
     * examined.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when a negative number is given
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures >= 0) {
            $this->numUsableFeatures = $numFeatures;

            return $this;
        }

        throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
    }

    /**
     * Assigns display names to the columns; they are used by the HTML
     * output and as keys of the feature importances.
     *
     * Once the feature count is known (non-zero), the number of names has to
     * match it.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when the number of names does not match the feature count
     */
    public function setColumnNames(array $names)
    {
        $expected = $this->featureCount;
        $lengthMismatch = $expected !== 0 && count($names) !== $expected;

        if ($lengthMismatch) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $expected));
        }

        $this->columnNames = $names;

        return $this;
    }

    /**
     * Returns the HTML produced by the root node, labelled with the
     * current column names.
     */
    public function getHtml(): string
    {
        $root = $this->tree;

        return $root->getHTML($this->columnNames);
    }

    /**
     * Returns the importance of every column, keyed by column name.
     *
     * The importance of a column is the sum of the impurity decreases of all
     * internal nodes that split on it. When the total is positive the values
     * are normalized to add up to 1 and sorted in descending order. The
     * result is cached until the tree is trained again.
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $this->featureImportances = [];

        foreach ($this->columnNames as $column => $columnName) {
            $score = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $splitNode) {
                $score += $splitNode->getNodeImpurityDecrease($sampleCount);
            }

            $this->featureImportances[$columnName] = $score;
        }

        $total = array_sum($this->featureImportances);
        if ($total <= 0) {
            // Nothing to normalize
            return $this->featureImportances;
        }

        foreach ($this->featureImportances as $columnName => $score) {
            $this->featureImportances[$columnName] = $score / $total;
        }

        arsort($this->featureImportances);

        return $this->featureImportances;
    }

    /**
     * Recursively builds the subtree for the given records.
     *
     * The node becomes terminal when all records are equal, when the maximum
     * depth is reached or when only one class is left; its class value is
     * then the most frequent target among the records.
     *
     * @param array $records Indices of the training samples that reach this node
     * @param int   $depth   Level of the node being built (root is 0)
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $node = $this->getBestSplit($records);
        $node->level = $depth;
        $this->actualDepth = max($this->actualDepth, $depth);

        $goLeft = [];
        $goRight = [];
        $targetCounts = [];
        $previous = null;
        $identical = true;

        // Single pass: detect whether all samples are equal, route every
        // record to a side of the split and count the targets
        foreach ($records as $recordNo) {
            $current = $this->samples[$recordNo];
            if ($previous !== null && $previous != $current) {
                $identical = false;
            }

            $previous = $current;

            if ($node->evaluate($current)) {
                $goLeft[] = $recordNo;
            } else {
                $goRight[] = $recordNo;
            }

            $target = $this->targets[$recordNo];
            if (array_key_exists($target, $targetCounts)) {
                ++$targetCounts[$target];
            } else {
                $targetCounts[$target] = 1;
            }
        }

        $mustStop = $identical || $depth >= $this->maxDepth || count($targetCounts) === 1;
        if ($mustStop) {
            $node->isTerminal = true;
            // Most frequent target first
            arsort($targetCounts);
            $node->classValue = (string) key($targetCounts);

            return $node;
        }

        if ($goLeft !== []) {
            $node->leftLeaf = $this->getSplitLeaf($goLeft, $depth + 1);
        }

        if ($goRight !== []) {
            $node->rightLeaf = $this->getSplitLeaf($goRight, $depth + 1);
        }

        return $node;
    }

    /**
     * Finds the split with the lowest Gini index among the available columns.
     *
     * For every candidate column the most frequent (preprocessed) value is
     * used as the split value.
     *
     * @param array $records Indices of the training samples to split
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = (array) array_combine(
            $records,
            $this->preprocess(array_intersect_key($this->samples, $recordKeys))
        );
        $bestGiniVal = 1;
        $bestSplit = null;
        $features = $this->getSelectedFeatures();
        foreach ($features as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            $counts = array_count_values($colValues);
            arsort($counts);
            $baseKey = key($counts);
            if ($baseKey === null) {
                continue;
            }

            // PHP casts integer-like string keys ("1") to int, so the key may
            // differ in type from the column values. getGiniIndex() compares
            // strictly, hence the original value has to be recovered.
            $baseValue = $baseKey;
            foreach ($colValues as $candidate) {
                if ((is_int($candidate) || is_string($candidate)) && (string) $candidate === (string) $baseKey) {
                    $baseValue = $candidate;

                    break;
                }
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;
                $split->records = $records;

                // For a continuous column the split value is a string such as
                // "<= 3.5"; the operator and the number are stored separately
                // on the leaf for later access
                if ($split->isContinuous) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = (float) $matches[2];
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Returns the column indices the tree may use when deciding a split.
     *
     * - Features chosen with setSelectedFeatures() are returned as they are.
     * - Otherwise, when a limit was given with setNumFeatures(), a random
     *   selection of at most that many columns is returned in ascending order.
     * - Otherwise every column is returned.
     */
    protected function getSelectedFeatures(): array
    {
        if (count($this->selectedFeatures) > 0) {
            return $this->selectedFeatures;
        }

        $candidates = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0) {
            return $candidates;
        }

        // Never ask for more columns than there are
        $limit = min($this->numUsableFeatures, $this->featureCount);

        shuffle($candidates);
        $chosen = array_slice($candidates, 0, $limit);
        sort($chosen);

        return $chosen;
    }

    /**
     * Converts the values of continuous columns into the discrete strings
     * "<= median" and "> median", using the median of the column within the
     * given samples as the threshold. Other columns are left untouched.
     *
     * @param array $samples Rows to convert; every row is a list of column values
     *
     * @return array Converted rows, re-indexed from 0 in the original order
     */
    protected function preprocess(array $samples): array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);
                foreach ($values as $key => $value) {
                    $values[$key] = $value <= $median ? "<= {$median}" : "> {$median}";
                }
            }

            $columns[] = $values;
        }

        // Transpose the columns back into rows. This is done by hand because
        // array_map(null, ...$columns) returns the column itself, not a list
        // of one-element rows, when there is only a single column.
        $rows = [];
        foreach ($columns as $columnIndex => $columnValues) {
            foreach ($columnValues as $rowIndex => $value) {
                $rows[$rowIndex][$columnIndex] = $value;
            }
        }

        return $rows;
    }

    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * - A column containing at least one float is never categorical.
     * - Otherwise a column containing a non-numeric value is categorical.
     * - Otherwise the column is categorical when its number of distinct
     *   values is at most 20% of the number of values.
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $sawNonNumeric = false;
        foreach ($columnValues as $value) {
            if (is_float($value)) {
                // A float anywhere in the column settles the question
                return false;
            }

            if (!is_numeric($value)) {
                $sawNonNumeric = true;
            }
        }

        if ($sawNonNumeric) {
            return true;
        }

        $distinctCount = count(array_count_values($columnValues));

        return $distinctCount <= count($columnValues) / 5;
    }

    /**
     * Restricts split decisions to the given column indices. A non-empty
     * list takes precedence over the limit given with setNumFeatures().
     *
     * @param array $selectedFeatures Column indices to consider for splits
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $features = $selectedFeatures;
        $this->selectedFeatures = $features;
    }

    /**
     * Gathers the internal (non-terminal) nodes of the given subtree that
     * split on the given column. The node itself comes first, followed by
     * the matches of its left and then of its right subtree.
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        if ($node->isTerminal) {
            return [];
        }

        $found = $node->columnIndex === $column ? [$node] : [];

        foreach ([$node->leftLeaf, $node->rightLeaf] as $child) {
            if ($child !== null) {
                $found = array_merge($found, $this->getSplitNodesByColumn($column, $child));
            }
        }

        return $found;
    }

    /**
     * Walks the tree from the root, following the left child when the node
     * accepts the sample and the right child otherwise, until a terminal
     * node is reached. When the walk runs off the tree the first known label
     * is returned.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $current = $this->tree;
        do {
            if ($current->isTerminal) {
                return $current->classValue;
            }

            $current = $current->evaluate($sample) ? $current->leftLeaf : $current->rightLeaf;
        } while ($current);

        return $this->labels[0];
    }
}