package org.campagnelab.dl.genotype.learning.architecture.graphs;

import org.campagnelab.dl.framework.architecture.graphs.ComputationGraphAssembler;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.models.ModelPropertiesHelper;
import org.campagnelab.dl.framework.tools.TrainingArguments;
import org.campagnelab.dl.genotype.learning.GenotypeTrainingArguments;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.lossfunctions.ILossFunction;

/* loaded from: input_file:org/campagnelab/dl/genotype/learning/architecture/graphs/GenotypeSixDenseLayersNarrower2.class */
/**
 * Assembles a computation graph made of the dense layers produced by a
 * {@link FeedForwardDenseLayerAssembler}, followed by one softmax output layer per
 * entry of {@link #getOutputNames()}.
 *
 * <p>{@link #setArguments(TrainingArguments)} must be called before
 * {@link #createComputationalGraph(DomainDescriptor)}: it creates the layer assembler
 * and reads the number of layers from the training arguments.
 */
public class GenotypeSixDenseLayersNarrower2 extends GenotypeAssembler implements ComputationGraphAssembler {
    // The three fields below (numHiddenNodes, arguments via args(), numInputs) are never
    // read by any code in this class; they are retained as-is.
    private int numHiddenNodes;
    private TrainingArguments arguments;
    private int numInputs;
    /** Names of the graph outputs, in the order passed to {@code setOutputs}. */
    private final String[] outputNames;
    /** Created by {@link #setArguments(TrainingArguments)}; null until then. */
    private FeedForwardDenseLayerAssembler layerAssembler;
    private int numLayers;

    private TrainingArguments args() {
        return this.arguments;
    }

    /**
     * Creates an assembler without the "isVariant" output
     * (equivalent to {@code new GenotypeSixDenseLayersNarrower2(false)}).
     */
    public GenotypeSixDenseLayersNarrower2() {
        this(false);
    }

    /**
     * Creates an assembler, optionally exposing an "isVariant" output.
     *
     * @param z when {@code true}, "isVariant" is appended to the output names; when
     *          {@code false}, the output names end with "metaData". The value is also
     *          stored in the inherited {@code hasIsVariant} field.
     */
    public GenotypeSixDenseLayersNarrower2(boolean z) {
        this.hasIsVariant = z;
        // The "isVariant" output must be listed exactly when the flag is set. The branches
        // were previously swapped, so the flag and the output list disagreed.
        if (z) {
            this.outputNames = new String[]{"homozygous", "A", "T", "C", "G", "N", "I1", "I2", "I3", "I4", "I5", "metaData", "isVariant"};
        } else {
            this.outputNames = new String[]{"homozygous", "A", "T", "C", "G", "N", "I1", "I2", "I3", "I4", "I5", "metaData"};
        }
    }

    /**
     * Stores the training arguments, reads {@code numLayers} from them and creates the
     * dense-layer assembler.
     *
     * @param trainingArguments the training arguments; cast to
     *                          {@link GenotypeTrainingArguments} to read {@code numLayers}
     * @throws ClassCastException if the argument is not a {@link GenotypeTrainingArguments}
     */
    public void setArguments(TrainingArguments trainingArguments) {
        this.arguments = trainingArguments;
        this.numLayers = ((GenotypeTrainingArguments) trainingArguments).numLayers;
        this.layerAssembler = new FeedForwardDenseLayerAssembler(trainingArguments);
    }

