package org.diffkt;

import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import kotlin.ranges.IntRange;
import kotlin.ranges.RangesKt;
import org.diffkt.StridedUtils;
import org.diffkt.external.Dnnl;
import org.diffkt.external.External;
import org.diffkt.external.ExternalLib;
import org.diffkt.external.Math;
import org.diffkt.external.Predicate;
import org.diffkt.model.BatchNormResult;
import org.diffkt.model.MaxPoolKt;
import org.jetbrains.annotations.NotNull;
import shapeTyping.annotations.AllowUnreduced;
import shapeTyping.annotations.SType;

/* compiled from: StridedFloatTensorOperations.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��l\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\bÁ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J \u0010\u0007\u001a\u00020\b2\u0006\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\n2\u0006\u0010\f\u001a\u00020\rH\u0016J \u0001\u0010\u000e\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010\u0012\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010\u0013\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010\u0014\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112\u0006\u0010\f\u001a\u00020\rH\u0017J<\u0010\u0015\u001a\u0016\u0012\u0004\u0012\u00020\n\u0012\f\u0012\n\u0012\u0004\u0012\u00020\u0018\u0018\u00010\u00170\u00162\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010\u001a\u001a\u00020\u001b2\u0006\u0010\u001c\u001a\u00020\u001b2\u0006\u0010\u001d\u001a\u00020\u001eH\u0016Jz\u0010\u001f\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010 \u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010!\u001a 
0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112\u0006\u0010\f\u001a\u00020\rH\u0017Jz\u0010\"\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010 \u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010!\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112\u0006\u0010\f\u001a\u00020\rH\u0017JL\u0010#\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010\u0019\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011H\u0017J \u0010$\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010%\u001a\u00020\n2\u0006\u0010\f\u001a\u00020\rH\u0016J\u0018\u0010&\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010'\u001a\u00020(H\u0016J\u0018\u0010)\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010*\u001a\u00020\u001bH\u0016J \u0010+\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010,\u001a\u00020\u00182\u0006\u0010-\u001a\u00020\u001eH\u0016Jz\u0010.\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010 \u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010!\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112\u0006\u0010\f\u001a\u00020\rH\u0017J\\\u0010/\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112\u0006\u0010 \u001a\u0002002$\u0010!\u001a 
0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112\u0006\u0010\f\u001a\u00020\rH\u0017J\u0018\u00101\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010,\u001a\u00020\u0018H\u0016JL\u00102\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u00112$\u0010\u0019\u001a 0\n¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011¢\u0006\f\b\u000f\u0012\b\b\u0010\u0012\u0004\b\b(\u0011H\u0017J\u0018\u00103\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u0010*\u001a\u00020\u001bH\u0016J\u0018\u00104\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u00105\u001a\u00020\u0018H\u0016J \u00106\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u00107\u001a\u00020\u001b2\u0006\u0010*\u001a\u00020\u001bH\u0016J \u00108\u001a\u00020\n2\u0006\u0010\u0019\u001a\u00020\n2\u0006\u00107\u001a\u0002092\u0006\u0010*\u001a\u00020\u001bH\u0016J\u0010\u0010:\u001a\u00020;2\u0006\u0010\u0010\u001a\u00020\nH\u0002R\u0014\u0010\u0003\u001a\u00020\u00048VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0005\u0010\u0006¨\u0006<"}, d2 = {"Lorg/diffkt/StridedFloatTensorOperations;", "Lorg/diffkt/FloatTensorOperations;", "()V", "name", "", "getName", "()Ljava/lang/String;", "batchNorm", "Lorg/diffkt/model/BatchNormResult;", "input", "Lorg/diffkt/DTensor;", "scaleShift", "derivativeId", "Lorg/diffkt/DerivativeID;", "ifThenElse", "LshapeTyping/annotations/SType;", "value", "S", "condition", "whenTrue", "whenFalse", "maxPoolWithIndices", "Lkotlin/Pair;", "", "", "x", "poolHeight", "", "poolWidth", "withIndices", "", "minus", "left", "right", "plus", "relu", "reluGrad", "reluUpstream", "reshape", "newShape", "Lorg/diffkt/Shape;", "squeeze", "axis", "sum", "axes", "keepDims", "times", "timesScalar", "Lorg/diffkt/DScalar;", "transpose", "unaryMinus", "unsqueeze", "view1", "indices", "view2", "index", "view3", "Lkotlin/ranges/IntRange;", "wrap", 
"Lorg/diffkt/StridedFloatTensor;", "api"})
@AllowUnreduced
/* loaded from: input_file:org/diffkt/StridedFloatTensorOperations.class */
// Operations for StridedFloatTensor values.
//
// Arithmetic/activation operations ask UtilsKt.shouldSendToCpp whether their inputs should be
// handed to a native kernel (Dnnl or External); when it declines, they defer to the generic
// FloatTensorOperations implementation via super. The first shouldSendToCpp argument (10, 32,
// 100 or 200 below) is presumably a minimum-size threshold — NOTE(review): confirm in UtilsKt.
//
// view1/view2/view3/squeeze/unsqueeze/transpose build new StridedFloatTensor objects that share
// the input's backing data array; reshape does the same after normalizing layouts it cannot
// reinterpret directly.
public final class StridedFloatTensorOperations extends FloatTensorOperations {

