package aima.learning.statistics;

import aima.probability.Randomizer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:aima/learning/statistics/FeedForwardNetwork.class */
/**
 * A feed-forward network built from an ordered list of {@link Layer}s.
 *
 * <p>The first layer added is the input layer and the last is the output layer. Any layers in
 * between are hidden layers. Every layer added after the first is wired to the previously-last
 * layer via {@code Layer.connectTo(Layer, Randomizer)}.
 */
public class FeedForwardNetwork {
    /** Message used when an operation needs at least one layer but the network is empty. */
    private static final String NO_LAYERS_MESSAGE =
            "cannot call this method on network with zero layers";

    /** Layers in input-to-output order. */
    private final List<Layer> layers;

    /** Creates an empty network with no layers. */
    public FeedForwardNetwork() {
        this.layers = new ArrayList<>();
    }

    /**
     * Creates a two-layer network: a layer using {@link SigmoidActivationFunction} followed by
     * an output layer using {@link IdentityActivationFunction}.
     *
     * @param i first constructor argument passed to the sigmoid layer (presumably its size --
     *     confirm against {@code Layer})
     * @param i2 first constructor argument passed to the identity output layer (presumably its
     *     size)
     * @param randomizer passed to {@code connectTo} when wiring the layers together
     */
    public FeedForwardNetwork(int i, int i2, Randomizer randomizer) {
        this();
        addLayer(new Layer(i, 1.0d, new SigmoidActivationFunction()), randomizer);
        addLayer(new Layer(i2, 1.0d, new IdentityActivationFunction()), randomizer);
    }

    /**
     * Creates a three-layer network: two layers using {@link SigmoidActivationFunction} followed
     * by an output layer using {@link IdentityActivationFunction}.
     *
     * @param i first constructor argument for the first sigmoid layer (presumably its size)
     * @param i2 first constructor argument for the second sigmoid layer (presumably its size)
     * @param i3 first constructor argument for the identity output layer (presumably its size)
     * @param randomizer passed to {@code connectTo} when wiring consecutive layers
     */
    public FeedForwardNetwork(int i, int i2, int i3, Randomizer randomizer) {
        this();
        addLayer(new Layer(i, 1.0d, new SigmoidActivationFunction()), randomizer);
        addLayer(new Layer(i2, 1.0d, new SigmoidActivationFunction()), randomizer);
        addLayer(new Layer(i3, 1.0d, new IdentityActivationFunction()), randomizer);
    }

    /**
     * Appends a layer to the end of the network.
     *
     * <p>If the network already has layers, the current last layer is first connected to
     * {@code layer} via {@code connectTo(layer, randomizer)}.
     *
     * @param layer the layer to append; it becomes the new output layer
     * @param randomizer passed through to {@code connectTo}; unused when the network is empty
     */
    public void addLayer(Layer layer, Randomizer randomizer) {
        if (atLeastOneLayerPresent()) {
            getLastLayer().connectTo(layer, randomizer);
        }
        this.layers.add(layer);
    }

    /**
     * Runs a forward pass.
     *
     * <p>{@code list} is handed to the input layer's {@code acceptInput}. Then {@code update()}
     * is called on each hidden layer in order, and finally on the output layer. For a
     * single-layer network the input layer is also the output layer, so {@code update()} is
     * called on it after {@code acceptInput}. NOTE(review): confirm {@code Layer.update()}
     * tolerates that case.
     *
     * <p>The method name keeps its historical spelling for compatibility with existing callers.
     *
     * @param list the input values
     * @throws IllegalStateException if the network has no layers
     */
    public void propogateInput(List<Double> list) {
        getInputLayer().acceptInput(list);
        if (layerCount() > 2) {
            for (Layer hidden : getHiddenLayers()) {
                hidden.update();
            }
        }
        getOutputLayer().update();
    }

    /**
     * Returns the output layer's {@code activation()} values.
     *
     * @return the list returned by the output layer's {@code activation()}
     * @throws IllegalStateException if the network has no layers
     */
    public List<Double> output() {
        return getOutputLayer().activation();
    }

    /**
     * Delegates to the output layer's {@code getError}.
     *
     * @param list values passed to {@code getError} (presumably the target outputs -- confirm
     *     against {@code Layer})
     * @return the list returned by the output layer's {@code getError}
     * @throws IllegalStateException if the network has no layers
     */
    public List<Double> error(List<Double> list) {
        return getOutputLayer().getError(list);
    }

    /** Returns {@code true} when at least one layer has been added. */
    private boolean atLeastOneLayerPresent() {
        return !this.layers.isEmpty();
    }

    /** Throws {@link IllegalStateException} when the network has no layers. */
    private void requireAtLeastOneLayer() {
        if (!atLeastOneLayerPresent()) {
            throw new IllegalStateException(NO_LAYERS_MESSAGE);
        }
    }

    /**
     * Returns the most recently added layer.
     *
     * @throws IllegalStateException if the network has no layers
     */
    private Layer getLastLayer() {
        requireAtLeastOneLayer();
        return this.layers.get(this.layers.size() - 1);
    }

    /**
     * Returns the output (last) layer.
     *
     * @throws IllegalStateException if the network has no layers
     */
    public Layer getOutputLayer() {
        return getLastLayer();
    }

    /**
     * Returns the input (first) layer.
     *
     * @throws IllegalStateException if the network has no layers
     */
    public Layer getInputLayer() {
        // Checked explicitly so an empty network fails the same way as getOutputLayer(),
        // instead of with a bare IndexOutOfBoundsException from List.get(0).
        requireAtLeastOneLayer();
        return this.layers.get(0);
    }

    /**
     * Returns the layers strictly between the input and output layers, in order.
     *
     * @return a new mutable list; changes to it do not affect the network
     * @throws IllegalStateException if the network has fewer than three layers
     */
    public List<Layer> getHiddenLayers() {
        int count = this.layers.size();
        if (count < 3) {
            throw new IllegalStateException(
                    "cannot call this method on network with " + count + " layers");
        }
        // Copy so callers cannot mutate the network's internal layer list.
        return new ArrayList<>(this.layers.subList(1, count - 1));
    }

    /** Returns the number of layers in the network. */
    public int layerCount() {
        return this.layers.size();
    }
}
