package org.diffkt.model;

import kotlin.Metadata;
import kotlin.Pair;
import kotlin.jvm.internal.Intrinsics;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.DerivativeID;
import org.diffkt.DivKt;
import org.diffkt.MinusKt;
import org.diffkt.Operations;
import org.diffkt.OperationsKt;
import org.diffkt.PlusKt;
import org.diffkt.PowerKt;
import org.diffkt.Shape;
import org.diffkt.SumKt;
import org.diffkt.TimesKt;
import org.diffkt.TranscendentalKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: BatchNorm.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��\u0016\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\u001a\u0018\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00052\u0006\u0010\u0006\u001a\u00020\u0005H��\u001a\u0016\u0010\u0007\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00052\u0006\u0010\u0006\u001a\u00020\u0005\"\u000e\u0010��\u001a\u00020\u0001X\u0086T¢\u0006\u0002\n��¨\u0006\b"}, d2 = {"BATCHNORM_EPSILON", "", "baseBatchNorm", "Lorg/diffkt/model/BatchNormResult;", "input", "Lorg/diffkt/DTensor;", "scaleShift", "batchNorm", "api"})
/* loaded from: input_file:org/diffkt/model/BatchNormKt.class */
public final class BatchNormKt {
    /** Constant added to the variance before the square root is taken in {@link #baseBatchNorm}. */
    public static final float BATCHNORM_EPSILON = 1.0E-5f;

    /**
     * Validates the arguments and forwards the batch-norm computation to the
     * {@code Operations} instance selected by {@code OperationsKt.commonKind}.
     *
     * @param dTensor  the input tensor (Kotlin parameter {@code input}); its rank must be at least 2
     * @param dTensor2 the scale/shift tensor (Kotlin parameter {@code scaleShift}); its shape must
     *                 equal {@code Shape(2, lastDimensionOfInput)}
     * @return whatever the selected {@code Operations.batchNorm} implementation returns
     * @throws IllegalArgumentException if either requirement above does not hold
     */
    @NotNull
    public static final BatchNormResult batchNorm(@NotNull DTensor dTensor, @NotNull DTensor dTensor2) {
        Intrinsics.checkNotNullParameter(dTensor, "input");
        Intrinsics.checkNotNullParameter(dTensor2, "scaleShift");

        // Guard 1: the input needs at least two axes.
        if (dTensor.getRank() < 2) {
            throw new IllegalArgumentException("Failed requirement.");
        }

        // Guard 2: scaleShift must be shaped (2, size of the input's last axis).
        boolean scaleShiftShapeMatches =
                Intrinsics.areEqual(dTensor2.getShape(), Shape.Companion.invoke(2, dTensor.getShape().getLast()));
        if (!scaleShiftShapeMatches) {
            throw new IllegalArgumentException("Failed requirement.");
        }

        // Pick the implementation (and derivative id) shared by both operands, then delegate.
        Pair<Operations, DerivativeID> kind = OperationsKt.commonKind(dTensor, dTensor2);
        Operations operations = kind.component1();
        DerivativeID derivativeId = kind.component2();
        return operations.batchNorm(dTensor, dTensor2, derivativeId);
    }

    /**
     * Computes batch normalization directly from tensor primitives.
     *
     * <p>Statistics are reduced over every axis except the last one: the mean is
     * {@code sum / n} and the variance is {@code sumOfSquares / n - mean^2}, where
     * {@code n} is the element count of the input divided by the size of its last axis.
     * The input is then centered, divided by {@code sqrt(variance + BATCHNORM_EPSILON)},
     * multiplied by {@code dTensor2.get(0)} and offset by {@code dTensor2.get(1)}.
     *
     * <p>No argument validation beyond the null checks is performed here.
     *
     * @param dTensor  the input tensor (Kotlin parameter {@code input})
     * @param dTensor2 the scale/shift tensor (Kotlin parameter {@code scaleShift}); index 0 is
     *                 used as the multiplier and index 1 as the offset
     * @return a {@code BatchNormResult} built from, in constructor order: the normalized and
     *         affine-transformed tensor, {@code n}, the sum, the sum of squares, the mean and
     *         the variance
     */
    @NotNull
    public static final BatchNormResult baseBatchNorm(@NotNull DTensor dTensor, @NotNull DTensor dTensor2) {
        Intrinsics.checkNotNullParameter(dTensor, "input");
        Intrinsics.checkNotNullParameter(dTensor2, "scaleShift");

        // Number of elements contributing to each statistic; integer division happens
        // before the widening to float, exactly as in the compiled original.
        float n = dTensor.getShape().getProduct() / dTensor.getShape().getLast();

        // Reduction axes: 0, 1, ..., rank - 2 (everything but the last axis).
        int[] reductionAxes = new int[dTensor.getRank() - 1];
        for (int axis = 0; axis < reductionAxes.length; axis++) {
            reductionAxes[axis] = axis;
        }

        // The trailing (false, 2, null) arguments belong to the Kotlin-generated $default bridge;
        // presumably mask 2 selects the declared default for the third parameter — confirm in SumKt.
        DTensor sum = SumKt.sum$default(dTensor, reductionAxes, false, 2, null);
        DTensor squared = PowerKt.pow(dTensor, 2);
        DTensor sumOfSquares = SumKt.sum$default(squared, reductionAxes, false, 2, null);

        DTensor mean = DivKt.div(sum, n);
        DTensor meanOfSquares = DivKt.div(sumOfSquares, n);
        DTensor variance = MinusKt.minus(meanOfSquares, PowerKt.pow(mean, 2));

        DTensor centered = MinusKt.minus(dTensor, mean);
        DTensor standardDeviation = TranscendentalKt.sqrt(PlusKt.plus(variance, BATCHNORM_EPSILON));
        DTensor normalized = DivKt.div(centered, standardDeviation);

        DTensor scaled = TimesKt.times(dTensor2.get(0), normalized);
        DTensor result = PlusKt.plus(scaled, dTensor2.get(1));
        return new BatchNormResult(result, n, sum, sumOfSquares, mean, variance);
    }
}
