package org.nd4j.linalg.api.ops.executioner;

import java.util.List;
import java.util.Properties;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.class */
public class DefaultOpExecutioner implements OpExecutioner {
    private static final Logger log = LoggerFactory.getLogger(DefaultOpExecutioner.class);

    /** Current profiling mode; consulted by the {@code profilingHook*} methods. */
    protected OpExecutioner.ProfilingMode profilingMode = OpExecutioner.ProfilingMode.DISABLED;

    /** Current execution mode; stored and returned here but not otherwise consulted by this class. */
    protected OpExecutioner.ExecutionMode executionMode = OpExecutioner.ExecutionMode.JAVA;

    /**
     * Rejects INT operands, then decompresses any compressed x, y or z operand of {@code op}
     * (checked in that order) via the global compressor's {@code decompressi}.
     *
     * @param op op whose operands are inspected; null operands are skipped
     * @throws ND4JIllegalStateException if any operand holds INT data
     */
    protected void checkForCompression(Op op) {
        interceptIntDataType(op);
        decompressIfCompressed(op.x());
        decompressIfCompressed(op.y());
        decompressIfCompressed(op.z());
    }

    /** Calls {@code decompressi} on {@code array} when it is non-null and reports itself as compressed. */
    private static void decompressIfCompressed(INDArray array) {
        if (array != null && array.isCompressed()) {
            Nd4j.getCompressor().decompressi(array);
        }
    }