    /** The shared singleton; the constructor is private. */
    @NotNull
    public static final StridedFloatTensorOperations INSTANCE = new StridedFloatTensorOperations();

    /** Message used for every precondition failure (the same text as Kotlin's {@code require}). */
    private static final String FAILED_REQUIREMENT = "Failed requirement.";

    private StridedFloatTensorOperations() {
    }

    @Override
    @NotNull
    public String getName() {
        return "StridedFloatTensor";
    }

    /**
     * Returns {@code tensor} viewed as a {@link StridedFloatTensor}.
     *
     * @throws IllegalArgumentException if the tensor's derivative ID has a non-zero sequence
     *     number, or if the tensor is not a {@link FloatTensor}
     */
    private StridedFloatTensor wrap(DTensor tensor) {
        if (tensor.getDerivativeID().getSequence() != 0) {
            throw new IllegalArgumentException(FAILED_REQUIREMENT);
        }
        if (!(tensor instanceof FloatTensor)) {
            throw new IllegalArgumentException(FAILED_REQUIREMENT);
        }
        return ((FloatTensor) tensor).asStrided();
    }

    /**
     * Casts {@code tensor} to {@link StridedFloatTensor}.
     *
     * @throws IllegalArgumentException if it is not a StridedFloatTensor
     */
    private static StridedFloatTensor requireStrided(DTensor tensor) {
        if (!(tensor instanceof StridedFloatTensor)) {
            throw new IllegalArgumentException(FAILED_REQUIREMENT);
        }
        return (StridedFloatTensor) tensor;
    }

    /**
     * Rejects every derivative ID other than {@code NoDerivativeID.INSTANCE}.
     *
     * @throws IllegalArgumentException if {@code derivativeId} is not NoDerivativeID
     */
    private static void requireNoDerivative(DerivativeID derivativeId) {
        if (!Intrinsics.areEqual(derivativeId, NoDerivativeID.INSTANCE)) {
            throw new IllegalArgumentException(FAILED_REQUIREMENT);
        }
    }

