package plugins.nherve.toolbox.image.feature.learning;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Iterator;
import java.util.List;
import plugins.nherve.toolbox.image.db.ImageDatabaseSplit;
import plugins.nherve.toolbox.image.feature.FeatureException;
import plugins.nherve.toolbox.image.feature.signature.DefaultVectorSignature;
import plugins.nherve.toolbox.image.feature.signature.SignatureException;
import plugins.nherve.toolbox.libsvm.svm;
import plugins.nherve.toolbox.libsvm.svm_model;
import plugins.nherve.toolbox.libsvm.svm_node;
import plugins.nherve.toolbox.libsvm.svm_parameter;
import plugins.nherve.toolbox.libsvm.svm_problem;

/**
 * SVM classifier backed by the libsvm port bundled in {@code plugins.nherve.toolbox.libsvm}.
 *
 * <p>Two problem kinds can be built:
 * <ul>
 *   <li>{@link #createProblem(DefaultVectorSignature[])} — one-class problem (libsvm svm_type 2)
 *       from positive examples only;</li>
 *   <li>{@link #createProblem(DefaultVectorSignature[], DefaultVectorSignature[])} — two-class
 *       problem (libsvm svm_type 0) labelled +1 for positives and -1 for negatives, optionally
 *       with class weights balanced by class frequency.</li>
 * </ul>
 * A model is then trained with {@link #learnModel()} and queried with {@link #predict},
 * {@link #rawScore} or the inherited scoring methods.
 *
 * <p>The parameter accessors ({@link #setC}, {@link #getGamma}, ...) operate on the parameter
 * object created by {@code createProblem}; they must not be called before it, nor after
 * {@link #cleanAfterLearn()}.
 */
public class SVMClassifier extends LearningAlgorithm {
    /** Offset added to the raw SVM decision value by {@link #scoreImpl}. */
    public static final double SCORE_OFFSET = 100.0d;

    // libsvm svm_type codes used by this class (standard libsvm numbering: 0 = C_SVC, 2 = ONE_CLASS).
    private static final int SVM_TYPE_C_SVC = 0;
    private static final int SVM_TYPE_ONE_CLASS = 2;
    // NOTE(review): kernel code 5 is outside the standard libsvm kernel range (0..4); presumably a
    // kernel added to this libsvm port — confirm against svm.kernel_type_table.
    private static final int DEFAULT_KERNEL_TYPE = 5;

    private svm_parameter param;
    private svm_problem prob;
    private svm_model model;
    private int sigSize;
    private boolean balanceWeight = false;

    /**
     * Builds a one-class problem from a list of positive signatures.
     *
     * @param list positive signatures; must not be empty
     * @throws SignatureException if a signature cannot be converted to libsvm nodes
     */
    public void createProblem(List<DefaultVectorSignature> list) throws SignatureException {
        createProblem(list.toArray(new DefaultVectorSignature[0]));
    }

    /**
     * Builds a two-class problem from lists of positive and negative signatures.
     *
     * @param list positive signatures (label +1); must not be empty
     * @param list2 negative signatures (label -1)
     * @throws SignatureException if data processing or node conversion fails
     */
    public void createProblem(List<DefaultVectorSignature> list, List<DefaultVectorSignature> list2) throws SignatureException {
        createProblem(list.toArray(new DefaultVectorSignature[0]), list2.toArray(new DefaultVectorSignature[0]));
    }

    /**
     * Builds a one-class SVM problem (svm_type 2) in which every signature is labelled +1.
     *
     * <p>Note: unlike the two-class overload, this method does not run the data processor.
     *
     * @param defaultVectorSignatureArr positive signatures; must not be null or empty
     * @throws SignatureException if a signature cannot be converted to libsvm nodes
     * @throws IllegalArgumentException if no signature is given
     */
    public void createProblem(DefaultVectorSignature[] defaultVectorSignatureArr) throws SignatureException {
        requireNonEmpty(defaultVectorSignatureArr, "positive");
        info("SVM create problem pos(" + defaultVectorSignatureArr.length + ")");
        // gamma defaults to 1 / dimension, so the dimension must be known before the parameters.
        this.sigSize = defaultVectorSignatureArr[0].getSize();
        this.param = createDefaultParameter(SVM_TYPE_ONE_CLASS);
        this.param.nr_weight = 0;
        this.param.weight_label = new int[0];
        this.param.weight = new double[0];
        this.prob = new svm_problem();
        this.prob.l = defaultVectorSignatureArr.length;
        this.prob.x = new svm_node[this.prob.l];
        this.prob.y = new double[this.prob.l];
        for (int i = 0; i < defaultVectorSignatureArr.length; i++) {
            this.prob.x[i] = getNode(defaultVectorSignatureArr[i]);
            this.prob.y[i] = 1.0d;
        }
    }

