package org.neuroph.eval;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.SerializationUtils;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.util.data.sample.Sampling;
import org.neuroph.util.data.sample.SubSampling;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs k-fold cross-validation of a {@link NeuralNetwork} over a {@link DataSet}.
 *
 * <p>{@link #run()} shuffles the data set in place and splits it into {@code numberOfFolds}
 * contiguous folds. For each fold, a serialized copy of the network is trained on the remaining
 * rows and then evaluated on the fold. Folds are trained concurrently on a small fixed thread pool.
 */
public class CrossValidationBak {
    private static final Logger LOGGER = LoggerFactory.getLogger(CrossValidationBak.class.getName());

    /** Upper bound on the number of folds trained concurrently. */
    private static final int THREAD_POOL_SIZE = 4;

    private NeuralNetwork neuralNetwork;
    private DataSet dataSet;
    // NOTE(review): stored and exposed via getter/setter, but run() does not consult it.
    private Sampling sampling;
    private int numberOfFolds;
    private int foldSize;
    private final Evaluation evaluation = new Evaluation();
    /** Per-fold results of the most recent successful run(), in fold order. */
    private List<EvaluationResult> foldResults = new ArrayList<>();

    /**
     * Trains a private copy of the network on all rows outside one fold and evaluates it on that
     * fold.
     */
    private static final class CrossValidationWorker implements Callable<EvaluationResult> {
        private final NeuralNetwork neuralNetwork;
        private final DataSet dataSet;
        private final Evaluation evaluation;
        private final int foldIndex;
        private final int foldSize;
        private final int numberOfFolds;

        CrossValidationWorker(NeuralNetwork neuralNetwork, DataSet dataSet, Evaluation evaluation,
                int foldIndex, int foldSize, int numberOfFolds) {
            this.neuralNetwork = neuralNetwork;
            this.dataSet = dataSet;
            this.evaluation = evaluation;
            this.foldIndex = foldIndex;
            this.foldSize = foldSize;
            this.numberOfFolds = numberOfFolds;
        }

        public Evaluation getEvaluation() {
            return evaluation;
        }

        @Override
        public EvaluationResult call() {
            // Work on a deep copy so that concurrent folds never touch the caller's network.
            NeuralNetwork foldNetwork = (NeuralNetwork) SerializationUtils.clone(neuralNetwork);

            int rowCount = dataSet.size();
            int testStart = foldSize * foldIndex;
            // The last fold absorbs the remainder rows, so every row is tested exactly once.
            int testEnd = (foldIndex == numberOfFolds - 1) ? rowCount : testStart + foldSize;

            // Size the splits by the source set's input/output vector sizes.
            // NOTE(review): the single-int DataSet constructor is the input-size constructor in
            // Neuroph. Passing a row count to it made add() reject every row.
            DataSet trainingSet = new DataSet(dataSet.getInputSize(), dataSet.getOutputSize());
            DataSet testSet = new DataSet(dataSet.getInputSize(), dataSet.getOutputSize());
            for (int row = 0; row < rowCount; row++) {
                if (row >= testStart && row < testEnd) {
                    testSet.add(dataSet.getRowAt(row));
                } else {
                    trainingSet.add(dataSet.getRowAt(row));
                }
            }

            foldNetwork.learn(trainingSet);

            EvaluationResult result;
            // All workers share one Evaluation instance whose thread-safety is not established,
            // so evaluation is serialized while training stays parallel.
            synchronized (evaluation) {
                result = evaluation.evaluate(foldNetwork, testSet);
            }
            if (result != null) {
                // Attach the trained network so that callers can inspect the fold's model.
                result.setNeuralNetwork(foldNetwork);
            }
            return result;
        }
    }

    /**
     * Stores the inputs and registers the default evaluators: a binary classifier evaluator
     * (threshold 0.5) for single-output networks, a multi-class evaluator otherwise, and a
     * mean-squared-error evaluator.
     *
     * @throws NullPointerException if {@code neuralNetwork} or {@code dataSet} is null
     * @throws IllegalArgumentException if {@code numberOfFolds} is less than 1
     */
    private void initialize(NeuralNetwork neuralNetwork, DataSet dataSet, int numberOfFolds) {
        this.neuralNetwork = Objects.requireNonNull(neuralNetwork, "neuralNetwork");
        this.dataSet = Objects.requireNonNull(dataSet, "dataSet");
        if (numberOfFolds < 1) {
            throw new IllegalArgumentException("numberOfFolds must be positive: " + numberOfFolds);
        }
        this.numberOfFolds = numberOfFolds;
        if (neuralNetwork.getOutputsCount() == 1) {
            this.evaluation.addEvaluator(new ClassifierEvaluator.Binary(0.5d));
        } else {
            this.evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(dataSet.getColumnNames()));
        }
        this.evaluation.addEvaluator(new ErrorEvaluator(new MeanSquaredError()));
    }

    /**
     * Creates a cross-validation over {@code dataSet} with {@code numberOfFolds} folds.
     *
     * @param neuralNetwork network to copy and train for each fold
     * @param dataSet data to split into folds. It is shuffled in place by {@link #run()}.
     * @param numberOfFolds number of folds; must be at least 1
     */
    public CrossValidationBak(NeuralNetwork neuralNetwork, DataSet dataSet, int numberOfFolds) {
        initialize(neuralNetwork, dataSet, numberOfFolds);
        this.sampling = new SubSampling(numberOfFolds);
    }

    public Sampling getSampling() {
        return this.sampling;
    }

    public void setSampling(Sampling sampling) {
        this.sampling = sampling;
    }

    /** Returns the shared evaluation whose evaluators are applied to every fold. */
    public Evaluation getEvaluation() {
        return this.evaluation;
    }

    /**
     * Returns the per-fold evaluation results of the most recent successful {@link #run()}.
     *
     * @return an unmodifiable list in fold order; empty if {@code run()} has not completed
     */
    public List<EvaluationResult> getFoldResults() {
        return Collections.unmodifiableList(foldResults);
    }

    /**
     * Shuffles the data set in place, then trains and evaluates one network copy per fold.
     *
     * @throws IllegalStateException if the data set has fewer rows than there are folds
     * @throws InterruptedException if interrupted while waiting for the folds to finish
     * @throws ExecutionException if training or evaluating any fold failed. The cause is the
     *     worker's exception.
     */
    public void run() throws InterruptedException, ExecutionException {
        int rowCount = this.dataSet.size();
        if (rowCount < this.numberOfFolds) {
            throw new IllegalStateException("Data set has " + rowCount + " rows, fewer than "
                    + this.numberOfFolds + " folds");
        }
        this.dataSet.shuffle();
        this.foldSize = rowCount / this.numberOfFolds;

        List<CrossValidationWorker> workers = new ArrayList<>(this.numberOfFolds);
        for (int fold = 0; fold < this.numberOfFolds; fold++) {
            workers.add(new CrossValidationWorker(this.neuralNetwork, this.dataSet, this.evaluation,
                    fold, this.foldSize, this.numberOfFolds));
        }

        ExecutorService executor =
                Executors.newFixedThreadPool(Math.min(THREAD_POOL_SIZE, this.numberOfFolds));
        try {
            List<Future<EvaluationResult>> futures = executor.invokeAll(workers);
            List<EvaluationResult> results = new ArrayList<>(futures.size());
            for (Future<EvaluationResult> future : futures) {
                // get() rethrows a worker's failure instead of silently dropping it.
                results.add(future.get());
            }
            this.foldResults = results;
        } finally {
            // Release pool threads even when invokeAll is interrupted or a fold failed.
            executor.shutdown();
        }
    }

    public void addEvaluator(Evaluator evaluator) {
        this.evaluation.addEvaluator(evaluator);
    }

    public <T extends Evaluator> T getEvaluator(Class<T> cls) {
        return (T) this.evaluation.getEvaluator(cls);
    }
}
