package aima.test.learningtest;

import aima.learning.statistics.FeedForwardNetwork;
import aima.learning.statistics.IdentityActivationFunction;
import aima.learning.statistics.Layer;
import aima.learning.statistics.Link;
import aima.learning.statistics.LogSigActivationFunction;
import aima.learning.statistics.Neuron;
import aima.learning.statistics.SquareActivationFunction;
import aima.learning.statistics.StandardBackPropogation;
import aima.test.probabilitytest.MockRandomizer;
import java.util.Arrays;
import java.util.List;
import junit.framework.Assert;
import junit.framework.TestCase;

/* loaded from: input_file:aima/test/learningtest/NeuralNetworkTest.class */
/**
 * JUnit 3 tests for the feed-forward neural network classes in
 * {@code aima.learning.statistics}: single neurons, links, layers, a
 * layered {@link FeedForwardNetwork}, and {@link StandardBackPropogation}.
 *
 * <p>The expected activations below are consistent with a neuron computing
 * {@code f(sum(weight * sourceActivation) + bias)}, e.g. in
 * {@link #testActivationOfConnectedNeurons()}: {@code (1*2 + 0)^2 = 4} and
 * {@code 2*4 + 1 = 9}.
 */
public class NeuralNetworkTest extends TestCase {

    /**
     * Target function used by {@link #testBackPropogation()}:
     * {@code g(d) = 1 + sin(pi/4) * d}.
     *
     * <p>Not a JUnit test case despite the name: JUnit 3 only runs public,
     * no-argument, {@code void} methods whose names start with "test".
     *
     * @param d the input value
     * @return {@code 1 + sin(pi/4) * d}
     */
    public double testFunction(double d) {
        // Math.PI / 4 is bit-identical to the literal 0.7853981633974483.
        return 1.0d + Math.sin(Math.PI / 4) * d;
    }

    /** A default-constructed neuron has bias 1.0 and no links. */
    public void testDefaultValuesOfUnConnectedNeuron() {
        Neuron neuron = new Neuron();
        Assert.assertEquals(Double.valueOf(1.0d), Double.valueOf(neuron.bias()));
        Assert.assertEquals(0, neuron.outLinks().size());
        Assert.assertEquals(0, neuron.inLinks().size());
    }

    /** The single-argument constructor sets the bias. */
    public void testBiasSettingOnNeuron() {
        Neuron neuron = new Neuron(3.0d);
        Assert.assertEquals(Double.valueOf(3.0d), Double.valueOf(neuron.bias()));
    }

    /** A link records its source, target and weight. */
    public void testLinkCreation() {
        Neuron source = new Neuron();
        Neuron target = new Neuron();
        Link link = new Link(source, target, 4.0d);
        Assert.assertEquals(source, link.source());
        Assert.assertEquals(target, link.target());
        Assert.assertEquals(Double.valueOf(4.0d), Double.valueOf(link.weight()));
    }

    /** connectTo adds an out-link on the source and an in-link (with weight) on the target. */
    public void testNeuronConnection() {
        Neuron source = new Neuron();
        Neuron target = new Neuron();
        source.connectTo(target, 5.0d);
        Assert.assertEquals(1, source.outLinks().size());
        Assert.assertEquals(1, target.inLinks().size());
        Assert.assertEquals(Double.valueOf(5.0d), target.weights().get(0));
    }

    /**
     * Chain input -> hidden(square, bias 0) -> output(identity, bias 1):
     * hidden = (1 * 2 + 0)^2 = 4, output = 2 * 4 + 1 = 9.
     */
    public void testActivationOfConnectedNeurons() {
        Neuron input = new Neuron();
        Neuron hidden = new Neuron(0.0d, new SquareActivationFunction());
        Neuron output = new Neuron(1.0d, new IdentityActivationFunction());
        input.connectTo(hidden, 1.0d);
        hidden.connectTo(output, 2.0d);

        input.acceptAsInput(2.0d);
        hidden.update();
        Assert.assertEquals(Double.valueOf(4.0d), Double.valueOf(hidden.activation()));
        output.update();
        Assert.assertEquals(Double.valueOf(9.0d), Double.valueOf(output.activation()));
    }

    /** A layer of size n exposes n non-null neurons. */
    public void testLayerConstruction() {
        Layer layer = new Layer(3);
        for (int i = 0; i < 3; i++) {
            Assert.assertNotNull(layer.getNeuron(i));
        }
    }

    /**
     * With inputs {1, 2, 3} and targets {2, 3, 4}, the layer error is {1, 1, 1},
     * which is consistent with error = target - activation.
     */
    public void testLayerAcceptsInputAndGeneratesErrorCorrectly() {
        Layer layer = new Layer(3);
        List<Double> input = Arrays.asList(1.0d, 2.0d, 3.0d);
        List<Double> target = Arrays.asList(2.0d, 3.0d, 4.0d);
        List<Double> expectedError = Arrays.asList(1.0d, 1.0d, 1.0d);

        layer.acceptInput(input);
        Assert.assertEquals(expectedError, layer.getError(target));
    }

