package org.campagnelab.dl.genotype.performance;

import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.prediction.BinaryClassPrediction;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/campagnelab/dl/genotype/performance/AccuracyHelper.class */
/**
 * Estimates the prediction accuracy of a multi-output {@link ComputationGraph} over the records
 * produced by a {@link MultiDataSetIterator}.
 *
 * <p>Scoring rule, per record (row):
 * <ul>
 *   <li>The arg-max column of output 0 is taken as the predicted class.</li>
 *   <li>If that class is {@link #DEFER_TO_BINARY_OUTPUTS_CLASS}, the record is correct only when
 *       every remaining output (1..n-1) agrees with its label after thresholding both at
 *       {@link #BINARY_THRESHOLD}.</li>
 *   <li>Otherwise the record is correct when the label of output 0 is non-zero at the predicted
 *       column.</li>
 * </ul>
 */
public class AccuracyHelper {
    // NOTE(review): not referenced anywhere in this class.
    private static final int IS_HOMOZYGOUS_OUTPUT = 0;

    /**
     * Predicted class of output 0 that hands the decision over to the binary outputs 1..n-1.
     * NOTE(review): presumably the "heterozygous" class of the genotype model — confirm.
     */
    private static final int DEFER_TO_BINARY_OUTPUTS_CLASS = 10;

    /** Threshold separating "yes" from "no" for binary outputs and their labels. */
    private static final double BINARY_THRESHOLD = 0.5d;

    /**
     * Computes the fraction of records whose prediction is correct (see the class comment for the
     * scoring rule).
     *
     * @param multiDataSetIterator source of feature/label batches, consumed from its current
     *     position
     * @param computationGraph model evaluated on each batch's features
     * @param predicate tested after every batch with the running number of processed records;
     *     iteration stops as soon as it returns {@code true}
     * @return correct records divided by processed records, or {@link Double#NaN} when the
     *     iterator yields no records
     */
    public double estimateWithGraph(MultiDataSetIterator multiDataSetIterator,
                                    ComputationGraph computationGraph,
                                    Predicate<Integer> predicate) {
        int nProcessed = 0;
        int nCorrect = 0;
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = multiDataSetIterator.next();
            INDArray[] outputs = computationGraph.output(batch.getFeatures());
            INDArray[] labels = batch.getLabels();
            int rows = outputs[0].rows();
            for (int row = 0; row < rows; row++) {
                nProcessed++;
                int predictedClass = argMax(outputs[0], row);
                boolean correct = predictedClass == DEFER_TO_BINARY_OUTPUTS_CLASS
                        ? binaryOutputsCorrect(outputs, labels, row)
                        : labels[0].getDouble(row, predictedClass) != 0.0d;
                if (correct) {
                    nCorrect++;
                }
            }
            // The stop condition is only checked between batches, so a whole batch is always scored.
            if (predicate.test(nProcessed)) {
                break;
            }
        }
        if (nProcessed == 0) {
            // Nothing was scored; integer 0/0 used to throw ArithmeticException here.
            return Double.NaN;
        }
        // Cast before dividing: integer division truncated every accuracy below 1.0 to 0.
        return (double) nCorrect / nProcessed;
    }

    /**
     * Returns the column index of the largest value in {@code row} of {@code scores}. Ties go to
     * the first such column. Returns -1 only when {@code scores} has no columns.
     */
    private static int argMax(INDArray scores, int row) {
        int columns = scores.columns();
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int col = 0; col < columns; col++) {
            double score = scores.getDouble(row, col);
            // Seed from the first column instead of a sentinel, so any score range is handled.
            if (best < 0 || score > bestScore) {
                best = col;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Returns true when every binary output 1..n-1 agrees with its label at {@code row}, after
     * thresholding both at {@link #BINARY_THRESHOLD}. An output with no readable label is skipped.
     * A record with no scoreable binary output therefore counts as correct.
     */
    private static boolean binaryOutputsCorrect(INDArray[] outputs, INDArray[] labels, int row) {
        for (int k = 1; k < outputs.length; k++) {
            double trueLabel;
            try {
                trueLabel = labels[k].getDouble(row, 0);
            } catch (IndexOutOfBoundsException ignored) {
                // No label exists for this output/row. Skip it explicitly instead of comparing
                // against a stale label left over from a previous output or record.
                continue;
            }
            double predicted = outputs[k].getDouble(row, 0);
            boolean agrees = trueLabel < BINARY_THRESHOLD
                    ? predicted < BINARY_THRESHOLD
                    : predicted >= BINARY_THRESHOLD;
            if (!agrees) {
                return false;
            }
        }
        return true;
    }
}
