package org.nd4j.linalg.dataset.api.iterator;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Standardizes feature columns to zero mean and unit standard deviation.
 *
 * <p>Per-column statistics are computed either from a single {@link DataSet} or incrementally
 * from a {@link DataSetIterator}, pooling per-batch means and variances. After fitting,
 * {@code Nd4j.EPS_THRESHOLD} is added to every standard-deviation entry so that
 * {@link #transform(DataSet)} never divides by an exact zero.
 *
 * <p>Instances are mutable and hold no synchronization.
 *
 * @deprecated kept for existing callers only.
 */
@Deprecated
public class StandardScaler {
    private static final Logger logger = LoggerFactory.getLogger(StandardScaler.class);
    private INDArray mean;
    private INDArray std;
    private int runningTotal = 0;
    private int batchCount = 0;

    /**
     * Computes column means and standard deviations from all rows of {@code dataSet}.
     *
     * @param dataSet data whose feature matrix supplies the statistics
     */
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.mean = features.mean(0);
        // A single row has no spread; mirror the one-row special case used when fitting from an
        // iterator instead of calling std(0) on it.
        this.std = features.size(0) == 1 ? Nd4j.zeros(this.mean.shape()) : features.std(0);
        addEpsilonToStd();
    }

    /**
     * Computes column means and standard deviations over every batch the iterator yields,
     * pooling batches with the parallel mean/variance update, then resets the iterator.
     *
     * <p>Any statistics from a previous fit are discarded first.
     *
     * @param dataSetIterator source of batches; it is reset when fitting finishes
     * @throws IllegalStateException if the iterator yields no batches
     */
    public void fit(DataSetIterator dataSetIterator) {
        // Start from scratch so a repeated fit does not pool with stale (already finalized) values.
        this.mean = null;
        this.std = null;
        this.runningTotal = 0;
        this.batchCount = 0;

        while (dataSetIterator.hasNext()) {
            DataSet next = dataSetIterator.next();
            INDArray features = next.getFeatureMatrix();
            this.runningTotal += next.numExamples();
            this.batchCount = next.getFeatures().size(0);

            INDArray batchMean = features.mean(0);
            INDArray batchSquares = scaledBatchVariance(features, batchMean, this.batchCount);
            if (this.mean == null) {
                this.mean = batchMean;
                this.std = batchSquares;
            } else {
                int previousTotal = this.runningTotal - this.batchCount;
                INDArray updatedMean = this.mean.add(
                        features.subRowVector(this.mean).sum(0).divi(this.runningTotal));
                // Pooling correction delta^2 * n_a * n_b / n. Evaluated in floating point: the
                // integer quotient would truncate (often to 0) and could overflow int.
                double weight = (double) previousTotal * this.batchCount / this.runningTotal;
                INDArray meanShift =
                        Transforms.pow(batchMean.subRowVector(this.mean), 2).muli(weight);
                // this.mean must still hold the previous mean while meanShift is computed.
                this.std = this.std.add(batchSquares).addi(meanShift);
                this.mean = updatedMean;
            }
        }

        if (this.mean == null) {
            dataSetIterator.reset();
            throw new IllegalStateException(
                    "Cannot fit StandardScaler: the iterator produced no data");
        }

        // Accumulated squared spread -> variance -> standard deviation.
        this.std.divi(this.runningTotal);
        this.std = Transforms.sqrt(this.std);
        addEpsilonToStd();
        dataSetIterator.reset();
    }

    /**
     * Returns {@code std(0)^2 * rows} for one batch, or zeros when the batch has a single row
     * (std(0) is not evaluated for one row).
     */
    private static INDArray scaledBatchVariance(INDArray features, INDArray batchMean, int rows) {
        if (rows == 1) {
            return Nd4j.zeros(batchMean.shape());
        }
        return Transforms.pow(features.std(0), 2).muli(rows);
    }

    /**
     * Logs when some column's std is below epsilon, then adds epsilon to every std entry.
     *
     * <p>The check runs before the addition: comparing against epsilon afterwards is unreliable
     * in float precision.
     */
    private void addEpsilonToStd() {
        if (this.std.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
        }
        this.std.addi(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    }

    /**
     * Loads previously saved statistics.
     *
     * @param file file holding the mean array (read with {@code Nd4j.readBinary})
     * @param file2 file holding the std array (read with {@code Nd4j.readBinary})
     * @throws IOException if reading either file fails
     */
    public void load(File file, File file2) throws IOException {
        this.mean = Nd4j.readBinary(file);
        this.std = Nd4j.readBinary(file2);
    }

    /**
     * Saves the fitted statistics.
     *
     * @param file destination for the mean array
     * @param file2 destination for the std array
     * @throws IOException if writing either file fails
     * @throws IllegalStateException if the scaler has not been fitted or loaded
     */
    public void save(File file, File file2) throws IOException {
        requireFitted();
        Nd4j.saveBinary(this.mean, file);
        Nd4j.saveBinary(this.std, file2);
    }

    /**
     * Replaces the features of {@code dataSet} with {@code (features - mean) / std}, row-wise.
     *
     * @param dataSet data set whose features are replaced by standardized copies
     * @throws IllegalStateException if the scaler has not been fitted or loaded
     */
    public void transform(DataSet dataSet) {
        requireFitted();
        INDArray centered = dataSet.getFeatures().subRowVector(this.mean);
        dataSet.setFeatures(centered.divRowVector(this.std));
    }

    /** Fails fast with a clear message instead of an NPE deep inside ND4J. */
    private void requireFitted() {
        if (this.mean == null || this.std == null) {
            throw new IllegalStateException("StandardScaler has not been fitted or loaded");
        }
    }

    /** @return the fitted per-column mean, or {@code null} before fit/load */
    public INDArray getMean() {
        return this.mean;
    }

    /** @return the fitted per-column std (epsilon added after fitting), or {@code null} before fit/load */
    public INDArray getStd() {
        return this.std;
    }
}
