package org.campagnelab.dl.genotype.performance;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.domains.prediction.Prediction;
import org.campagnelab.dl.framework.tools.PredictWithModel;
import org.campagnelab.dl.genotype.helpers.GenotypeHelper;
import org.campagnelab.dl.genotype.predictions.GenotypePrediction;
import org.campagnelab.dl.varanalysis.protobuf.BaseInformationRecords;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/campagnelab/dl/genotype/performance/GenotypeTrainingPerformanceHelper.class */
/**
 * Runs a {@link ComputationGraph} over a {@link MultiDataSetIterator} and accumulates genotype
 * prediction statistics in a {@link StatsAccumulator}.
 */
public class GenotypeTrainingPerformanceHelper extends PredictWithModel<BaseInformationRecords.BaseInformation> {
    /** Statistics from the most recent {@code estimateWithGraph} call; {@code null} before the first call. */
    protected StatsAccumulator accumulator;

    public GenotypeTrainingPerformanceHelper(DomainDescriptor<BaseInformationRecords.BaseInformation> domainDescriptor) {
        super(domainDescriptor);
    }

    /**
     * Same as the five-argument overload, with no-op prediction and score consumers.
     *
     * @param multiDataSetIterator minibatches to evaluate (reset before use)
     * @param computationGraph     model to evaluate
     * @param predicate            stop condition, tested with the zero-based count of examples
     *                             processed so far; evaluation stops once it returns true
     * @return element {@code [3]} of {@link StatsAccumulator#createOutputStatistics()}
     */
    public double estimateWithGraph(MultiDataSetIterator multiDataSetIterator, ComputationGraph computationGraph, Predicate<Integer> predicate) {
        return estimateWithGraph(multiDataSetIterator, computationGraph, predicate, prediction -> {
        }, score -> {
        });
    }

    /**
     * Evaluates {@code computationGraph} on every example produced by {@code multiDataSetIterator},
     * feeding each aggregated {@link GenotypePrediction} into a fresh {@link StatsAccumulator}.
     *
     * <p>After each example is observed, {@code predicate} is tested with the number of examples
     * processed before it (0 for the first). When it returns true, evaluation of the whole dataset
     * stops; the example just tested has already been observed.
     *
     * @param multiDataSetIterator minibatches to evaluate; {@code reset()} is called first
     * @param computationGraph     model used for {@code output} and {@code score}
     * @param predicate            stop condition (see above)
     * @param consumer             receives every aggregated genotype prediction, in order
     * @param consumer2            receives each minibatch score that is not NaN
     * @return element {@code [3]} of {@link StatsAccumulator#createOutputStatistics()}
     *         (NOTE(review): which metric index 3 is depends on StatsAccumulator — confirm)
     */
    public double estimateWithGraph(MultiDataSetIterator multiDataSetIterator, ComputationGraph computationGraph, Predicate<Integer> predicate, Consumer<GenotypePrediction> consumer, Consumer<Double> consumer2) {
        multiDataSetIterator.reset();
        this.accumulator = new StatsAccumulator();
        this.accumulator.initializeStats();
        int processed = 0;
        // Labeled so the stop condition terminates the whole evaluation, not just the current minibatch.
        batches:
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = multiDataSetIterator.next();
            INDArray[] outputs = computationGraph.output(batch.getFeatures());
            double score = computationGraph.score(batch);
            if (!Double.isNaN(score)) {
                consumer2.accept(score);
            }
            INDArray[] labels = batch.getLabels();
            int numExamples = batch.getFeatures(0).size(0);
            for (int exampleIndex = 0; exampleIndex < numExamples; exampleIndex++) {
                GenotypePrediction prediction = aggregateExample(labels, outputs, exampleIndex, processed);
                // NOTE(review): the reference argument is the goby reference index rendered as a string;
                // confirm this is what GenotypeHelper.isVariant expects.
                this.accumulator.observe(prediction, prediction.isVariant(),
                        GenotypeHelper.isVariant(prediction.predictedGenotype, Integer.toString(prediction.referenceGobyIndex)));
                consumer.accept(prediction);
                if (predicate.test(processed)) {
                    break batches;
                }
                processed++;
            }
        }
        return this.accumulator.createOutputStatistics()[3];
    }

    /**
     * Interprets every model output that has an interpreter for one example of a minibatch and
     * aggregates the per-output predictions into a single genotype prediction.
     *
     * @param labels       label arrays of the minibatch, one per model output
     * @param outputs      model output arrays of the minibatch, one per model output
     * @param exampleIndex index of the example within the minibatch
     * @param globalIndex  running index of the example across the whole evaluation
     */
    private GenotypePrediction aggregateExample(INDArray[] labels, INDArray[] outputs, int exampleIndex, int globalIndex) {
        int numOutputs = this.domainDescriptor.getNumModelOutputs();
        // A fresh list per example, so no previously returned prediction can see it cleared later.
        List<Prediction> predictions = new ArrayList<>(numOutputs);
        for (int outputIndex = 0; outputIndex < numOutputs; outputIndex++) {
            if (this.interpretors[outputIndex] == null) {
                continue;
            }
            Prediction prediction = this.interpretors[outputIndex].interpret(labels[outputIndex], outputs[outputIndex], exampleIndex);
            prediction.outputIndex = outputIndex;
            prediction.index = globalIndex;
            predictions.add(prediction);
        }
        return (GenotypePrediction) this.domainDescriptor.aggregatePredictions((Object) null, predictions);
    }

    /**
     * Returns the named metric values from the statistics of the most recent evaluation.
     *
     * @param strArr names of the metrics to return, passed to
     *               {@link StatsAccumulator#createOutputStatistics(String...)}
     * @return the values produced by the accumulator for those names
     * @throws IllegalStateException if {@code estimateWithGraph} has not been called yet
     */
    public double[] getMetricValues(String... strArr) {
        if (this.accumulator == null) {
            throw new IllegalStateException("getMetricValues called before estimateWithGraph");
        }
        return this.accumulator.createOutputStatistics(strArr);
    }
}