    /**
     * Builds the computation graph: the dense layers from the layer assembler, an
     * 11-way softmax "homozygous" output, a 2-way softmax output for every other name in
     * {@link #getOutputNames()}, then whatever the inherited {@code appendMetaDataLayer}
     * and {@code appendIsVariantLayer} add.
     *
     * @param domainDescriptor supplies the size of the "input" input, the number of hidden
     *                         nodes of the "firstDense" component and the loss of each output
     * @return a new {@link ComputationGraph} whose outputs are {@link #getOutputNames()}
     * @throws IllegalStateException if {@link #setArguments(TrainingArguments)} has not been called
     */
    public ComputationGraph createComputationalGraph(DomainDescriptor domainDescriptor) {
        if (this.layerAssembler == null) {
            throw new IllegalStateException("setArguments must be called before createComputationalGraph");
        }
        LearningRatePolicy learningRatePolicy = LearningRatePolicy.Poly;
        this.layerAssembler.setLearningRatePolicy(learningRatePolicy);
        this.layerAssembler.initializeBuilder(new String[0]);
        int inputSize = domainDescriptor.getNumInputs("input")[0];
        int hiddenNodes = domainDescriptor.getNumHiddenNodes("firstDense");
        ComputationGraphConfiguration.GraphBuilder graphBuilder = this.layerAssembler.assemble(inputSize, hiddenNodes, this.numLayers);
        // Every output layer is fed by the last dense layer, so its nIn is that layer's width.
        int numIn = this.layerAssembler.getNumOutputs();
        WeightInit weightInit = WeightInit.XAVIER;
        String lastLayerName = this.layerAssembler.lastLayerName();

        addSoftmaxOutput(graphBuilder, domainDescriptor, "homozygous", learningRatePolicy, weightInit, numIn, 11, lastLayerName);
        // NOTE(review): this loop also registers layers named "metaData" (and "isVariant"
        // when it is listed) with 2 outputs. appendMetaDataLayer/appendIsVariantLayer are
        // inherited and not visible here -- confirm whether they register the same names again.
        for (int i = 1; i < this.outputNames.length; i++) {
            addSoftmaxOutput(graphBuilder, domainDescriptor, this.outputNames[i], learningRatePolicy, weightInit, numIn, 2, lastLayerName);
        }
        appendMetaDataLayer(domainDescriptor, learningRatePolicy, graphBuilder, numIn, weightInit, lastLayerName);
        appendIsVariantLayer(domainDescriptor, learningRatePolicy, graphBuilder, numIn, weightInit, lastLayerName);
        return new ComputationGraph(graphBuilder.setOutputs(this.outputNames).build());
    }

    /**
     * Adds one softmax output layer named {@code outputName}, using the loss the domain
     * descriptor reports for that name and {@code inputLayerName} as its only input.
     */
    private static void addSoftmaxOutput(ComputationGraphConfiguration.GraphBuilder graphBuilder,
                                         DomainDescriptor domainDescriptor,
                                         String outputName,
                                         LearningRatePolicy learningRatePolicy,
                                         WeightInit weightInit,
                                         int numIn,
                                         int numOut,
                                         String inputLayerName) {
        graphBuilder.addLayer(outputName,
                new OutputLayer.Builder(domainDescriptor.getOutputLoss(outputName))
                        .weightInit(weightInit)
                        .activation("softmax")
                        .learningRateDecayPolicy(learningRatePolicy)
                        .nIn(numIn)
                        .nOut(numOut)
                        .build(),
                new String[]{inputLayerName});
    }

    /** No-op: input sizes are read from the domain descriptor when the graph is built. */
    public void setNumInputs(String str, int... iArr) {
    }

    /** No-op: output sizes are fixed in {@link #createComputationalGraph(DomainDescriptor)}. */
    public void setNumOutputs(String str, int... iArr) {
    }

    /** No-op: the hidden node count is read from the domain descriptor when the graph is built. */
    public void setNumHiddenNodes(String str, int i) {
    }

    /** @return the single input name, "input" */
    public String[] getInputNames() {
        return new String[]{"input"};
    }

    /** @return the output names chosen at construction time (the internal array, not a copy) */
    public String[] getOutputNames() {
        return this.outputNames;
    }

    /** @return the single component name, "firstDense" */
    public String[] getComponentNames() {
        return new String[]{"firstDense"};
    }

    /** No-op: losses are obtained from the domain descriptor when the graph is built. */
    public void setLossFunction(String str, ILossFunction iLossFunction) {
    }

    /** No-op: this assembler writes nothing to the model properties. */
    public void saveProperties(ModelPropertiesHelper modelPropertiesHelper) {
    }
}