    /**
     * Builds a two-class problem from the learning signatures of a database split.
     *
     * @param imageDatabaseSplit split providing the learning signatures
     * @param str class name; also recorded as model info
     * @param str2 descriptor name passed to {@code getLrnSignatures}
     * @throws ClassifierException wrapping any {@link FeatureException}
     */
    public void createProblem(ImageDatabaseSplit imageDatabaseSplit, String str, String str2) throws ClassifierException {
        try {
            setModelInfo(str);
            createProblem(imageDatabaseSplit.getLrnSignatures(str, true, str2), imageDatabaseSplit.getLrnSignatures(str, false, str2));
        } catch (FeatureException e) {
            throw new ClassifierException(e);
        }
    }

    /**
     * Builds a two-class SVM problem (svm_type 0) with positives labelled +1 and negatives -1.
     *
     * <p>If a data processor is configured and the learning data has not been processed yet,
     * both sets are processed first and the processed arrays are used from then on.
     * When {@link #isBalanceWeight()} is set, each class weight is the relative frequency of the
     * <em>other</em> class, so the minority class receives the larger penalty.
     *
     * @param defaultVectorSignatureArr positive signatures; must not be empty after processing
     * @param defaultVectorSignatureArr2 negative signatures
     * @throws SignatureException if data processing or node conversion fails
     * @throws IllegalArgumentException if there is no positive signature
     */
    public void createProblem(DefaultVectorSignature[] defaultVectorSignatureArr, DefaultVectorSignature[] defaultVectorSignatureArr2) throws SignatureException {
        DefaultVectorSignature[] positives = defaultVectorSignatureArr;
        DefaultVectorSignature[] negatives = defaultVectorSignatureArr2;
        try {
            if (hasDataProcessor() && !isLearnDataProcessed()) {
                DefaultVectorSignature[][] dataProcess = dataProcess(defaultVectorSignatureArr, defaultVectorSignatureArr2);
                positives = dataProcess[0];
                negatives = dataProcess[1];
            }
            requireNonEmpty(positives, "positive");
            this.sigSize = positives[0].getSize();
            final int total = positives.length + negatives.length;
            info("SVM create problem (" + total + ") : pos(" + positives.length + "), neg(" + negatives.length + "), sig size(" + this.sigSize + ")");
            this.param = createDefaultParameter(SVM_TYPE_C_SVC);
            this.prob = new svm_problem();
            this.prob.l = total;
            if (this.balanceWeight) {
                this.param.nr_weight = 2;
                this.param.weight_label = new int[] {1, -1};
                // Floating-point division: an int/int division here truncates both weights to 0.
                this.param.weight = new double[] {
                    (double) negatives.length / total,
                    (double) positives.length / total
                };
            } else {
                this.param.nr_weight = 0;
                this.param.weight_label = null;
                this.param.weight = null;
            }
            this.prob.x = new svm_node[total];
            this.prob.y = new double[total];
            // Iterate over the (possibly processed) arrays actually being indexed, not the inputs.
            for (int i = 0; i < positives.length; i++) {
                this.prob.x[i] = getNode(positives[i]);
                this.prob.y[i] = 1.0d;
            }
            for (int i = 0; i < negatives.length; i++) {
                int row = positives.length + i;
                this.prob.x[row] = getNode(negatives[i]);
                this.prob.y[row] = -1.0d;
            }
        } catch (ClassifierException e) {
            throw new SignatureException(e);
        }
    }

    /**
     * Runs 10-fold cross-validation for C = 2^-10 .. 2^15 (gamma unchanged) and prints the
     * accuracy obtained for each C.
     *
     * <p>The C value in effect before the call is restored afterwards, so a subsequent
     * {@link #learnModel()} is not silently trained with the last C tried.
     * Requires a problem built by {@code createProblem}.
     */
    public void crossValidation() {
        final double originalC = this.param.C;
        double[] dArr = new double[this.prob.l];
        try {
            for (int i = -10; i < 16; i++) {
                this.param.C = Math.pow(2.0d, i);
                svm.svm_cross_validation(this.prob, this.param, 10, dArr);
                int i2 = 0;
                for (int i3 = 0; i3 < this.prob.l; i3++) {
                    if (dArr[i3] == this.prob.y[i3]) {
                        i2++;
                    }
                }
                out("C = " + this.param.C + ", gamma = " + this.param.gamma + " - " + ((100.0d * i2) / this.prob.l) + "%");
            }
        } finally {
            this.param.C = originalC;
        }
    }

