package hivemall.tools;

import hivemall.utils.collections.BoundedPriorityQueue;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;

@Description(name = "each_top_k", value = "_FUNC_(int K, Object group, double cmpKey, *) - Returns top-K values (or tail-K values when k is less than 0)")
/* loaded from: input_file:hivemall/tools/EachTopKUDTF.class */
public final class EachTopKUDTF extends GenericUDTF {
    private ObjectInspector[] argOIs;
    private PrimitiveObjectInspector kOI;
    private ObjectInspector prevGroupOI;
    private PrimitiveObjectInspector cmpKeyOI;
    private boolean _constantK;
    private int _prevK;
    private BoundedPriorityQueue<TupleWithKey> _queue;
    private TupleWithKey _tuple;
    private Object _previousGroup;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hivemall/tools/EachTopKUDTF$TupleWithKey.class */
    public static final class TupleWithKey implements Comparable<TupleWithKey> {
        double key;
        Object[] row;

        TupleWithKey(double d, Object[] objArr) {
            this.key = d;
            this.row = objArr;
        }

        double getKey() {
            return this.key;
        }

        Object[] getRow() {
            return this.row;
        }

        void setKey(double d) {
            this.key = d;
        }

        @Override // java.lang.Comparable
        public int compareTo(TupleWithKey tupleWithKey) {
            return Double.compare(this.key, tupleWithKey.key);
        }
    }

    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        int length = objectInspectorArr.length;
        if (length < 4) {
            throw new UDFArgumentException("each_top_k(int K, Object group, double cmpKey, *) takes at least 4 arguments: " + length);
        }
        this.argOIs = objectInspectorArr;
        this._constantK = ObjectInspectorUtils.isConstantObjectInspector(objectInspectorArr[0]);
        if (this._constantK) {
            int asConstInt = HiveUtils.getAsConstInt(objectInspectorArr[0]);
            if (asConstInt == 0) {
                throw new UDFArgumentException("k should not be 0");
            }
            this._queue = getQueue(asConstInt);
        } else {
            this.kOI = HiveUtils.asIntCompatibleOI(objectInspectorArr[0]);
            this._prevK = 0;
        }
        this.prevGroupOI = ObjectInspectorUtils.getStandardObjectInspector(objectInspectorArr[1], ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
        this.cmpKeyOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr[2]);
        this._tuple = null;
        this._previousGroup = null;
        ArrayList arrayList = new ArrayList(length);
        ArrayList arrayList2 = new ArrayList(length);
        arrayList.add("rank");
        arrayList2.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        arrayList.add("key");
        arrayList2.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        for (int i = 3; i < length; i++) {
            arrayList.add(WikipediaTokenizer.CATEGORY + (i - 2));
            arrayList2.add(ObjectInspectorUtils.getStandardObjectInspector(objectInspectorArr[i], ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT));
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(arrayList, arrayList2);
    }

    @Nonnull
    private static BoundedPriorityQueue<TupleWithKey> getQueue(int i) {
        return new BoundedPriorityQueue<>(Math.abs(i), i < 0 ? Collections.reverseOrder() : new Comparator<TupleWithKey>() { // from class: hivemall.tools.EachTopKUDTF.1
            @Override // java.util.Comparator
            public int compare(TupleWithKey tupleWithKey, TupleWithKey tupleWithKey2) {
                return tupleWithKey.compareTo(tupleWithKey2);
            }
        });
    }

    public void process(Object[] objArr) throws HiveException {
        Object[] row;
        Object obj = objArr[1];
        if (!isSameGroup(obj)) {
            this._previousGroup = ObjectInspectorUtils.copyToStandardObject(obj, this.argOIs[1], ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
            if (this._queue != null) {
                drainQueue();
            }
            if (!this._constantK) {
                int i = PrimitiveObjectInspectorUtils.getInt(objArr[0], this.kOI);
                if (i == 0) {
                    return;
                }
                if (i != this._prevK) {
                    this._queue = getQueue(i);
                    this._prevK = i;
                }
            }
        }
        double d = PrimitiveObjectInspectorUtils.getDouble(objArr[2], this.cmpKeyOI);
        TupleWithKey tupleWithKey = this._tuple;
        if (this._tuple == null) {
            row = new Object[objArr.length - 1];
            tupleWithKey = new TupleWithKey(d, row);
            this._tuple = tupleWithKey;
        } else {
            row = tupleWithKey.getRow();
            tupleWithKey.setKey(d);
        }
        for (int i2 = 3; i2 < objArr.length; i2++) {
            row[i2 - 1] = ObjectInspectorUtils.copyToStandardObject(objArr[i2], this.argOIs[i2], ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
        }
        if (this._queue.offer(tupleWithKey)) {
            this._tuple = null;
        }
    }

    private boolean isSameGroup(Object obj) {
        if (obj == null && this._previousGroup == null) {
            return true;
        }
        return (obj == null || this._previousGroup == null || ObjectInspectorUtils.compare(obj, this.argOIs[1], this._previousGroup, this.prevGroupOI) != 0) ? false : true;
    }

    private void drainQueue() throws HiveException {
        int size = this._queue.size();
        if (size > 0) {
            TupleWithKey[] tupleWithKeyArr = new TupleWithKey[size];
            for (int i = 0; i < size; i++) {
                TupleWithKey poll = this._queue.poll();
                if (poll == null) {
                    throw new IllegalStateException("Found null element in the queue");
                }
                tupleWithKeyArr[i] = poll;
            }
            IntWritable intWritable = new IntWritable(-1);
            DoubleWritable doubleWritable = new DoubleWritable(Double.NaN);
            int i2 = 0;
            double d = Double.NaN;
            for (int i3 = size - 1; i3 >= 0; i3--) {
                TupleWithKey tupleWithKey = tupleWithKeyArr[i3];
                tupleWithKeyArr[i3] = null;
                double key = tupleWithKey.getKey();
                if (key != d) {
                    i2++;
                    intWritable.set(i2);
                    doubleWritable.set(key);
                    d = key;
                }
                Object[] row = tupleWithKey.getRow();
                row[0] = intWritable;
                row[1] = doubleWritable;
                forward(row);
            }
            this._queue.clear();
        }
    }

    public void close() throws HiveException {
        drainQueue();
        this._queue = null;
        this._tuple = null;
    }
}
