package org.diffkt;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin._Assertions;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.JvmName;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import org.diffkt.forward.ForwardDerivativeID;
import org.diffkt.forward.ForwardScalar;
import org.diffkt.forward.ForwardTensor;
import org.jetbrains.annotations.NotNull;

/* compiled from: ForwardDerivative.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��.\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010 \n\u0002\b\u0004\u001a*\u0010��\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a*\u0010��\u001a\u00020\u00072\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001a\"\u0010��\u001a\u00020\u00012\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a<\u0010��\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\b2\u0006\u0010\u0004\u001a\u00020\u00012\u0006\u0010\t\u001a\u00020\u00012\u0018\u0010\u0005\u001a\u0014\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\n\u001a\"\u0010��\u001a\u00020\u00072\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001a<\u0010��\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\b2\u0006\u0010\u0004\u001a\u00020\u00072\u0006\u0010\t\u001a\u00020\u00072\u0018\u0010\u0005\u001a\u0014\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\n\u001a\"\u0010\u000b\u001a\u00020\u00012\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a\"\u0010\u000b\u001a\u00020\u00072\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001a\"\u0010\f\
u001a\u00020\u00012\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a\"\u0010\f\u001a\u00020\u00072\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001a\"\u0010\r\u001a\u00020\u00012\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a\"\u0010\r\u001a\u00020\u00072\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001a\"\u0010\u000e\u001a\u00020\u00012\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a\"\u0010\u000e\u001a\u00020\u00072\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001a&\u0010\u000f\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u00062\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001a*\u0010\u0010\u001a\u00020\u00072\u0006\u0010\u0004\u001a\u00020\u00072\u0006\u0010\u0011\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001aG\u0010\u0012\u001a\u0014\u0012\u0004\u0012\u00020\u0001\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00010\u00130\b2\f\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00010\u00132\u0018\u0010\u0005\u001a\u0014\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00010\u0013\u0012\u0004\u0012\u00020\u00010\u0006H\u0007¢\u0006\u0002\b\u0015\u001a@\u0010\u0012\u001a\u0014\u0012\u0004\u0012\u00020\u0007\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00070\u00130\b2\f\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00070\u00132\u0018\u0010\u0005\u001a\u0014\u0012\n\u0012\
b\u0012\u0004\u0012\u00020\u00070\u0013\u0012\u0004\u0012\u00020\u00070\u0006\u001a.\u0010\u0012\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\b2\u0006\u0010\u0004\u001a\u00020\u00012\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\u0006\u001aH\u0010\u0012\u001a\u001a\u0012\u0004\u0012\u00020\u0001\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\b0\b2\u0006\u0010\u0004\u001a\u00020\u00012\u0006\u0010\t\u001a\u00020\u00012\u0018\u0010\u0005\u001a\u0014\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u0001\u0012\u0004\u0012\u00020\u00010\n\u001a.\u0010\u0012\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\b2\u0006\u0010\u0004\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006\u001aH\u0010\u0012\u001a\u001a\u0012\u0004\u0012\u00020\u0007\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\b0\b2\u0006\u0010\u0004\u001a\u00020\u00072\u0006\u0010\t\u001a\u00020\u00072\u0018\u0010\u0005\u001a\u0014\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\n\u001a6\u0010\u0016\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\b2\u0006\u0010\u0004\u001a\u00020\u00072\u0006\u0010\u0011\u001a\u00020\u00072\u0012\u0010\u0005\u001a\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u00070\u0006¨\u0006\u0017"}, d2 = {"forwardDerivative", "Lorg/diffkt/DScalar;", "n", "", "x", "f", "Lkotlin/Function1;", "Lorg/diffkt/DTensor;", "Lkotlin/Pair;", "y", "Lkotlin/Function2;", "forwardDerivative1", "forwardDerivative2", "forwardDerivative3", "forwardDerivative4", "forwardDiff", "jvp", "v", "primalAndForwardDerivative", "", "inputs", "primalAndForwardderivative_2", "primalAndJvp", "api"})
/* loaded from: input_file:org/diffkt/ForwardDerivativeKt.class */
/*
 * Forward-mode differentiation entry points: a Java rendering of the top-level functions in
 * ForwardDerivative.kt.
 *
 * Every public function funnels into primalAndForwardDerivative(DTensor, Function1) or
 * primalAndJvp. Both attach a tangent to the input under a fresh ForwardDerivativeID, evaluate
 * the user function, and read the primal value and tangent back out of its result.
 */
