package org.campagnelab.dl.genotype.learning.architecture.graphs;

import java.util.Set;
import org.campagnelab.dl.framework.architecture.graphs.ComputationGraphAssembler;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.models.ModelPropertiesHelper;
import org.campagnelab.dl.framework.tools.TrainingArguments;
import org.campagnelab.dl.genotype.learning.GenotypeTrainingArguments;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.lossfunctions.ILossFunction;

/* loaded from: input_file:org/campagnelab/dl/genotype/learning/architecture/graphs/GenotypeSixDenseLayersWithIndelLSTMAggregate.class */
/**
 * Computation-graph assembler for genotype models that merge a feed-forward input with an LSTM
 * summary of the {@code "indel"} sequence input.
 *
 * <p>Layout produced by {@link #createComputationalGraph(DomainDescriptor)}:
 * <ol>
 *   <li>a stack of {@code max(1, numLSTMLayers)} LSTM layers over {@code "indel"};</li>
 *   <li>a {@link LastTimeStepVertex} on the last LSTM layer;</li>
 *   <li>a {@link MergeVertex} concatenating that vertex with {@code "input"};</li>
 *   <li>dense layers assembled by {@code FeedForwardDenseLayerAssembler}
 *       ({@code max(1, numLayers)} is passed to it);</li>
 *   <li>output layers selected by {@link OutputType}, followed by meta-data, is-variant and
 *       true-genotype layers appended by helpers presumably inherited from
 *       {@code GenotypeAssembler}.</li>
 * </ol>
 *
 * <p>NOTE(review): this file was recovered with the JADX decompiler. The main constructor body
 * could not be decompiled and currently always throws {@link UnsupportedOperationException}; it
 * must be restored from the original source before this class can be instantiated.
 */
public class GenotypeSixDenseLayersWithIndelLSTMAggregate extends GenotypeAssembler implements ComputationGraphAssembler {
    /** Names of the graph outputs, passed to {@code setOutputs}; assigned by the constructor. */
    private final String[] outputNames;
    /** Names of the graph inputs; {@link #getInputTypes} only accepts "input", "indel" and "trueGenotypeInput". */
    private final String[] inputNames;
    /** Layer name of the single output used in {@link OutputType#COMBINED} mode (also its loss / output-size key). */
    private final String combined;
    /** Selects which output heads {@link #addOutputLayers} creates. */
    private final OutputType outputType;
    /** Forwarded to {@code appendTrueGenotypeLayers}. */
    private final boolean addTrueGenotypeLabels;
    /** Output names that receive a 2-class softmax head in HOMOZYGOUS / DISTINCT_ALLELES mode. */
    private final Set<String> basicOutputNames;
    /** Training arguments; must be supplied via {@link #setArguments} before building a graph. */
    private GenotypeTrainingArguments arguments;
    /** Sequence inputs summarised by an LSTM stack; used to name the merge-vertex inputs. */
    private static final String[] lstmInputNames = {"indel"};
    private static final WeightInit WEIGHT_INIT = WeightInit.XAVIER;
    private static final LearningRatePolicy LEARNING_RATE_POLICY = LearningRatePolicy.Poly;
    private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.DISTINCT_ALLELES;
    /** Not referenced in the decompiled code; presumably the source of basicOutputNames in the constructor — confirm. */
    private static final String[] basicOutputNamesArray = {"A", "T", "C", "G", "N", "I1", "I2", "I3", "I4", "I5"};

    /*
     * Decompiler artifact: javac's synthetic switch map for switching on OutputType (originally
     * the nested class $1, renamed by JADX because "$1" is not a valid source name). Presumably
     * used only by the undecompiled constructor's switch; once that constructor is restored from
     * source, javac will regenerate its own map and this class can be deleted.
     */
    /* renamed from: org.campagnelab.dl.genotype.learning.architecture.graphs.GenotypeSixDenseLayersWithIndelLSTMAggregate$1, reason: invalid class name */
    /* loaded from: input_file:org/campagnelab/dl/genotype/learning/architecture/graphs/GenotypeSixDenseLayersWithIndelLSTMAggregate$1.class */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$campagnelab$dl$genotype$learning$architecture$graphs$GenotypeSixDenseLayersWithIndelLSTMAggregate$OutputType = new int[OutputType.values().length];

