package org.diffkt;

import java.util.ArrayList;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Triple;
import kotlin._Assertions;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.jetbrains.annotations.NotNull;

/* compiled from: Broadcast.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��8\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0011\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0015\n\u0002\b\u0005\bÀ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u0016\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00042\u0006\u0010\u0006\u001a\u00020\u0004J!\u0010\u0007\u001a\u00020\u00042\u0012\u0010\b\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u00040\t\"\u00020\u0004H\u0002¢\u0006\u0002\u0010\nJ\"\u0010\u000b\u001a\u000e\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0\f2\u0006\u0010\u000e\u001a\u00020\r2\u0006\u0010\u000f\u001a\u00020\rJ0\u0010\u000b\u001a\u0014\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\r0\u00102\u0006\u0010\u000e\u001a\u00020\r2\u0006\u0010\u000f\u001a\u00020\r2\u0006\u0010\u0011\u001a\u00020\rJ\u0016\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u00042\u0006\u0010\u0015\u001a\u00020\u0004J*\u0010\u0016\u001a\u000e\u0012\u0004\u0012\u00020\u0013\u0012\u0004\u0012\u00020\u00130\f2\u0006\u0010\u0014\u001a\u00020\u00042\u0006\u0010\u0017\u001a\u00020\u00132\u0006\u0010\u0015\u001a\u00020\u0004¨\u0006\u0018"}, d2 = {"Lorg/diffkt/Broadcasting;", "", "()V", "broadcastShape", "Lorg/diffkt/Shape;", "xShape", "yShape", "broadcastShapeImpl", "shapes", "", "([Lorg/diffkt/Shape;)Lorg/diffkt/Shape;", "broadcastToCommonShape", "Lkotlin/Pair;", "Lorg/diffkt/DTensor;", "a", "b", "Lkotlin/Triple;", "c", "getBroadcastedAxes", "", "shape", "newShape", "getBroadcastedStridesAndAxes", "strides", "api"})
/* loaded from: input_file:org/diffkt/Broadcasting.class */
public final class Broadcasting {

    /** Shared singleton instance; the constructor is private. */
    @NotNull
    public static final Broadcasting INSTANCE = new Broadcasting();

    private Broadcasting() {
    }

    /**
     * Broadcasts two tensors to their common shape.
     *
     * <p>If both tensors already have equal shapes they are returned unchanged. Otherwise the
     * common shape is computed with {@link #broadcastShape(Shape, Shape)} and each tensor is
     * passed through {@code BroadcastKt.broadcastTo}.
     *
     * @param a first tensor
     * @param b second tensor
     * @return the (possibly broadcast) pair {@code (a, b)}
     * @throws RuntimeException if the shapes are not broadcast-compatible
     */
    @NotNull
    public final Pair<DTensor, DTensor> broadcastToCommonShape(@NotNull DTensor a, @NotNull DTensor b) {
        Intrinsics.checkNotNullParameter(a, "a");
        Intrinsics.checkNotNullParameter(b, "b");
        if (Intrinsics.areEqual(a.getShape(), b.getShape())) {
            return new Pair<>(a, b);
        }
        Shape common = broadcastShape(a.getShape(), b.getShape());
        return new Pair<>(BroadcastKt.broadcastTo(a, common), BroadcastKt.broadcastTo(b, common));
    }

    /**
     * Broadcasts three tensors to their common shape.
     *
     * <p>If all three shapes are already equal the tensors are returned unchanged. Otherwise the
     * common shape of all three is computed and each tensor is passed through
     * {@code BroadcastKt.broadcastTo}.
     *
     * @param a first tensor
     * @param b second tensor
     * @param c third tensor
     * @return the (possibly broadcast) triple {@code (a, b, c)}
     * @throws RuntimeException if the shapes are not broadcast-compatible
     */
    @NotNull
    public final Triple<DTensor, DTensor, DTensor> broadcastToCommonShape(@NotNull DTensor a, @NotNull DTensor b, @NotNull DTensor c) {
        Intrinsics.checkNotNullParameter(a, "a");
        Intrinsics.checkNotNullParameter(b, "b");
        Intrinsics.checkNotNullParameter(c, "c");
        if (Intrinsics.areEqual(a.getShape(), b.getShape()) && Intrinsics.areEqual(a.getShape(), c.getShape())) {
            return new Triple<>(a, b, c);
        }
        Shape common = broadcastShapeImpl(a.getShape(), b.getShape(), c.getShape());
        return new Triple<>(
                BroadcastKt.broadcastTo(a, common),
                BroadcastKt.broadcastTo(b, common),
                BroadcastKt.broadcastTo(c, common));
    }

    /**
     * Returns the axes of {@code newShape} along which a size-1 dim of {@code shape} is
     * stretched when broadcasting {@code shape} to {@code newShape}.
     *
     * <p>Leading axes introduced only by rank padding are NOT reported (see
     * {@link #getBroadcastedStridesAndAxes(Shape, int[], Shape)}).
     *
     * @param shape    the original shape
     * @param newShape the broadcast target shape
     * @return stretched axis indices in ascending order; empty if the shapes are equal
     * @throws IllegalArgumentException if {@code shape} has a higher rank than {@code newShape}
     * @throws RuntimeException         if the shapes are not broadcast-compatible
     */
    @NotNull
    public final int[] getBroadcastedAxes(@NotNull Shape shape, @NotNull Shape newShape) {
        Intrinsics.checkNotNullParameter(shape, "shape");
        Intrinsics.checkNotNullParameter(newShape, "newShape");
        // Only the axes are needed, so zero strides (Java's default array fill) are placeholders.
        return getBroadcastedStridesAndAxes(shape, new int[shape.getRank()], newShape).getSecond();
    }

