package org.diffkt;

import java.util.ArrayList;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin._Assertions;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.diffkt.external.SparseOps;
import org.jetbrains.annotations.NotNull;
import shapeTyping.annotations.AllowUnreduced;
import shapeTyping.annotations.SType;

/* compiled from: Matmul.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��\u0016\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0004\u001an\u0010��\u001a 0\u0001¢\u0006\f\b\u0002\u0012\b\b\u0003\u0012\u0004\b\b(\u0004¢\u0006\f\b\u0002\u0012\b\b\u0003\u0012\u0004\b\b(\u0004* 0\u0001¢\u0006\f\b\u0002\u0012\b\b\u0003\u0012\u0004\b\b(\u0005¢\u0006\f\b\u0002\u0012\b\b\u0003\u0012\u0004\b\b(\u00052$\u0010\u0006\u001a 0\u0001¢\u0006\f\b\u0002\u0012\b\b\u0003\u0012\u0004\b\b(\u0007¢\u0006\f\b\u0002\u0012\b\b\u0003\u0012\u0004\b\b(\u0007H\u0007\u001a2\u0010��\u001a\u00020\u0001*\u00020\u00012\u0006\u0010\u0006\u001a\u00020\u00012\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\u0006\u0010\u000b\u001a\u00020\t2\u0006\u0010\f\u001a\u00020\t¨\u0006\r"}, d2 = {"matmul", "Lorg/diffkt/DTensor;", "LshapeTyping/annotations/SType;", "value", "matmul(A,B)", "A", "right", "B", "a", "Lorg/diffkt/Shape;", "b", "c", "d", "api"})
/* loaded from: input_file:org/diffkt/MatmulKt.class */
public final class MatmulKt {
    /**
     * Matrix product of {@code dTensor} (left operand, the Kotlin receiver) and {@code dTensor2}
     * (right operand).
     *
     * <p>If both operands are {@link SparseFloatTensor}s, their shapes are validated (only rank 2,
     * or rank 3 with a leading batch axis, is accepted) and the work is delegated to
     * {@code SparseOps.INSTANCE.matmul}.
     *
     * <p>Otherwise the dense path is taken:
     * <ul>
     *   <li>a rank-1 left operand is unsqueezed to {@code [1, k]} and a rank-1 right operand to
     *       {@code [k, 1]}; the added axis is squeezed out of the result again;</li>
     *   <li>the lower-rank operand receives leading size-1 axes until both ranks match;</li>
     *   <li>the leading (batch) axes are broadcast: each pair must be equal or contain a 1;</li>
     *   <li>the broadcast operands are passed to the six-argument {@code matmul} overload.</li>
     * </ul>
     *
     * @param dTensor left operand; must have rank >= 1 on the dense path
     * @param dTensor2 right operand; must have rank >= 1 on the dense path
     * @return the matrix product
     * @throws IllegalArgumentException if the ranks or shapes of the operands are incompatible
     */
    @SType("A: Shape, B: Shape")
    @AllowUnreduced
    @NotNull
    public static final DTensor matmul(@NotNull DTensor dTensor, @NotNull DTensor dTensor2) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(dTensor2, "right");

        if ((dTensor instanceof SparseFloatTensor) && (dTensor2 instanceof SparseFloatTensor)) {
            checkSparseMatmulShapes(dTensor.getShape(), dTensor2.getShape());
            return SparseOps.INSTANCE.matmul((SparseFloatTensor) dTensor, (SparseFloatTensor) dTensor2);
        }

        // A scalar has no matrix axes; fail clearly instead of indexing the shape at -2.
        if (dTensor.getRank() == 0 || dTensor2.getRank() == 0) {
            throw new IllegalArgumentException("Shape mismatch: matmul requires operands of rank >= 1, got "
                    + dTensor.getShape() + " and " + dTensor2.getShape());
        }

        // Promote vectors to matrices: left [k] -> [1, k], right [k] -> [k, 1].
        DTensor left = dTensor.getRank() == 1 ? UnsqueezeKt.unsqueeze(dTensor, 0) : dTensor;
        DTensor right = dTensor2.getRank() == 1 ? UnsqueezeKt.unsqueeze(dTensor2, 1) : dTensor2;

        // The last two axes of each operand are the matrix axes.
        int rows = left.getShape().get(left.getRank() - 2);
        int inner = left.getShape().get(left.getRank() - 1);
        int rightInner = right.getShape().get(right.getRank() - 2);
        int cols = right.getShape().get(right.getRank() - 1);
        if (inner != rightInner) {
            throw new IllegalArgumentException("Shape mismatch: Inner dimensions of " + dTensor.getShape()
                    + " and " + dTensor2.getShape() + " do not match");
        }

        // Align ranks by prepending size-1 axes, then broadcast the leading batch axes.
        int rank = Math.max(left.getRank(), right.getRank());
        left = prependUnitAxes(left, rank);
        right = prependUnitAxes(right, rank);
        Shape batch = broadcastBatchShape(left.getShape(), right.getShape(), rank - 2,
                dTensor.getShape(), dTensor2.getShape());

        DTensor product = matmul(
                BroadcastKt.broadcastTo(left, batch.plus(rows).plus(inner)),
                BroadcastKt.broadcastTo(right, batch.plus(inner).plus(cols)),
                batch,
                Shape.Companion.invoke(rows),
                Shape.Companion.invoke(inner),
                Shape.Companion.invoke(cols));
        boolean shapeMatches = Intrinsics.areEqual(product.getShape(), batch.plus(rows).plus(cols));
        if (_Assertions.ENABLED && !shapeMatches) {
            throw new AssertionError("Assertion failed");
        }

