package hivemall.classifier;

import hivemall.GeneralLearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.lucene.util.packed.PackedInts;

/**
 * Generic binary classifier UDTF registered as {@code train_classifier}.
 *
 * <p>Specializes {@link GeneralLearnerBaseUDTF} for classification: the default loss is
 * hinge loss, labels are restricted to {@code -1}, {@code 0} or {@code 1}, and each label is
 * collapsed to {@code +1}/{@code -1} before the model is updated.
 */
@Description(name = "train_classifier", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>", extended = "Build a prediction model by a generic classifier")
public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF {

    /**
     * Returns the help text describing the loss functions selectable for this learner.
     *
     * @return description listing the classification and regression loss names
     */
    @Override
    protected String getLossOptionDescription() {
        return "Loss function [HingeLoss (default), LogLoss, SquaredHingeLoss, ModifiedHuberLoss, or\na regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, SquaredEpsilonInsensitiveLoss, HuberLoss]";
    }

    /**
     * Returns the loss used when the caller does not pick one explicitly.
     *
     * @return {@link LossFunctions.LossType#HingeLoss}
     */
    @Override
    protected LossFunctions.LossType getDefaultLossType() {
        return LossFunctions.LossType.HingeLoss;
    }

    /**
     * Validates the selected loss function. This implementation performs no check, so any
     * loss function is accepted.
     *
     * @param lossFunction the loss function chosen for training
     * @throws UDFArgumentException never thrown by this implementation
     */
    @Override
    protected void checkLossFunction(@Nonnull LossFunctions.LossFunction lossFunction) throws UDFArgumentException {
        // Intentionally empty: no restriction is placed on the loss function here.
    }

    /**
     * Validates a training label. Only {@code -1}, {@code 0} and {@code 1} are accepted;
     * every other value (including NaN) is rejected.
     *
     * @param f the label value to validate
     * @throws UDFArgumentException if {@code f} is not one of {@code -1}, {@code 0}, {@code 1}
     */
    @Override
    protected void checkTargetValue(float f) throws UDFArgumentException {
        if (f != -1.0f && f != 0.0f && f != 1.0f) {
            // Report the offending value; previously the concatenation was stuck inside the literal.
            throw new UDFArgumentException("Invalid label value for classification: " + f);
        }
    }

    /**
     * Trains on a single example. The label is collapsed to {@code +1} when it is strictly
     * positive and to {@code -1} otherwise (so {@code 0} is treated as the negative class),
     * then passed to {@code update} together with the current prediction for the features.
     *
     * @param featureValueArr the features of the example
     * @param f the raw label of the example
     */
    @Override
    protected void train(@Nonnull FeatureValue[] featureValueArr, float f) {
        final float target = f > 0.0f ? 1.0f : -1.0f;
        final float predicted = predict(featureValueArr);
        update(featureValueArr, target, predicted);
    }
}