    /**
     * Computes {@code left + right}, using Dnnl's add kernel when shouldSendToCpp approves.
     *
     * @throws IllegalArgumentException if {@code derivativeId} is not NoDerivativeID or an
     *     operand is rejected by {@link #wrap}
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor plus(@NotNull DTensor left, @NotNull DTensor right, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(left, "left");
        Intrinsics.checkNotNullParameter(right, "right");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        requireNoDerivative(derivativeId);
        StridedFloatTensor lhs = wrap(left);
        StridedFloatTensor rhs = wrap(right);
        if (UtilsKt.shouldSendToCpp(200, (ExternalLib) Dnnl.INSTANCE, new FloatTensor[] {lhs, rhs}, false, false)) {
            return Dnnl.INSTANCE.add(lhs, rhs);
        }
        return super.plus(left, right, derivativeId);
    }

    /**
     * Computes {@code left - right}, using Dnnl's sub kernel when shouldSendToCpp approves.
     *
     * @throws IllegalArgumentException if {@code derivativeId} is not NoDerivativeID or an
     *     operand is rejected by {@link #wrap}
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor minus(@NotNull DTensor left, @NotNull DTensor right, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(left, "left");
        Intrinsics.checkNotNullParameter(right, "right");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        requireNoDerivative(derivativeId);
        StridedFloatTensor lhs = wrap(left);
        StridedFloatTensor rhs = wrap(right);
        if (UtilsKt.shouldSendToCpp(200, (ExternalLib) Dnnl.INSTANCE, new FloatTensor[] {lhs, rhs}, false, false)) {
            return Dnnl.INSTANCE.sub(lhs, rhs);
        }
        return super.minus(left, right, derivativeId);
    }

    /**
     * Computes {@code left * right}, using the External library's times kernel when approved.
     *
     * @throws IllegalArgumentException if an operand is rejected by {@link #wrap}
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor times(@NotNull DTensor left, @NotNull DTensor right, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(left, "left");
        Intrinsics.checkNotNullParameter(right, "right");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        StridedFloatTensor lhs = wrap(left);
        StridedFloatTensor rhs = wrap(right);
        if (UtilsKt.shouldSendToCpp$default(200, (ExternalLib) External.INSTANCE, new FloatTensor[] {lhs, rhs}, false, false, 24, (Object) null)) {
            // `Math` is org.diffkt.external.Math (explicit import), not java.lang.Math.
            // NOTE(review): raw backing arrays are passed without offset/stride and left.getSize()
            // elements are read; this assumes shouldSendToCpp only approves same-shaped,
            // contiguous, zero-offset tensors — confirm.
            return FloatTensor.Companion.invoke(left.getShape(),
                    Math.INSTANCE.times(lhs.getData$api(), rhs.getData$api(), left.getSize()));
        }
        return super.times(left, right, derivativeId);
    }

    /**
     * Multiplies tensor {@code right} by scalar {@code left}, using Dnnl's mulScalar when approved.
     *
     * @throws IllegalArgumentException if {@code right} is rejected by {@link #wrap}
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor timesScalar(@NotNull DScalar left, @NotNull DTensor right, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(left, "left");
        Intrinsics.checkNotNullParameter(right, "right");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        StridedFloatTensor rhs = wrap(right);
        if (UtilsKt.shouldSendToCpp$default(32, (ExternalLib) Dnnl.INSTANCE, new FloatTensor[] {rhs}, false, false, 16, (Object) null)) {
            return Dnnl.INSTANCE.mulScalar(rhs, FloatScalarOperations.INSTANCE.wrap$api(left).getValue());
        }
        return super.timesScalar(left, right, derivativeId);
    }

    /**
     * Negates {@code x}, using the External library's unaryMinus kernel when approved.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor unaryMinus(@NotNull DTensor x) {
        Intrinsics.checkNotNullParameter(x, "x");
        StridedFloatTensor strided = requireStrided(x);
        if (UtilsKt.shouldSendToCpp$default(100, (ExternalLib) External.INSTANCE, new FloatTensor[] {strided}, false, false, 24, (Object) null)) {
            // NOTE(review): as in times(), the raw backing array is used without offset/stride.
            return FloatTensor.Companion.invoke(x.getShape(),
                    Math.INSTANCE.unaryMinus(strided.getData$api(), x.getSize()));
        }
        return super.unaryMinus(x);
    }

    /**
     * Fixes the leading {@code indices.length} axes of {@code x} at the given positions.
     *
     * @return a view over the remaining axes sharing x's data, or a {@link FloatScalar} holding
     *     the selected element when no axes remain
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor view1(@NotNull DTensor x, @NotNull int[] indices) {
        Intrinsics.checkNotNullParameter(x, "x");
        Intrinsics.checkNotNullParameter(indices, "indices");
        StridedFloatTensor strided = requireStrided(x);
        Shape remaining = x.getShape().drop(indices.length);
        int[] strides = strided.getStrides$api();
        int offset = strided.getOffset$api();
        for (int axis = 0; axis < indices.length; axis++) {
            offset += strides[axis] * indices[axis];
        }
        if (remaining.isScalar()) {
            return new FloatScalar(strided.getData$api()[offset]);
        }
        int[] remainingStrides = ArraysKt.sliceArray(strides, RangesKt.until(indices.length, x.getShape().getRank()));
        return new StridedFloatTensor(remaining, offset, remainingStrides, strided.getData$api(), null, 16, null);
    }

    /**
     * Selects position {@code index} along {@code axis}, removing that axis.
     *
     * @return a view sharing x's data, or a {@link FloatScalar} when the result has no axes
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor view2(@NotNull DTensor x, int index, int axis) {
        Intrinsics.checkNotNullParameter(x, "x");
        StridedFloatTensor strided = requireStrided(x);
        Shape remaining = x.getShape().remove(axis);
        if (remaining.isScalar()) {
            return new FloatScalar(strided.at(index));
        }
        int[] strides = strided.getStrides$api();
        int offset = strided.getOffset$api() + index * strides[axis];
        return new StridedFloatTensor(remaining, offset, UtilsKt.remove(strides, axis), strided.getData$api(), null, 16, null);
    }

    /**
     * Restricts {@code axis} to the positions in {@code index}, keeping the axis.
     *
     * <p>The view starts at {@code index.start} and its stride along {@code axis} is multiplied
     * by the range's step; the data array is shared with {@code x}.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor view3(@NotNull DTensor x, @NotNull IntRange index, int axis) {
        Intrinsics.checkNotNullParameter(x, "x");
        Intrinsics.checkNotNullParameter(index, "index");
        StridedFloatTensor strided = requireStrided(x);
        int start = index.getStart().intValue();
        int step = index.getStep();
        // NOTE(review): an empty range (end < start) yields a non-positive length here;
        // callers presumably never pass one — confirm.
        int length = 1 + (index.getEndInclusive().intValue() - start) / step;
        Shape resultShape = x.getShape().updated(axis, length);
        int[] strides = strided.getStrides$api();
        int offset = strided.getOffset$api() + start * strides[axis];
        int[] newStrides = strides.clone();
        newStrides[axis] = strides[axis] * step;
        return new StridedFloatTensor(resultShape, offset, newStrides, strided.getData$api(), null, 16, null);
    }

    /**
     * Reinterprets {@code x} with {@code newShape}.
     *
     * <p>SINGLETON and NATURAL layouts are reinterpreted in place, keeping the data array and
     * layout; any other layout is normalized first and then reshaped.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor reshape(@NotNull DTensor x, @NotNull Shape newShape) {
        Intrinsics.checkNotNullParameter(x, "x");
        Intrinsics.checkNotNullParameter(newShape, "newShape");
        StridedFloatTensor strided = requireStrided(x);
        StridedUtils.Layout layout = strided.getLayout$api();
        switch (layout) {
            case SINGLETON:
                return new StridedFloatTensor(newShape, strided.getOffset$api(),
                        StridedUtils.INSTANCE.singletonStrides$api(newShape.getRank()), strided.getData$api(), layout);
            case NATURAL:
                return new StridedFloatTensor(newShape, strided.getOffset$api(),
                        StridedUtils.INSTANCE.contigStrides(newShape), strided.getData$api(), layout);
            default:
                return reshape(strided.normalize(), newShape);
        }
    }

    /**
     * Removes {@code axis} from the shape and strides of {@code x}, sharing its data.
     *
     * <p>The size of {@code axis} is not checked here; callers presumably guarantee it is 1.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor squeeze(@NotNull DTensor x, int axis) {
        Intrinsics.checkNotNullParameter(x, "x");
        StridedFloatTensor strided = requireStrided(x);
        int[] strides = strided.getStrides$api();
        Shape shape = x.getShape();
        Shape squeezed = shape.take(axis).plus(shape.drop(axis + 1));
        int offset = strided.getOffset$api();
        int[] squeezedStrides = CollectionsKt.toIntArray(
                CollectionsKt.plus(ArraysKt.take(strides, axis), ArraysKt.drop(strides, axis + 1)));
        return new StridedFloatTensor(squeezed, offset, squeezedStrides, strided.getData$api(), null, 16, null);
    }

    /**
     * Inserts a new size-1 axis at position {@code axis}, sharing the data of {@code x}.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor unsqueeze(@NotNull DTensor x, int axis) {
        Intrinsics.checkNotNullParameter(x, "x");
        StridedFloatTensor strided = requireStrided(x);
        int[] strides = strided.getStrides$api();
        Shape shape = x.getShape();
        Shape unsqueezed = shape.take(axis).plus(1).plus(shape.drop(axis));
        int offset = strided.getOffset$api();
        // A size-1 axis is only ever indexed at 0, so its stride never moves the offset. The value
        // chosen equals the stride a contiguous layout would assign when x itself is contiguous.
        int insertedStride;
        if (axis == strides.length) {
            insertedStride = 1;
        } else if (axis == 0) {
            insertedStride = strides[0] * shape.get(0);
        } else {
            insertedStride = strides[axis - 1];
        }
        int[] unsqueezedStrides = CollectionsKt.toIntArray(CollectionsKt.plus(
                CollectionsKt.plus(ArraysKt.take(strides, axis), Integer.valueOf(insertedStride)),
                ArraysKt.drop(strides, axis)));
        return new StridedFloatTensor(unsqueezed, offset, unsqueezedStrides, strided.getData$api(), null, 16, null);
    }

    /**
     * Permutes the axes of {@code x}: result axis {@code k} is input axis {@code axes[k]}.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor transpose(@NotNull DTensor x, @NotNull int[] axes) {
        Intrinsics.checkNotNullParameter(x, "x");
        Intrinsics.checkNotNullParameter(axes, "axes");
        StridedFloatTensor strided = requireStrided(x);
        Shape permutedShape = new Shape(permute(axes, x.getShape().getDims()));
        return new StridedFloatTensor(permutedShape, strided.getOffset$api(),
                permute(axes, strided.getStrides$api()), strided.getData$api(), null, 16, null);
    }

    /**
     * Applies ReLU to {@code x}. When approved for Dnnl, the input is normalized and the result
     * is written into a new contiguous tensor.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor relu(@NotNull final DTensor x) {
        Intrinsics.checkNotNullParameter(x, "x");
        StridedFloatTensor strided = requireStrided(x);
        if (!UtilsKt.shouldSendToCpp$default(10, (FloatTensor) strided, (ExternalLib) Dnnl.INSTANCE, false, false, 16, (Object) null)) {
            return super.relu(x);
        }
        final StridedFloatTensor input = strided.normalize();
        // `contiguous` hands the freshly allocated output array to this callback to fill.
        return StridedFloatTensor.Companion.contiguous(x.getShape(), (float[] data) -> {
            Intrinsics.checkNotNullParameter(data, "data");
            Dnnl.INSTANCE.relu(x.getShape().getDims(), data, input.getData$api());
            return Unit.INSTANCE;
        });
    }

    /**
     * Computes the ReLU gradient for input {@code x} given upstream gradient {@code reluUpstream}.
     * Dnnl is used only when both tensors have the same shape and shouldSendToCpp approves.
     *
     * @throws IllegalArgumentException if either tensor is rejected by {@link #wrap}
     */
    @Override
    @NotNull
    public DTensor reluGrad(@NotNull final DTensor x, @NotNull DTensor reluUpstream, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(x, "x");
        Intrinsics.checkNotNullParameter(reluUpstream, "reluUpstream");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        final StridedFloatTensor xStrided = wrap(x);
        final StridedFloatTensor upstream = wrap(reluUpstream);
        // NOTE(review): the raw arrays are passed without normalize(); this presumably relies on
        // the `true` flag to shouldSendToCpp requiring a natural layout — confirm.
        boolean useDnnl = Intrinsics.areEqual(reluUpstream.getShape(), x.getShape())
                && UtilsKt.shouldSendToCpp$default(10, (ExternalLib) Dnnl.INSTANCE, new FloatTensor[] {xStrided, upstream}, true, false, 16, (Object) null);
        if (!useDnnl) {
            return super.reluGrad(x, reluUpstream, derivativeId);
        }
        return StridedFloatTensor.Companion.contiguous(x.getShape(), (float[] data) -> {
            Intrinsics.checkNotNullParameter(data, "data");
            Dnnl.INSTANCE.reluGrad(x.getShape().getDims(), data, upstream.getData$api(), xStrided.getData$api());
            return Unit.INSTANCE;
        });
    }