        // Undo the vector promotions performed above.
        DTensor result = dTensor.getRank() == 1 ? SqueezeKt.squeeze(product, product.getRank() - 2) : product;
        return dTensor2.getRank() == 1 ? SqueezeKt.squeeze(result, result.getRank() - 1) : result;
    }

    /**
     * Validates operand shapes for the sparse path: equal ranks, and either rank 2
     * ({@code [m, k] x [k, n]}) or rank 3 ({@code [b, m, k] x [b, k, n]}).
     */
    private static void checkSparseMatmulShapes(Shape left, Shape right) {
        if (left.getRank() != right.getRank()) {
            throw new IllegalArgumentException("The number of dimensions on both sides should be consistent.");
        }
        switch (left.getRank()) {
            case 2:
                if (left.get(1) != right.get(0)) {
                    throw new IllegalArgumentException(
                            "The number of cols of left should be the same as the number of rows on the right.");
                }
                return;
            case 3:
                if (left.get(0) != right.get(0) || left.get(2) != right.get(1)) {
                    throw new IllegalArgumentException(
                            "For 3D matmul, the number of batches on both sides should be the same; "
                                    + "the number of cols of left should be the same as the number of rows on the right.");
                }
                return;
            default:
                throw new IllegalArgumentException("The numbers of dimensions supported for Sparse Matmul are 2 and 3.");
        }
    }

    /** Returns {@code tensor} with size-1 axes unsqueezed at axis 0 until it has {@code rank} axes. */
    private static DTensor prependUnitAxes(DTensor tensor, int rank) {
        DTensor padded = tensor;
        int missing = rank - tensor.getRank();
        for (int i = 0; i < missing; i++) {
            padded = UnsqueezeKt.unsqueeze(padded, 0);
        }
        return padded;
    }

    /**
     * Broadcasts the first {@code batchRank} axes of two equal-rank shapes.
     *
     * <p>Each axis pair must be equal or contain a 1. The result takes the non-1 size, so a
     * size-0 axis broadcast against a size-1 axis stays 0.
     *
     * @param originalLeft shape of the left operand before any reshaping, used for the error message
     * @param originalRight shape of the right operand before any reshaping, used for the error message
     */
    private static Shape broadcastBatchShape(Shape left, Shape right, int batchRank,
                                             Shape originalLeft, Shape originalRight) {
        ArrayList<Integer> dims = new ArrayList<>(batchRank);
        for (int axis = 0; axis < batchRank; axis++) {
            int l = left.get(axis);
            int r = right.get(axis);
            if (l != r && l != 1 && r != 1) {
                throw new IllegalArgumentException("Shape mismatch: Batch dimension mismatch for "
                        + originalLeft + " and " + originalRight);
            }
            // Not Math.max: broadcasting 0 against 1 must yield 0.
            dims.add(l == 1 ? r : l);
        }
        return Shape.Companion.invoke(dims);
    }

    /**
     * Batched contraction of {@code dTensor}, whose shape is {@code a + b + c}, with
     * {@code dTensor2}, whose shape is {@code a + c + d}, over the {@code c} axes.
     *
     * <p>Both operand shapes are checked with Kotlin-style assertions, which only run when
     * assertions are enabled. If {@code c} is scalar, the result comes from
     * {@code OuterProductKt.outerProduct}. Otherwise the call is dispatched to the
     * {@link Operations} returned by {@code OperationsKt.commonKind} for the two operands.
     *
     * <p>NOTE(review): with a non-empty batch shape {@code a} and a scalar {@code c}, confirm
     * that {@code outerProduct} yields {@code a + b + d} rather than {@code a + b + a + d}.
     *
     * @param dTensor left operand, expected shape {@code a + b + c}
     * @param dTensor2 right operand, expected shape {@code a + c + d}
     * @param shape batch shape {@code a}
     * @param shape2 left free shape {@code b}
     * @param shape3 contracted shape {@code c}
     * @param shape4 right free shape {@code d}
     * @return the contraction result
     */
    @NotNull
    public static final DTensor matmul(@NotNull DTensor dTensor, @NotNull DTensor dTensor2, @NotNull Shape shape, @NotNull Shape shape2, @NotNull Shape shape3, @NotNull Shape shape4) {
        Intrinsics.checkNotNullParameter(dTensor, "<this>");
        Intrinsics.checkNotNullParameter(dTensor2, "right");
        Intrinsics.checkNotNullParameter(shape, "a");
        Intrinsics.checkNotNullParameter(shape2, "b");
        Intrinsics.checkNotNullParameter(shape3, "c");
        Intrinsics.checkNotNullParameter(shape4, "d");
        boolean leftShapeOk = Intrinsics.areEqual(dTensor.getShape(), shape.plus(shape2).plus(shape3));
        if (_Assertions.ENABLED && !leftShapeOk) {
            throw new AssertionError("Assertion failed");
        }
        boolean rightShapeOk = Intrinsics.areEqual(dTensor2.getShape(), shape.plus(shape3).plus(shape4));
        if (_Assertions.ENABLED && !rightShapeOk) {
            throw new AssertionError("Assertion failed");
        }
        if (shape3.isScalar()) {
            return OuterProductKt.outerProduct(dTensor, dTensor2);
        }
        Pair<Operations, DerivativeID> commonKind = OperationsKt.commonKind(dTensor, dTensor2);
        Operations ops = commonKind.component1();
        DerivativeID derivativeId = commonKind.component2();
        return ops.matmul(dTensor, dTensor2, shape, shape2, shape3, shape4, derivativeId);
    }
}