    /** This executioner does not track ops, so the last op is always reported as {@code "UNKNOWN"}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public String getLastOp() {
        return "UNKNOWN";
    }

    /**
     * Fails fast when any operand of {@code op} is backed by an INT buffer.
     *
     * @param op op whose x, z and y operands are checked (in that order); null operands are skipped
     * @throws ND4JIllegalStateException if an operand's data type is {@link DataBuffer.Type#INT}
     */
    protected void interceptIntDataType(Op op) {
        if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Op.X contains INT data. Operations on INT dataType are not supported yet");
        }
        if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Op.Z contains INT data. Operations on INT dataType are not supported yet");
        }
        if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Op.Y contains INT data. Operations on INT dataType are not supported yet.");
        }
    }

    /**
     * Executes {@code op} by letting it run itself. Only pass-through ops are supported.
     *
     * @param op op to execute
     * @return the same op instance, after execution
     * @throws IllegalStateException if the op is not pass-through (Java computation was removed)
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Op exec(Op op) {
        if (!op.isPassThrough()) {
            throw new IllegalStateException("Java computation no longer supported");
        }
        op.exec();
        return op;
    }

    /**
     * Executes {@code op} and returns its result as an array.
     *
     * <p>Transform and scalar ops return their z array; accumulations and index accumulations
     * return their final result wrapped in a scalar array.
     *
     * @throws IllegalArgumentException for any other op type
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(Op op) {
        if (op instanceof TransformOp) {
            return execAndReturn((TransformOp) op);
        }
        if (op instanceof ScalarOp) {
            return execAndReturn((ScalarOp) op);
        }
        if (op instanceof Accumulation) {
            return Nd4j.scalar(execAndReturn((Accumulation) op).getFinalResult());
        }
        if (op instanceof IndexAccumulation) {
            return Nd4j.scalar(execAndReturn((IndexAccumulation) op).getFinalResult());
        }
        throw new IllegalArgumentException("Illegal type of op: " + op.getClass());
    }

    /**
     * Applies {@code op} row by row.
     *
     * <ul>
     *   <li>Vector x: the op runs once over its current operands.</li>
     *   <li>x that is neither a vector nor a matrix: recurses into each slice of x and z
     *       (y is not sliced on this path).</li>
     *   <li>Matrix x: the op runs on a copy of each row and its z is copied back into the
     *       corresponding row of the original z.</li>
     * </ul>
     *
     * <p>On return the op's operands may still refer to the last row or slice processed.
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void iterateOverAllRows(Op op) {
        if (op.x().isVector()) {
            // Operands are re-assigned as in the original implementation; the setters may have
            // side effects on the op's bookkeeping. TODO confirm whether this is required.
            op.setX(op.x());
            if (op.y() != null) {
                op.setY(op.y());
            }
            op.setZ(op.z());
            exec(op);
            return;
        }
        if (!op.x().isMatrix()) {
            INDArray x = op.x();
            INDArray z = op.z();
            for (int i = 0; i < x.slices(); i++) {
                op.setX(x.slice(i));
                op.setZ(z.slice(i));
                iterateOverAllRows(op);
            }
            return;
        }
        if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray complexX = (IComplexNDArray) op.x();
            IComplexNDArray complexZ = (IComplexNDArray) op.z();
            IComplexNDArray complexY = (IComplexNDArray) op.y();
            for (int i = 0; i < complexX.rows(); i++) {
                op.setX(complexX.slice(i).dup());
                op.setZ(complexZ.slice(i).dup());
                if (complexY != null) {
                    op.setY(complexY.slice(i));
                }
                exec(op);
                // The op worked on a copy, so write its result back into the original z.
                complexZ.slice(i).assign(op.z());
            }
            return;
        }
        INDArray x = op.x();
        INDArray z = op.z();
        INDArray y = op.y();
        for (int i = 0; i < x.rows(); i++) {
            INDArray targetRow = z.getRow(i);
            op.setX(x.getRow(i).dup());
            op.setZ(targetRow.dup());
            if (y != null) {
                op.setY(y.getRow(i).dup());
            }
            exec(op);
            // The op worked on a copy, so write its result back into the original z.
            targetRow.assign(op.z());
        }
    }

    /**
     * Applies {@code op} column by column.
     *
     * <p>Vectors run once over the current operands. Matrices and column vectors delegate to
     * {@link #exec(Op, int...)} along dimension 1. Higher-rank arrays recurse into
     * {@code getColumn(i)} of x and z (and y when present) for each slice of the original x.
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void iterateOverAllColumns(Op op) {
        if (op.x().isVector()) {
            exec(op);
            return;
        }
        if (op.x().isMatrix() || op.x().isColumnVector()) {
            exec(op, 1);
            return;
        }
        // The loops below replace op's operands on every iteration (and the recursive call
        // replaces them again), so the loop bound must come from the captured original x,
        // not from op.x().
        if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray complexX = (IComplexNDArray) op.x();
            IComplexNDArray complexZ = (IComplexNDArray) op.z();
            IComplexNDArray complexY = (IComplexNDArray) op.y();
            for (int i = 0; i < complexX.slices(); i++) {
                op.setX(complexX.getColumn(i));
                op.setZ(complexZ.getColumn(i));
                if (complexY != null) {
                    op.setY(complexY.getColumn(i));
                }
                iterateOverAllColumns(op);
            }
            return;
        }
        INDArray x = op.x();
        INDArray z = op.z();
        INDArray y = op.y();
        for (int i = 0; i < x.slices(); i++) {
            op.setX(x.getColumn(i));
            op.setZ(z.getColumn(i));
            if (y != null) {
                op.setY(y.getColumn(i));
            }
            iterateOverAllColumns(op);
        }
    }

    /** Executes the transform op and returns its z array. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(TransformOp transformOp) {
        return exec(transformOp).z();
    }

    /** Executes the accumulation and returns the same op instance. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Accumulation execAndReturn(Accumulation accumulation) {
        return (Accumulation) exec(accumulation);
    }

    /**
     * Sets the bias-correction flag on {@code variance}, then executes it like any other
     * accumulation. This matches {@link #exec(Variance, boolean, int...)}; the previous
     * implementation ignored the flag and returned {@code null}.
     *
     * @param variance variance op to execute
     * @param biasCorrected value passed to {@code setBiasCorrected}
     * @return the executed op
     * @throws IllegalStateException if the op is not pass-through
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Accumulation execAndReturn(Variance variance, boolean biasCorrected) {
        variance.setBiasCorrected(biasCorrected);
        return execAndReturn((Accumulation) variance);
    }

    /** Executes the scalar op and returns its z array. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(ScalarOp scalarOp) {
        return exec(scalarOp).z();
    }

    /** Executes the index accumulation and returns the same op instance. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public IndexAccumulation execAndReturn(IndexAccumulation indexAccumulation) {
        return (IndexAccumulation) exec(indexAccumulation);
    }

    /** Executes the broadcast op and returns its z array. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(BroadcastOp broadcastOp) {
        return exec(broadcastOp).z();
    }

    /**
     * Executes {@code op} along the given dimensions.
     *
     * <p>If as many dimensions are given as x has, they are replaced by
     * {@code {Integer.MAX_VALUE}}, the "all dimensions" marker. Only pass-through ops are
     * executed; every other case throws.
     *
     * @return the same op instance, after execution
     * @throws IllegalStateException for non-pass-through accumulations, index accumulations and scalar ops
     * @throws UnsupportedOperationException for non-pass-through transform ops and unknown op types
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Op exec(Op op, int... dimensions) {
        if (dimensions.length == op.x().rank()) {
            dimensions = new int[] {Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimensions);
            return op;
        }
        if ((op instanceof Accumulation) || (op instanceof IndexAccumulation)) {
            throw new IllegalStateException("exec(Op,int...) should never be invoked for Accumulation/IndexAccumulation");
        }
        if (op instanceof ScalarOp) {
            throw new IllegalStateException("Java computation no longer supported");
        }
        if (op instanceof TransformOp) {
            throw new UnsupportedOperationException("Executing transform ops along a dimension should be done via exec special");
        }
        throw new UnsupportedOperationException("Unknown op type");
    }

    /**
     * Reduces {@code accumulation} along the given dimensions.
     *
     * <p>Pass-through ops run themselves and their z is returned. A full reduction (all
     * dimensions, or {@code Integer.MAX_VALUE}) is returned as a scalar array. Per-dimension
     * reduction without pass-through is only attempted for complex inputs.
     *
     * @param accumulation reduction op
     * @param dimensions dimensions to reduce along; must be non-empty on the non-pass-through path
     * @return the reduction result
     * @throws UnsupportedOperationException for non-pass-through, per-dimension, real-valued reductions
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(Accumulation accumulation, int... dimensions) {
        if (dimensions.length == accumulation.x().rank()) {
            dimensions = new int[] {Integer.MAX_VALUE};
        }
        if (accumulation.isPassThrough()) {
            accumulation.exec(dimensions);
            return accumulation.z();
        }
        boolean complexInput = accumulation.x() instanceof IComplexNDArray;
        if (dimensions[0] == Integer.MAX_VALUE) {
            return complexInput
                    ? Nd4j.scalar(execAndReturn(accumulation).getFinalResultComplex())
                    : Nd4j.scalar(execAndReturn(accumulation).getFinalResult().doubleValue());
        }
        // The original tested the op itself (`accumulation instanceof IComplexNDArray`), which
        // can never be true; the input array x is what determines whether the data is complex.
        if (!complexInput) {
            throw new UnsupportedOperationException("Java computation no longer supported");
        }
        int[] resultShape = ArrayUtil.removeIndex(accumulation.x().shape(), dimensions);
        if (resultShape.length == 1) {
            // Promote a 1-D result to a row (reduced dim 0) or column (otherwise) vector.
            resultShape = dimensions[0] == 0 ? new int[] {1, resultShape[0]} : new int[] {resultShape[0], 1};
        } else if (resultShape.length == 0) {
            resultShape = new int[] {1, 1};
        }
        IComplexNDArray result = Nd4j.createComplex(resultShape);
        for (int i = 0; i < accumulation.x().tensorssAlongDimension(dimensions); i++) {
            Accumulation perTensor = (Accumulation) accumulation.opForDimension(i, dimensions);
            result.putScalar(i, execAndReturn(perTensor).getFinalResultComplex());
        }
        if (result.ordering() == 'c') {
            // Kept from the original implementation. NOTE(review): intent of reversing the
            // strides for 'c' ordering is unclear — TODO confirm.
            result.setStride(ArrayUtil.reverseCopy(result.stride()));
        }
        return result;
    }

    /** Sets the bias-correction flag on {@code variance}, then reduces it along {@code dimensions}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(Variance variance, boolean biasCorrected, int... dimensions) {
        variance.setBiasCorrected(biasCorrected);
        // Explicit cast selects exec(Accumulation, int...) instead of recursing into this overload.
        return exec((Accumulation) variance, dimensions);
    }

    /** Not supported here; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(IndexAccumulation indexAccumulation, int... dimensions) {
        throw new UnsupportedOperationException("Operation should use exec special");
    }

    /** Returns the stored execution mode. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public OpExecutioner.ExecutionMode executionMode() {
        return this.executionMode;
    }

    /** Stores the execution mode. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        this.executionMode = executionMode;
    }

    /**
     * Executes a pass-through broadcast op along the given dimensions and returns its z.
     *
     * @throws IllegalStateException if the op is not pass-through
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(BroadcastOp broadcastOp, int... dimensions) {
        if (dimensions.length == broadcastOp.x().rank()) {
            dimensions = new int[] {Integer.MAX_VALUE};
        }
        if (!broadcastOp.isPassThrough()) {
            throw new IllegalStateException("Java computation no longer supported");
        }
        broadcastOp.exec(dimensions);
        return broadcastOp.z();
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void exec(MetaOp metaOp) {
        throw new UnsupportedOperationException("MetaOp execution isn't supported for this OpExecutioner yet");
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void exec(GridOp gridOp) {
        throw new UnsupportedOperationException("GridOp execution isn't supported for this OpExecutioner yet");
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public <T extends Aggregate> void exec(Batch<T> batch) {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void exec(Aggregate aggregate) {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void exec(List<Aggregate> list) {
        throw new UnsupportedOperationException();
    }

    /** Delegates to {@link #exec(RandomOp, Random)} with {@code Nd4j.getRandom()}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(RandomOp randomOp) {
        return exec(randomOp, Nd4j.getRandom());
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(RandomOp randomOp, Random random) {
        throw new UnsupportedOperationException();
    }

    /** Stores the profiling mode consulted by the profiling hooks. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void setProfilingMode(OpExecutioner.ProfilingMode profilingMode) {
        this.profilingMode = profilingMode;
    }

    /** Returns the current profiling mode. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return this.profilingMode;
    }

    /**
     * Profiling entry hook.
     *
     * <p>In ALL and OPERATIONS modes, records the op call (with its buffers) in
     * {@link OpProfiler}. In ALL, OPERATIONS and METHODS modes, returns {@code System.nanoTime()}.
     * All other modes return {@code 0}.
     */
    public long profilingHookIn(Op op, DataBuffer... dataBufferArr) {
        switch (this.profilingMode) {
            case ALL:
            case OPERATIONS:
                OpProfiler.getInstance().processOpCall(op, dataBufferArr);
                return System.nanoTime();
            case METHODS:
                return System.nanoTime();
            default:
                return 0L;
        }
    }

    /**
     * Profiling entry hook without buffers.
     *
     * <p>Behaves like {@link #profilingHookIn(Op, DataBuffer...)}, but records only the op.
     */
    public long profilingHookIn(Op op) {
        switch (this.profilingMode) {
            case ALL:
            case OPERATIONS:
                OpProfiler.getInstance().processOpCall(op);
                return System.nanoTime();
            case METHODS:
                return System.nanoTime();
            default:
                return 0L;
        }
    }

    /**
     * Profiling exit hook.
     *
     * <ul>
     *   <li>ALL: records both the stack call and the op timing.</li>
     *   <li>METHODS: records the stack call only.</li>
     *   <li>OPERATIONS: records the op timing only.</li>
     *   <li>NAN_PANIC / INF_PANIC / ANY_PANIC: runs the corresponding
     *       {@code OpExecutionerUtil} checks.</li>
     *   <li>Any other mode: does nothing.</li>
     * </ul>
     *
     * @param op op that just executed
     * @param timeStart timestamp forwarded to {@link OpProfiler}; presumably the value returned
     *     by {@code profilingHookIn} — TODO confirm with callers
     */
    public void profilingHookOut(Op op, long timeStart) {
        switch (this.profilingMode) {
            case ALL:
                OpProfiler.getInstance().processStackCall(op, timeStart);
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case METHODS:
                OpProfiler.getInstance().processStackCall(op, timeStart);
                break;
            case OPERATIONS:
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case NAN_PANIC:
                OpExecutionerUtil.checkForNaN(op);
                break;
            case INF_PANIC:
                OpExecutionerUtil.checkForInf(op);
                break;
            case ANY_PANIC:
                OpExecutionerUtil.checkForNaN(op);
                OpExecutionerUtil.checkForInf(op);
                break;
            default:
                break;
        }
    }

    /**
     * Decompresses any COMPRESSED operand of {@code op}, then checks that every operand and
     * the extra-args buffer use {@code type}.
     *
     * @param type expected data type
     * @param op op to validate; null operands and a null extra-args buffer are skipped
     * @throws ND4JIllegalStateException naming the first operand whose type differs
     */
    public static void validateDataType(DataBuffer.Type type, Op op) {
        if (op.x() != null && op.x().data().dataType() == DataBuffer.Type.COMPRESSED) {
            Nd4j.getCompressor().decompressi(op.x());
        }
        if (op.y() != null && op.y().data().dataType() == DataBuffer.Type.COMPRESSED) {
            Nd4j.getCompressor().decompressi(op.y());
        }
        if (op.z() != null && op.z().data().dataType() == DataBuffer.Type.COMPRESSED) {
            Nd4j.getCompressor().decompressi(op.z());
        }
        if (op.x() != null && op.x().data().dataType() != type && op.x().data().dataType() != DataBuffer.Type.COMPRESSED) {
            throw new ND4JIllegalStateException("op.X dataType is [" + op.x().data().dataType() + "] instead of expected [" + type + "]");
        }
        if (op.z() != null && op.z().data().dataType() != type && op.z().data().dataType() != DataBuffer.Type.COMPRESSED) {
            throw new ND4JIllegalStateException("op.Z dataType is [" + op.z().data().dataType() + "] instead of expected [" + type + "]");
        }
        // NOTE(review): unlike x and z, y is not exempted when it still reports COMPRESSED after
        // decompression; kept as in the original.
        if (op.y() != null && op.y().data().dataType() != type) {
            throw new ND4JIllegalStateException("op.Y dataType is [" + op.y().data().dataType() + "] instead of expected [" + type + "]");
        }
        DataBuffer extraArgs = op.extraArgsDataBuff();
        if (extraArgs != null && extraArgs.dataType() != type) {
            throw new ND4JIllegalStateException("op.Extras dataType is [" + extraArgs.dataType() + "] instead of expected [" + type + "]");
        }
    }

    /**
     * Checks that every non-null array uses {@code type}.
     *
     * @param type expected data type
     * @param iNDArrayArr arrays to check; null or empty input is a no-op and null elements are skipped
     * @throws ND4JIllegalStateException naming the index of the first mismatching array
     */
    public static void validateDataType(DataBuffer.Type type, INDArray... iNDArrayArr) {
        if (iNDArrayArr == null || iNDArrayArr.length == 0) {
            return;
        }
        for (int index = 0; index < iNDArrayArr.length; index++) {
            INDArray operand = iNDArrayArr[index];
            if (operand != null && operand.data().dataType() != type) {
                // Report the actual offending position (the old message always said [0]).
                throw new ND4JIllegalStateException("INDArray [" + index + "] dataType is [" + operand.data().dataType() + "] instead of expected [" + type + "]");
            }
        }
    }

    /** Not supported by this executioner; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public TADManager getTADManager() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the CPU core count, {@code Runtime.maxMemory()} and the OS name as properties.
     *
     * <p>NOTE(review): the value stored under {@code HOST_TOTAL_MEMORY_KEY} is the JVM's maximum
     * heap ({@code Runtime.maxMemory()}), not the host's physical memory.
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Properties getEnvironmentInformation() {
        Properties properties = new Properties();
        properties.put(Nd4jEnvironment.CPU_CORES_KEY, Integer.valueOf(Runtime.getRuntime().availableProcessors()));
        properties.put(Nd4jEnvironment.HOST_TOTAL_MEMORY_KEY, Long.valueOf(Runtime.getRuntime().maxMemory()));
        properties.put(Nd4jEnvironment.OS_KEY, System.getProperty("os.name"));
        return properties;
    }

    /** Logs backend, OS, core count, memory (in GB, one decimal) and BLAS vendor at INFO level. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void printEnvironmentInformation() {
        Properties env = getEnvironmentInformation();
        // Number rather than Long: subclasses overriding getEnvironmentInformation may store
        // another numeric type under this key.
        long totalMemoryBytes = ((Number) env.get(Nd4jEnvironment.HOST_TOTAL_MEMORY_KEY)).longValue();
        String memoryGb = String.format("%.1f", totalMemoryBytes / 1024.0d / 1024.0d / 1024.0d);
        log.info("Backend used: [{}]; OS: [{}]", env.get(Nd4jEnvironment.BACKEND_KEY), env.get(Nd4jEnvironment.OS_KEY));
        log.info("Cores: [{}]; Memory: [{}GB];", env.get(Nd4jEnvironment.CPU_CORES_KEY), memoryGb);
        log.info("Blas vendor: [{}]", env.get(Nd4jEnvironment.BLAS_VENDOR_KEY));
    }

    /** No-op in this executioner. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void push() {
    }

    /** No-op in this executioner. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void commit() {
    }

    /** Not implemented here; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray thresholdEncode(INDArray iNDArray, double d) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not implemented here; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray thresholdEncode(INDArray iNDArray, double d, Integer num) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not implemented here; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray thresholdDecode(INDArray iNDArray, INDArray iNDArray2) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not implemented here; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public long bitmapEncode(INDArray iNDArray, INDArray iNDArray2, double d) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Allocates an INT buffer of {@code length / 16 + 5} elements, wraps it with
     * {@code iNDArray}'s shape-info buffer, and delegates to
     * {@link #bitmapEncode(INDArray, INDArray, double)}. In this class that overload throws
     * {@link UnsupportedOperationException}.
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray bitmapEncode(INDArray iNDArray, double d) {
        INDArray encoded = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt((iNDArray.length() / 16) + 5), iNDArray.shapeInfoDataBuffer());
        bitmapEncode(iNDArray, encoded, d);
        return encoded;
    }

    /** Not implemented here; always throws {@link UnsupportedOperationException}. */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray bitmapDecode(INDArray iNDArray, INDArray iNDArray2) {
        throw new UnsupportedOperationException("Not yet implemented");
    }
}
