package org.campagnelab.dl.genotype.learning.architecture.graphs;

import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;

/* loaded from: input_file:org/campagnelab/dl/genotype/learning/architecture/graphs/GenotypeAssembler.class */
public abstract class GenotypeAssembler {
    /**
     * When true, {@link #appendIsVariantLayer} adds the "isVariant" output layer; when false that
     * method does nothing.
     */
    protected boolean hasIsVariant;

    /**
     * Adds a softmax output layer named "isVariant" to the graph, but only when
     * {@link #hasIsVariant} is set.
     *
     * <p>NOTE: the decompiler reports this method was originally protected; it is kept public so
     * existing callers continue to compile.
     *
     * @param domainDescriptor supplies the loss function and output width for "isVariant"
     * @param learningRatePolicy learning-rate decay policy applied to the new layer
     * @param graphBuilder graph under construction; mutated in place
     * @param numIn number of inputs to the new layer (must match the output width of
     *     {@code inputLayerName})
     * @param weightInit weight initialisation scheme for the new layer
     * @param inputLayerName name of the existing vertex/layer that feeds this output layer
     */
    public void appendIsVariantLayer(DomainDescriptor domainDescriptor, LearningRatePolicy learningRatePolicy,
                                     ComputationGraphConfiguration.GraphBuilder graphBuilder, int numIn,
                                     WeightInit weightInit, String inputLayerName) {
        if (hasIsVariant) {
            appendSoftmaxOutputLayer("isVariant", domainDescriptor, learningRatePolicy, graphBuilder, numIn,
                    weightInit, inputLayerName);
        }
    }

    /**
     * Adds a softmax output layer named "metaData" to the graph (always).
     *
     * <p>NOTE: the decompiler reports this method was originally protected; it is kept public so
     * existing callers continue to compile.
     *
     * @param domainDescriptor supplies the loss function and output width for "metaData"
     * @param learningRatePolicy learning-rate decay policy applied to the new layer
     * @param graphBuilder graph under construction; mutated in place
     * @param numIn number of inputs to the new layer (must match the output width of
     *     {@code inputLayerName})
     * @param weightInit weight initialisation scheme for the new layer
     * @param inputLayerName name of the existing vertex/layer that feeds this output layer
     */
    public void appendMetaDataLayer(DomainDescriptor domainDescriptor, LearningRatePolicy learningRatePolicy,
                                    ComputationGraphConfiguration.GraphBuilder graphBuilder, int numIn,
                                    WeightInit weightInit, String inputLayerName) {
        appendSoftmaxOutputLayer("metaData", domainDescriptor, learningRatePolicy, graphBuilder, numIn,
                weightInit, inputLayerName);
    }