    /**
     * Sums {@code x} over {@code axes}. With {@code keepDims} the reduced axes remain as size 1;
     * otherwise they are removed.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public DTensor sum(@NotNull DTensor x, @NotNull int[] axes, boolean keepDims) {
        Intrinsics.checkNotNullParameter(x, "x");
        Intrinsics.checkNotNullParameter(axes, "axes");
        StridedFloatTensor strided = requireStrided(x);
        if (!UtilsKt.shouldSendToCpp$default(100, (ExternalLib) Dnnl.INSTANCE, new FloatTensor[] {strided}, false, false, 16, (Object) null)) {
            return super.sum(x, axes, keepDims);
        }
        Shape resultShape = reducedShape(x, axes, keepDims);
        // Dnnl's reduceSum is given the rank-preserving shape; the element count is identical.
        Shape keptShape = keepDims ? resultShape : reducedShape(x, axes, true);
        float[] result = new float[resultShape.getProduct()];
        Dnnl.INSTANCE.reduceSum(keptShape.getDims(), result, x.getShape().getDims(), strided.normalize().getData$api());
        return FloatTensor.Companion.invoke(resultShape, result);
    }

    /**
     * Batch normalization of a rank-4 {@code input} via Dnnl, falling back to super for other
     * ranks or when either tensor is not Dnnl-eligible.
     *
     * <p>The normalized axis is the last one (index 3). NOTE(review): presumably an NHWC channel
     * axis — confirm.
     *
     * @throws IllegalArgumentException if {@code scaleShift} does not have shape (2, C), where C
     *     is the size of input axis 3
     */
    @Override
    @NotNull
    public BatchNormResult batchNorm(@NotNull DTensor input, @NotNull DTensor scaleShift, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(input, "input");
        Intrinsics.checkNotNullParameter(scaleShift, "scaleShift");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        if (input.getRank() != 4 || !UtilsKt.isDnnlEligible(input) || !UtilsKt.isDnnlEligible(scaleShift)) {
            return super.batchNorm(input, scaleShift, derivativeId);
        }
        int channels = input.getShape().get(3);
        Shape expectedScaleShift = Shape.Companion.invoke(2, channels);
        if (!Intrinsics.areEqual(scaleShift.getShape(), expectedScaleShift)) {
            throw new IllegalArgumentException("scaleShift must have shape " + expectedScaleShift);
        }
        StridedFloatTensor output = StridedFloatTensor.Companion.contigZeros(input.getShape());
        // Filled by Dnnl and passed on as fromMeanAndVariance's 2nd/3rd arguments (mean, variance).
        StridedFloatTensor mean = StridedFloatTensor.Companion.contigZeros(Shape.Companion.invoke(channels));
        StridedFloatTensor variance = StridedFloatTensor.Companion.contigZeros(Shape.Companion.invoke(channels));
        Dnnl.INSTANCE.batchNorm(output.getShape().getDims(), output.getData$api(), mean.getData$api(),
                variance.getData$api(), ((FloatTensor) input).normalize().getData$api(),
                ((FloatTensor) scaleShift).normalize().getData$api());
        return BatchNormResult.Companion.fromMeanAndVariance(output, mean, variance);
    }

