package org.campagnelab.dl.genotype.performance;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.performance.AreaUnderTheROCCurve;
import org.campagnelab.dl.varanalysis.protobuf.BaseInformationRecords;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/campagnelab/dl/genotype/performance/GenotypeTrainingPerformanceHelperWithAUC.class */
/**
 * Training performance helper that wraps a {@link GenotypeTrainingPerformanceHelper} delegate and
 * adds three metrics it computes itself:
 * <ul>
 *   <li>{@code "score"}: mean of the scores the delegate reports during {@link #estimateWithGraph};</li>
 *   <li>{@code "AUC"}: area under the ROC curve over predictions whose {@code isVariant()} is true,
 *       using {@code isVariantProbability} as the score and {@code isCorrect()} as the label;</li>
 *   <li>{@code "AUC+F1"}: {@code 0.1 * AUC + 0.9 * F1}, with F1 obtained from the delegate.</li>
 * </ul>
 * All other metric names are forwarded to the delegate.
 *
 * <p>NOTE(review): this class both extends and wraps {@code GenotypeTrainingPerformanceHelper};
 * inherited methods not overridden here act on the superclass state, not on {@code delegate} —
 * confirm this is intended.
 */
public class GenotypeTrainingPerformanceHelperWithAUC extends GenotypeTrainingPerformanceHelper {
    private static final String SCORE = "score";
    private static final String AUC = "AUC";
    private static final String AUC_F1 = "AUC+F1";
    private static final String F1 = "F1";

    /** Argument passed to the {@link AreaUnderTheROCCurve} constructor. */
    private static final int AUC_CALCULATOR_CAPACITY = 100000;
    /** Weights used by {@link #combine} to build the "AUC+F1" metric. */
    private static final double AUC_WEIGHT = 0.1d;
    private static final double F1_WEIGHT = 0.9d;

    private final GenotypeTrainingPerformanceHelper delegate;
    private double observedScore;
    private double observedAUC;
    AreaUnderTheROCCurve aucCalculator;
    private double observedAUC_F1;

    /**
     * Creates the helper and its delegate for the given domain.
     *
     * @param domainDescriptor domain descriptor passed both to the superclass and to the delegate
     */
    public GenotypeTrainingPerformanceHelperWithAUC(DomainDescriptor<BaseInformationRecords.BaseInformation> domainDescriptor) {
        super(domainDescriptor);
        this.aucCalculator = new AreaUnderTheROCCurve(AUC_CALCULATOR_CAPACITY);
        this.delegate = new GenotypeTrainingPerformanceHelper(domainDescriptor);
    }

    /**
     * Runs the delegate's estimation while recording the mean score and the AUC for this pass.
     *
     * <p>The AUC calculator is recreated on every call so that, like the mean score, the AUC only
     * reflects predictions seen during this evaluation pass.
     *
     * @param iterator  data to evaluate, forwarded to the delegate
     * @param graph     model to evaluate, forwarded to the delegate
     * @param predicate forwarded unchanged to the delegate
     * @return the value returned by the delegate's {@code estimateWithGraph}
     */
    @Override // org.campagnelab.dl.genotype.performance.GenotypeTrainingPerformanceHelper
    public double estimateWithGraph(MultiDataSetIterator iterator, ComputationGraph graph, Predicate<Integer> predicate) {
        final AreaUnderTheROCCurve auc = new AreaUnderTheROCCurve(AUC_CALCULATOR_CAPACITY);
        this.aucCalculator = auc;
        // Single-element arrays so the lambdas below can accumulate into them.
        final double[] scoreSum = {0.0d};
        final int[] scoreCount = {0};
        double estimate = this.delegate.estimateWithGraph(iterator, graph, predicate, prediction -> {
            if (prediction.isVariant()) {
                // Positive label when the prediction is correct, negative otherwise.
                auc.observe(prediction.isVariantProbability, prediction.isCorrect() ? 1.0d : -1.0d);
            }
        }, score -> {
            scoreSum[0] += score.doubleValue();
            scoreCount[0]++;
        });
        // No reported scores: the mean is undefined, report NaN (same value 0.0/0 would give).
        this.observedScore = scoreCount[0] == 0 ? Double.NaN : scoreSum[0] / scoreCount[0];
        this.observedAUC = auc.evaluateStatistic();
        return estimate;
    }

    /**
     * Returns the values of the requested metrics, in the order they were requested.
     *
     * <p>"score", "AUC" and "AUC+F1" are served from values recorded by the most recent
     * {@link #estimateWithGraph} call; every other name is forwarded to the delegate. When
     * "AUC+F1" is requested, F1 is fetched from the delegate even if the caller did not ask for it.
     *
     * @param metricNames metric names to look up
     * @return one value per requested name, in request order
     */
    @Override // org.campagnelab.dl.genotype.performance.GenotypeTrainingPerformanceHelper
    public double[] getMetricValues(String... metricNames) {
        // Forward every non-local name (all occurrences of local names are filtered, not just the first).
        ObjectArrayList<String> delegatedNames = new ObjectArrayList<>(metricNames.length + 1);
        boolean combinedRequested = false;
        for (String name : metricNames) {
            if (AUC_F1.equals(name)) {
                combinedRequested = true;
            }
            if (!isLocalMetric(name)) {
                delegatedNames.add(name);
            }
        }
        // AUC+F1 needs F1 even if it was not requested; fetch it as a trailing extra metric.
        final boolean extraF1 = combinedRequested && !delegatedNames.contains(F1);
        if (extraF1) {
            delegatedNames.add(F1);
        }
        DoubleArrayList values = DoubleArrayList.wrap(
                this.delegate.getMetricValues(delegatedNames.toArray(new String[delegatedNames.size()])));
        double f1 = 0.0d;
        for (int i = 0; i < delegatedNames.size(); i++) {
            if (F1.equals(delegatedNames.get(i))) {
                f1 = values.getDouble(i);
            }
        }
        if (extraF1) {
            // Drop the F1 value the caller did not ask for.
            values.removeDouble(values.size() - 1);
        }
        this.observedAUC_F1 = combine(this.observedAUC, f1);
        // Re-insert the locally computed values at their requested positions. Inserting in
        // increasing index order keeps every earlier position already correct.
        for (int i = 0; i < metricNames.length; i++) {
            String name = metricNames[i];
            if (SCORE.equals(name)) {
                values.add(i, this.observedScore);
            } else if (AUC.equals(name)) {
                values.add(i, this.observedAUC);
            } else if (AUC_F1.equals(name)) {
                values.add(i, this.observedAUC_F1);
            }
        }
        return values.toDoubleArray();
    }

    /** Returns true for metric names this class computes itself rather than via the delegate. */
    private static boolean isLocalMetric(String name) {
        return SCORE.equals(name) || AUC.equals(name) || AUC_F1.equals(name);
    }

    /** Weighted combination used for the "AUC+F1" metric: {@code 0.1 * auc + 0.9 * f1}. */
    private double combine(double auc, double f1) {
        return (auc * AUC_WEIGHT) + (f1 * F1_WEIGHT);
    }
}
