package aima.learning.learners;

import aima.learning.framework.DataSet;
import aima.learning.framework.Example;
import aima.learning.framework.Learner;
import aima.learning.statistics.FeedForwardNetwork;
import aima.learning.statistics.IrisDataSetNumerizer;
import aima.learning.statistics.NeuralNetworkTrainingScheme;
import aima.util.Pair;
import aima.util.Util;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:aima/learning/learners/NeuralNetLearner.class */
/**
 * {@link Learner} backed by a {@link FeedForwardNetwork}.
 *
 * <p>Examples are converted to numeric (input, target) vectors by an {@link IrisDataSetNumerizer}.
 * The network is trained with a {@link NeuralNetworkTrainingScheme} for a fixed number of passes
 * over the data set. Inputs are passed through {@code Util.normalize} both during training and
 * during prediction, so the network is always queried on inputs prepared the same way.
 */
public class NeuralNetLearner implements Learner {
    private final FeedForwardNetwork network;
    private final IrisDataSetNumerizer numerizer;
    private final int numberOfIterations;
    private final NeuralNetworkTrainingScheme trainingScheme;

    /**
     * Sum of {@code trainingScheme.error(...)} over all examples of the most recent epoch completed
     * by {@link #train(DataSet)}. Stays 0.0 until an epoch has been run.
     */
    private double lastEpochError;

    /**
     * Creates a learner.
     *
     * @param feedForwardNetwork the network that is trained and then used for prediction
     * @param irisDataSetNumerizer converts examples to (input, target) vectors and network outputs
     *     back to a label
     * @param neuralNetworkTrainingScheme performs back-propagation, weight/bias updates and error
     *     measurement
     * @param i number of passes over the data set performed by {@link #train(DataSet)}; a value
     *     {@code <= 0} means {@code train} does nothing
     */
    public NeuralNetLearner(FeedForwardNetwork feedForwardNetwork, IrisDataSetNumerizer irisDataSetNumerizer, NeuralNetworkTrainingScheme neuralNetworkTrainingScheme, int i) {
        this.network = feedForwardNetwork;
        this.numerizer = irisDataSetNumerizer;
        this.numberOfIterations = i;
        this.trainingScheme = neuralNetworkTrainingScheme;
    }

    /**
     * Trains the network on every example of {@code dataSet}, repeated {@code numberOfIterations}
     * times.
     *
     * <p>For each example, the normalized input and the target are back-propagated, then weights
     * and biases are updated, then the scheme's error for the example is added to the epoch total.
     * The total of the last completed epoch is available from {@link #getLastEpochError()}.
     *
     * @param dataSet the examples to train on
     */
    @Override // aima.learning.framework.Learner
    public void train(DataSet dataSet) {
        for (int epoch = 0; epoch < this.numberOfIterations; epoch++) {
            double epochError = 0.0d;
            for (Example example : dataSet.examples) {
                Pair<List<Double>, List<Double>> numerized = this.numerizer.numerize(example);
                List<Double> input = numerized.getFirst();
                List<Double> target = numerized.getSecond();
                this.trainingScheme.backPropogate(this.network, Util.normalize(input), target);
                this.trainingScheme.updateWeightsAndBiases(this.network);
                epochError += this.trainingScheme.error(target, this.network);
            }
            this.lastEpochError = epochError;
        }
    }

    /**
     * Returns the summed training error of the most recent epoch completed by
     * {@link #train(DataSet)}, or 0.0 if no epoch has been run yet.
     *
     * @return the last epoch's accumulated error
     */
    public double getLastEpochError() {
        return this.lastEpochError;
    }

    /**
     * Predicts the label of {@code example}.
     *
     * <p>The example's input vector is normalized exactly as in {@link #train(DataSet)}, fed to the
     * network via {@code propogateInput}, and the network's output is converted back to a label by
     * the numerizer.
     *
     * @param example the example to classify (its target value is not used)
     * @return the label produced by {@code numerizer.denumerize}
     */
    @Override // aima.learning.framework.Learner
    public String predict(Example example) {
        List<Double> input = this.numerizer.numerize(example).getFirst();
        // Must match train(): the network was fitted on Util.normalize(input), not on the raw vector.
        this.network.propogateInput(Util.normalize(input));
        return this.numerizer.denumerize(this.network.output());
    }

    /**
     * Compares the prediction for every example in {@code dataSet} with its target value.
     *
     * @param dataSet the examples to evaluate
     * @return a two-element array: index 0 holds the number of examples whose prediction equals
     *     {@code targetValue()}, index 1 holds the number that differ
     */
    @Override // aima.learning.framework.Learner
    public int[] test(DataSet dataSet) {
        int correct = 0;
        int incorrect = 0;
        for (Example example : dataSet.examples) {
            if (example.targetValue().equals(predict(example))) {
                correct++;
            } else {
                incorrect++;
            }
        }
        return new int[] {correct, incorrect};
    }
}