public final class ForwardDerivativeKt {

    /**
     * Evaluates {@code f(x)} together with its forward-mode derivative with respect to {@code x}.
     *
     * <p>The input is wrapped via {@code ForwardTensor.Companion.invoke} with the tangent
     * {@code identityGradientOfSameKind(x, x.shape)}, under a new {@link ForwardDerivativeID}
     * built from {@code x.getShape()}.
     *
     * @param dTensor   the point {@code x} at which to differentiate
     * @param function1 the function {@code f} to differentiate
     * @return the pair (primal {@code f(x)}, derivative). When the output carries no tangent for
     *     this derivative ID, the derivative is {@code zeroOfSameKind} with shape
     *     {@code output.shape + inputTangentShapeForJacobian}.
     * @throws IllegalStateException if the output at this derivative level is neither a
     *     {@link ForwardScalar} nor a {@link ForwardTensor}
     */
    @NotNull
    public static final Pair<DTensor, DTensor> primalAndForwardDerivative(@NotNull DTensor dTensor, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        ForwardDerivativeID derivativeId = new ForwardDerivativeID(dTensor.getShape());
        DTensor output = function1.invoke(ForwardTensor.Companion.invoke(
                dTensor, derivativeId, dTensor.mo153getOperations().identityGradientOfSameKind(dTensor, dTensor.getShape())));
        return primalAndTangentAt(output, derivativeId);
    }

    /**
     * Evaluates {@code f(x)} together with the Jacobian-vector product {@code J_f(x) · v}.
     *
     * <p>The input is seeded with {@code v} itself as its tangent. The derivative ID is created
     * with {@code Shape.Companion.invoke()} (presumably the empty shape) rather than
     * {@code x}'s shape.
     *
     * @param dTensor   the point {@code x}
     * @param dTensor2  the tangent direction {@code v}; must have the same shape as {@code x}
     * @param function1 the function {@code f}
     * @return the pair (primal {@code f(x)}, JVP)
     * @throws IllegalArgumentException if {@code v.getShape()} differs from {@code x.getShape()}
     * @throws IllegalStateException if the output at this derivative level is neither a
     *     {@link ForwardScalar} nor a {@link ForwardTensor}
     */
    @NotNull
    public static final Pair<DTensor, DTensor> primalAndJvp(@NotNull DTensor dTensor, @NotNull DTensor dTensor2, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(dTensor2, "v");
        Intrinsics.checkNotNullParameter(function1, "f");
        if (!Intrinsics.areEqual(dTensor2.getShape(), dTensor.getShape())) {
            throw new IllegalArgumentException("The shape of v was " + dTensor2.getShape() + " but should be " + dTensor.getShape() + ".");
        }
        ForwardDerivativeID derivativeId = new ForwardDerivativeID(Shape.Companion.invoke());
        DTensor output = function1.invoke(ForwardTensor.Companion.invoke(dTensor, derivativeId, dTensor2));
        return primalAndTangentAt(output, derivativeId);
    }

    /**
     * Returns only the Jacobian-vector product {@code J_f(x) · v}.
     * See {@link #primalAndJvp(DTensor, DTensor, Function1)}.
     */
    @NotNull
    public static final DTensor jvp(@NotNull DTensor dTensor, @NotNull DTensor dTensor2, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(dTensor2, "v");
        Intrinsics.checkNotNullParameter(function1, "f");
        return primalAndJvp(dTensor, dTensor2, function1).getSecond();
    }

