package hivemall.regression;

import hivemall.optimizer.EtaEstimator;
import hivemall.optimizer.LossFunctions;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.lucene.util.packed.PackedInts;

/**
 * Hive table-generating function registered under the name {@code logress}.
 *
 * <p>This class layers three things on top of {@link RegressionBaseUDTF}:
 * <ul>
 *   <li>an argument-count check (2 or 3 arguments) before delegating initialization,</li>
 *   <li>three extra command-line options ({@code total_steps}, {@code power_t}, {@code eta0})
 *       and an {@link EtaEstimator} built from the parsed options,</li>
 *   <li>a gradient equal to the estimator's learning rate for the current {@code count}
 *       multiplied by {@link LossFunctions#logisticLoss(float, float)}.</li>
 * </ul>
 *
 * <p>The class is marked {@link Deprecated}; this file does not name a replacement.
 */
@Description(name = "logress", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - Returns a relation consists of <{int|bigint|string} feature, float weight>")
@Deprecated
public final class LogressUDTF extends RegressionBaseUDTF {
    /** Lower bound (inclusive) accepted by {@link #checkTargetValue(float)}. */
    private static final float MIN_TARGET = 0.0f;

    /** Upper bound (inclusive) accepted by {@link #checkTargetValue(float)}. */
    private static final float MAX_TARGET = 1.0f;

    /** Learning-rate schedule; assigned in {@link #processOptions(ObjectInspector[])}. */
    private EtaEstimator etaEstimator;

    /**
     * Validates the number of arguments and delegates to the parent initializer.
     *
     * @param objectInspectorArr inspectors for the call arguments; must have length 2 or 3
     * @return whatever the parent {@code initialize} returns
     * @throws UDFArgumentException if the argument count is neither 2 nor 3, or if the parent
     *         initializer throws
     */
    @Override // hivemall.regression.RegressionBaseUDTF
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        int length = objectInspectorArr.length;
        if (length != 2 && length != 3) {
            throw new UDFArgumentException("LogressUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, float target [, constant string options]");
        }
        return super.initialize(objectInspectorArr);
    }

    /**
     * Returns the parent's options extended with the learning-rate settings of this learner.
     *
     * <p>NOTE(review): the decompiler reported that this method was originally
     * {@code protected}; it is kept {@code public} so the visible signature is unchanged.
     *
     * @return the option set obtained from the parent with {@code t}/{@code total_steps},
     *         {@code power_t} and {@code eta0} added
     */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public Options getOptions() {
        Options options = super.getOptions();
        options.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps [default: 10000]");
        options.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]");
        options.addOption("eta0", true, "The initial learning rate [default 0.1]");
        return options;
    }

    /**
     * Lets the parent parse the options, then builds the {@link EtaEstimator} from the result.
     *
     * <p>NOTE(review): the decompiler reported that this method was originally
     * {@code protected}; it is kept {@code public} so the visible signature is unchanged.
     *
     * @param objectInspectorArr inspectors for the call arguments, forwarded to the parent
     * @return the command line produced by the parent
     * @throws UDFArgumentException if the parent fails to process the options
     */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        CommandLine processOptions = super.processOptions(objectInspectorArr);
        this.etaEstimator = EtaEstimator.get(processOptions);
        return processOptions;
    }

    /**
     * Rejects any target outside the closed interval [0, 1].
     *
     * @param f the target value to validate
     * @throws UDFArgumentException if {@code f} is below 0, above 1, or NaN
     */
    @Override // hivemall.regression.RegressionBaseUDTF
    protected void checkTargetValue(float f) throws UDFArgumentException {
        // Written as a negated "in range" test so that NaN, for which every ordered
        // comparison is false, is rejected instead of silently accepted.
        boolean inRange = f >= MIN_TARGET && f <= MAX_TARGET;
        if (!inRange) {
            throw new UDFArgumentException("target must be in range 0 to 1: " + f);
        }
    }

    /**
     * Scales the logistic loss term by the learning rate for the current step.
     *
     * <p>NOTE(review): the two arguments are presumably the target and the prediction, in the
     * order expected by {@code LossFunctions.logisticLoss} — confirm against
     * {@code RegressionBaseUDTF}. The {@code count} field is inherited and not visible here.
     *
     * @param f first argument, passed through to {@code LossFunctions.logisticLoss}
     * @param f2 second argument, passed through to {@code LossFunctions.logisticLoss}
     * @return {@code etaEstimator.eta(count) * LossFunctions.logisticLoss(f, f2)}
     */
    @Override // hivemall.regression.RegressionBaseUDTF
    protected float computeGradient(float f, float f2) {
        return this.etaEstimator.eta(this.count) * LossFunctions.logisticLoss(f, f2);
    }
}
