package org.neuroph.eval;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.SerializationUtils;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.eval.classification.ConfusionMatrix;

/**
 * Runs k-fold cross validation of a {@link NeuralNetwork} over a {@link DataSet}.
 *
 * <p>The data set is split into {@code numFolds} folds. For every fold a serialized deep copy of
 * the supplied network is trained on the remaining folds and evaluated on the held-out fold.
 * Folds are processed in parallel. The network passed to the constructor is never trained.
 *
 * <p>Not thread-safe: {@link #run()} rebuilds the per-run result lists and must not be invoked
 * concurrently on the same instance.
 */
public class KFoldCrossValidation {
    private final NeuralNetwork neuralNetwork;
    private final DataSet dataSet;
    private final int numFolds;
    private List<ConfusionMatrix> confusionMatrices;
    private List<ClassificationMetrics.Stats> statlist;
    private List<FoldResult> crossFoldResults;

    /** Trains and evaluates one independent copy of the network for a single fold. */
    private class CrossValidationWorker implements Callable<FoldResult> {
        private final DataSet trainingSet;
        private final DataSet validationSet;

        CrossValidationWorker(DataSet trainingSet, DataSet validationSet) {
            this.trainingSet = trainingSet;
            this.validationSet = validationSet;
        }

        @Override
        public FoldResult call() throws Exception {
            // Each fold works on its own deep copy, so folds neither share weights nor race on
            // the caller's network.
            NeuralNetwork foldNetwork = (NeuralNetwork) SerializationUtils.clone(neuralNetwork);

            Evaluation evaluation = new Evaluation();
            evaluation.addEvaluator(new ErrorEvaluator(new MeanSquaredError()));
            if (foldNetwork.getOutputsCount() == 1) {
                evaluation.addEvaluator(new ClassifierEvaluator.Binary(0.5d));
            } else {
                evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(dataSet.getColumnNames()));
            }

            // Train exactly the copy that is evaluated below. The shared original network
            // must not be trained here: all folds would mutate it concurrently.
            foldNetwork.learn(trainingSet);
            EvaluationResult evaluationResult = evaluation.evaluate(foldNetwork, validationSet);