    /**
     * Returns only the forward-mode derivative of {@code f} at {@code x}.
     * See {@link #primalAndForwardDerivative(DTensor, Function1)}.
     */
    @NotNull
    public static final DTensor forwardDerivative(@NotNull DTensor dTensor, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return primalAndForwardDerivative(dTensor, function1).getSecond();
    }

    /** First forward-mode derivative of {@code f} at {@code x}; identical to {@code forwardDerivative(x, f)}. */
    @NotNull
    public static final DTensor forwardDerivative1(@NotNull DTensor dTensor, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(dTensor, function1);
    }

    /** Second forward-mode derivative of {@code f} at {@code x}. */
    @NotNull
    public static final DTensor forwardDerivative2(@NotNull DTensor dTensor, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(2, dTensor, function1);
    }

    /** Third forward-mode derivative of {@code f} at {@code x}. */
    @NotNull
    public static final DTensor forwardDerivative3(@NotNull DTensor dTensor, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(3, dTensor, function1);
    }

    /** Fourth forward-mode derivative of {@code f} at {@code x}. */
    @NotNull
    public static final DTensor forwardDerivative4(@NotNull DTensor dTensor, @NotNull Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(4, dTensor, function1);
    }

    /**
     * Returns the {@code i}-th forward-mode derivative of {@code f} at {@code x}.
     *
     * <p>{@code i == 0} returns {@code f(x)}. Otherwise the function
     * {@code y -> forwardDerivative(i - 1, y, f)} is forward-differentiated at {@code x}.
     *
     * @param i         the derivative order; must be non-negative
     * @param dTensor   the point {@code x}
     * @param function1 the function {@code f}
     * @return the {@code i}-th derivative
     * @throws IllegalArgumentException if {@code i < 0}; previously this recursed until stack overflow
     */
    @NotNull
    public static final DTensor forwardDerivative(final int i, @NotNull DTensor dTensor, @NotNull final Function1<? super DTensor, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        requireNonNegativeOrder(i);
        if (i == 0) {
            return function1.invoke(dTensor);
        }
        // Explicitly typed so overload resolution picks the DTensor overload unambiguously.
        Function1<DTensor, DTensor> lowerOrder = y -> forwardDerivative(i - 1, y, function1);
        return forwardDerivative(dTensor, lowerOrder);
    }

    /**
     * Scalar variant of {@link #primalAndForwardDerivative(DTensor, Function1)}.
     *
     * <p>{@code f} is lifted to the tensor signature by casting its argument back to
     * {@link DScalar}. This assumes {@code ForwardTensor.Companion.invoke} yields a
     * {@code DScalar} for a scalar primal; confirm. Both results are cast back to {@code DScalar}.
     *
     * @return the pair (primal {@code f(x)}, derivative)
     */
    @NotNull
    public static final Pair<DScalar, DScalar> primalAndForwardDerivative(@NotNull DScalar dScalar, @NotNull final Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        Function1<DTensor, DTensor> padded = xx -> function1.invoke((DScalar) xx);
        Pair<DTensor, DTensor> result = primalAndForwardDerivative((DTensor) dScalar, padded);
        return new Pair<>(asScalar(result.getFirst()), asScalar(result.getSecond()));
    }

    /** Returns only the forward-mode derivative of scalar {@code f} at {@code x}. */
    @NotNull
    public static final DScalar forwardDerivative(@NotNull DScalar dScalar, @NotNull Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return primalAndForwardDerivative(dScalar, function1).getSecond();
    }

    /**
     * Returns a function that maps {@code x} to the forward-mode derivative of {@code f} at
     * {@code x}. Each call of the returned function performs a fresh differentiation.
     */
    @NotNull
    public static final Function1<DScalar, DScalar> forwardDiff(@NotNull final Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(function1, "f");
        return x -> {
            // The returned function is user-facing, so keep the Kotlin-style null check on its argument.
            Intrinsics.checkNotNullParameter(x, "x");
            return primalAndForwardDerivative(x, function1).getSecond();
        };
    }

