package org.campagnelab.dl.genotype.learning.architecture.graphs;

import org.campagnelab.dl.framework.tools.TrainingArguments;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;

/* loaded from: input_file:org/campagnelab/dl/genotype/learning/architecture/graphs/FeedForwardDenseLayerAssembler.class */
/**
 * Builds a stack of fully-connected ReLU layers on a DL4J {@link ComputationGraphConfiguration.GraphBuilder}.
 *
 * <p>Usage: call {@link #initializeBuilder(String...)} first (it creates the underlying graph builder and
 * reads model capacity / reduction rate from the training arguments), then {@link #assemble} to append the
 * dense layers. {@link #getNumOutputs()} and {@link #lastLayerName()} describe the last layer created so
 * callers can attach an output layer to it.
 */
public class FeedForwardDenseLayerAssembler {
    /** Epsilon passed to the global network configuration builder. */
    private static final double BUILDER_EPSILON = 1.0E-8d;
    /** Epsilon passed to each individual dense layer. */
    private static final double LAYER_EPSILON = 0.1d;
    /** Weight initialization used both globally and for every dense layer. */
    private static final WeightInit WEIGHT_INIT = WeightInit.XAVIER;

    private final TrainingArguments args;
    private int numOutputs;
    private String lastLayerName;
    private ComputationGraphConfiguration.GraphBuilder build;
    private LearningRatePolicy learningRatePolicy = LearningRatePolicy.Poly;
    /** Multiplier applied to every layer width; overwritten from the arguments in initializeBuilder. */
    private double modelCapacity = 1.0d;
    /** Per-layer geometric width factor; overwritten from the arguments in initializeBuilder. */
    private double reductionRate = 1.0d;

    /**
     * Creates an assembler that reads its hyper-parameters from the given training arguments.
     *
     * @param trainingArguments source of seed, learning rate, regularization, dropout, capacity and reduction rate
     */
    public FeedForwardDenseLayerAssembler(TrainingArguments trainingArguments) {
        this.args = trainingArguments;
    }

    TrainingArguments args() {
        return this.args;
    }

