package hivemall.math.matrix.sparse;

import hivemall.annotations.Experimental;
import hivemall.math.matrix.AbstractMatrix;
import hivemall.math.matrix.ColumnMajorMatrix;
import hivemall.math.matrix.MatrixUtils;
import hivemall.math.matrix.RowMajorMatrix;
import hivemall.math.matrix.builders.DoKMatrixBuilder;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
import hivemall.utils.collections.maps.Long2DoubleOpenHashTable;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.lucene.util.packed.PackedInts;

/**
 * A sparse matrix in Dictionary-of-Keys (DoK) format.
 *
 * <p>Cells are stored in a {@link Long2DoubleOpenHashTable} whose key packs the row and column
 * indices into a single {@code long} (see {@link #index(int, int)}). Cells that are not stored
 * read as {@code 0}. {@link #set(int, int, double)} and {@link #getAndSet(int, int, double)}
 * extend {@link #numRows()} / {@link #numColumns()} so that a newly stored cell lies inside the
 * matrix.
 *
 * <p>Writing {@code 0} to an absent cell is a no-op. Overwriting a stored cell with {@code 0}
 * keeps it in the table as an explicit zero. {@link #nnz()} counts only the stored cells whose
 * value is non-zero.
 *
 * <p>This class performs no synchronization.
 */
@Experimental
public final class DoKMatrix extends AbstractMatrix {

    /** Lower bound on the initial hash table capacity requested by every constructor. */
    private static final int MIN_INITIAL_CAPACITY = 16384;

    /** Stored cells, keyed by {@link #index(int, int)}; absent keys read as {@code 0}. */
    @Nonnull
    private final Long2DoubleOpenHashTable elements;

    @Nonnegative
    private int numRows;

    @Nonnegative
    private int numColumns;

    /** Number of stored cells whose current value is non-zero. */
    @Nonnegative
    private int nnz;

    /** Creates an empty 0 x 0 matrix that grows as non-zero cells are written. */
    public DoKMatrix() {
        this(0, 0);
    }

