package org.neuroph.eval;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/neuroph/eval/Evaluation.class */
public final class Evaluation {
    private static final Logger LOGGER = LoggerFactory.getLogger("neuroph");
    private final Map<Class<?>, Evaluator> evaluators = new HashMap();

    public Evaluation() {
        addEvaluator(new ErrorEvaluator(new MeanSquaredError()));
    }

    public EvaluationResult evaluate(NeuralNetwork neuralNetwork, DataSet dataSet) {
        Iterator<Evaluator> it = this.evaluators.values().iterator();
        while (it.hasNext()) {
            it.next().reset();
        }
        for (DataSetRow dataSetRow : dataSet.getRows()) {
            neuralNetwork.setInput(dataSetRow.getInput());
            neuralNetwork.calculate();
            Iterator<Evaluator> it2 = this.evaluators.values().iterator();
            while (it2.hasNext()) {
                it2.next().processNetworkResult(neuralNetwork.getOutput(), dataSetRow.getDesiredOutput());
            }
        }
        ConfusionMatrix result = neuralNetwork.getOutputsCount() > 1 ? ((ClassifierEvaluator.MultiClass) getEvaluator(ClassifierEvaluator.MultiClass.class)).getResult() : ((ClassifierEvaluator.Binary) getEvaluator(ClassifierEvaluator.Binary.class)).getResult();
        double doubleValue = ((ErrorEvaluator) getEvaluator(ErrorEvaluator.class)).getResult().doubleValue();
        EvaluationResult evaluationResult = new EvaluationResult();
        evaluationResult.setDataSet(dataSet);
        evaluationResult.setConfusionMatrix(result);
        evaluationResult.setMeanSquareError(doubleValue);
        return evaluationResult;
    }

    public void addEvaluator(Evaluator evaluator) {
        if (evaluator == null) {
            throw new IllegalArgumentException("Evaluator cannot be null!");
        }
        this.evaluators.put(evaluator.getClass(), evaluator);
    }

    public <T extends Evaluator> T getEvaluator(Class<T> cls) {
        return cls.cast(this.evaluators.get(cls));
    }

    public Map<Class<?>, Evaluator> getEvaluators() {
        return this.evaluators;
    }

    public double getMeanSquareError() {
        return ((ErrorEvaluator) getEvaluator(ErrorEvaluator.class)).getResult().doubleValue();
    }

    public static void runFullEvaluation(NeuralNetwork<?> neuralNetwork, DataSet dataSet) {
        Evaluation evaluation = new Evaluation();
        evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(dataSet.getColumnNames()));
        evaluation.evaluate(neuralNetwork, dataSet);
        LOGGER.info("##############################################################################");
        LOGGER.info("MeanSquare Error: " + ((ErrorEvaluator) evaluation.getEvaluator(ErrorEvaluator.class)).getResult());
        LOGGER.info("##############################################################################");
        ConfusionMatrix result = ((ClassifierEvaluator) evaluation.getEvaluator(ClassifierEvaluator.MultiClass.class)).getResult();
        LOGGER.info("Confusion Matrix: \r\n" + result.toString());
        LOGGER.info("##############################################################################");
        LOGGER.info("Classification metrics: ");
        for (ClassificationMetrics classificationMetrics : ClassificationMetrics.createFromMatrix(result)) {
            LOGGER.info(classificationMetrics.toString());
        }
        LOGGER.info("##############################################################################");
    }
}