    /**
     * Computes the strides a view of {@code shape} would have after broadcasting to
     * {@code newShape}, together with the axes that were stretched.
     *
     * <p>Dims are aligned at the trailing end. For each axis of {@code newShape}:
     * <ul>
     *   <li>leading axes added by rank padding get stride 0 and are not reported as stretched;</li>
     *   <li>axes where the dims already match keep the corresponding entry of {@code strides};</li>
     *   <li>axes where the original dim is 1 get stride 0 and are reported as stretched;</li>
     *   <li>any other mismatch is an error.</li>
     * </ul>
     *
     * <p>When {@code shape} equals {@code newShape}, the given {@code strides} array itself (not
     * a copy) is returned with an empty axis array.
     *
     * @param shape    the original shape
     * @param strides  strides of the original shape; presumably one entry per dim of
     *                 {@code shape} (TODO confirm with callers)
     * @param newShape the broadcast target shape
     * @return {@code (newStrides, stretchedAxes)}, axes in ascending order
     * @throws IllegalArgumentException if {@code shape} has a higher rank than {@code newShape}
     * @throws RuntimeException         if some dim is neither equal to the target dim nor 1
     */
    @NotNull
    public final Pair<int[], int[]> getBroadcastedStridesAndAxes(@NotNull Shape shape, @NotNull int[] strides, @NotNull Shape newShape) {
        Intrinsics.checkNotNullParameter(shape, "shape");
        Intrinsics.checkNotNullParameter(strides, "strides");
        Intrinsics.checkNotNullParameter(newShape, "newShape");
        if (Intrinsics.areEqual(shape, newShape)) {
            return new Pair<>(strides, new int[0]);
        }
        int rank = shape.getRank();
        int newRank = newShape.getRank();
        if (rank > newRank) {
            throw new IllegalArgumentException("broadcastTo: tensor rank " + rank + " > broadcast rank " + newRank);
        }
        int lead = newRank - rank;
        // Leading padded axes keep the default stride of 0.
        int[] newStrides = new int[newRank];
        int[] stretched = new int[newRank];
        int stretchedCount = 0;
        for (int axis = lead; axis < newRank; axis++) {
            int oldDim = shape.get(axis - lead);
            int newDim = newShape.get(axis);
            int stride = strides[axis - lead];
            if (oldDim == newDim) {
                newStrides[axis] = stride;
            } else if (oldDim == 1) {
                // A size-1 dim is repeated along this axis, so it does not advance in memory.
                newStrides[axis] = 0;
                stretched[stretchedCount++] = axis;
            } else if (newDim == 1) {
                // Previously only caught by an assertion (i.e. only under -ea); without it the
                // method silently returned strides for an impossible broadcast.
                throw new RuntimeException("cannot broadcast dim " + axis + " from " + oldDim + " to 1");
            } else {
                throw new RuntimeException("broadcast: incompatible shapes at dim " + axis + ": " + shape + " " + newShape);
            }
        }
        int[] axes = new int[stretchedCount];
        System.arraycopy(stretched, 0, axes, 0, stretchedCount);
        return new Pair<>(newStrides, axes);
    }

    /**
     * Computes the NumPy-style broadcast of two shapes: shapes are right-aligned, a dim of 1
     * takes the other operand's size, and otherwise the sizes must match.
     *
     * @param xShape first shape
     * @param yShape second shape
     * @return the broadcast shape
     * @throws RuntimeException if the shapes are not broadcast-compatible
     */
    @NotNull
    public final Shape broadcastShape(@NotNull Shape xShape, @NotNull Shape yShape) {
        Intrinsics.checkNotNullParameter(xShape, "xShape");
        Intrinsics.checkNotNullParameter(yShape, "yShape");
        return broadcastShapeImpl(xShape, yShape);
    }

    /**
     * Folds the right-aligned broadcast rule over one or more shapes.
     *
     * <p>The running result starts as a copy of the first shape. For each further shape, the
     * running result is first left-padded with that shape's leading dims if it is shorter.
     * Then each overlapping trailing dim is combined: a running dim of 1 adopts the other size,
     * and any other mismatch that is not 1 is an error.
     *
     * @param shapes at least one shape
     * @return the broadcast shape
     * @throws IllegalArgumentException if no shapes are given
     * @throws RuntimeException         if the shapes are not broadcast-compatible; the reported
     *                                  dim index is relative to the running result
     */
    private final Shape broadcastShapeImpl(Shape... shapes) {
        if (shapes.length == 0) {
            throw new IllegalArgumentException("broadcastShape called with no tensors");
        }
        Shape first = shapes[0];
        int[] result = new int[first.getRank()];
        for (int i = 0; i < result.length; i++) {
            result[i] = first.get(i);
        }
        for (int s = 1; s < shapes.length; s++) {
            Shape shape = shapes[s];
            int offset = result.length - shape.getRank();
            if (offset < 0) {
                // Left-pad the running result with the extra leading dims of `shape`.
                int extra = -offset;
                int[] widened = new int[shape.getRank()];
                for (int i = 0; i < extra; i++) {
                    widened[i] = shape.get(i);
                }
                System.arraycopy(result, 0, widened, extra, result.length);
                result = widened;
                offset = 0;
            }
            for (int i = offset; i < result.length; i++) {
                int current = result[i];
                int other = shape.get(i - offset);
                if (current == 1) {
                    result[i] = other;
                } else if (other != 1 && other != current) {
                    throw new RuntimeException("broadcast: incompatible shapes at dim " + i + ": " + current + " and " + other);
                }
            }
        }
        return new Shape(result);
    }
}
