package aima.learning.learners;

import aima.learning.framework.DataSet;
import aima.learning.framework.Example;
import aima.learning.framework.Learner;
import aima.learning.inductive.ConstantDecisonTree;
import aima.learning.inductive.DecisionTree;
import aima.util.Util;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:aima/learning/learners/DecisionTreeLearner.class */
/**
 * A {@link Learner} that induces a {@link DecisionTree} from a {@link DataSet} by recursively
 * splitting on the attribute with the highest gain reported by {@code DataSet#calculateGainFor}.
 *
 * <p>A tree must be available before {@link #predict(Example)} or {@link #test(DataSet)} can
 * classify anything: either pass one to {@link #DecisionTreeLearner(DecisionTree, String)} or
 * call {@link #train(DataSet)} first.
 */
public class DecisionTreeLearner implements Learner {
    /** Label used by the no-arg constructor for branches that received no training examples. */
    private static final String UNABLE_TO_CLASSIFY = "Unable To Classify";

    /** The learned (or supplied) tree; {@code null} until trained or supplied. */
    private DecisionTree tree;

    /** Label predicted by leaves built from an empty partition at the top level of training. */
    private final String defaultValue;

    /** Creates an untrained learner whose empty-partition leaves predict "Unable To Classify". */
    public DecisionTreeLearner() {
        this.defaultValue = UNABLE_TO_CLASSIFY;
    }

    /**
     * Creates a learner around an existing tree, so it can predict without training.
     *
     * @param decisionTree the tree to use for prediction (may later be replaced by {@link #train})
     * @param defaultValue label for empty-partition leaves if {@link #train} is called later
     */
    public DecisionTreeLearner(DecisionTree decisionTree, String defaultValue) {
        this.tree = decisionTree;
        this.defaultValue = defaultValue;
    }

    /**
     * Builds a new decision tree from every non-target attribute of {@code dataSet},
     * replacing any tree previously held by this learner.
     *
     * @param dataSet the training examples
     */
    @Override
    public void train(DataSet dataSet) {
        this.tree = decisionTreeLearning(dataSet, dataSet.getNonTargetAttributes(),
                new ConstantDecisonTree(this.defaultValue));
    }

    /**
     * Classifies one example with the current tree.
     *
     * @param example the example to classify
     * @return the predicted target value
     * @throws IllegalStateException if no tree has been trained or supplied
     */
    @Override
    public String predict(Example example) {
        return (String) requireTree().predict(example);
    }

    /**
     * Classifies every example in {@code dataSet} and tallies the outcomes.
     *
     * @param dataSet the examples to evaluate
     * @return a two-element array: {@code [0]} = correct predictions, {@code [1]} = incorrect ones
     * @throws IllegalStateException if {@code dataSet} is non-empty and no tree is available
     */
    @Override
    public int[] test(DataSet dataSet) {
        int correct = 0;
        int incorrect = 0;
        for (Example example : dataSet.examples) {
            // The tree is looked up per example so an empty data set still yields {0, 0}.
            if (example.targetValue().equals(requireTree().predict(example))) {
                correct++;
            } else {
                incorrect++;
            }
        }
        return new int[] {correct, incorrect};
    }

    /**
     * Recursively builds a (sub)tree for {@code dataSet} restricted to {@code attributes}.
     *
     * @param dataSet examples reaching this node
     * @param attributes attributes still available for splitting
     * @param defaultTree leaf returned when {@code dataSet} is empty (the parent's majority)
     * @return a constant leaf or an internal node splitting on the best attribute
     */
    private DecisionTree decisionTreeLearning(DataSet dataSet, List<String> attributes,
            ConstantDecisonTree defaultTree) {
        if (dataSet.size() == 0) {
            return defaultTree;
        }
        if (allExamplesHaveSameClassification(dataSet)) {
            return new ConstantDecisonTree(dataSet.getExample(0).targetValue());
        }
        if (attributes.isEmpty()) {
            return majorityValue(dataSet);
        }
        String bestAttribute = chooseAttribute(dataSet, attributes);
        DecisionTree node = new DecisionTree(bestAttribute);
        // Children whose partition is empty fall back to this node's majority label.
        ConstantDecisonTree majorityFallback = majorityValue(dataSet);
        // Same for every branch, so compute once. NOTE(review): assumes Util.removeFrom returns
        // a fresh list rather than mutating its argument — confirm in aima.util.Util.
        List<String> remainingAttributes = Util.removeFrom(attributes, bestAttribute);
        for (String value : dataSet.getPossibleAttributeValues(bestAttribute)) {
            DataSet partition = dataSet.matchingDataSet(bestAttribute, value);
            node.addNode(value, decisionTreeLearning(partition, remainingAttributes, majorityFallback));
        }
        return node;
    }

    /**
     * Returns a leaf predicting the label chosen by a {@link MajorityLearner} trained on
     * {@code dataSet}. Requires a non-empty data set.
     */
    private ConstantDecisonTree majorityValue(DataSet dataSet) {
        MajorityLearner majorityLearner = new MajorityLearner();
        majorityLearner.train(dataSet);
        return new ConstantDecisonTree(majorityLearner.predict(dataSet.getExample(0)));
    }

    /**
     * Picks the attribute with the strictly greatest positive gain. Ties keep the earlier
     * attribute, and the first attribute is returned if no gain exceeds zero.
     *
     * @param attributes non-empty list of candidate attributes
     */
    private String chooseAttribute(DataSet dataSet, List<String> attributes) {
        double bestGain = 0.0d;
        String bestAttribute = attributes.get(0);
        for (String candidate : attributes) {
            double gain = dataSet.calculateGainFor(candidate);
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = candidate;
            }
        }
        return bestAttribute;
    }

    /** Returns true if every example shares the first example's target value; needs a non-empty set. */
    private boolean allExamplesHaveSameClassification(DataSet dataSet) {
        String firstTarget = dataSet.getExample(0).targetValue();
        Iterator<Example> examples = dataSet.iterator();
        while (examples.hasNext()) {
            if (!examples.next().targetValue().equals(firstTarget)) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return the current tree, or {@code null} if none has been trained or supplied
     */
    public DecisionTree getDecisionTree() {
        return this.tree;
    }

    /** Returns the current tree, failing with a descriptive message when there is none. */
    private DecisionTree requireTree() {
        if (this.tree == null) {
            throw new IllegalStateException(
                    "No decision tree available: call train(DataSet) or pass a tree to the constructor first");
        }
        return this.tree;
    }
}