    /**
     * Same topology as {@link #testActivationOfConnectedNeurons()}, built from
     * layers. The MockRandomizer presumably supplies the link weights (1.0 and
     * 2.0) in order. Input 1 gives 1 then 3; input 2 gives 4 then 9.
     */
    public void testLayerUpdatesActivationProperly() {
        Layer inputLayer = new Layer(1);
        Layer hiddenLayer = new Layer(1, 0.0d, new SquareActivationFunction());
        Layer outputLayer = new Layer(1, 1.0d, new IdentityActivationFunction());
        inputLayer.connectTo(hiddenLayer, new MockRandomizer(new double[]{1.0d}));
        hiddenLayer.connectTo(outputLayer, new MockRandomizer(new double[]{2.0d}));

        inputLayer.acceptInput(Arrays.asList(1.0d));
        hiddenLayer.update();
        Assert.assertEquals(Arrays.asList(1.0d), hiddenLayer.activation());
        outputLayer.update();
        Assert.assertEquals(Arrays.asList(3.0d), outputLayer.activation());

        inputLayer.acceptInput(Arrays.asList(2.0d));
        hiddenLayer.update();
        Assert.assertEquals(Arrays.asList(4.0d), hiddenLayer.activation());
        outputLayer.update();
        Assert.assertEquals(Arrays.asList(9.0d), outputLayer.activation());
    }

    /**
     * Same network as {@link #testLayerUpdatesActivationProperly()}, assembled
     * via FeedForwardNetwork. The input layer is added with a null randomizer,
     * since it has no incoming links.
     */
    public void testFeedForwardNeuralNetwork() {
        Layer inputLayer = new Layer(1);
        Layer hiddenLayer = new Layer(1, 0.0d, new SquareActivationFunction());
        Layer outputLayer = new Layer(1, 1.0d, new IdentityActivationFunction());
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.addLayer(inputLayer, null);
        network.addLayer(hiddenLayer, new MockRandomizer(new double[]{1.0d}));
        network.addLayer(outputLayer, new MockRandomizer(new double[]{2.0d}));

        network.propogateInput(Arrays.asList(1.0d));
        Assert.assertEquals(Arrays.asList(3.0d), network.output());
        network.propogateInput(Arrays.asList(2.0d));
        Assert.assertEquals(Arrays.asList(9.0d), network.output());
    }

    /**
     * One backpropagation step on a 1-2-1 network with a log-sigmoid hidden
     * layer and a linear output, trained toward {@link #testFunction(double)}
     * at input 1.
     *
     * <p>The initial weights and biases and the expected values appear to come
     * from a hand-worked textbook example (presumably Hagan et al., "Neural
     * Network Design"). NOTE(review): confirm the source. The weight-update
     * expectations depend on StandardBackPropogation's learning rate and sign
     * convention, which are not visible here.
     */
    public void testBackPropogation() {
        Layer inputLayer = new Layer(1);
        Layer hiddenLayer = new Layer(2, Arrays.asList(-0.48d, -0.13d), new LogSigActivationFunction());
        Layer outputLayer = new Layer(1, 0.48d, new IdentityActivationFunction());
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.addLayer(inputLayer, null);
        network.addLayer(hiddenLayer, new MockRandomizer(new double[]{-0.27d, -0.41d}));
        network.addLayer(outputLayer, new MockRandomizer(new double[]{0.09d, -0.17d}));

        StandardBackPropogation backProp = new StandardBackPropogation();
        backProp.backPropogate(network, Arrays.asList(1.0d), Arrays.asList(testFunction(1.0d)));

        // Forward pass.
        Assert.assertEquals(0.321d, hiddenLayer.activation().get(0).doubleValue(), 0.001d);
        Assert.assertEquals(0.368d, hiddenLayer.activation().get(1).doubleValue(), 0.001d);
        Assert.assertEquals(0.446d, outputLayer.activation().get(0).doubleValue(), 0.001d);

        // Deltas propagated backwards.
        Assert.assertEquals(-1.261d, backProp.delta(outputLayer).get(0).doubleValue(), 0.001d);
        Assert.assertEquals(-0.0247d, backProp.delta(hiddenLayer).get(0).doubleValue(), 0.001d);
        Assert.assertEquals(0.04986d, backProp.delta(hiddenLayer).get(1).doubleValue(), 0.001d);

        // Weights and biases after a single update.
        backProp.updateWeightsAndBiases(network);
        Assert.assertEquals(0.13d, outputLayer.weights().get(0).doubleValue(), 0.001d);
        Assert.assertEquals(-0.123d, outputLayer.weights().get(1).doubleValue(), 0.001d);
        Assert.assertEquals(0.606d, outputLayer.getNeurons().get(0).bias(), 0.001d);
        Assert.assertEquals(-0.268d, hiddenLayer.weights().get(0).doubleValue(), 0.001d);
        Assert.assertEquals(-0.415d, hiddenLayer.weights().get(1).doubleValue(), 0.001d);
        Assert.assertEquals(-0.478d, hiddenLayer.getNeuron(0).bias(), 0.001d);
        Assert.assertEquals(-0.135d, hiddenLayer.getNeuron(1).bias(), 0.001d);
    }
}