    /**
     * Max-pools {@code x}. The Dnnl path is taken only when {@code withIndices} is false; the
     * indices are not computed there and the second element of the returned pair is null.
     *
     * @throws IllegalArgumentException if {@code x} is not a StridedFloatTensor
     */
    @Override
    @NotNull
    public Pair<DTensor, List<int[]>> maxPoolWithIndices(@NotNull DTensor x, int poolHeight, int poolWidth, boolean withIndices) {
        Intrinsics.checkNotNullParameter(x, "x");
        StridedFloatTensor strided = requireStrided(x);
        if (!UtilsKt.shouldSendToCpp$default(100, (ExternalLib) Dnnl.INSTANCE, new FloatTensor[] {strided}, false, false, 24, (Object) null) || withIndices) {
            return super.maxPoolWithIndices(x, poolHeight, poolWidth, withIndices);
        }
        return new Pair<DTensor, List<int[]>>(
                MaxPoolKt.maxPoolWithIndicesDnnl(strided, poolHeight, poolWidth, false).getFirst(), null);
    }

    /**
     * Selects from {@code whenTrue} or {@code whenFalse} according to {@code condition}, using
     * the External Predicate kernel when approved.
     *
     * <p>NOTE(review): the inputs are cast to FloatTensor/StridedFloatTensor without a check, so
     * other tensor types raise ClassCastException, and the raw backing arrays are passed without
     * offset/stride — confirm shouldSendToCpp guarantees contiguous, zero-offset inputs.
     */
    @Override
    @SType("S: Shape")
    @NotNull
    public DTensor ifThenElse(@NotNull DTensor condition, @NotNull DTensor whenTrue, @NotNull DTensor whenFalse, @NotNull DerivativeID derivativeId) {
        Intrinsics.checkNotNullParameter(condition, "condition");
        Intrinsics.checkNotNullParameter(whenTrue, "whenTrue");
        Intrinsics.checkNotNullParameter(whenFalse, "whenFalse");
        Intrinsics.checkNotNullParameter(derivativeId, "derivativeId");
        FloatTensor[] operands = {(FloatTensor) condition, (FloatTensor) whenTrue, (FloatTensor) whenFalse};
        if (!UtilsKt.shouldSendToCpp$default(200, (ExternalLib) External.INSTANCE, operands, false, false, 24, (Object) null)) {
            return super.ifThenElse(condition, whenTrue, whenFalse, derivativeId);
        }
        return FloatTensor.Companion.invoke(whenTrue.getShape(), Predicate.INSTANCE.ifThenElse(
                ((StridedFloatTensor) condition).getData$api(),
                ((StridedFloatTensor) whenTrue).getData$api(),
                ((StridedFloatTensor) whenFalse).getData$api(),
                condition.getSize()));
    }

    /** Returns {@code result} with {@code result[k] == values[order[k]]}. */
    private static int[] permute(int[] order, int[] values) {
        int[] result = new int[order.length];
        for (int k = 0; k < order.length; k++) {
            result[k] = values[order[k]];
        }
        return result;
    }

    /**
     * Shape of {@code x} after reducing over {@code axes}: reduced axes become size 1 when
     * {@code keepDims} is true and are removed otherwise.
     */
    private static Shape reducedShape(DTensor x, int[] axes, boolean keepDims) {
        int[] dims = x.getShape().getDims();
        List<Integer> kept = new ArrayList<>(dims.length);
        for (int axis = 0; axis < dims.length; axis++) {
            if (!ArraysKt.contains(axes, axis)) {
                kept.add(dims[axis]);
            } else if (keepDims) {
                kept.add(1);
            }
        }
        return new Shape(CollectionsKt.toIntArray(kept));
    }
}
