package aima.learning.statistics;

import aima.util.Util;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:aima/learning/statistics/PerceptronLearning.class */
public class PerceptronLearning implements NeuralNetworkTrainingScheme {
    private Hashtable<Neuron, Double> neuronDeltaMap = new Hashtable<>();
    private Hashtable<Neuron, Double> neuronBiasMap = new Hashtable<>();
    private Hashtable<Link, Double> linkWeightMap = new Hashtable<>();
    private double learningRate = 0.1d;

    @Override // aima.learning.statistics.NeuralNetworkTrainingScheme
    public void backPropogate(FeedForwardNetwork feedForwardNetwork, List<Double> list, List<Double> list2) {
        if (feedForwardNetwork.layerCount() != 2) {
            throw new RuntimeException("Perceptron larning can be used only with 2 layer networks. This one has " + feedForwardNetwork.layerCount());
        }
        feedForwardNetwork.propogateInput(list);
        calculateDelta(feedForwardNetwork, list2);
    }

    @Override // aima.learning.statistics.NeuralNetworkTrainingScheme
    public void updateWeightsAndBiases(FeedForwardNetwork feedForwardNetwork) {
        for (Neuron neuron : feedForwardNetwork.getOutputLayer().getNeurons()) {
            double doubleValue = this.neuronDeltaMap.get(neuron).doubleValue();
            this.neuronBiasMap.put(neuron, Double.valueOf(neuron.bias()));
            neuron.setBias(neuron.bias() - (this.learningRate * doubleValue));
            for (Link link : neuron.inLinks()) {
                this.linkWeightMap.put(link, Double.valueOf(link.weight()));
                link.setWeight(link.weight() - ((this.learningRate * doubleValue) * link.source().activation()));
            }
        }
    }

    private void calculateDelta(FeedForwardNetwork feedForwardNetwork, List<Double> list) {
        Layer outputLayer = feedForwardNetwork.getOutputLayer();
        Iterator<Neuron> it = outputLayer.iterator();
        Iterator<Double> it2 = outputLayer.getError(list).iterator();
        while (it.hasNext() && it2.hasNext()) {
            this.neuronDeltaMap.put(it.next(), Double.valueOf((-1.0d) * it2.next().doubleValue()));
        }
    }

    @Override // aima.learning.statistics.NeuralNetworkTrainingScheme
    public double error(List<Double> list, FeedForwardNetwork feedForwardNetwork) {
        return Util.sumOfSquares(feedForwardNetwork.error(list));
    }
}
