package org.diffkt;

import java.util.Arrays;
import kotlin.Metadata;
import kotlin._Assertions;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: Transpose.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��\u0016\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0015\n��\u001a\u001a\u0010��\u001a\u00020\u0001*\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u0003\u001a\u001a\u0010\u0005\u001a\u00020\u0001*\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u0003\u001a\u0014\u0010\u0006\u001a\u00020\u0001*\u00020\u00012\b\b\u0002\u0010\u0007\u001a\u00020\b¨\u0006\t"}, d2 = {"leftTranspose", "Lorg/diffkt/DTensor;", "a", "Lorg/diffkt/Shape;", "b", "rightTranspose", "transpose", "axes", "", "api"})
/* loaded from: input_file:org/diffkt/TransposeKt.class */
public final class TransposeKt {
    /**
     * Permutes the axes of {@code dTensor} according to {@code iArr} (the Kotlin {@code axes}
     * parameter).
     *
     * <p>When {@code _Assertions.ENABLED} is set, {@code iArr} is checked to be a permutation of
     * {@code 0 until rank}. The check is evaluated only in that case, so calls with assertions
     * off do not pay for de-duplicating and sorting the axes. An identity permutation returns
     * {@code dTensor} itself. Any other permutation is delegated to the tensor's operations object.
     * NOTE(review): presumably output axis {@code i} is taken from input axis {@code iArr[i]};
     * confirm against {@code Operations.transpose}.
     *
     * @param dTensor the tensor to transpose (Kotlin receiver)
     * @param iArr    the axis permutation
     * @return {@code dTensor} unchanged for the identity permutation, otherwise the transposed tensor
     * @throws AssertionError if assertions are enabled and {@code iArr} is not a permutation
     */
    @NotNull
    public static final DTensor transpose(@NotNull DTensor dTensor, @NotNull int[] iArr) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(iArr, "axes");
        if (_Assertions.ENABLED) {
            int rank = dTensor.getShape().getRank();
            boolean isPermutation = ArraysKt.distinct(iArr).size() == iArr.length
                    && Intrinsics.areEqual(ArraysKt.sorted(iArr), CollectionsKt.toList(RangesKt.until(0, rank)));
            if (!isPermutation) {
                // Arrays.toString: concatenating an int[] directly would print "[I@<hash>".
                throw new AssertionError("transpose: axes " + Arrays.toString(iArr)
                        + " are not a permutation of " + RangesKt.until(0, rank));
            }
        }
        return transpose$isIdentityPermutation(iArr) ? dTensor : dTensor.mo153getOperations().transpose(dTensor, iArr);
    }

    /**
     * Kotlin default-argument bridge for {@link #transpose(DTensor, int[])}.
     *
     * <p>When bit 0 of {@code i} is set, the {@code axes} argument was omitted. The axes are then
     * {@code UtilsKt.getAllAxes(dTensor)} reversed.
     *
     * @param dTensor the tensor to transpose
     * @param iArr    the explicit axes; ignored when bit 0 of {@code i} is set
     * @param i       Kotlin default-argument mask
     * @param obj     unused; part of the Kotlin synthetic signature
     * @return the transposed tensor
     */
    public static /* synthetic */ DTensor transpose$default(DTensor dTensor, int[] iArr, int i, Object obj) {
        if ((i & 1) != 0) {
            iArr = ArraysKt.reversedArray(UtilsKt.getAllAxes(dTensor));
        }
        return transpose(dTensor, iArr);
    }

    /**
     * Swaps two leading axis groups.
     *
     * <p>The shape of {@code dTensor} must start with {@code a} followed by {@code b}, as checked
     * via {@code (a + b).isPrefix(shape)}. The permutation built is:
     * <ul>
     *   <li>{@code [rank(a) .. rank(a)+rank(b))} for the first {@code rank(b)} positions,</li>
     *   <li>then {@code [0 .. rank(a))},</li>
     *   <li>then the remaining axes unchanged.</li>
     * </ul>
     * If either group is scalar, {@code dTensor} is returned as is, after the prefix check.
     *
     * @param dTensor the tensor (Kotlin receiver)
     * @param shape   the first leading group {@code a}
     * @param shape2  the second leading group {@code b}
     * @return the tensor with the two leading groups swapped, or {@code dTensor} if either is scalar
     * @throws IllegalArgumentException if {@code a + b} is not a prefix of the tensor's shape
     */
    @NotNull
    public static final DTensor leftTranspose(@NotNull DTensor dTensor, @NotNull Shape shape, @NotNull Shape shape2) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(shape, "a");
        Intrinsics.checkNotNullParameter(shape2, "b");
        if (!shape.plus(shape2).isPrefix(dTensor.getShape())) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        if (shape.isScalar() || shape2.isScalar()) {
            return dTensor;
        }
        int aRank = shape.getRank();
        int bRank = shape2.getRank();
        // Consistency check only; evaluated solely when assertions are on.
        if (_Assertions.ENABLED) {
            Shape rebuilt = shape.plus(shape2).plus(dTensor.getShape().drop(aRank + bRank));
            if (!Intrinsics.areEqual(rebuilt, dTensor.getShape())) {
                throw new AssertionError("Assertion failed");
            }
        }
        int rank = dTensor.getRank();
        int[] axes = new int[rank];
        for (int i = 0; i < rank; i++) {
            if (i < bRank) {
                axes[i] = i + aRank;          // b's axes move to the front
            } else if (i < aRank + bRank) {
                axes[i] = i - bRank;          // followed by a's axes
            } else {
                axes[i] = i;                  // trailing axes untouched
            }
        }
        return transpose(dTensor, axes);
    }

    /**
     * Swaps two trailing axis groups.
     *
     * <p>The shape of {@code dTensor} must equal {@code prefix + a + b}, where {@code prefix} is
     * the shape with the last {@code rank(a) + rank(b)} dimensions dropped. The permutation built
     * keeps the prefix axes in place, then takes {@code b}'s axes, then {@code a}'s axes. If either
     * group is scalar, {@code dTensor} is returned unchanged without any check. This matches the
     * original behavior.
     *
     * <p>The shape decomposition is validated unconditionally, mirroring
     * {@link #leftTranspose(DTensor, Shape, Shape)}. Previously it was only asserted, so a mismatch
     * with assertions disabled silently produced a wrong permutation.
     *
     * @param dTensor the tensor (Kotlin receiver)
     * @param shape   the first trailing group {@code a}
     * @param shape2  the second trailing group {@code b}
     * @return the tensor with the two trailing groups swapped, or {@code dTensor} if either is scalar
     * @throws IllegalArgumentException if the tensor's shape does not end with {@code a + b}
     */
    @NotNull
    public static final DTensor rightTranspose(@NotNull DTensor dTensor, @NotNull Shape shape, @NotNull Shape shape2) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(shape, "a");
        Intrinsics.checkNotNullParameter(shape2, "b");
        if (shape.isScalar() || shape2.isScalar()) {
            return dTensor;
        }
        int aRank = shape.getRank();
        int bRank = shape2.getRank();
        Shape prefix = dTensor.getShape().dropLast(aRank + bRank);
        if (!Intrinsics.areEqual(dTensor.getShape(), prefix.plus(shape).plus(shape2))) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        int prefixRank = prefix.getRank();
        int rank = dTensor.getRank();
        int[] axes = new int[rank];
        for (int i = 0; i < rank; i++) {
            if (i < prefixRank) {
                axes[i] = i;                  // leading axes untouched
            } else if (i < prefixRank + bRank) {
                axes[i] = i + aRank;          // b's axes come first
            } else {
                axes[i] = i - bRank;          // followed by a's axes
            }
        }
        return transpose(dTensor, axes);
    }

    /**
     * Returns {@code true} when every {@code iArr[i] == i}, including the empty array.
     */
    private static final boolean transpose$isIdentityPermutation(int[] iArr) {
        int length = iArr.length;
        for (int i = 0; i < length; i++) {
            if (iArr[i] != i) {
                return false;
            }
        }
        return true;
    }
}
