package org.campagnelab.dl.genotype.learning.domains.predictions;

import org.campagnelab.dl.framework.domains.prediction.PredictionInterpreter;
import org.campagnelab.dl.genotype.mappers.MetaDataLabelMapper;
import org.campagnelab.dl.genotype.predictions.TrueGenotypeOutputLayerPrediction;
import org.campagnelab.dl.varanalysis.protobuf.BaseInformationRecords;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.Max;
import org.nd4j.linalg.api.ops.impl.accum.Mean;
import org.nd4j.linalg.api.ops.impl.accum.Sum;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/campagnelab/dl/genotype/learning/domains/predictions/TrueGenotypeOutputLayerInterpreter.class */
public class TrueGenotypeOutputLayerInterpreter implements PredictionInterpreter<BaseInformationRecords.BaseInformation, TrueGenotypeOutputLayerPrediction> {
    /* renamed from: interpret, reason: merged with bridge method [inline-methods] */
    public TrueGenotypeOutputLayerPrediction m21interpret(INDArray iNDArray, INDArray iNDArray2, int i) {
        TrueGenotypeOutputLayerPrediction trueGenotypeOutputLayerPrediction = new TrueGenotypeOutputLayerPrediction();
        INDArray row = iNDArray.getRow(i);
        INDArray row2 = iNDArray2.getRow(i);
        String genotypeFromINDArray = getGenotypeFromINDArray(row);
        String substring = getGenotypeFromINDArray(row2).substring(0, genotypeFromINDArray.length());
        String replace = substring.replace("$", "").replace("*", "");
        trueGenotypeOutputLayerPrediction.trueGenotype = genotypeFromINDArray;
        trueGenotypeOutputLayerPrediction.predictedGenotype = substring;
        trueGenotypeOutputLayerPrediction.isPredictedIndel = replace.length() > 3;
        trueGenotypeOutputLayerPrediction.overallProbability = getAverageFloatMaxArray(row2);
        return trueGenotypeOutputLayerPrediction;
    }

    public TrueGenotypeOutputLayerPrediction interpret(BaseInformationRecords.BaseInformation baseInformation, INDArray iNDArray) {
        TrueGenotypeOutputLayerPrediction trueGenotypeOutputLayerPrediction = new TrueGenotypeOutputLayerPrediction();
        trueGenotypeOutputLayerPrediction.inspectRecord(baseInformation);
        String substring = getGenotypeFromINDArray(iNDArray).substring(0, trueGenotypeOutputLayerPrediction.trueGenotype.length());
        String replace = substring.replace("$", "").replace("*", "");
        trueGenotypeOutputLayerPrediction.predictedGenotype = substring;
        trueGenotypeOutputLayerPrediction.isPredictedIndel = replace.length() > 3;
        trueGenotypeOutputLayerPrediction.overallProbability = getAverageFloatMaxArray(iNDArray);
        return trueGenotypeOutputLayerPrediction;
    }

    private static String getGenotypeFromINDArray(INDArray iNDArray) {
        INDArray intArgMaxArray = getIntArgMaxArray(iNDArray);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < intArgMaxArray.length(); i++) {
            sb.append(labelToBase(intArgMaxArray.getInt(new int[]{i})));
        }
        return sb.toString();
    }

    private static INDArray getIntArgMaxArray(INDArray iNDArray) {
        int intValue = Nd4j.getExecutioner().exec(new Sum(iNDArray), new int[]{0}).gt(0).sumNumber().intValue();
        INDArray exec = Nd4j.getExecutioner().exec(new IMax(iNDArray), new int[]{0});
        return intValue > 0 ? exec.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, intValue)}) : exec;
    }

    private static double getAverageFloatMaxArray(INDArray iNDArray) {
        INDArray exec = Nd4j.getExecutioner().exec(new Max(iNDArray), new int[]{0});
        int intValue = exec.gt(0).sumNumber().intValue();
        return Nd4j.getExecutioner().execAndReturn(new Mean(intValue > 0 ? exec.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, intValue)}) : exec)).getFinalResult().doubleValue();
    }

    private static char labelToBase(int i) {
        switch (i) {
            case 0:
                return 'A';
            case MetaDataLabelMapper.IS_INDEL_FEATURE_INDEX /* 1 */:
                return 'T';
            case MetaDataLabelMapper.IS_MATCHING_REF_FEATURE_INDEX /* 2 */:
                return 'C';
            case 3:
                return 'G';
            case 4:
                return 'N';
            case MetaDataLabelMapper.IS_COUNT3_ORIGINAL_INDEX_FEATURE_INDEX /* 5 */:
                return '-';
            case MetaDataLabelMapper.IS_COUNT4_ORIGINAL_INDEX_FEATURE_INDEX /* 6 */:
                return '|';
            case MetaDataLabelMapper.IS_COUNT5_ORIGINAL_INDEX_FEATURE_INDEX /* 7 */:
                return '*';
            case MetaDataLabelMapper.IS_COUNT6_ORIGINAL_INDEX_FEATURE_INDEX /* 8 */:
                return '$';
            default:
                return ' ';
        }
    }
}
