package org.diffkt;

import java.util.Arrays;
import java.util.Iterator;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.jvm.JvmName;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: Sum.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��\u001e\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0015\n\u0002\u0010\b\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\u001a\n\u0010��\u001a\u00020\u0001*\u00020\u0002\u001a'\u0010��\u001a\u00020\u0002*\u00020\u00022\n\u0010\u0003\u001a\u00020\u0004\"\u00020\u00052\b\b\u0002\u0010\u0006\u001a\u00020\u0007H\u0007¢\u0006\u0002\b\b\u001a\u001e\u0010��\u001a\u00020\u0002*\u00020\u00022\b\b\u0002\u0010\u0003\u001a\u00020\u00042\b\b\u0002\u0010\u0006\u001a\u00020\u0007¨\u0006\t"}, d2 = {"sum", "Lorg/diffkt/DScalar;", "Lorg/diffkt/DTensor;", "axes", "", "", "keepDims", "", "varargSum", "api"})
/* loaded from: input_file:org/diffkt/SumKt.class */
// JVM facade for the top-level `sum` extension functions of Sum.kt.
// The `$default` methods are the synthetic bridges Kotlin emits for default arguments;
// their trailing `int` is a bit mask of which arguments should take their default value.
public final class SumKt {
    /**
     * Sums over every axis of {@code dTensor} (keepDims = false) and returns the result as a scalar.
     *
     * @param dTensor the receiver tensor
     * @return the full reduction, cast to {@link DScalar}
     */
    @NotNull
    public static final DScalar sum(@NotNull DTensor dTensor) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        // Mask 1 => only `axes` is defaulted (to every axis 0..rank-1); keepDims is passed as false.
        DTensor sum$default = sum$default(dTensor, null, false, 1, null);
        Intrinsics.checkNotNull(sum$default, "null cannot be cast to non-null type org.diffkt.DScalar");
        return (DScalar) sum$default;
    }

    /**
     * Vararg-friendly entry point (exported under the JVM name {@code varargSum}); delegates to
     * {@link #sum(DTensor, int[], boolean)}.
     *
     * @param dTensor the receiver tensor
     * @param iArr    the axes to reduce over
     * @param z       whether reduced axes are kept (as extent 1) in the result
     * @return the reduced tensor
     */
    @JvmName(name = "varargSum")
    @NotNull
    public static final DTensor varargSum(@NotNull DTensor dTensor, @NotNull int[] iArr, boolean z) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(iArr, "axes");
        return sum(dTensor, iArr, z);
    }

    /** Default-argument bridge for {@code varargSum}: mask bit 2 defaults {@code keepDims} to false. */
    public static /* synthetic */ DTensor varargSum$default(DTensor dTensor, int[] iArr, boolean z, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        return varargSum(dTensor, iArr, z);
    }

    /**
     * Sums {@code dTensor} over the given axes.
     *
     * <p>If {@code iArr} is empty the receiver is returned unchanged. If every requested axis has
     * extent 1 no arithmetic is needed: the receiver is returned as-is when {@code z} (keepDims) is
     * true, otherwise those axes are squeezed away. In all other cases the work is delegated to the
     * tensor's operations object.
     *
     * @param dTensor the receiver tensor
     * @param iArr    the axes to reduce over (in any order)
     * @param z       whether reduced axes are kept (as extent 1) in the result
     * @return the reduced tensor
     */
    @NotNull
    public static final DTensor sum(@NotNull DTensor dTensor, @NotNull int[] iArr, boolean z) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(iArr, "axes");
        if (iArr.length == 0) {
            return dTensor;
        }
        if (!allAxesHaveExtentOne(dTensor, iArr)) {
            return dTensor.mo153getOperations().sum(dTensor, iArr, z);
        }
        if (z) {
            return dTensor;
        }
        return squeezeAxes(dTensor, iArr);
    }

    /** Returns true when every axis listed in {@code axes} has extent 1 in {@code tensor}'s shape. */
    private static boolean allAxesHaveExtentOne(DTensor tensor, int[] axes) {
        for (int axis : axes) {
            if (tensor.getShape().get(axis) != 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Squeezes each distinct axis in {@code axes} out of {@code tensor}.
     *
     * <p>Axes are removed from the highest index down so that removing one axis never shifts
     * the index of an axis still waiting to be removed; this makes the result independent of the
     * order the caller listed them in. Repeated indices are squeezed only once.
     * NOTE(review): assumes non-negative axis indices — confirm whether Shape.get accepts negatives.
     */
    private static DTensor squeezeAxes(DTensor tensor, int[] axes) {
        int[] sorted = axes.clone();
        Arrays.sort(sorted);
        DTensor result = tensor;
        for (int k = sorted.length - 1; k >= 0; k--) {
            boolean duplicate = k < sorted.length - 1 && sorted[k] == sorted[k + 1];
            if (!duplicate) {
                result = SqueezeKt.squeeze(result, sorted[k]);
            }
        }
        return result;
    }

    /**
     * Default-argument bridge for {@code sum(axes, keepDims)}: mask bit 1 defaults {@code axes} to
     * every axis {@code 0..rank-1}; mask bit 2 defaults {@code keepDims} to false.
     */
    public static /* synthetic */ DTensor sum$default(DTensor dTensor, int[] iArr, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            int rank = dTensor.getRank();
            int[] allAxes = new int[rank];
            for (int axis = 0; axis < rank; axis++) {
                allAxes[axis] = axis;
            }
            iArr = allAxes;
        }
        if ((i & 2) != 0) {
            z = false;
        }
        return sum(dTensor, iArr, z);
    }
}