    /**
     * Scalar variant of {@link #forwardDerivative(int, DTensor, Function1)}: returns the
     * {@code i}-th forward-mode derivative of {@code f} at {@code x}.
     *
     * @throws IllegalArgumentException if {@code i < 0}
     */
    @NotNull
    public static final DScalar forwardDerivative(final int i, @NotNull DScalar dScalar, @NotNull final Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        requireNonNegativeOrder(i);
        if (i == 0) {
            return function1.invoke(dScalar);
        }
        // Explicitly typed so overload resolution picks the DScalar overload unambiguously.
        Function1<DScalar, DScalar> lowerOrder = xx -> forwardDerivative(i - 1, xx, function1);
        return forwardDerivative(dScalar, lowerOrder);
    }

    /** First forward-mode derivative of scalar {@code f} at {@code x}. */
    @NotNull
    public static final DScalar forwardDerivative1(@NotNull DScalar dScalar, @NotNull final Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(dScalar, function1);
    }

    /** Second forward-mode derivative of scalar {@code f} at {@code x}. */
    @NotNull
    public static final DScalar forwardDerivative2(@NotNull DScalar dScalar, @NotNull Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(2, dScalar, function1);
    }

    /** Third forward-mode derivative of scalar {@code f} at {@code x}. */
    @NotNull
    public static final DScalar forwardDerivative3(@NotNull DScalar dScalar, @NotNull Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(3, dScalar, function1);
    }

    /** Fourth forward-mode derivative of scalar {@code f} at {@code x}. */
    @NotNull
    public static final DScalar forwardDerivative4(@NotNull DScalar dScalar, @NotNull Function1<? super DScalar, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return forwardDerivative(4, dScalar, function1);
    }

    /**
     * Forward-differentiates a function of several tensor inputs.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Meld the inputs into one tensor with {@code MeldSplitKt.meld}.</li>
     *   <li>Differentiate once; {@code f} receives the melded tangent-carrying tensor split back
     *       into the original input shapes.</li>
     *   <li>Transpose the derivative with {@code leftTranspose} and split it into one piece per
     *       input.</li>
     *   <li>Transpose each piece again.</li>
     * </ol>
     * Assuming {@code leftTranspose(t, a, b)} turns an {@code a + b} layout into {@code b + a},
     * each returned derivative has shape {@code f.shape + input.shape}; confirm against
     * {@code TransposeKt}.
     *
     * @param list      the inputs
     * @param function1 the function {@code f} of the input list
     * @return the pair (primal {@code f(inputs)}, one derivative per input, in input order)
     */
    @NotNull
    public static final Pair<DTensor, List<DTensor>> primalAndForwardDerivative(@NotNull final List<? extends DTensor> list, @NotNull final Function1<? super List<? extends DTensor>, ? extends DTensor> function1) {
        Intrinsics.checkNotNullParameter(list, "inputs");
        Intrinsics.checkNotNullParameter(function1, "f");
        DTensor melded = MeldSplitKt.meld(list);
        List<Shape> inputShapes = new ArrayList<>(list.size());
        for (DTensor input : list) {
            inputShapes.add(input.getShape());
        }
        Function1<DTensor, DTensor> padded = wrappedInput -> function1.invoke(MeldSplitKt.split(wrappedInput, inputShapes));
        Pair<DTensor, DTensor> result = primalAndForwardDerivative(melded, padded);

        DTensor primal = result.getFirst();
        Shape outputShape = primal.getShape();
        DTensor transposed = TransposeKt.leftTranspose(result.getSecond(), outputShape, melded.getShape());

        List<Shape> pieceShapes = new ArrayList<>(inputShapes.size());
        for (Shape inputShape : inputShapes) {
            pieceShapes.add(inputShape.plus(outputShape));
        }
        List<DTensor> pieces = MeldSplitKt.split(transposed, pieceShapes);

        List<DTensor> derivatives = new ArrayList<>(pieces.size());
        for (DTensor piece : pieces) {
            derivatives.add(TransposeKt.leftTranspose(piece, piece.getShape().dropLast(primal.getRank()), outputShape));
        }
        return new Pair<>(primal, derivatives);
    }