    /**
     * Saves the trained model to a file.
     *
     * @param file destination file
     * @throws IOException if writing fails
     */
    public void saveModel(File file) throws IOException {
        svm.svm_save_model(file.getAbsolutePath(), this.model);
    }

    /**
     * Saves the trained model to a stream.
     *
     * @param outputStream destination stream (not closed here)
     * @throws IOException if writing fails
     */
    public void saveModel(OutputStream outputStream) throws IOException {
        svm.svm_save_model(outputStream, this.model);
    }

    /**
     * Loads a model from a stream.
     *
     * @param inputStream source stream (not closed here)
     * @throws IOException if reading fails or no model could be loaded
     */
    public void loadModel(InputStream inputStream) throws IOException {
        this.model = svm.svm_load_model(inputStream);
        if (this.model == null) {
            throw new IOException("unable to load the svm model");
        }
    }

    /**
     * Loads a model from a file.
     *
     * @param file source file
     * @throws IOException if reading fails or no model could be loaded
     */
    public void loadModel(File file) throws IOException {
        this.model = svm.svm_load_model(file.getAbsolutePath());
        // Same contract as the stream overload: never leave a silently null model behind.
        if (this.model == null) {
            throw new IOException("unable to load the svm model");
        }
    }

    /**
     * Converts a signature into a sparse libsvm node array holding one node per non-zero bin,
     * in the order produced by the signature's iterator.
     */
    private svm_node[] getNode(DefaultVectorSignature defaultVectorSignature) throws SignatureException {
        svm_node[] nodes = new svm_node[defaultVectorSignature.getNonZeroBins()];
        int i = 0;
        Iterator<?> it = defaultVectorSignature.iterator();
        while (it.hasNext()) {
            int index = ((Integer) it.next()).intValue();
            svm_node node = new svm_node();
            node.index = index;
            node.value = defaultVectorSignature.get(index);
            nodes[i++] = node;
        }
        return nodes;
    }

    /** Returns {@code true} when the model's prediction for the signature is strictly positive. */
    @Override // plugins.nherve.toolbox.image.feature.learning.LearningAlgorithm
    protected boolean isPositiveImpl(DefaultVectorSignature defaultVectorSignature) throws ClassifierException {
        try {
            return predictImpl(defaultVectorSignature) > 0.0d;
        } catch (SignatureException e) {
            throw new ClassifierException(e);
        }
    }

    /**
     * Builds a one-class problem from the list and trains the model.
     *
     * @param list positive signatures; must not be empty
     * @throws ClassifierException if problem creation fails
     */
    public void learn(List<DefaultVectorSignature> list) throws ClassifierException {
        learn(list.toArray(new DefaultVectorSignature[0]));
    }

    /**
     * Builds a one-class problem from the array and trains the model.
     *
     * @param defaultVectorSignatureArr positive signatures; must not be empty
     * @throws ClassifierException if problem creation fails
     */
    public void learn(DefaultVectorSignature[] defaultVectorSignatureArr) throws ClassifierException {
        try {
            createProblem(defaultVectorSignatureArr);
            learnModel();
        } catch (SignatureException e) {
            throw new ClassifierException(e);
        }
    }

    /** Builds a two-class problem and trains the model. */
    @Override // plugins.nherve.toolbox.image.feature.learning.LearningAlgorithm
    protected void learnImpl(DefaultVectorSignature[] defaultVectorSignatureArr, DefaultVectorSignature[] defaultVectorSignatureArr2) throws ClassifierException {
        try {
            createProblem(defaultVectorSignatureArr, defaultVectorSignatureArr2);
            learnModel();
        } catch (SignatureException e) {
            throw new ClassifierException(e);
        }
    }

    /**
     * Trains the model on the current problem and parameters.
     *
     * @throws SignatureException declared for API compatibility
     */
    public void learnModel() throws SignatureException {
        this.model = svm.svm_train(this.prob, this.param);
    }

    /**
     * Drops the problem and parameters to free memory; the trained model is kept.
     * Parameter accessors must not be used afterwards.
     */
    public void cleanAfterLearn() {
        this.prob = null;
        this.param = null;
    }