        static {
            try {
                $SwitchMap$org$campagnelab$dl$genotype$learning$architecture$graphs$GenotypeSixDenseLayersWithIndelLSTMAggregate$OutputType[OutputType.DISTINCT_ALLELES.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$campagnelab$dl$genotype$learning$architecture$graphs$GenotypeSixDenseLayersWithIndelLSTMAggregate$OutputType[OutputType.COMBINED.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$campagnelab$dl$genotype$learning$architecture$graphs$GenotypeSixDenseLayersWithIndelLSTMAggregate$OutputType[OutputType.HOMOZYGOUS.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    /** Which family of output heads the graph exposes. */
    /* loaded from: input_file:org/campagnelab/dl/genotype/learning/architecture/graphs/GenotypeSixDenseLayersWithIndelLSTMAggregate$OutputType.class */
    public enum OutputType {
        /** A "homozygous" softmax head plus one 2-class head per basic output name. */
        HOMOZYGOUS,
        /** A "numDistinctAlleles" softmax head plus one 2-class head per basic output name. */
        DISTINCT_ALLELES,
        /** A single softmax head named by the {@code combined} field. */
        COMBINED
    }

    /** Creates an assembler with {@code DISTINCT_ALLELES} output and all boolean flags {@code false}. */
    public GenotypeSixDenseLayersWithIndelLSTMAggregate() {
        this(DEFAULT_OUTPUT_TYPE, false, false, false);
    }

    /**
     * Main constructor; expected to initialise every final field of this class.
     *
     * <p>NOTE(review): JADX could not decompile this constructor (982 bytecode instructions and a
     * switch it failed to structure, presumably over {@link OutputType} given the synthetic
     * switch map above). As written it always throws {@link UnsupportedOperationException}, which
     * also makes the no-arg constructor unusable. Restore the body from the original source. The
     * meaning of the three boolean parameters cannot be recovered from this file.
     *
     * @param r7 the output configuration to build
     */
    public GenotypeSixDenseLayersWithIndelLSTMAggregate(org.campagnelab.dl.genotype.learning.architecture.graphs.GenotypeSixDenseLayersWithIndelLSTMAggregate.OutputType r7, boolean r8, boolean r9, boolean r10) {
        /*
            Method dump skipped, instructions count: 982
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.campagnelab.dl.genotype.learning.architecture.graphs.GenotypeSixDenseLayersWithIndelLSTMAggregate.<init>(org.campagnelab.dl.genotype.learning.architecture.graphs.GenotypeSixDenseLayersWithIndelLSTMAggregate$OutputType, boolean, boolean, boolean):void");
    }

    /**
     * Returns the training arguments supplied through {@link #setArguments}.
     *
     * @throws IllegalStateException if {@link #setArguments} has not been called yet
     */
    private GenotypeTrainingArguments args() {
        if (arguments == null) {
            throw new IllegalStateException("setArguments() must be called before using this assembler");
        }
        return arguments;
    }

    /**
     * Stores the training arguments used to size the graph.
     *
     * @param trainingArguments must be a {@link GenotypeTrainingArguments}; otherwise a
     *     {@link ClassCastException} is thrown
     */
    public void setArguments(TrainingArguments trainingArguments) {
        this.arguments = (GenotypeTrainingArguments) trainingArguments;
    }

    /**
     * Builds a softmax output layer sharing this assembler's weight init and learning-rate policy.
     *
     * @param builder builder already configured with the layer's loss function
     * @param numIn   number of inputs to the layer
     * @param numOut  number of output classes
     */
    private static OutputLayer softmaxOutputLayer(OutputLayer.Builder builder, int numIn, int numOut) {
        return builder.weightInit(WEIGHT_INIT)
                .activation("softmax")
                .learningRateDecayPolicy(LEARNING_RATE_POLICY)
                .nIn(numIn)
                .nOut(numOut)
                .build();
    }

    /**
     * Adds the output heads selected by {@link #outputType}, then the meta-data, is-variant and
     * true-genotype layers, all fed from {@code lastDenseLayerName}.
     *
     * @param graphBuilder                   graph under construction
     * @param domainDescriptor               source of loss functions and output sizes
     * @param lastDenseLayerName             name of the layer every head reads from
     * @param numIn                          number of activations produced by that layer
     * @param numLSTMLayers                  forwarded to {@code appendTrueGenotypeLayers}
     * @param numLSTMTrueGenotypeHiddenNodes forwarded to {@code appendTrueGenotypeLayers}
     * @param numTrueGenotypeInputs          forwarded to {@code appendTrueGenotypeLayers}
     * @param numReductionLayers             forwarded to {@code appendTrueGenotypeLayers}
     * @param reductionRate                  forwarded to {@code appendTrueGenotypeLayers}
     */
    private void addOutputLayers(ComputationGraphConfiguration.GraphBuilder graphBuilder, DomainDescriptor domainDescriptor,
                                 String lastDenseLayerName, int numIn, int numLSTMLayers,
                                 int numLSTMTrueGenotypeHiddenNodes, int numTrueGenotypeInputs,
                                 int numReductionLayers, float reductionRate) {
        if (outputType == OutputType.HOMOZYGOUS || outputType == OutputType.DISTINCT_ALLELES) {
            if (outputType == OutputType.DISTINCT_ALLELES) {
                // NOTE(review): the loss is looked up under "homozygous", not "numDistinctAlleles";
                // confirm this is intentional before changing it.
                graphBuilder.addLayer("numDistinctAlleles",
                        softmaxOutputLayer(new OutputLayer.Builder(domainDescriptor.getOutputLoss("homozygous")),
                                numIn, domainDescriptor.getNumOutputs("numDistinctAlleles")[0]),
                        lastDenseLayerName);
            } else {
                // NOTE(review): 11 classes is hard-coded rather than read from the domain descriptor.
                graphBuilder.addLayer("homozygous",
                        softmaxOutputLayer(new OutputLayer.Builder(domainDescriptor.getOutputLoss("homozygous")),
                                numIn, 11),
                        lastDenseLayerName);
            }
            // One binary head for each declared output that is a basic output.
            for (String outputName : outputNames) {
                if (basicOutputNames.contains(outputName)) {
                    graphBuilder.addLayer(outputName,
                            softmaxOutputLayer(new OutputLayer.Builder(domainDescriptor.getOutputLoss(outputName)),
                                    numIn, 2),
                            lastDenseLayerName);
                }
            }
        } else if (outputType == OutputType.COMBINED) {
            // The activation used to be the layer name (this.combined), which is not an activation
            // function; use softmax like every other head.
            graphBuilder.addLayer(combined,
                    softmaxOutputLayer(new OutputLayer.Builder(domainDescriptor.getOutputLoss(combined)),
                            numIn, domainDescriptor.getNumOutputs(combined)[0]),
                    lastDenseLayerName);
        }
        appendMetaDataLayer(domainDescriptor, LEARNING_RATE_POLICY, graphBuilder, numIn, WEIGHT_INIT, lastDenseLayerName);
        appendIsVariantLayer(domainDescriptor, LEARNING_RATE_POLICY, graphBuilder, numIn, WEIGHT_INIT, lastDenseLayerName);
        appendTrueGenotypeLayers(graphBuilder, addTrueGenotypeLabels, lastDenseLayerName, domainDescriptor, WEIGHT_INIT,
                LEARNING_RATE_POLICY, numLSTMLayers, numIn, numLSTMTrueGenotypeHiddenNodes, numTrueGenotypeInputs,
                numReductionLayers, reductionRate);
    }

    /**
     * Builds the computation graph described in the class comment.
     *
     * <p>Layer counts come from the training arguments, clamped to at least 1; the reduction rate
     * is clamped to [0.1, 1.0]. The resulting configuration is printed to standard output.
     *
     * @param domainDescriptor source of input sizes, hidden-node counts, losses and output sizes
     * @return a new (not yet initialised) {@link ComputationGraph}
     * @throws IllegalStateException if {@link #setArguments} has not been called
     */
    public ComputationGraph createComputationalGraph(DomainDescriptor domainDescriptor) {
        int numInputs = domainDescriptor.getNumInputs("input")[0];
        int numIndelInputs = domainDescriptor.getNumInputs("indel")[0];
        int numTrueGenotypeInputs = domainDescriptor.getNumInputs("trueGenotypeInput")[0];
        int numHiddenNodes = domainDescriptor.getNumHiddenNodes("firstDense");
        int numLSTMIndelHiddenNodes = domainDescriptor.getNumHiddenNodes("lstmIndelLayer");
        int numLSTMTrueGenotypeHiddenNodes = domainDescriptor.getNumHiddenNodes("lstmTrueGenotypeLayer");

        GenotypeTrainingArguments trainingArgs = args();
        int numLSTMLayers = Math.max(1, trainingArgs.numLSTMLayers);
        int numLayers = Math.max(1, trainingArgs.numLayers);
        int numReductionLayers = Math.max(1, trainingArgs.numReductionLayers);
        float reductionRate = Math.max(0.1f, Math.min(1.0f, trainingArgs.reductionRate));

        FeedForwardDenseLayerAssembler denseAssembler = new FeedForwardDenseLayerAssembler(trainingArgs);
        denseAssembler.setLearningRatePolicy(LEARNING_RATE_POLICY);
        denseAssembler.initializeBuilder(getInputNames());
        denseAssembler.setInputTypes(getInputTypes(domainDescriptor));
        ComputationGraphConfiguration.GraphBuilder graphBuilder = denseAssembler.getBuild();

        // Stacked LSTMs over "indel": lstmindel_0 reads the raw input, each later layer the previous one.
        String lastLSTMLayerName = "no layer";
        for (int layer = 0; layer < numLSTMLayers; layer++) {
            String layerName = "lstmindel_" + layer;
            int layerNumIn = layer == 0 ? numIndelInputs : numLSTMIndelHiddenNodes;
            String layerInput = layer == 0 ? "indel" : lastLSTMLayerName;
            graphBuilder.addLayer(layerName,
                    new LSTM.Builder().nIn(layerNumIn).nOut(numLSTMIndelHiddenNodes).build(),
                    layerInput);
            lastLSTMLayerName = layerName;
        }
        graphBuilder.addVertex("lstmindelLastTimeStepVertex", new LastTimeStepVertex("indel"), lastLSTMLayerName);

        // Merge every LSTM summary ("lstm<name>LastTimeStepVertex") with the feed-forward input.
        String[] mergeInputs = new String[lstmInputNames.length + 1];
        for (int index = 0; index < lstmInputNames.length; index++) {
            mergeInputs[index] = "lstm" + lstmInputNames[index] + "LastTimeStepVertex";
        }
        mergeInputs[lstmInputNames.length] = "input";
        graphBuilder.addVertex("lstmFeedForwardMerge", new MergeVertex(), mergeInputs);

        int mergedWidth = numInputs + lstmInputNames.length * numLSTMIndelHiddenNodes;
        denseAssembler.assemble(mergedWidth, numHiddenNodes, numLayers, "lstmFeedForwardMerge", 1);
        addOutputLayers(graphBuilder, domainDescriptor, denseAssembler.lastLayerName(), denseAssembler.getNumOutputs(),
                numLSTMLayers, numLSTMTrueGenotypeHiddenNodes, numTrueGenotypeInputs, numReductionLayers, reductionRate);

        ComputationGraphConfiguration configuration = graphBuilder.setOutputs(outputNames).build();
        System.out.println(configuration);
        return new ComputationGraph(configuration);
    }

    /** Ignored: {@link #createComputationalGraph} reads input sizes from the DomainDescriptor. */
    public void setNumInputs(String str, int... iArr) {
    }

    /** Ignored: {@link #createComputationalGraph} reads output sizes from the DomainDescriptor. */
    public void setNumOutputs(String str, int... iArr) {
    }

    /** Ignored: {@link #createComputationalGraph} reads hidden-node counts from the DomainDescriptor. */
    public void setNumHiddenNodes(String str, int i) {
    }

    /** Returns the graph input names (the internal array, not a copy). */
    public String[] getInputNames() {
        return this.inputNames;
    }

    /**
     * Maps each graph input name to its DL4J input type: {@code "input"} is feed-forward,
     * {@code "indel"} and {@code "trueGenotypeInput"} are recurrent. Sizes come from
     * {@code domainDescriptor.getNumInputs(name)[0]}.
     *
     * <p>Reconstructed from the bytecode dump that JADX emitted for this method; previously it
     * always threw {@link UnsupportedOperationException}, breaking graph construction.
     *
     * @throws RuntimeException if an input name is not one of the three supported names
     */
    private org.deeplearning4j.nn.conf.inputs.InputType[] getInputTypes(DomainDescriptor domainDescriptor) {
        String[] names = getInputNames();
        org.deeplearning4j.nn.conf.inputs.InputType[] inputTypes =
                new org.deeplearning4j.nn.conf.inputs.InputType[names.length];
        for (int index = 0; index < names.length; index++) {
            String name = names[index];
            switch (name) {
                case "input":
                    inputTypes[index] = org.deeplearning4j.nn.conf.inputs.InputType.feedForward(
                            domainDescriptor.getNumInputs(name)[0]);
                    break;
                case "indel":
                case "trueGenotypeInput":
                    inputTypes[index] = org.deeplearning4j.nn.conf.inputs.InputType.recurrent(
                            domainDescriptor.getNumInputs(name)[0]);
                    break;
                default:
                    throw new RuntimeException("Invalid input to computation graph: " + name);
            }
        }
        return inputTypes;
    }

    /** Returns the graph output names (the internal array, not a copy). */
    public String[] getOutputNames() {
        return this.outputNames;
    }

    /**
     * Returns the component names of this architecture.
     *
     * <p>NOTE(review): {@link #createComputationalGraph} queries hidden-node counts for
     * "lstmIndelLayer" and "lstmTrueGenotypeLayer", not "lstmLayer"; confirm whether this list
     * should match those keys.
     */
    public String[] getComponentNames() {
        return new String[]{"firstDense", "lstmLayer"};
    }

    /** Ignored: loss functions are read from the DomainDescriptor in {@link #addOutputLayers}. */
    public void setLossFunction(String str, ILossFunction iLossFunction) {
    }

    /**
     * Records the layer-count arguments under keys prefixed with this class's canonical name.
     *
     * <p>NOTE(review): the raw argument values are saved, not the values clamped to at least 1
     * that {@link #createComputationalGraph} actually uses.
     */
    public void saveProperties(ModelPropertiesHelper modelPropertiesHelper) {
        String prefix = getClass().getCanonicalName();
        GenotypeTrainingArguments trainingArgs = args();
        modelPropertiesHelper.put(prefix + ".numLayers", trainingArgs.numLayers);
        modelPropertiesHelper.put(prefix + ".numLstmLayers", trainingArgs.numLSTMLayers);
    }
}