    /**
     * Forward-differentiates a binary tensor function via the list overload.
     *
     * @return the pair (primal {@code f(x, y)}, pair of derivatives with respect to {@code x}
     *     and {@code y})
     */
    @NotNull
    public static final Pair<DTensor, Pair<DTensor, DTensor>> primalAndForwardDerivative(@NotNull DTensor dTensor, @NotNull DTensor dTensor2, @NotNull final Function2<? super DTensor, ? super DTensor, ? extends DTensor> function2) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(dTensor2, "y");
        Intrinsics.checkNotNullParameter(function2, "f");
        Function1<List<? extends DTensor>, DTensor> pairwise = args -> {
            assertTwoElements(args.size());
            return function2.invoke(args.get(0), args.get(1));
        };
        Pair<DTensor, List<DTensor>> result = primalAndForwardDerivative(CollectionsKt.listOf(dTensor, dTensor2), pairwise);
        List<DTensor> derivatives = result.getSecond();
        assertTwoElements(derivatives.size());
        return new Pair<>(result.getFirst(), new Pair<>(derivatives.get(0), derivatives.get(1)));
    }

    /**
     * Scalar variant of the list overload: every input, the primal result and every
     * derivative are cast to {@link DScalar}. Each cast is preceded by the same non-null check.
     *
     * @return the pair (primal {@code f(inputs)}, one derivative per input, in input order)
     */
    @JvmName(name = "primalAndForwardderivative_2")
    @NotNull
    public static final Pair<DScalar, List<DScalar>> primalAndForwardderivative_2(@NotNull List<? extends DScalar> list, @NotNull final Function1<? super List<? extends DScalar>, ? extends DScalar> function1) {
        Intrinsics.checkNotNullParameter(list, "inputs");
        Intrinsics.checkNotNullParameter(function1, "f");
        Function1<List<? extends DTensor>, DTensor> overTensors = l -> function1.invoke(asScalars(l));
        Pair<DTensor, List<DTensor>> result = primalAndForwardDerivative(list, overTensors);
        return new Pair<>(asScalar(result.getFirst()), asScalars(result.getSecond()));
    }

    /**
     * Returns only the derivatives of a binary tensor function.
     * See {@link #primalAndForwardDerivative(DTensor, DTensor, Function2)}.
     */
    @NotNull
    public static final Pair<DTensor, DTensor> forwardDerivative(@NotNull DTensor dTensor, @NotNull DTensor dTensor2, @NotNull Function2<? super DTensor, ? super DTensor, ? extends DTensor> function2) {
        Intrinsics.checkNotNullParameter(dTensor, "x");
        Intrinsics.checkNotNullParameter(dTensor2, "y");
        Intrinsics.checkNotNullParameter(function2, "f");
        return primalAndForwardDerivative(dTensor, dTensor2, function2).getSecond();
    }

    /**
     * Forward-differentiates a binary scalar function via {@code primalAndForwardderivative_2}.
     *
     * @return the pair (primal {@code f(x, y)}, pair of derivatives with respect to {@code x}
     *     and {@code y})
     */
    @NotNull
    public static final Pair<DScalar, Pair<DScalar, DScalar>> primalAndForwardDerivative(@NotNull DScalar dScalar, @NotNull DScalar dScalar2, @NotNull final Function2<? super DScalar, ? super DScalar, ? extends DScalar> function2) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(dScalar2, "y");
        Intrinsics.checkNotNullParameter(function2, "f");
        Function1<List<? extends DScalar>, DScalar> pairwise = args -> {
            assertTwoElements(args.size());
            return function2.invoke(args.get(0), args.get(1));
        };
        Pair<DScalar, List<DScalar>> result = primalAndForwardderivative_2(CollectionsKt.listOf(dScalar, dScalar2), pairwise);
        List<DScalar> derivatives = result.getSecond();
        assertTwoElements(derivatives.size());
        return new Pair<>(result.getFirst(), new Pair<>(derivatives.get(0), derivatives.get(1)));
    }

    /**
     * Returns only the derivatives of a binary scalar function.
     * See {@link #primalAndForwardDerivative(DScalar, DScalar, Function2)}.
     */
    @NotNull
    public static final Pair<DScalar, DScalar> forwardDerivative(@NotNull DScalar dScalar, @NotNull DScalar dScalar2, @NotNull Function2<? super DScalar, ? super DScalar, ? extends DScalar> function2) {
        Intrinsics.checkNotNullParameter(dScalar, "x");
        Intrinsics.checkNotNullParameter(dScalar2, "y");
        Intrinsics.checkNotNullParameter(function2, "f");
        return primalAndForwardDerivative(dScalar, dScalar2, function2).getSecond();
    }

    /**
     * Reads the primal value and the tangent for {@code derivativeId} out of a function's output.
     *
     * <p>Wrapping levels whose derivative ID has a larger sequence number than
     * {@code derivativeId} are peeled off via {@code getPrimal()}. If the level reached does not
     * belong to {@code derivativeId}, the output carries no tangent for it. In that case the
     * output itself is returned with a zero tangent of shape
     * {@code output.shape + inputTangentShapeForJacobian}.
     */
    private static Pair<DTensor, DTensor> primalAndTangentAt(DTensor output, ForwardDerivativeID derivativeId) {
        DTensor value = output;
        while (value.getDerivativeID().getSequence() > derivativeId.getSequence()) {
            value = value.getPrimal();
        }
        if (!Intrinsics.areEqual(value.getDerivativeID(), derivativeId)) {
            DTensor zero = UtilsKt.zeroOfSameKind(value, value.getShape().plus(derivativeId.getInputTangentShapeForJacobian()));
            return new Pair<>(value, zero);
        }
        return new Pair<>(value.getPrimal(), tangentOf(value));
    }

    /**
     * Returns the tangent stored in a forward-mode value. {@link ForwardScalar} is checked
     * first, then {@link ForwardTensor}.
     *
     * @throws IllegalStateException if {@code value} is neither
     */
    private static DTensor tangentOf(DTensor value) {
        if (value instanceof ForwardScalar) {
            return ((ForwardScalar) value).getTangent();
        }
        if (value instanceof ForwardTensor) {
            return ((ForwardTensor) value).getTangent();
        }
        throw new IllegalStateException("Expected a ForwardScalar or ForwardTensor at this derivative level but got " + value.getClass().getName());
    }

    /** Rejects a negative derivative order, which would otherwise recurse without bound. */
    private static void requireNonNegativeOrder(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("The derivative order n must be non-negative but was " + n + ".");
        }
    }

    /** Mirrors Kotlin's {@code assert(size == 2)}: only checked when {@code kotlin._Assertions} is enabled. */
    private static void assertTwoElements(int size) {
        if (_Assertions.ENABLED && size != 2) {
            throw new AssertionError("Assertion failed");
        }
    }

    /** Casts to {@link DScalar} after the same null check the Kotlin {@code as DScalar} performs. */
    private static DScalar asScalar(DTensor value) {
        Intrinsics.checkNotNull(value, "null cannot be cast to non-null type org.diffkt.DScalar");
        return (DScalar) value;
    }

    /** Element-wise {@link #asScalar(DTensor)} into a new list that preserves order. */
    private static List<DScalar> asScalars(List<? extends DTensor> values) {
        List<DScalar> scalars = new ArrayList<>(values.size());
        for (DTensor value : values) {
            scalars.add(asScalar(value));
        }
        return scalars;
    }
}
