package org.neuroph.nnet.learning;

import java.util.Iterator;
import java.util.List;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;

/* loaded from: input_file:org/neuroph/nnet/learning/MomentumBackpropagation.class */
/**
 * Back-propagation learning rule that adds a momentum term to every weight update.
 *
 * <p>For each input connection of a neuron the weight change is computed as
 * {@code -learningRate * delta * input + momentum * previousChange}, where {@code previousChange}
 * is the change computed for that same weight on the previous update. The per-weight history is
 * kept in a {@link MomentumTrainingData} object stored as the weight's training data;
 * {@link #onStart()} attaches a fresh one to every weight when learning starts.
 */
public class MomentumBackpropagation extends BackPropagation {
    private static final long serialVersionUID = 1L;

    /**
     * Layers with at least this many neurons have their hidden-neuron updates processed with a
     * parallel stream; smaller layers are processed sequentially.
     */
    private static final int PARALLEL_LAYER_THRESHOLD = 100;

    /** Coefficient applied to the previous weight change; defaults to 0.25. */
    protected double momentum = 0.25d;

    /** Per-weight state required by the momentum term. */
    public static class MomentumTrainingData {
        /** Weight change computed for this weight on the most recent update (0 initially). */
        public double previousWeightChange;
    }

    /**
     * Back-propagates the error through every layer except the first and the last.
     *
     * <p>Layers are visited from index {@code layers.size() - 2} down to index 1. For each neuron,
     * the result of {@code calculateHiddenNeuronError} is stored as the neuron's delta, and then the
     * neuron's input weight changes are computed.
     *
     * <p>Layers with {@code PARALLEL_LAYER_THRESHOLD} (100) or more neurons are processed with a
     * parallel stream. NOTE(review): this relies on each neuron writing only its own delta and
     * the weights of its own input connections (no weight sharing inside a layer). It also relies
     * on {@code calculateHiddenNeuronError} reading only state that is not being written
     * concurrently. Confirm both before using shared weights.
     */
    @Override
    protected void calculateErrorAndUpdateHiddenNeurons() {
        List<Layer> layers = this.neuralNetwork.getLayers();
        for (int layerIndex = layers.size() - 2; layerIndex > 0; layerIndex--) {
            List<Neuron> neurons = layers.get(layerIndex).getNeurons();
            if (neurons.size() >= PARALLEL_LAYER_THRESHOLD) {
                neurons.parallelStream().forEach(this::updateHiddenNeuron);
            } else {
                for (Neuron neuron : neurons) {
                    updateHiddenNeuron(neuron);
                }
            }
        }
    }

    /** Sets the hidden-layer delta of {@code neuron} and then computes its input weight changes. */
    private void updateHiddenNeuron(Neuron neuron) {
        neuron.setDelta(calculateHiddenNeuronError(neuron));
        calculateWeightChanges(neuron);
    }

    /**
     * Computes the momentum-adjusted weight change for every input connection of {@code neuron}.
     *
     * <p>The change for a weight is {@code -learningRate * delta * input + momentum * previous},
     * where {@code previous} is the change computed for that weight on the previous call. In batch
     * mode the change is added to {@code weight.weightChange}; otherwise it replaces it.
     * Connections whose input is exactly 0 are skipped and keep their change history unchanged.
     * If a weight has no {@link MomentumTrainingData} attached yet, a fresh one (zero history) is
     * attached first.
     *
     * @param neuron neuron whose input weights are updated; its delta is read via
     *     {@code getDelta()}, so it must already be set
     */
    @Override
    public void calculateWeightChanges(Neuron neuron) {
        for (Connection connection : neuron.getInputConnections()) {
            double input = connection.getInput();
            if (input == 0.0d) {
                continue; // zero input: the gradient term vanishes, so this weight is left as-is
            }
            double delta = neuron.getDelta();
            Weight weight = connection.getWeight();
            MomentumTrainingData history = momentumDataOf(weight);

            double change = (-this.learningRate) * delta * input
                    + this.momentum * history.previousWeightChange;
            // Remember the change just computed so the next update's momentum term uses it.
            // Recording weight.weightChange *before* this update (as earlier code did) made the
            // momentum lag one step, or vanish if weightChange is cleared after being applied.
            history.previousWeightChange = change;

            if (isBatchMode()) {
                weight.weightChange += change;
            } else {
                weight.weightChange = change;
            }
        }
    }

    /**
     * Returns the {@link MomentumTrainingData} attached to {@code weight}, attaching a new
     * (zero-history) instance when none is present or the training data has a different type.
     */
    private static MomentumTrainingData momentumDataOf(Weight weight) {
        Object data = weight.getTrainingData();
        if (data instanceof MomentumTrainingData) {
            return (MomentumTrainingData) data;
        }
        MomentumTrainingData fresh = new MomentumTrainingData();
        weight.setTrainingData(fresh);
        return fresh;
    }

    /** Returns the momentum coefficient applied to the previous weight change. */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Sets the momentum coefficient applied to the previous weight change.
     *
     * @param d new momentum coefficient
     */
    public void setMomentum(double d) {
        this.momentum = d;
    }

    /**
     * Calls {@code super.onStart()} and then attaches a fresh {@link MomentumTrainingData} to the
     * weight of every input connection of every neuron in every layer. This resets the momentum
     * history to zero when learning starts.
     */
    @Override
    public void onStart() {
        super.onStart();
        for (Layer layer : this.neuralNetwork.getLayers()) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    connection.getWeight().setTrainingData(new MomentumTrainingData());
                }
            }
        }
    }
}