    /**
     * Optionally appends the "trueGenotype" sub-network to the graph.
     *
     * <p>The sub-network is built in three stages:
     * <ol>
     *   <li>{@code numReductionLayers} ReLU dense layers ("denseReduction_0", "denseReduction_1", …),
     *       each multiplying the width by {@code reductionRate} (truncated to int).</li>
     *   <li>A {@code DuplicateToTimeSeriesVertex} ("feedForwardLstmDuplicate"). It repeats the reduced
     *       features along the time axis of the "trueGenotypeInput" input.</li>
     *   <li>{@code numLSTMLayers} softsign GravesLSTM layers ("lstmTrueGenotype_0", …). The first one
     *       reads both "trueGenotypeInput" and "feedForwardLstmDuplicate". A final softmax
     *       {@code RnnOutputLayer} named "trueGenotype" then reads the last LSTM layer.</li>
     * </ol>
     *
     * <p>NOTE: the decompiler reports this method was originally protected; it is kept public so
     * existing callers continue to compile.
     *
     * @param graphBuilder graph under construction; mutated in place
     * @param addTrueGenotypeLayers when false, this method does nothing
     * @param lastDenseLayerName name of the existing layer that feeds the reduction stack (or, when
     *     {@code numReductionLayers == 0}, the duplicate vertex directly)
     * @param domainDescriptor supplies the loss function and output width for "trueGenotype"
     * @param weightInit weight initialisation scheme for all added layers
     * @param learningRatePolicy learning-rate decay policy for all added layers
     * @param numLSTMLayers number of stacked LSTM layers; must be at least 1
     * @param numLastDenseOutputs output width of {@code lastDenseLayerName}
     * @param numLSTMHiddenNodes hidden size of every LSTM layer (and nIn of the output layer)
     * @param numTrueGenotypeInputs added to the reduced width to form the first LSTM layer's nIn;
     *     presumably the per-timestep feature count of "trueGenotypeInput" (TODO confirm)
     * @param numReductionLayers number of dense reduction layers; 0 skips the reduction stage
     * @param reductionRate width multiplier applied by each reduction layer
     * @throws IllegalArgumentException if layers are requested and {@code numLSTMLayers < 1}
     */
    public void appendTrueGenotypeLayers(ComputationGraphConfiguration.GraphBuilder graphBuilder,
                                         boolean addTrueGenotypeLayers, String lastDenseLayerName,
                                         DomainDescriptor domainDescriptor, WeightInit weightInit,
                                         LearningRatePolicy learningRatePolicy, int numLSTMLayers,
                                         int numLastDenseOutputs, int numLSTMHiddenNodes,
                                         int numTrueGenotypeInputs, int numReductionLayers, float reductionRate) {
        if (!addTrueGenotypeLayers) {
            return;
        }
        // Validate before mutating the builder. With zero LSTM layers the output layer would
        // otherwise be wired to a vertex that does not exist.
        if (numLSTMLayers < 1) {
            throw new IllegalArgumentException("numLSTMLayers must be at least 1, was " + numLSTMLayers);
        }

        // Stage 1: shrinking dense stack. Start from the incoming layer so that zero reduction
        // layers wires the duplicate vertex straight to lastDenseLayerName. The original code used
        // a non-existent "no layer" placeholder here.
        String reductionOutputName = lastDenseLayerName;
        int reducedSize = numLastDenseOutputs;
        for (int r = 0; r < numReductionLayers; r++) {
            String layerName = "denseReduction_" + r;
            int nOut = (int) (reducedSize * reductionRate);
            graphBuilder.addLayer(layerName, new DenseLayer.Builder()
                    .weightInit(weightInit)
                    .activation("relu")
                    .learningRateDecayPolicy(learningRatePolicy)
                    .nIn(reducedSize)
                    .nOut(nOut)
                    .build(), reductionOutputName);
            reductionOutputName = layerName;
            reducedSize = nOut;
        }

        // Stage 2: repeat the reduced feature vector along the time axis of "trueGenotypeInput".
        graphBuilder.addVertex("feedForwardLstmDuplicate", new DuplicateToTimeSeriesVertex("trueGenotypeInput"),
                reductionOutputName);

        // Stage 3: LSTM stack. The first layer consumes both the raw sequence and the duplicated
        // features, so its nIn is the sum of the two widths.
        String lstmOutputName = null;
        for (int l = 0; l < numLSTMLayers; l++) {
            String layerName = "lstmTrueGenotype_" + l;
            boolean firstLayer = l == 0;
            String[] inputs = firstLayer
                    ? new String[]{"trueGenotypeInput", "feedForwardLstmDuplicate"}
                    : new String[]{lstmOutputName};
            graphBuilder.addLayer(layerName, new GravesLSTM.Builder()
                    .activation("softsign")
                    .nIn(firstLayer ? reducedSize + numTrueGenotypeInputs : numLSTMHiddenNodes)
                    .nOut(numLSTMHiddenNodes)
                    .learningRateDecayPolicy(learningRatePolicy)
                    .weightInit(weightInit)
                    .build(), inputs);
            lstmOutputName = layerName;
        }

        graphBuilder.addLayer("trueGenotype", new RnnOutputLayer.Builder(domainDescriptor.getOutputLoss("trueGenotype"))
                .weightInit(weightInit)
                .activation("softmax")
                .learningRateDecayPolicy(learningRatePolicy)
                .nIn(numLSTMHiddenNodes)
                .nOut(domainDescriptor.getNumOutputs("trueGenotype")[0])
                .build(), lstmOutputName);
    }

    /**
     * Shared builder for the feed-forward softmax output layers. The layer name, loss lookup and
     * output-width lookup all use {@code outputName}.
     */
    private static void appendSoftmaxOutputLayer(String outputName, DomainDescriptor domainDescriptor,
                                                 LearningRatePolicy learningRatePolicy,
                                                 ComputationGraphConfiguration.GraphBuilder graphBuilder,
                                                 int numIn, WeightInit weightInit, String inputLayerName) {
        graphBuilder.addLayer(outputName, new OutputLayer.Builder(domainDescriptor.getOutputLoss(outputName))
                .weightInit(weightInit)
                .activation("softmax")
                .learningRateDecayPolicy(learningRatePolicy)
                .nIn(numIn)
                .nOut(domainDescriptor.getNumOutputs(outputName)[0])
                .build(), inputLayerName);
    }
}
