package org.diffkt;

import kotlin.Metadata;
import kotlin.jvm.JvmName;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: Loss.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��(\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\b\u001a\u0018\u0010��\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u0004H��\u001a\u0018\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\u0007\u001a\u00020\u0006H\u0002\u001a \u0010\b\u001a\u00020\t2\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\n\u001a\u00020\u00012\b\b\u0002\u0010\u000b\u001a\u00020\f\u001a\u0016\u0010\r\u001a\u00020\t2\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\u000e\u001a\u00020\u0001\u001a\u0016\u0010\u000f\u001a\u00020\t2\u0006\u0010\u0010\u001a\u00020\u00012\u0006\u0010\n\u001a\u00020\u0001\u001a\u0016\u0010\u0011\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u0004\u001a\u0012\u0010\u0012\u001a\u00020\u0001*\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u0004\u001a\u0019\u0010\u0011\u001a\u00020\u0001*\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u0004H\u0007¢\u0006\u0002\b\u0013¨\u0006\u0014"}, d2 = {"baseLogSoftmax", "Lorg/diffkt/DTensor;", "x", "axis", "", "createOneHotFromClasses", "Lorg/diffkt/FloatTensor;", "l", "crossEntropyLoss", "Lorg/diffkt/DScalar;", "labels", "fromOneHot", "", "crossEntropyLossFromOneHot", "oneHotLabels", "nllLossFromOneHot", "guesses", "softmax", "logSoftmax", "DTensorSoftmaxExt", "api"})
/* loaded from: input_file:org/diffkt/LossKt.class */
/**
 * Loss functions (cross-entropy, negative log-likelihood) and softmax helpers over {@link DTensor}s.
 *
 * <p>Compiled from {@code Loss.kt}. The {@code Intrinsics.checkNotNullParameter} calls are the
 * Kotlin compiler's runtime checks for non-null parameters.
 */
public final class LossKt {
    /**
     * Computes the cross-entropy loss of raw scores against one-hot encoded labels.
     *
     * <p>Equivalent to {@code nllLossFromOneHot(logSoftmax(x, 1), oneHotLabels)}. The log-softmax is
     * taken over axis 1, which presumably holds the per-class scores of a (batch, classes) tensor.
     * TODO: confirm this layout with callers.
     *
     * @param dTensor the raw (unnormalized) scores, {@code x}
     * @param dTensor2 the one-hot labels, multiplied elementwise with the log-probabilities
     * @return the negated sum of {@code logSoftmax(x, 1) * oneHotLabels}, divided by
     *     {@code size / shape[1]}
     */
    @NotNull
    public static final DScalar crossEntropyLossFromOneHot(@NotNull DTensor dTensor, @NotNull DTensor dTensor2) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(dTensor2, "oneHotLabels");
        return nllLossFromOneHot(logSoftmax(dTensor, 1), dTensor2);
    }

    /**
     * Computes the cross-entropy loss of raw scores against labels.
     *
     * @param dTensor the raw scores, {@code x}; must be rank 2 when {@code z} is false
     * @param dTensor2 the labels. When {@code z} is true they are used as one-hot labels directly.
     *     Otherwise they must be a rank-1 {@link FloatTensor} of class indices, one per row of
     *     {@code x}. A tensor of any other class fails the cast with {@link ClassCastException}.
     * @param z {@code fromOneHot}: whether {@code labels} are already one-hot encoded
     * @return the cross-entropy loss, as computed by {@link #crossEntropyLossFromOneHot}
     * @throws IllegalArgumentException if class-index labels fail the checks in
     *     {@code createOneHotFromClasses}
     */
    @NotNull
    public static final DScalar crossEntropyLoss(@NotNull DTensor dTensor, @NotNull DTensor dTensor2, boolean z) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(dTensor2, "labels");
        return crossEntropyLossFromOneHot(dTensor, z ? dTensor2 : createOneHotFromClasses(dTensor, (FloatTensor) dTensor2));
    }

    /**
     * Kotlin default-argument bridge for {@link #crossEntropyLoss}.
     *
     * <p>Bit {@code 4} of {@code i} flags that the third argument ({@code fromOneHot}) was omitted. In
     * that case it defaults to {@code false}.
     */
    public static /* synthetic */ DScalar crossEntropyLoss$default(DTensor dTensor, DTensor dTensor2, boolean z, int i, Object obj) {
        if ((i & 4) != 0) {
            z = false;
        }
        return crossEntropyLoss(dTensor, dTensor2, z);
    }

    /**
     * Builds a one-hot {@link FloatTensor} with the same shape as {@code dTensor}.
     *
     * <p>Row {@code r} has {@code 1.0f} at column {@code (int) labels[r]} and {@code 0.0f} elsewhere.
     *
     * @param dTensor the scores whose shape (rows, numClasses) the result copies; must be rank 2
     * @param floatTensor rank-1 class indices, one per row of {@code dTensor}; each is truncated to
     *     {@code int}
     * @return the one-hot encoding of {@code floatTensor}
     * @throws IllegalArgumentException on a rank or dim-0 mismatch, or if any class index is
     *     outside {@code [0, numClasses)}
     */
    private static final FloatTensor createOneHotFromClasses(DTensor dTensor, FloatTensor floatTensor) {
        if (!(dTensor.getRank() == 2)) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        if (!(floatTensor.getRank() == 1)) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        if (!(dTensor.getShape().get(0) == floatTensor.getShape().get(0))) {
            throw new IllegalArgumentException("NLL weight and target dimension mismatch at dim 0 (weight: " + dTensor.getShape().get(0) + ", target: " + floatTensor.getShape().get(0) + ")");
        }
        float[] fArr = new float[dTensor.getSize()];
        int last = dTensor.getShape().getLast();
        int size = floatTensor.getSize();
        for (int i = 0; i < size; i++) {
            int cls = (int) floatTensor.at(i);
            // Without this check, an out-of-range label either lands in a neighbouring row's slot
            // (silently corrupting the target) or fails with an uninformative array-index error.
            if (cls < 0 || cls >= last) {
                throw new IllegalArgumentException("Class index out of range at position " + i + " (class: " + cls + ", numClasses: " + last + ")");
            }
            fArr[(i * last) + cls] = 1.0f;
        }
        return FloatTensor.Companion.invoke(dTensor.getShape(), fArr);
    }

    /**
     * Log-softmax of the receiver along {@code i}.
     *
     * <p>Delegates to the tensor's operations object. The {@code "<this>"} null-check name shows
     * that this is a Kotlin extension function on {@link DTensor}.
     *
     * @param dTensor the receiver tensor
     * @param i the axis to normalize over
     * @return the log-softmax as computed by {@code dTensor}'s operations
     */
    @NotNull
    public static final DTensor logSoftmax(@NotNull DTensor dTensor, int i) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        return dTensor.mo153getOperations().logSoftmax(dTensor, i);
    }

    /**
     * Reference log-softmax built from elementary ops: {@code s - ln(sum(exp(s), axis))}, where
     * {@code s = x - max(primal(x), axis)}.
     *
     * <p>Subtracting the maximum keeps {@code exp} from overflowing and leaves the mathematical
     * result unchanged. The maximum is taken over {@code basePrimal(x)}, presumably so the shift
     * contributes no derivative. The trailing {@code true} arguments presumably keep the reduced
     * axis so the results broadcast against {@code x}. TODO: confirm both against the
     * {@code MinMaxKt}/{@code SumKt} signatures.
     *
     * @param dTensor the input tensor, {@code x}
     * @param i the axis to normalize over
     * @return the log-softmax of {@code x} along {@code i}
     */
    @NotNull
    public static final DTensor baseLogSoftmax(@NotNull DTensor dTensor, int i) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        DTensor minus = MinusKt.minus(dTensor, MinMaxKt.max(UtilsKt.basePrimal(dTensor), new int[]{i}, true));
        return MinusKt.minus(minus, TranscendentalKt.ln(SumKt.varargSum(TranscendentalKt.exp(minus), new int[]{i}, true)));
    }

    /**
     * Softmax along {@code i}: {@code e / sum(e, axis)}, where
     * {@code e = exp(x - max(primal(x), axis))}.
     *
     * <p>The max-shift guards {@code exp} against overflow without changing the result.
     *
     * @param dTensor the input tensor, {@code x}
     * @param i the axis to normalize over
     * @return the softmax of {@code x} along {@code i}
     */
    @NotNull
    public static final DTensor softmax(@NotNull DTensor dTensor, int i) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        DTensor exp = TranscendentalKt.exp(MinusKt.minus(dTensor, MinMaxKt.max(UtilsKt.basePrimal(dTensor), new int[]{i}, true)));
        return DivKt.div(exp, SumKt.sum(exp, new int[]{i}, true));
    }

    /**
     * Extension-function form of {@link #softmax}, exported to Java as {@code DTensorSoftmaxExt}.
     *
     * @param dTensor the receiver tensor
     * @param i the axis to normalize over
     * @return {@code softmax(dTensor, i)}
     */
    @JvmName(name = "DTensorSoftmaxExt")
    @NotNull
    public static final DTensor DTensorSoftmaxExt(@NotNull DTensor dTensor, int i) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        return softmax(dTensor, i);
    }

    /**
     * Negative log-likelihood loss against one-hot labels.
     *
     * <p>Computes {@code -sum(guesses * labels) / (guesses.size / guesses.shape[1])}. For a rank-2
     * {@code guesses} the divisor is {@code shape[0]}, the number of rows. {@code guesses} is
     * presumably log-probabilities (as produced by {@link #logSoftmax}) and must have at least two
     * dimensions, because {@code shape[1]} is read.
     *
     * @param dTensor the predicted log-probabilities, {@code guesses}
     * @param dTensor2 the one-hot labels
     * @return the negative log-likelihood, averaged as described above
     */
    @NotNull
    public static final DScalar nllLossFromOneHot(@NotNull DTensor dTensor, @NotNull DTensor dTensor2) {
        Intrinsics.checkNotNullParameter(dTensor, "guesses");
        Intrinsics.checkNotNullParameter(dTensor2, "labels");
        // size / shape[1] is the product of every dimension except the class axis.
        return DivKt.div(MinusKt.unaryMinus(SumKt.sum(TimesKt.times(dTensor, dTensor2))), dTensor.getSize() / dTensor.getShape().get(1));
    }
}
