package hivemall.common;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:hivemall/common/ConversionState.class */
/**
 * Tracks the cumulative loss of an iterative training procedure and decides when
 * the loss has stopped improving.
 *
 * <p>Typical use per iteration: call {@link #incrLoss(double)} for each example,
 * then {@link #isConverged(long)}, then {@link #next()} to roll the current loss
 * over into the previous loss.
 *
 * <p>Convergence is reported only after the relative change rate has been below
 * the configured threshold on two consecutive checks. This class performs no
 * synchronization.
 */
public final class ConversionState {
    private static final Log logger = LogFactory.getLog(ConversionState.class);

    /** When {@code false}, {@link #isConverged(long)} always returns {@code false}. */
    private final boolean conversionCheck;
    /** Relative change-rate threshold below which an iteration counts as "not improving". */
    private final double convergenceRate;
    /** Set once the change rate drops below the threshold; convergence needs it set on entry. */
    private boolean readyToFinishIterations;
    private double totalErrors;
    private double currLosses;
    private double prevLosses;
    private int curIter;

    /** Creates a state with the convergence check enabled and a threshold of {@code 0.005}. */
    public ConversionState() {
        this(true, 0.005d);
    }

    /**
     * @param conversionCheck whether {@link #isConverged(long)} should evaluate convergence at all
     * @param convergenceRate relative change-rate threshold used by {@link #isConverged(long)}
     */
    public ConversionState(boolean conversionCheck, double convergenceRate) {
        this.conversionCheck = conversionCheck;
        this.convergenceRate = convergenceRate;
        this.readyToFinishIterations = false;
        this.totalErrors = 0.0d;
        this.currLosses = 0.0d;
        // No previous iteration exists yet; +Infinity guarantees "loss increased" is false at first.
        this.prevLosses = Double.POSITIVE_INFINITY;
        this.curIter = 1;
    }

    /** @return the sum of all values passed to {@link #incrError(double)} */
    public double getTotalErrors() {
        return this.totalErrors;
    }

    /** @return the loss accumulated in the current iteration */
    public double getCumulativeLoss() {
        return this.currLosses;
    }

    /**
     * @param numExamples number of examples the current cumulative loss is spread over
     * @return the current cumulative loss divided by {@code numExamples}, or {@code 0.0}
     *         when {@code numExamples} is zero
     */
    public double getAverageLoss(@Nonnegative long numExamples) {
        return numExamples == 0 ? 0.0d : this.currLosses / numExamples;
    }

    /** @return the cumulative loss of the previous iteration ({@code +Infinity} before the first {@link #next()}) */
    public double getPreviousLoss() {
        return this.prevLosses;
    }

    /** Adds {@code delta} to the total error counter. */
    public void incrError(double delta) {
        this.totalErrors += delta;
    }

    /** Adds {@code delta} to the current iteration's cumulative loss. */
    public void incrLoss(double delta) {
        this.currLosses += delta;
    }

    /** Multiplies the current iteration's cumulative loss by {@code factor}. */
    public void multiplyLoss(double factor) {
        this.currLosses *= factor;
    }

    /** @return {@code true} if the current cumulative loss exceeds the previous one */
    public boolean isLossIncreased() {
        return this.currLosses > this.prevLosses;
    }

    /**
     * Evaluates convergence for the current iteration and updates the internal
     * "ready to finish" flag.
     *
     * @param observedTrainingExamples example count, used only for the log message
     * @return {@code true} only when the change rate is below the threshold now and was
     *         also below it on the preceding check; {@code false} otherwise, and always
     *         {@code false} when the convergence check is disabled
     */
    public boolean isConverged(long observedTrainingExamples) {
        if (!this.conversionCheck) {
            return false;
        }
        if (this.currLosses > this.prevLosses) {
            if (logger.isInfoEnabled()) {
                logger.info("Iteration #" + this.curIter + " current cumulative loss `" + this.currLosses + "` > previous cumulative loss `" + this.prevLosses + '`');
            }
            this.readyToFinishIterations = false;
            return false;
        }

        final double changeRate = getChangeRate();
        // Written as a negated `<` on purpose: the rate is NaN on the first iteration
        // (prevLosses is +Infinity) and for 0/0, and every comparison with NaN is false.
        // Testing `changeRate >= convergenceRate` would let NaN fall through and arm the
        // flag, so a single later small change would be reported as convergence.
        if (!(changeRate < this.convergenceRate)) {
            if (logger.isInfoEnabled()) {
                logger.info(getInfo(observedTrainingExamples));
            }
            this.readyToFinishIterations = false;
            return false;
        }

        if (this.readyToFinishIterations) {
            if (logger.isInfoEnabled()) {
                logger.info("Training converged at " + this.curIter + "-th iteration!\n" + getInfo(observedTrainingExamples));
            }
            return true;
        }

        // First check below the threshold: arm the flag and require one more confirmation.
        if (logger.isInfoEnabled()) {
            logger.info(getInfo(observedTrainingExamples));
        }
        this.readyToFinishIterations = true;
        return false;
    }

    /**
     * @return {@code (prevLosses - currLosses) / prevLosses}; NaN when that quotient is
     *         indeterminate (e.g. the previous loss is {@code +Infinity} or both are zero)
     */
    double getChangeRate() {
        return (this.prevLosses - this.currLosses) / this.prevLosses;
    }

    /** Moves to the next iteration: current loss becomes previous loss and is reset to zero. */
    public void next() {
        this.prevLosses = this.currLosses;
        this.currLosses = 0.0d;
        this.curIter++;
    }

    /** @return the 1-based index of the current iteration */
    public int getCurrentIteration() {
        return this.curIter;
    }

    /**
     * Builds a one-line, human-readable summary of the current state.
     *
     * @param observedTrainingExamples example count, used for the average loss and echoed in the text
     * @return the summary string
     */
    @Nonnull
    public String getInfo(@Nonnegative long observedTrainingExamples) {
        StringBuilder sb = new StringBuilder();
        sb.append("Iteration #").append(this.curIter).append(" | ");
        sb.append("average loss=").append(getAverageLoss(observedTrainingExamples)).append(", ");
        sb.append("current cumulative loss=").append(this.currLosses).append(", ");
        sb.append("previous cumulative loss=").append(this.prevLosses).append(", ");
        sb.append("change rate=").append(getChangeRate()).append(", ");
        sb.append("#trainingExamples=").append(observedTrainingExamples);
        return sb.toString();
    }
}
