package hivemall;

import hivemall.mix.MixMessage;
import hivemall.mix.client.MixClient;
import hivemall.model.DenseModel;
import hivemall.model.NewDenseModel;
import hivemall.model.NewSpaceEfficientDenseModel;
import hivemall.model.NewSparseModel;
import hivemall.model.PredictionModel;
import hivemall.model.SpaceEfficientDenseModel;
import hivemall.model.SparseModel;
import hivemall.model.SynchronizedModelWrapper;
import hivemall.optimizer.DenseOptimizerFactory;
import hivemall.optimizer.Optimizer;
import hivemall.optimizer.SparseOptimizerFactory;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
import io.netty.util.internal.StringUtil;
import java.util.Map;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Base class for learner UDTFs.
 *
 * <p>Parses the options shared by every learner: dense vs. sparse model, model dimensions,
 * half-float toggle, mini-batch size and MIX-server settings. It then builds the
 * {@link PredictionModel} and {@link Optimizer} that subclasses train.
 *
 * <p>NOTE(review): several members were reported by the decompiler as originally
 * {@code protected}. They are kept {@code public} here so existing overrides stay valid.
 */
public abstract class LearnerBaseUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class);

    /** Initial size of a sparse model, and of a sparse optimizer when no dimension is known. */
    private static final int DEFAULT_SPARSE_DIMS = 16384;
    /** Model dimension (2^24) used for dense models when {@code -dims} is not given. */
    private static final int DEFAULT_DENSE_DIMS = 16777216;
    /**
     * Dense models with more dimensions than this use the space-efficient (half-float) variant,
     * unless {@code -disable_halffloat} is given.
     */
    private static final int SPACE_EFFICIENT_DIMS_THRESHOLD = 16777216;

    /** Selects the {@code New*Model} implementations instead of the legacy ones. */
    protected final boolean enableNewModel;
    protected boolean dense_model;
    protected int model_dims;
    protected boolean disable_halffloat;
    protected boolean is_mini_batch;
    protected int mini_batch_size;
    protected String mixConnectInfo;
    protected String mixSessionName;
    protected int mixThreshold;
    protected boolean mixCancel;
    protected boolean ssl;

    /** Client created when a MIX server is configured; closed and cleared in {@link #close()}. */
    @Nullable
    protected MixClient mixClient;

    /**
     * @param enableNewModel {@code true} to build {@code NewSparseModel}/{@code NewDenseModel}/
     *        {@code NewSpaceEfficientDenseModel}; {@code false} for the legacy model classes
     */
    public LearnerBaseUDTF(boolean enableNewModel) {
        this.enableNewModel = enableNewModel;
    }

    /**
     * Whether the learner keeps per-feature covariances. The default is {@code false}.
     *
     * <p>Besides being passed to the model constructors, this also picks the MIX event
     * ({@code argminKLD} vs. {@code average}).
     */
    public boolean useCovariance() {
        return false;
    }

    /** Declares the command-line options common to all learners. */
    @Override // hivemall.UDTFWithOptions
    public Options getOptions() {
        Options options = new Options();
        options.addOption("dense", "densemodel", false, "Use dense model or not");
        options.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]");
        options.addOption("disable_halffloat", false, "Toggle this option to disable the use of SpaceEfficientDenseModel");
        options.addOption("mini_batch", "mini_batch_size", true, "Mini batch size [default: 1]. Expecting the value in range [1,100] or so.");
        options.addOption("mix", "mix_servers", true, "Comma separated list of MIX servers");
        options.addOption("mix_session", "mix_session_name", true, "Mix session name [default: ${mapred.job.id}]");
        options.addOption("mix_threshold", true, "Threshold to mix local updates in range (0,127] [default: 3]");
        options.addOption("mix_cancel", "enable_mix_canceling", false, "Enable mix cancel requests");
        options.addOption("ssl", false, "Use SSL for the communication with mix servers");
        return options;
    }

    /**
     * Parses the constant option string in the third UDTF argument and stores the results in
     * this object's fields. When fewer than three arguments are given, defaults are stored.
     *
     * @param objectInspectorArr the UDTF argument inspectors; index 2 holds the option string
     * @return the parsed command line, or {@code null} when no option argument was supplied
     * @throws UDFArgumentException if {@code mini_batch_size <= 0} or {@code mix_threshold} is
     *         outside {@code (0,127]}
     */
    @Override // hivemall.UDTFWithOptions
    @Nullable
    public CommandLine processOptions(@Nonnull ObjectInspector[] objectInspectorArr)
            throws UDFArgumentException {
        boolean denseModel = false;
        int modelDims = -1;
        boolean disableHalfFloat = false;
        int miniBatchSize = 1;
        String mixServers = null;
        String sessionName = null;
        int threshold = -1;
        boolean cancelMix = false;
        boolean useSsl = false;

        CommandLine cl = null;
        if (objectInspectorArr.length >= 3) {
            cl = parseOptions(HiveUtils.getConstString(objectInspectorArr[2]));
            denseModel = cl.hasOption("dense");
            if (denseModel) {
                modelDims = Primitives.parseInt(cl.getOptionValue("dims"), DEFAULT_DENSE_DIMS);
            }
            disableHalfFloat = cl.hasOption("disable_halffloat");
            miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 1);
            if (miniBatchSize <= 0) {
                throw new UDFArgumentException("mini_batch_size must be greater than 0: " + miniBatchSize);
            }
            mixServers = cl.getOptionValue("mix");
            sessionName = cl.getOptionValue("mix_session");
            threshold = Primitives.parseInt(cl.getOptionValue("mix_threshold"), 3);
            // Both the help text and the error message promise (0,127]. Previously only the
            // upper bound was enforced, so zero/negative thresholds were silently accepted.
            if (threshold <= 0 || threshold > Byte.MAX_VALUE) {
                throw new UDFArgumentException("mix_threshold must be in range (0,127]: " + threshold);
            }
            cancelMix = cl.hasOption("mix_cancel");
            useSsl = cl.hasOption("ssl");
        }

        this.dense_model = denseModel;
        this.model_dims = modelDims;
        this.disable_halffloat = disableHalfFloat;
        this.is_mini_batch = miniBatchSize > 1;
        this.mini_batch_size = miniBatchSize;
        this.mixConnectInfo = mixServers;
        this.mixSessionName = sessionName;
        this.mixThreshold = threshold;
        this.mixCancel = cancelMix;
        this.ssl = useSsl;
        return cl;
    }

    /**
     * Builds the prediction model selected by the parsed options.
     *
     * <p>If a MIX server was configured, the model is wrapped for MIX and {@link #mixClient}
     * is set.
     */
    @Nullable
    public PredictionModel createModel() {
        return enableNewModel ? createNewModel(null) : createOldModel(null);
    }

    /** Builds a legacy model (SparseModel / DenseModel / SpaceEfficientDenseModel). */
    @Nonnull
    private PredictionModel createOldModel(@Nullable String context) {
        final boolean useCovariance = useCovariance();
        final PredictionModel model;
        if (!dense_model) {
            int initialModelSize = getInitialModelSize();
            logger.info("Build a sparse model with initial with " + initialModelSize + " initial dimensions");
            model = new SparseModel(initialModelSize, useCovariance);
        } else if (useSpaceEfficientDenseModel()) {
            logger.info("Build a space efficient dense model with " + model_dims + " initial dimensions" + covarianceSuffix(useCovariance));
            model = new SpaceEfficientDenseModel(model_dims, useCovariance);
        } else {
            logger.info("Build a dense model with initial with " + model_dims + " initial dimensions" + covarianceSuffix(useCovariance));
            model = new DenseModel(model_dims, useCovariance);
        }
        return wrapForMix(model, context);
    }

    /** Builds a new-style model (NewSparseModel / NewDenseModel / NewSpaceEfficientDenseModel). */
    @Nonnull
    private PredictionModel createNewModel(@Nullable String context) {
        final boolean useCovariance = useCovariance();
        final PredictionModel model;
        if (!dense_model) {
            int initialModelSize = getInitialModelSize();
            logger.info("Build a sparse model with initial with " + initialModelSize + " initial dimensions");
            model = new NewSparseModel(initialModelSize, useCovariance);
        } else if (useSpaceEfficientDenseModel()) {
            logger.info("Build a space efficient dense model with " + model_dims + " initial dimensions" + covarianceSuffix(useCovariance));
            model = new NewSpaceEfficientDenseModel(model_dims, useCovariance);
        } else {
            logger.info("Build a dense model with initial with " + model_dims + " initial dimensions" + covarianceSuffix(useCovariance));
            model = new NewDenseModel(model_dims, useCovariance);
        }
        return wrapForMix(model, context);
    }

    /**
     * Decides between the half-float dense model and the plain dense model.
     *
     * @return {@code true} when the half-float model should be used: half-float is not
     *         disabled and the dimensions exceed {@link #SPACE_EFFICIENT_DIMS_THRESHOLD}
     */
    private boolean useSpaceEfficientDenseModel() {
        return !disable_halffloat && model_dims > SPACE_EFFICIENT_DIMS_THRESHOLD;
    }

    /** Log-message suffix noting that covariances are kept. */
    @Nonnull
    private static String covarianceSuffix(boolean useCovariance) {
        return useCovariance ? " w/ covariances" : "";
    }

    /**
     * Returns {@code model} unchanged when no MIX server is configured.
     *
     * <p>Otherwise it:
     * <ol>
     *   <li>configures the model's clock;</li>
     *   <li>wraps the model in a {@link SynchronizedModelWrapper};</li>
     *   <li>attaches a freshly configured {@link MixClient} to the wrapper;</li>
     *   <li>records that client in {@link #mixClient} so {@link #close()} can release it.</li>
     * </ol>
     */
    @Nonnull
    private PredictionModel wrapForMix(@Nonnull PredictionModel model, @Nullable String context) {
        if (mixConnectInfo == null) {
            return model;
        }
        model.configureClock();
        PredictionModel synchronizedModel = new SynchronizedModelWrapper(model);
        MixClient client = configureMixClient(mixConnectInfo, context, synchronizedModel);
        synchronizedModel.configureMix(client, mixCancel);
        this.mixClient = client;
        return synchronizedModel;
    }

    /**
     * Creates an optimizer sized for the current model kind.
     *
     * <p>When no dimension was parsed ({@code model_dims < 0}), the dense or sparse default size
     * is used.
     *
     * @param map optimizer options; must not be {@code null}
     */
    @Nonnull
    public final Optimizer createOptimizer(@CheckForNull Map<String, String> map) {
        Preconditions.checkNotNull(map);
        if (dense_model) {
            return DenseOptimizerFactory.create(model_dims < 0 ? DEFAULT_DENSE_DIMS : model_dims, map);
        }
        return SparseOptimizerFactory.create(model_dims < 0 ? DEFAULT_SPARSE_DIMS : model_dims, map);
    }

    /**
     * Creates a {@link MixClient} for the given servers.
     *
     * @param connectInfo comma-separated MIX server list
     * @param context optional suffix appended to the session name as {@code "-" + context}
     * @param model the (synchronized) model the client mixes
     * @return the configured client. The session name falls back to
     *         {@code MixClient.DUMMY_JOB_ID} when no {@code -mix_session} was given.
     */
    @Nonnull
    protected MixClient configureMixClient(@Nonnull String connectInfo, @Nullable String context,
            @Nonnull PredictionModel model) {
        String sessionName = (mixSessionName == null) ? MixClient.DUMMY_JOB_ID : mixSessionName;
        if (context != null) {
            sessionName = sessionName + '-' + context;
        }
        MixMessage.MixEventName event = useCovariance()
                ? MixMessage.MixEventName.argminKLD
                : MixMessage.MixEventName.average;
        MixClient client = new MixClient(event, sessionName, connectInfo, this.ssl, this.mixThreshold, model);
        logger.info("Successfully configured mix client: " + connectInfo);
        return client;
    }

    /** Initial capacity of sparse models; subclasses may override. */
    protected int getInitialModelSize() {
        return DEFAULT_SPARSE_DIMS;
    }

    /**
     * Returns the object inspector for the emitted feature column.
     *
     * @return the Java {@code int} inspector for dense models; otherwise a standard copy of the
     *         given feature inspector
     */
    @Nonnull
    public ObjectInspector getFeatureOutputOI(@Nonnull PrimitiveObjectInspector primitiveObjectInspector)
            throws UDFArgumentException {
        return dense_model
                ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
                : ObjectInspectorUtils.getStandardObjectInspector(primitiveObjectInspector);
    }

    /** Quietly closes the MIX client, if any, and clears the reference. */
    public void close() throws HiveException {
        if (mixClient != null) {
            IOUtils.closeQuietly(mixClient);
            mixClient = null;
        }
    }
}