            FoldResult foldResult = new FoldResult(foldNetwork, trainingSet, validationSet);
            foldResult.setConfusionMatrix(evaluationResult.getConfusionMatrix());
            return foldResult;
        }
    }

    /**
     * Creates a cross validation run.
     *
     * @param neuralNetwork network to validate; it is copied for every fold and never trained itself
     * @param dataSet       data to split into folds
     * @param numFolds      number of folds, at least 2 (one fold would leave an empty training set)
     * @throws NullPointerException     if {@code neuralNetwork} or {@code dataSet} is null
     * @throws IllegalArgumentException if {@code numFolds < 2}
     */
    public KFoldCrossValidation(NeuralNetwork neuralNetwork, DataSet dataSet, int numFolds) {
        this.neuralNetwork = Objects.requireNonNull(neuralNetwork, "neuralNetwork");
        this.dataSet = Objects.requireNonNull(dataSet, "dataSet");
        if (numFolds < 2) {
            throw new IllegalArgumentException("numFolds must be at least 2, was " + numFolds);
        }
        this.numFolds = numFolds;
    }

    /**
     * Runs all folds, prints the aggregated results to standard output and returns them.
     *
     * @return an evaluation result whose confusion matrix is the sum of all per-fold matrices
     * @throws InterruptedException if interrupted while waiting for the folds to finish
     * @throws ExecutionException   if training or evaluating any fold failed
     */
    public EvaluationResult run() throws InterruptedException, ExecutionException {
        confusionMatrices = new ArrayList<>(numFolds);
        statlist = new ArrayList<>(numFolds);
        crossFoldResults = new ArrayList<>(numFolds);

        List<FoldResult> foldResults = runFolds(dataSet.m1split(numFolds));
        for (FoldResult foldResult : foldResults) {
            ConfusionMatrix foldMatrix = foldResult.getConfusionMatrix();
            confusionMatrices.add(foldMatrix);
            crossFoldResults.add(foldResult);
            statlist.add(ClassificationMetrics.average(ClassificationMetrics.createFromMatrix(foldMatrix)));
        }

        ConfusionMatrix sumConfusionMatrix = sumConfusionMatrix(confusionMatrices, dataSet);
        System.out.println();
        System.out.println("All folds results:");
        System.out.println();
        EvaluationResult evaluationResult = new EvaluationResult();
        evaluationResult.setConfusionMatrix(sumConfusionMatrix);
        // The aggregated matrix over all folds is printed under the label "Fold: 0".
        printFoldResults(sumConfusionMatrix, 0);
        printResults(dataSet, averageClassificationMetrics(statlist), numFolds);
        return evaluationResult;
    }

    /**
     * Trains and evaluates every fold in parallel.
     *
     * @param folds the data set split into folds
     * @return one result per fold, in fold order
     */
    private List<FoldResult> runFolds(DataSet[] folds) throws InterruptedException, ExecutionException {
        List<CrossValidationWorker> workers = new ArrayList<>(folds.length);
        for (int fold = 0; fold < folds.length; fold++) {
            workers.add(new CrossValidationWorker(createTrainingSetFromFolds(folds, fold), folds[fold]));
        }

        // Training is CPU-bound: never use more threads than folds or available processors.
        int threads = Math.max(1, Math.min(folds.length, Runtime.getRuntime().availableProcessors()));
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            // invokeAll blocks until every task has completed, so no extra latch is needed.
            List<Future<FoldResult>> futures = executor.invokeAll(workers);
            List<FoldResult> results = new ArrayList<>(futures.size());
            for (Future<FoldResult> future : futures) {
                results.add(future.get());
            }
            return results;
        } finally {
            // Always release the pool threads, also when a fold failed.
            executor.shutdown();
        }
    }

    /**
     * Sums confusion matrices cell by cell.
     *
     * <p>The size is taken from the matrices' own class labels rather than from the data set's
     * output size. For example, a binary evaluator on a single-output network produces a matrix
     * larger than 1x1.
     *
     * @param list    matrices to sum; all must have the same number of class labels
     * @param dataSet unused; kept for backward compatibility of the signature
     * @return a new matrix labelled like the first element of {@code list}
     * @throws IllegalArgumentException if {@code list} is null or empty, or the matrices differ in size
     */
    public ConfusionMatrix sumConfusionMatrix(List<ConfusionMatrix> list, DataSet dataSet) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("list must contain at least one confusion matrix");
        }
        ConfusionMatrix first = list.get(0);
        int classCount = first.getClassLabels().length;
        int[][] totals = new int[classCount][classCount];
        for (ConfusionMatrix matrix : list) {
            if (matrix.getClassLabels().length != classCount) {
                throw new IllegalArgumentException("confusion matrices differ in size: expected "
                        + classCount + " classes, found " + matrix.getClassLabels().length);
            }
            for (int row = 0; row < classCount; row++) {
                for (int col = 0; col < classCount; col++) {
                    totals[row][col] += matrix.get(row, col);
                }
            }
        }
        ConfusionMatrix sum = new ConfusionMatrix(first.getClassLabels());
        sum.setValues(totals);
        return sum;
    }

    /**
     * Computes the arithmetic mean of accuracy, precision, recall, F-score and correlation
     * coefficient over the given statistics.
     *
     * <p>The elements of {@code list} are not modified; a new {@code Stats} instance is returned.
     *
     * @param list per-fold statistics
     * @return the averaged statistics
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    public static ClassificationMetrics.Stats averageClassificationMetrics(List<ClassificationMetrics.Stats> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("list must contain at least one Stats instance");
        }
        ClassificationMetrics.Stats average = new ClassificationMetrics.Stats();
        for (ClassificationMetrics.Stats stats : list) {
            average.accuracy += stats.accuracy;
            average.precision += stats.precision;
            average.recall += stats.recall;
            average.fScore += stats.fScore;
            average.correlationCoefficient += stats.correlationCoefficient;
        }
        double count = list.size();
        average.accuracy /= count;
        average.precision /= count;
        average.recall /= count;
        average.fScore /= count;
        average.correlationCoefficient /= count;
        return average;
    }

    /**
     * Prints the summary of a cross validation run to standard output.
     *
     * @param dataSet  the validated data set (its size is printed)
     * @param stats    averaged statistics to print
     * @param numFolds number of folds used
     */
    public void printResults(DataSet dataSet, ClassificationMetrics.Stats stats, int numFolds) {
        System.out.println();
        System.out.println("=== Cross validation result ===");
        System.out.println("Instances: " + dataSet.size());
        System.out.println("Number of folds: " + numFolds);
        System.out.println("\n");
        System.out.println("=== Summary ===");
        System.out.println("Accuracy: " + stats.accuracy);
        System.out.println("Precision: " + stats.precision);
        System.out.println("Recall: " + stats.recall);
        System.out.println("FScore: " + stats.fScore);
        System.out.println("Correlation coefficient: " + stats.correlationCoefficient);
    }

    /**
     * Prints a confusion matrix, its per-class metrics and their average to standard output.
     *
     * @param confusionMatrix matrix to print
     * @param fold            label printed after "Fold: "
     */
    public void printFoldResults(ConfusionMatrix confusionMatrix, int fold) {
        System.out.println();
        System.out.println("Fold: " + fold);
        System.out.println();
        System.out.println("Confusion matrix:\r\n");
        System.out.println(confusionMatrix.toString() + "\r\n\r\n");
        System.out.println("Classification metrics\r\n");
        ClassificationMetrics[] metrics = ClassificationMetrics.createFromMatrix(confusionMatrix);
        ClassificationMetrics.Stats average = ClassificationMetrics.average(metrics);
        for (ClassificationMetrics classMetrics : metrics) {
            System.out.println(classMetrics.toString() + "\r\n");
        }
        System.out.println(average.toString());
    }

    /**
     * Prints the averaged classification metrics derived from a confusion matrix.
     *
     * @param confusionMatrix matrix to derive the metrics from
     */
    public void printStats(ConfusionMatrix confusionMatrix) {
        System.out.println(ClassificationMetrics.average(ClassificationMetrics.createFromMatrix(confusionMatrix)).toString());
    }

    /**
     * Builds a training set by concatenating every fold except the held-out one.
     *
     * @param folds          all folds
     * @param validationFold index of the fold to leave out
     * @return a new data set with the same input/output sizes as the full data set
     */
    private DataSet createTrainingSetFromFolds(DataSet[] folds, int validationFold) {
        DataSet trainingSet = new DataSet(dataSet.getInputSize(), dataSet.getOutputSize());
        for (int fold = 0; fold < folds.length; fold++) {
            if (fold != validationFold) {
                trainingSet.addAll(folds[fold]);
            }
        }
        return trainingSet;
    }

    /**
     * Returns the per-fold results of the most recent {@link #run()}.
     *
     * @return an unmodifiable view of the fold results, or an empty list if {@code run()} has not
     *         completed yet
     */
    public List<FoldResult> getResultsByFolds() {
        if (crossFoldResults == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(crossFoldResults);
    }
}