    /**
     * Creates a {@code numRows x numCols} matrix whose storage is pre-sized for an expected
     * sparsity of 5%.
     *
     * @param numRows initial number of rows
     * @param numCols initial number of columns
     */
    public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
        this(numRows, numCols, 0.05f);
    }

    /**
     * Creates a {@code numRows x numCols} matrix whose storage is pre-sized for
     * {@code numRows * numCols * sparsity} cells, but never fewer than 16384.
     *
     * @param numRows initial number of rows
     * @param numCols initial number of columns
     * @param sparsity expected fraction of stored cells; must lie in {@code [0, 1]}, which is
     *        enforced via {@link Preconditions#checkArgument}
     */
    public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
            @Nonnegative float sparsity) {
        Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f,
            "Invalid Sparsity value: " + sparsity);
        // Compute in double precision: numRows * numCols overflows int for large matrices.
        // Previously that produced a garbage (often negative) capacity that was silently
        // replaced by the minimum.
        final double expectedCells = (double) numRows * numCols * sparsity;
        final int initialCapacity = (int) Math.min(Integer.MAX_VALUE, Math.round(expectedCells));
        this.elements =
                new Long2DoubleOpenHashTable(Math.max(MIN_INITIAL_CAPACITY, initialCapacity));
        this.elements.defaultReturnValue(0.d);
        this.numRows = numRows;
        this.numColumns = numCols;
        this.nnz = 0;
    }

    /**
     * Creates an empty 0 x 0 matrix whose storage is pre-sized for {@code initialCapacity}
     * cells, but never fewer than 16384.
     *
     * @param initialCapacity requested initial capacity of the backing hash table
     */
    public DoKMatrix(@Nonnegative int initialCapacity) {
        this.elements =
                new Long2DoubleOpenHashTable(Math.max(initialCapacity, MIN_INITIAL_CAPACITY));
        this.elements.defaultReturnValue(0.d);
        this.numRows = 0;
        this.numColumns = 0;
        this.nnz = 0;
    }

    @Override
    public boolean isSparse() {
        return true;
    }

    @Override
    public boolean isRowMajorMatrix() {
        return false;
    }

    @Override
    public boolean isColumnMajorMatrix() {
        return false;
    }

    @Override
    public boolean readOnly() {
        return false;
    }

    @Override
    public boolean swappable() {
        return true;
    }

    /** @return the number of stored cells whose current value is non-zero */
    @Override
    public int nnz() {
        return nnz;
    }

    @Override
    public int numRows() {
        return numRows;
    }

    @Override
    public int numColumns() {
        return numColumns;
    }

    /**
     * Counts the stored cells of {@code row}, including explicit zeros, by probing every column
     * in {@code [0, numColumns())}.
     *
     * @param row row index; not bounds-checked
     * @return the number of stored cells in {@code row}
     */
    @Override
    public int numColumns(@Nonnegative final int row) {
        int count = 0;
        for (int col = 0; col < numColumns; col++) {
            if (elements.containsKey(index(row, col))) {
                count++;
            }
        }
        return count;
    }

    /**
     * Returns {@code row} as a dense array, using the buffer provided by
     * {@code AbstractMatrix#row()}.
     */
    @Override
    public double[] getRow(@Nonnegative final int row) {
        return getRow(row, row());
    }

    /**
     * Copies the first {@code min(dst.length, numColumns())} cells of {@code row} into
     * {@code dst}. Absent cells are written as {@code 0}.
     *
     * @return {@code dst}
     */
    @Override
    public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
        checkRowIndex(row, numRows);

        final int end = Math.min(dst.length, numColumns);
        for (int col = 0; col < end; col++) {
            dst[col] = elements.get(index(row, col));
        }
        return dst;
    }

    /**
     * Clears {@code row} and then sets it to the non-zero cells of {@code row}.
     */
    @Override
    public void getRow(@Nonnegative final int row, @Nonnull final Vector dst) {
        checkRowIndex(row, numRows);

        dst.clear();
        for (int col = 0; col < numColumns; col++) {
            final double value = elements.get(index(row, col), 0.d);
            if (value != 0.d) {
                dst.set(col, value);
            }
        }
    }

    /**
     * Returns the stored value at ({@code row}, {@code col}). Indices are not bounds-checked.
     *
     * @param defaultValue value returned when the cell is not stored
     */
    @Override
    public double get(@Nonnegative final int row, @Nonnegative final int col,
            final double defaultValue) {
        return elements.get(index(row, col), defaultValue);
    }

    /**
     * Writes {@code value} at ({@code row}, {@code col}). The row/column counts grow as needed
     * when a new cell is stored.
     */
    @Override
    public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
        store(row, col, value);
    }

    /**
     * Writes {@code value} at ({@code row}, {@code col}) and returns the previous value.
     *
     * @return the previous value of the cell, or {@code 0} if it was not stored
     */
    @Override
    public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
            final double value) {
        return store(row, col, value);
    }

    /**
     * Shared write path for {@link #set} and {@link #getAndSet}. It keeps {@link #nnz} and the
     * matrix bounds in sync with the table.
     *
     * <p>{@code nnz} follows zero/non-zero transitions of the cell value rather than whether
     * {@code put} reported a zero old value. The old approach double-counted cells that were
     * zeroed and later written again, and never decremented when a non-zero was overwritten
     * with {@code 0}.
     *
     * @return the previous value of the cell, or {@code 0} if it was not stored
     */
    private double store(@Nonnegative final int row, @Nonnegative final int col,
            final double value) {
        checkIndex(row, col);

        final long index = index(row, col);
        final boolean present = elements.containsKey(index);
        if (!present && value == 0.d) {
            return 0.d; // do not materialize zeros in a sparse matrix
        }

        final double old = elements.put(index, value, 0.d);
        if (old == 0.d) {
            if (value != 0.d) {
                nnz++;
            }
        } else if (value == 0.d) {
            nnz--;
        }

        if (!present) {
            // Stored keys never shrink the bounds, so only a new key can extend them.
            this.numRows = Math.max(numRows, row + 1);
            this.numColumns = Math.max(numColumns, col + 1);
        }
        return old;
    }

    /**
     * Swaps rows {@code row1} and {@code row2} column by column. A cell stored in only one of
     * the rows is moved to the other row, so stored-ness travels with the value.
     */
    @Override
    public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
        checkRowIndex(row1, numRows);
        checkRowIndex(row2, numRows);

        for (int col = 0; col < numColumns; col++) {
            final long index1 = index(row1, col);
            final long index2 = index(row2, col);
            final int slot1 = elements._findKey(index1);
            final int slot2 = elements._findKey(index2);
            if (slot1 >= 0) {
                if (slot2 >= 0) {
                    // Both cells are stored: exchange their values in place.
                    final double value1 = elements._get(slot1);
                    elements._set(slot1, elements._set(slot2, value1));
                } else {
                    // Only row1 has this cell: move it to row2.
                    elements.put(index2, elements._remove(slot1));
                }
            } else if (slot2 >= 0) {
                // Only row2 has this cell: move it to row1.
                elements.put(index1, elements._remove(slot2));
            }
        }
    }

    /**
     * Applies {@code procedure} to each stored cell of {@code row}. When {@code nullOutput} is
     * true, it is also applied with {@code 0} for absent cells.
     */
    @Override
    public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
            final boolean nullOutput) {
        checkRowIndex(row, numRows);

        for (int col = 0; col < numColumns; col++) {
            final int slot = elements._findKey(index(row, col));
            if (slot >= 0) {
                procedure.apply(col, elements._get(slot));
            } else if (nullOutput) {
                procedure.apply(col, 0.d);
            }
        }
    }

    /** Applies {@code procedure} to each cell of {@code row} whose value is non-zero. */
    @Override
    public void eachNonZeroInRow(@Nonnegative final int row,
            @Nonnull final VectorProcedure procedure) {
        checkRowIndex(row, numRows);

        for (int col = 0; col < numColumns; col++) {
            final double value = elements.get(index(row, col), 0.d);
            if (value != 0.d) {
                procedure.apply(col, value);
            }
        }
    }

    /** Applies {@code procedure} to the column index of each stored cell of {@code row}. */
    @Override
    public void eachColumnIndexInRow(@Nonnegative final int row,
            @Nonnull final VectorProcedure procedure) {
        checkRowIndex(row, numRows);

        for (int col = 0; col < numColumns; col++) {
            if (elements._findKey(index(row, col)) >= 0) {
                procedure.apply(col);
            }
        }
    }

    /**
     * Applies {@code procedure} to each stored cell of {@code col}. When {@code nullOutput} is
     * true, it is also applied with {@code 0} for absent cells.
     */
    @Override
    public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
            final boolean nullOutput) {
        checkColIndex(col, numColumns);

        for (int row = 0; row < numRows; row++) {
            final int slot = elements._findKey(index(row, col));
            if (slot >= 0) {
                procedure.apply(row, elements._get(slot));
            } else if (nullOutput) {
                procedure.apply(row, 0.d);
            }
        }
    }

    /** Applies {@code procedure} to each cell of {@code col} whose value is non-zero. */
    @Override
    public void eachNonZeroInColumn(@Nonnegative final int col,
            @Nonnull final VectorProcedure procedure) {
        checkColIndex(col, numColumns);

        for (int row = 0; row < numRows; row++) {
            final double value = elements.get(index(row, col), 0.d);
            if (value != 0.d) {
                procedure.apply(row, value);
            }
        }
    }

    /**
     * Applies {@code procedure} to every non-zero cell as (row, column, value), in hash table
     * iteration order. Explicit zeros left by overwriting a cell with {@code 0} are skipped, as
     * in {@link #eachNonZeroInRow} and {@link #eachNonZeroInColumn}.
     */
    @Override
    public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
        if (nnz == 0) {
            return;
        }

        final Long2DoubleOpenHashTable.IMapIterator itor = elements.entries();
        while (itor.next() != -1) {
            final double value = itor.getValue();
            if (value == 0.d) {
                continue;
            }
            final long key = itor.getKey();
            procedure.apply(Primitives.getHigh(key), Primitives.getLow(key), value);
        }
    }

    /** Converts every stored cell (including explicit zeros) to a CSR matrix. */
    @Override
    public RowMajorMatrix toRowMajorMatrix() {
        final int size = elements.size();
        final int[] rows = new int[size];
        final int[] cols = new int[size];
        final double[] values = new double[size];
        copyEntriesTo(rows, cols, values);
        return MatrixUtils.coo2csr(rows, cols, values, numRows, numColumns, true);
    }

    /** Converts every stored cell (including explicit zeros) to a CSC matrix. */
    @Override
    public ColumnMajorMatrix toColumnMajorMatrix() {
        final int size = elements.size();
        final int[] rows = new int[size];
        final int[] cols = new int[size];
        final double[] values = new double[size];
        copyEntriesTo(rows, cols, values);
        return MatrixUtils.coo2csc(rows, cols, values, numRows, numColumns, true);
    }

    /**
     * Fills parallel COO arrays (all of length {@code elements.size()}) with the stored cells,
     * in hash table iteration order.
     *
     * @throws IllegalStateException if the iterator yields fewer entries than the table size
     */
    private void copyEntriesTo(@Nonnull final int[] rows, @Nonnull final int[] cols,
            @Nonnull final double[] values) {
        final Long2DoubleOpenHashTable.IMapIterator itor = elements.entries();
        for (int i = 0; i < rows.length; i++) {
            if (itor.next() == -1) {
                throw new IllegalStateException("itor.next() returns -1 where i=" + i);
            }
            final long key = itor.getKey();
            rows[i] = Primitives.getHigh(key);
            cols[i] = Primitives.getLow(key);
            values[i] = itor.getValue();
        }
    }

    /** Returns a builder pre-sized to the current number of stored cells. */
    @Override
    public DoKMatrixBuilder builder() {
        return new DoKMatrixBuilder(elements.size());
    }

    /**
     * Packs ({@code row}, {@code col}) into the hash table key via {@link Primitives#toLong}. The
     * row is later recovered with {@code Primitives.getHigh} and the column with
     * {@code Primitives.getLow}.
     */
    @Nonnegative
    private static long index(@Nonnegative final int row, @Nonnegative final int col) {
        return Primitives.toLong(row, col);
    }
}