    /**
     * Creates the graph builder (SGD + AdaGrad, Xavier init) and declares the graph inputs.
     *
     * <p>L2 regularization is enabled only when a regularization rate is configured; dropout (with
     * drop-connect) only when a dropout rate is configured. Also copies model capacity and reduction rate
     * from the arguments for later use by {@link #assemble}.
     *
     * @param inputNames names of the graph inputs; defaults to a single {@code "input"} when none (or null)
     *                   are given
     */
    public void initializeBuilder(String... inputNames) {
        if (inputNames == null || inputNames.length == 0) {
            inputNames = new String[]{"input"};
        }
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(args().seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(args().learningRate)
                .updater(Updater.ADAGRAD)
                .epsilon(BUILDER_EPSILON)
                .lrPolicyDecayRate(0.5d)
                .weightInit(WEIGHT_INIT);
        if (args().regularizationRate != null) {
            builder.l2(args().regularizationRate.doubleValue());
            // We are inside the null check, so regularization is always switched on here.
            builder.regularization(true);
        }
        if (args().dropoutRate != null) {
            builder.dropOut(args().dropoutRate.doubleValue());
            builder.setUseDropConnect(true);
        }
        this.modelCapacity = args().modelCapacity;
        this.reductionRate = args().reductionRate;
        this.build = builder.graphBuilder().addInputs(inputNames);
    }

    /**
     * Appends {@code numLayers} dense layers fed from the graph input named {@code "input"}, numbering
     * them from {@code dense1}.
     *
     * @see #assemble(int, int, int, String, int)
     */
    public ComputationGraphConfiguration.GraphBuilder assemble(int numInputs, int numHiddenNodes, int numLayers) {
        return assemble(numInputs, numHiddenNodes, numLayers, "input", 1);
    }

    /**
     * Appends a chain of dense ReLU layers named {@code dense<startingIndex>}, {@code dense<startingIndex+1>}, ...
     * The first layer reads from {@code inputName}; each subsequent layer reads from the previous one.
     *
     * <p>Layer {@code i} has {@code (int) (numHiddenNodes * reductionRate^i * modelCapacity)} outputs, where
     * {@code i} is the absolute layer index (so with {@code startingIndex == 1} the first layer already has one
     * reduction step applied). NOTE(review): confirm the absolute-index exponent is intended.
     *
     * <p>Requires {@link #initializeBuilder(String...)} to have been called; otherwise the graph builder is
     * null and this method throws a {@link NullPointerException}.
     *
     * @param numInputs      number of inputs to the first layer
     * @param numHiddenNodes base width before capacity/reduction scaling; must be positive
     * @param numLayers      number of dense layers to add (0 adds none)
     * @param inputName      name of the vertex feeding the first layer
     * @param startingIndex  index used for the first layer's name and width exponent
     * @return the underlying graph builder, with the layers added
     */
    public ComputationGraphConfiguration.GraphBuilder assemble(int numInputs, int numHiddenNodes, int numLayers,
                                                               String inputName, int startingIndex) {
        assert numHiddenNodes > 0 : "model capacity is too small. At least some hidden nodes must be created.";
        // Check the widths that will actually be built (the old check ignored reductionRate/modelCapacity).
        int smallestWidth = smallestLayerWidth(numHiddenNodes, numLayers, startingIndex);
        assert smallestWidth > 2 : "Too much reduction, not enough outputs: " + smallestWidth;

        int layerInputs = numInputs;
        // If no layers are added, the reported output count is the capacity-scaled hidden size.
        int layerOutputs = (int) (numHiddenNodes * this.modelCapacity);
        String currentLayerName = "no layers";
        for (int layerIndex = startingIndex; layerIndex < startingIndex + numLayers; layerIndex++) {
            currentLayerName = "dense" + layerIndex;
            String previousLayerName = layerIndex == startingIndex ? inputName : "dense" + (layerIndex - 1);
            layerOutputs = layerWidth(numHiddenNodes, layerIndex);
            this.build.addLayer(currentLayerName, new DenseLayer.Builder()
                    .nIn(layerInputs)
                    .nOut(layerOutputs)
                    .weightInit(WEIGHT_INIT)
                    .activation("relu")
                    .learningRateDecayPolicy(this.learningRatePolicy)
                    .epsilon(LAYER_EPSILON)
                    .build(), previousLayerName);
            layerInputs = layerOutputs;
        }
        this.numOutputs = layerOutputs;
        this.lastLayerName = currentLayerName;
        return this.build;
    }

    /** Width of the dense layer at the given absolute index, using the same formula as {@link #assemble}. */
    private int layerWidth(int numHiddenNodes, int layerIndex) {
        return (int) (numHiddenNodes * Math.pow(this.reductionRate, layerIndex) * this.modelCapacity);
    }

    /** Smallest width among the layers assemble will create; Integer.MAX_VALUE when numLayers <= 0. */
    private int smallestLayerWidth(int numHiddenNodes, int numLayers, int startingIndex) {
        int smallest = Integer.MAX_VALUE;
        for (int layerIndex = startingIndex; layerIndex < startingIndex + numLayers; layerIndex++) {
            smallest = Math.min(smallest, layerWidth(numHiddenNodes, layerIndex));
        }
        return smallest;
    }

    /** Returns the output width of the last layer built by the most recent {@link #assemble} call. */
    public int getNumOutputs() {
        return this.numOutputs;
    }

    /** Sets the learning-rate decay policy applied to dense layers created by later {@link #assemble} calls. */
    public void setLearningRatePolicy(LearningRatePolicy learningRatePolicy) {
        this.learningRatePolicy = learningRatePolicy;
    }

    /** Forwards the input types to the underlying graph builder; requires {@link #initializeBuilder}. */
    public void setInputTypes(InputType... inputTypes) {
        this.build.setInputTypes(inputTypes);
    }

    /** Returns the underlying graph builder, or null before {@link #initializeBuilder} has been called. */
    public ComputationGraphConfiguration.GraphBuilder getBuild() {
        return this.build;
    }

    /** Returns the name of the last layer built by {@link #assemble}, or {@code "no layers"} if none were added. */
    public String lastLayerName() {
        return this.lastLayerName;
    }
}
