package org.campagnelab.dl.genotype.performance;

import it.unimi.dsi.fastutil.doubles.DoubleList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.domains.prediction.Prediction;
import org.campagnelab.dl.framework.tools.PredictWithModel;
import org.campagnelab.dl.genotype.predictions.SegmentPrediction;
import org.campagnelab.dl.varanalysis.protobuf.SegmentInformationRecords;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/campagnelab/dl/genotype/performance/SegmentTrainingPerformanceHelper.class */
/**
 * Runs a computation graph over a dataset iterator, interprets every example's outputs into a
 * {@link SegmentPrediction}, and feeds those predictions to a {@link SegmentStatsAccumulator}.
 *
 * <p>The helper also keeps a running sum and count of the per-minibatch scores seen by the
 * three-argument {@code estimateWithGraph} overload, so that {@link #getMetricValues()} can
 * report their mean alongside the accumulator's estimates.
 *
 * <p>This class holds mutable state ({@code validationScore}, {@code n}, and the accumulator)
 * and performs no synchronization.
 */
public class SegmentTrainingPerformanceHelper extends PredictWithModel<SegmentInformationRecords.SegmentInformation> {
    /** Collects statistics over the aggregated segment predictions. */
    protected SegmentStatsAccumulator accumulator;
    /** Sum of the non-NaN minibatch scores recorded by the three-argument overload. */
    double validationScore;
    /** Number of minibatch scores summed into {@link #validationScore}. */
    int n;

    /**
     * Creates a helper whose accumulator combines an {@code AccuracySegmentAccumulator} and an
     * {@code IndelAccuracySegmentAccumulator}.
     *
     * @param domainDescriptor domain descriptor, passed to the superclass constructor
     */
    public SegmentTrainingPerformanceHelper(DomainDescriptor<SegmentInformationRecords.SegmentInformation> domainDescriptor) {
        super(domainDescriptor);
        this.accumulator = new SegmentStatsAccumulator(new AccuracySegmentAccumulator(), new IndelAccuracySegmentAccumulator());
    }

    /**
     * Evaluates the graph over the iterator while tracking the mean minibatch score.
     *
     * <p>Resets the running score sum and count, then delegates to the five-argument overload
     * with a no-op prediction observer and a score observer that accumulates each score.
     *
     * @param multiDataSetIterator source of minibatches; it is reset before use
     * @param computationGraph model used to produce outputs and scores
     * @param predicate tested with the zero-based example index after each example; evaluation
     *     stops as soon as it returns {@code true}
     * @return the first value of the accumulator's estimates
     */
    public double estimateWithGraph(MultiDataSetIterator multiDataSetIterator, ComputationGraph computationGraph, Predicate<Integer> predicate) {
        this.validationScore = 0.0d;
        this.n = 0;
        return estimateWithGraph(multiDataSetIterator, computationGraph, predicate, segmentPrediction -> {
        }, score -> {
            this.validationScore += score;
            this.n++;
        });
    }

    /**
     * Evaluates the graph over the iterator, observing every aggregated prediction.
     *
     * <p>For each minibatch the graph output and score are computed. The score is reported to
     * {@code consumer2} unless it is NaN. Then, for each example in the minibatch, the
     * interpreters of all model outputs that have one are applied, the resulting predictions are
     * aggregated through the domain descriptor, and the aggregate is passed to the accumulator
     * and to {@code consumer}.
     *
     * @param multiDataSetIterator source of minibatches; it is reset before use
     * @param computationGraph model used to produce outputs and scores
     * @param predicate tested with the zero-based example index after each example has been
     *     observed; once it returns {@code true} no further example or minibatch is processed
     * @param consumer receives each aggregated prediction; may be {@code null}
     * @param consumer2 receives each non-NaN minibatch score
     * @return the first value of the accumulator's estimates
     */
    public double estimateWithGraph(MultiDataSetIterator multiDataSetIterator, ComputationGraph computationGraph, Predicate<Integer> predicate, Consumer<SegmentPrediction> consumer, Consumer<Double> consumer2) {
        multiDataSetIterator.reset();
        this.accumulator.initializeStats();
        int exampleIndex = 0;
        boolean stopRequested = false;
        ArrayList<Prediction> predictions = new ArrayList<>();
        // The stop flag is checked here so that a positive predicate ends the whole evaluation,
        // not just the current minibatch.
        while (!stopRequested && multiDataSetIterator.hasNext()) {
            MultiDataSet multiDataSet = (MultiDataSet) multiDataSetIterator.next();
            INDArray[] output = computationGraph.output(multiDataSet.getFeatures());
            double score = computationGraph.score(multiDataSet);
            // NaN scores are not reported, so they cannot poison a running sum.
            if (!Double.isNaN(score)) {
                consumer2.accept(score);
            }
            INDArray[] labels = multiDataSet.getLabels();
            // The size of dimension 0 of the first feature array is used as the example count.
            int batchSize = multiDataSet.getFeatures(0).size(0);
            for (int example = 0; example < batchSize; example++) {
                predictions.clear();
                for (int outputIndex = 0; outputIndex < this.domainDescriptor.getNumModelOutputs(); outputIndex++) {
                    // Outputs without an interpreter contribute no prediction.
                    if (this.interpretors[outputIndex] != null) {
                        Prediction prediction = this.interpretors[outputIndex].interpret(labels[outputIndex], output[outputIndex], example);
                        prediction.outputIndex = outputIndex;
                        prediction.index = 0;
                        predictions.add(prediction);
                    }
                }
                SegmentPrediction aggregated = (SegmentPrediction) this.domainDescriptor.aggregatePredictions(null, predictions);
                this.accumulator.observe(aggregated);
                // Null-guarded so callers that passed null for the observer keep working.
                if (consumer != null) {
                    consumer.accept(aggregated);
                }
                if (predicate.test(exampleIndex)) {
                    stopRequested = true;
                    break;
                }
                exampleIndex++;
            }
        }
        return this.accumulator.estimates().getDouble(0);
    }

    /**
     * Returns the mean minibatch score followed by the accumulator's estimates.
     *
     * <p>Element 0 is {@code validationScore / n}; it is NaN when no score has been recorded
     * (0.0 / 0). Elements 1 onward are the accumulator's estimates in their original order.
     *
     * <p>NOTE(review): the array is one element longer than the estimates; confirm whether
     * {@link #getMetricNames()} includes a name for element 0.
     *
     * @return a new array of length {@code estimates.size() + 1}
     */
    public double[] getMetricValues() {
        DoubleList estimates = this.accumulator.estimates();
        int estimateCount = estimates.size();
        double[] values = new double[estimateCount + 1];
        values[0] = this.validationScore / this.n;
        for (int k = 0; k < estimateCount; k++) {
            values[k + 1] = estimates.getDouble(k);
        }
        return values;
    }

    /**
     * Returns the metric names reported by the accumulator.
     *
     * @return the accumulator's metric names
     */
    public Collection<? extends String> getMetricNames() {
        return this.accumulator.metricNames();
    }
}