    /** Runs libsvm prediction on an already-processed signature. */
    private double predictImpl(DefaultVectorSignature defaultVectorSignature) throws SignatureException {
        return svm.svm_predict(this.model, getNode(defaultVectorSignature));
    }

    /**
     * Predicts the label of a signature, applying the data processor first when one is set.
     *
     * @param defaultVectorSignature signature to classify
     * @return the value returned by libsvm's {@code svm_predict}
     * @throws SignatureException if processing or node conversion fails
     */
    public double predict(DefaultVectorSignature defaultVectorSignature) throws SignatureException {
        return hasDataProcessor() ? predictImpl(getDataProcessor().apply(defaultVectorSignature)) : predictImpl(defaultVectorSignature);
    }

    /** Returns the raw decision value shifted by {@link #SCORE_OFFSET}. */
    @Override // plugins.nherve.toolbox.image.feature.learning.LearningAlgorithm
    protected double scoreImpl(DefaultVectorSignature defaultVectorSignature) throws ClassifierException {
        return SCORE_OFFSET + rawScore(defaultVectorSignature);
    }

    /**
     * Returns the first decision value computed by libsvm's {@code svm_predict_values}.
     *
     * <p>NOTE(review): unlike {@link #predict}, no data processor is applied here; presumably the
     * caller passes an already-processed signature — confirm against LearningAlgorithm.
     *
     * @param defaultVectorSignature signature to score
     * @return the first decision value
     * @throws ClassifierException if node conversion fails
     */
    public double rawScore(DefaultVectorSignature defaultVectorSignature) throws ClassifierException {
        try {
            double[] decisionValues = new double[1];
            svm.svm_predict_values(this.model, getNode(defaultVectorSignature), decisionValues);
            return decisionValues[0];
        } catch (SignatureException e) {
            throw new ClassifierException(e);
        }
    }

    /** Sets the penalty parameter C (requires a prior {@code createProblem}). */
    public void setC(double d) {
        this.param.C = d;
    }

    /** Sets the kernel gamma (requires a prior {@code createProblem}). */
    public void setGamma(double d) {
        this.param.gamma = d;
    }

    /** Returns the penalty parameter C (requires a prior {@code createProblem}). */
    public double getC() {
        return this.param.C;
    }

    /** Returns the kernel gamma (requires a prior {@code createProblem}). */
    public double getGamma() {
        return this.param.gamma;
    }

    /** Sets the nu parameter (requires a prior {@code createProblem}). */
    public void setNu(double d) {
        this.param.nu = d;
    }

    /** Sets the libsvm kernel type code (requires a prior {@code createProblem}). */
    public void setKernel(int i) {
        this.param.kernel_type = i;
    }

    /** Returns the model's {@code l} field, i.e. its number of support vectors. */
    public int getNbSupportVector() {
        return this.model.l;
    }

    @Override
    public String toString() {
        if (this.param == null) {
            // Parameters are absent before createProblem and after cleanAfterLearn.
            return "[SVM, no parameters]";
        }
        return "[" + svm.kernel_type_table[this.param.kernel_type] + ", C = " + this.param.C + ", gamma = " + this.param.gamma + "]";
    }

    /** Returns whether two-class problems use frequency-balanced class weights. */
    public boolean isBalanceWeight() {
        return this.balanceWeight;
    }

    /** Enables or disables frequency-balanced class weights for two-class problems. */
    public void setBalanceWeight(boolean z) {
        this.balanceWeight = z;
    }

    /**
     * Creates the default parameters shared by both problem kinds. Class weights are left for
     * the caller to set. Requires {@link #sigSize} to be set (gamma = 1 / sigSize).
     */
    private svm_parameter createDefaultParameter(int svmType) {
        svm_parameter p = new svm_parameter();
        p.svm_type = svmType;
        p.kernel_type = DEFAULT_KERNEL_TYPE;
        p.degree = 3;
        p.gamma = 1.0d / this.sigSize;
        p.coef0 = 0.0d;
        p.nu = 0.5d;
        p.cache_size = 100.0d;
        p.C = 1.0d;
        p.eps = 0.001d;
        p.p = 0.1d;
        p.shrinking = 1;
        p.probability = 0;
        return p;
    }

    /** Fails fast with a clear message instead of an index error on {@code sigs[0]}. */
    private static void requireNonEmpty(DefaultVectorSignature[] sigs, String what) {
        if (sigs == null || sigs.length == 0) {
            throw new IllegalArgumentException("SVM create problem: no " + what + " signature");
        }
    }
}
