package org.diffkt;

import java.util.Map;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Ref;
import org.diffkt.reverse.ReverseDerivativeID;
import org.diffkt.reverse.ReverseTensor;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: WrappedPullback.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��:\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010��\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\u000b\u001a\u0091\u0001\u0010��\u001a\u001a\u0012\u0004\u0012\u0002H\u0002\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u0002H\u00040\u00030\u0001\"\b\b��\u0010\u0004*\u00020\u0005\"\b\b\u0001\u0010\u0002*\u00020\u00052\u0006\u0010\u0006\u001a\u0002H\u00042\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u0002H\u00020\u00032\u001c\b\u0002\u0010\b\u001a\u0016\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0004\u0018\u00010\t2\u001c\b\u0002\u0010\u000b\u001a\u0016\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0002\u0018\u00010\tH��¢\u0006\u0002\u0010\f\u001aú\u0001\u0010��\u001a\u001a\u0012\u0004\u0012\u0002H\u0002\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u0002H\r\u0012\u0004\u0012\u0002H\u00040\u00030\u0001\"\b\b��\u0010\u0004*\u00020\u0005\"\b\b\u0001\u0010\u0002*\u00020\u0005\"\b\b\u0002\u0010\r*\u00020\u00052\u0006\u0010\u0006\u001a\u0002H\u00042\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u0002H\u00020\u00032\u001c\b\u0002\u0010\b\u001a\u0016\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0004\u0018\u00010\t2\u001c\b\u0002\u0010\u000b\u001a\u0016\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0002\u0018\u00010\t2]\u0010\u000e\u001aY\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u0002H\r\u0012C\u0012A\u0012\u0013\u0012\u00110\u0010¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u0013\u0012\u0013\u0012\u00110\u0010¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u0014\u0012\u0004\u00
12\u00020\u00150\t¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u0016\u0012\u0004\u0012\u00020\u00150\u000fH��¢\u0006\u0002\u0010\u0017\u001a\u0085\u0002\u0010��\u001a\u001a\u0012\u0004\u0012\u0002H\u0002\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u0002H\u00180\u00030\u0001\"\b\b��\u0010\u0004*\u00020\u0005\"\b\b\u0001\u0010\u0002*\u00020\u0005\"\b\b\u0002\u0010\u0018*\u00020\u00052\u0006\u0010\u0006\u001a\u0002H\u00042\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u0002H\u00020\u00032\u001c\b\u0002\u0010\u000b\u001a\u0016\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0002\u0018\u00010\t2Q\u0010\u0019\u001aM\u0012\u0013\u0012\u0011H\u0004¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001a\u0012.\u0012,\u0012\u0013\u0012\u00110\u0010¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001a\u0012\u0004\u0012\u00020\u00100\u0003¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001b\u0012\u0004\u0012\u0002H\u00040\t23\u0010\u001c\u001a/\u0012\u0004\u0012\u0002H\u0004\u0012\u001f\u0012\u001d\u0012\u0004\u0012\u00020\u0010\u0012\u0004\u0012\u00020\u00100\u0003¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001d\u0012\u0004\u0012\u0002H\u00180\tH��¢\u0006\u0002\u0010\u001e\u001aî\u0002\u0010��\u001a\u001a\u0012\u0004\u0012\u0002H\u0002\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u0002H\r\u0012\u0004\u0012\u0002H\u00180\u00030\u0001\"\b\b��\u0010\u0004*\u00020\u0005\"\b\b\u0001\u0010\u0002*\u00020\u0005\"\b\b\u0002\u0010\u0018*\u00020\u0005\"\b\b\u0003\u0010\r*\u00020\u00052\u0006\u0010\u0006\u001a\u0002H\u00042\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u0002H\u00020\u00032\u001c\b\u0002\u0010\u000b\u001a\u0016\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0002\u0018\u00010\t2Q\u0010\u0019\u001aM\u0012\u0013\u0012\u0011H\u0004¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001a\u0012.
\u0012,\u0012\u0013\u0012\u00110\u0010¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001a\u0012\u0004\u0012\u00020\u00100\u0003¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001b\u0012\u0004\u0012\u0002H\u00040\t2]\u0010\u000e\u001aY\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u0002H\r\u0012C\u0012A\u0012\u0013\u0012\u00110\u0010¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u0013\u0012\u0013\u0012\u00110\u0010¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u0014\u0012\u0004\u0012\u00020\u00150\t¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u0016\u0012\u0004\u0012\u00020\u00150\u000f23\u0010\u001c\u001a/\u0012\u0004\u0012\u0002H\u0004\u0012\u001f\u0012\u001d\u0012\u0004\u0012\u00020\u0010\u0012\u0004\u0012\u00020\u00100\u0003¢\u0006\f\b\u0011\u0012\b\b\u0012\u0012\u0004\b\b(\u001d\u0012\u0004\u0012\u0002H\u00180\tH��¢\u0006\u0002\u0010\u001f¨\u0006 "}, d2 = {"primalAndPullback", "Lkotlin/Pair;", "Output", "Lkotlin/Function1;", "Input", "", "x", "f", "wrapInput", "Lkotlin/Function2;", "Lorg/diffkt/Wrapper;", "wrapOutput", "(Ljava/lang/Object;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)Lkotlin/Pair;", "OutputTangent", "setOutputTangent", "Lkotlin/Function3;", "Lorg/diffkt/DTensor;", "Lkotlin/ParameterName;", "name", "tensor", "tangent", "", "setTensorTangent", "(Ljava/lang/Object;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function3;)Lkotlin/Pair;", "InputTangent", "makeReverseInput", "primal", "makeReverseTensor", "extractInputTangent", "extractTensorTangent", "(Ljava/lang/Object;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)Lkotlin/Pair;", "(Ljava/lang/Object;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function3;Lkotlin/jvm/functions/Function2;)Lkotlin/Pair;", 
"api"})
/* loaded from: input_file:org/diffkt/WrappedPullbackKt.class */
/*
 * Top-level functions of WrappedPullback.kt, as decompiled (JADX) from
 * org/diffkt/WrappedPullbackKt.class.
 *
 * Every public primalAndPullback overload funnels into the six-argument overload,
 * which runs f once under a fresh ReverseDerivativeID and returns a Pair of
 * (primal output passed through an "unwrapping" Wrapper, one-shot pullback).
 * The shorter overloads fill omitted callbacks with the
 * ReverseDerivativeWrappersKt.default* helpers.
 *
 * NOTE(review): this is decompiler output, not hand-written Java. Some statements
 * are likely not valid Java as written: Object locals are passed to functions
 * typed `? super Input` / `? super Output`, and the casts in the $default bridges
 * name type variables (Output, InputTangent) that those non-generic methods do
 * not declare. The closures are synthetic classes
 * (WrappedPullbackKt$primalAndPullback$...) whose bodies are not shown here.
 * Confirm behavior against WrappedPullback.kt before relying on this file.
 */
public final class WrappedPullbackKt {
    /**
     * Runs {@code f} on a reverse-mode-active version of {@code x} and returns the primal
     * output together with a pullback. All other overloads delegate here.
     *
     * <p>A fresh {@code ReverseDerivativeID} is created. {@code makeReverseInput} rebuilds
     * {@code x}, given a DTensor-to-DTensor callback bound to that ID (presumably
     * {@code primalAndPullback$makeReverseTensor}; the synthetic callback class is not
     * visible here). {@code f} is applied to the rebuilt input. The returned primal is the
     * output passed through a {@code Wrapper} whose {@code wrapDTensor} peels derivative
     * layers off each tensor. That pass goes through {@code wrapOutput} when it is non-null
     * and through {@code Wrapper.wrap} otherwise.
     *
     * @param input      {@code x}: the value to differentiate with respect to
     * @param function1  {@code f}: the function being differentiated
     * @param function2  {@code wrapOutput} (nullable): applies a Wrapper to an Output
     * @param function22 {@code makeReverseInput}: rebuilds the input, mapping its tensors
     *                   through the supplied callback
     * @param function3  {@code setOutputTangent}: receives the output, the upstream tangent
     *                   and a (tensor, tangent) callback used to seed output tangents
     * @param function23 {@code extractInputTangent}: builds the InputTangent from the
     *                   rebuilt input, using a callback that maps a tensor to its tangent
     * @return pair of (unwrapped primal output, pullback); the pullback throws
     *         IllegalStateException if invoked more than once
     */
    @NotNull
    public static final <Input, Output, InputTangent, OutputTangent> Pair<Output, Function1<OutputTangent, InputTangent>> primalAndPullback(@NotNull Input input, @NotNull Function1<? super Input, ? extends Output> function1, @Nullable Function2<? super Output, ? super Wrapper, ? extends Output> function2, @NotNull Function2<? super Input, ? super Function1<? super DTensor, ? extends DTensor>, ? extends Input> function22, @NotNull Function3<? super Output, ? super OutputTangent, ? super Function2<? super DTensor, ? super DTensor, Unit>, Unit> function3, @NotNull Function2<? super Input, ? super Function1<? super DTensor, ? extends DTensor>, ? extends InputTangent> function23) {
        Intrinsics.checkNotNullParameter(input, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        Intrinsics.checkNotNullParameter(function22, "makeReverseInput");
        Intrinsics.checkNotNullParameter(function3, "setOutputTangent");
        Intrinsics.checkNotNullParameter(function23, "extractInputTangent");
        // Identity of this reverse-mode derivative; tags the active input tensors.
        final ReverseDerivativeID reverseDerivativeID = new ReverseDerivativeID();
        // invoke = the rebuilt ("wrapped") input that f actually sees.
        Object invoke = function22.invoke(input, new WrappedPullbackKt$primalAndPullback$wrappedInput$1(reverseDerivativeID));
        // invoke2 = raw output of f, still carrying derivative layers.
        Object invoke2 = function1.invoke(invoke);
        // One-shot guard shared with the pullback closure (see primalAndPullback$pullback).
        Ref.BooleanRef booleanRef = new Ref.BooleanRef();
        Wrapper wrapper = new Wrapper() { // from class: org.diffkt.WrappedPullbackKt$primalAndPullback$outputUnwrapper$1
            /**
             * Removes this derivative's layer from {@code value}. First it follows
             * {@code getPrimal()} while the tensor's derivative ID has a larger sequence
             * number than this pullback's ID (presumably derivatives nested inside this
             * one; confirm). Then, if the remaining layer is exactly this ID, it takes
             * one more {@code getPrimal()}.
             */
            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                DTensor dTensor2;
                Intrinsics.checkNotNullParameter(dTensor, "value");
                DTensor dTensor3 = dTensor;
                while (true) {
                    dTensor2 = dTensor3;
                    if (dTensor2.getDerivativeID().getSequence() <= ReverseDerivativeID.this.getSequence()) {
                        break;
                    }
                    dTensor3 = dTensor2.getPrimal();
                }
                if (Intrinsics.areEqual(dTensor2.getDerivativeID(), ReverseDerivativeID.this)) {
                    dTensor2 = dTensor2.getPrimal();
                }
                return dTensor2;
            }
        };
        // The closure's constructor args mirror primalAndPullback$pullback's parameters;
        // presumably it forwards to that method. Note it captures the RAW output (invoke2),
        // not the unwrapped one returned to the caller.
        return new Pair<>(function2 != null ? function2.invoke(invoke2, wrapper) : wrapper.wrap(invoke2), new WrappedPullbackKt$primalAndPullback$1(booleanRef, reverseDerivativeID, function3, invoke2, function23, invoke));
    }

    /**
     * Kotlin default-argument bridge for the six-argument overload. When bit 2
     * ({@code 4}) of {@code i} is set, the third argument ({@code wrapOutput}) is replaced
     * by {@code null}. {@code obj2} is unused here.
     */
    public static /* synthetic */ Pair primalAndPullback$default(Object obj, Function1 function1, Function2 function2, Function2 function22, Function3 function3, Function2 function23, int i, Object obj2) {
        if ((i & 4) != 0) {
            function2 = null;
        }
        return primalAndPullback(obj, function1, function2, function22, function3, function23);
    }

    /**
     * Overload without {@code setOutputTangent}. The output tangent has the same type as
     * the output, and it is seeded with
     * {@code ReverseDerivativeWrappersKt.defaultSetOutputTangent(wrapOutput)}.
     * Otherwise it delegates to the six-argument overload.
     *
     * @param input      {@code x}
     * @param function1  {@code f}
     * @param function2  {@code wrapOutput} (nullable)
     * @param function22 {@code makeReverseInput}
     * @param function23 {@code extractInputTangent}
     */
    @NotNull
    public static final <Input, Output, InputTangent> Pair<Output, Function1<Output, InputTangent>> primalAndPullback(@NotNull Input input, @NotNull Function1<? super Input, ? extends Output> function1, @Nullable Function2<? super Output, ? super Wrapper, ? extends Output> function2, @NotNull Function2<? super Input, ? super Function1<? super DTensor, ? extends DTensor>, ? extends Input> function22, @NotNull Function2<? super Input, ? super Function1<? super DTensor, ? extends DTensor>, ? extends InputTangent> function23) {
        Intrinsics.checkNotNullParameter(input, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        Intrinsics.checkNotNullParameter(function22, "makeReverseInput");
        Intrinsics.checkNotNullParameter(function23, "extractInputTangent");
        return primalAndPullback(input, function1, function2, function22, ReverseDerivativeWrappersKt.defaultSetOutputTangent(function2), function23);
    }

    /**
     * Default-argument bridge for the five-argument (makeReverseInput /
     * extractInputTangent) overload. Bit 2 ({@code 4}) of {@code i} nulls
     * {@code wrapOutput}. {@code obj2} is unused.
     *
     * <p>NOTE(review): the casts below name type variables (Output, InputTangent) that this
     * non-generic method does not declare. This is a decompiler artifact; the bytecode
     * presumably makes an unchecked call.
     */
    public static /* synthetic */ Pair primalAndPullback$default(Object obj, Function1 function1, Function2 function2, Function2 function22, Function2 function23, int i, Object obj2) {
        if ((i & 4) != 0) {
            function2 = null;
        }
        return primalAndPullback(obj, (Function1<? super Object, ? extends Output>) function1, function2, (Function2<? super Object, ? super Function1<? super DTensor, ? extends DTensor>, ? extends Object>) function22, (Function2<? super Object, ? super Function1<? super DTensor, ? extends DTensor>, ? extends InputTangent>) function23);
    }

    /**
     * Overload taking {@code wrapInput} instead of explicit input callbacks. The input
     * tangent has the same type as the input. {@code makeReverseInput} and
     * {@code extractInputTangent} come from
     * {@code ReverseDerivativeWrappersKt.defaultMakeReverseInput(wrapInput)} and
     * {@code defaultExtractInputTangent(wrapInput)}.
     *
     * @param input      {@code x}
     * @param function1  {@code f}
     * @param function2  {@code wrapInput} (nullable): applies a Wrapper to an Input
     * @param function22 {@code wrapOutput} (nullable)
     * @param function3  {@code setOutputTangent}
     */
    @NotNull
    public static final <Input, Output, OutputTangent> Pair<Output, Function1<OutputTangent, Input>> primalAndPullback(@NotNull Input input, @NotNull Function1<? super Input, ? extends Output> function1, @Nullable Function2<? super Input, ? super Wrapper, ? extends Input> function2, @Nullable Function2<? super Output, ? super Wrapper, ? extends Output> function22, @NotNull Function3<? super Output, ? super OutputTangent, ? super Function2<? super DTensor, ? super DTensor, Unit>, Unit> function3) {
        Intrinsics.checkNotNullParameter(input, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        Intrinsics.checkNotNullParameter(function3, "setOutputTangent");
        return primalAndPullback(input, function1, function22, ReverseDerivativeWrappersKt.defaultMakeReverseInput(function2), function3, ReverseDerivativeWrappersKt.defaultExtractInputTangent(function2));
    }

    /**
     * Default-argument bridge for the (wrapInput, wrapOutput, setOutputTangent) overload.
     * Bit 2 ({@code 4}) of {@code i} nulls {@code wrapInput}, and bit 3 ({@code 8}) nulls
     * {@code wrapOutput}. {@code obj2} is unused.
     *
     * <p>NOTE(review): the casts name an out-of-scope type variable (Output); this is a
     * decompiler artifact.
     */
    public static /* synthetic */ Pair primalAndPullback$default(Object obj, Function1 function1, Function2 function2, Function2 function22, Function3 function3, int i, Object obj2) {
        if ((i & 4) != 0) {
            function2 = null;
        }
        if ((i & 8) != 0) {
            function22 = null;
        }
        return primalAndPullback(obj, (Function1<? super Object, ? extends Output>) function1, (Function2<? super Object, ? super Wrapper, ? extends Object>) function2, function22, function3);
    }

    /**
     * Simplest overload: the output tangent has the type of the output, and the input
     * tangent has the type of the input. All callbacks are derived from the
     * ReverseDerivativeWrappersKt defaults for {@code wrapInput} / {@code wrapOutput}.
     *
     * @param input      {@code x}
     * @param function1  {@code f}
     * @param function2  {@code wrapInput} (nullable)
     * @param function22 {@code wrapOutput} (nullable)
     */
    @NotNull
    public static final <Input, Output> Pair<Output, Function1<Output, Input>> primalAndPullback(@NotNull Input input, @NotNull Function1<? super Input, ? extends Output> function1, @Nullable Function2<? super Input, ? super Wrapper, ? extends Input> function2, @Nullable Function2<? super Output, ? super Wrapper, ? extends Output> function22) {
        Intrinsics.checkNotNullParameter(input, "x");
        Intrinsics.checkNotNullParameter(function1, "f");
        return primalAndPullback(input, function1, function22, ReverseDerivativeWrappersKt.defaultMakeReverseInput(function2), ReverseDerivativeWrappersKt.defaultSetOutputTangent(function22), ReverseDerivativeWrappersKt.defaultExtractInputTangent(function2));
    }

    /**
     * Default-argument bridge for the four-argument overload. Bit 2 ({@code 4}) of
     * {@code i} nulls {@code wrapInput}, and bit 3 ({@code 8}) nulls {@code wrapOutput}.
     * {@code obj2} is unused.
     */
    public static /* synthetic */ Pair primalAndPullback$default(Object obj, Function1 function1, Function2 function2, Function2 function22, int i, Object obj2) {
        if ((i & 4) != 0) {
            function2 = null;
        }
        if ((i & 8) != 0) {
            function22 = null;
        }
        return primalAndPullback(obj, function1, function2, function22);
    }

    /**
     * Wraps one input tensor as an active variable of {@code reverseDerivativeID}.
     * A {@code DScalar} becomes an {@code ActiveVariableReverseScalar}; any other tensor
     * becomes an {@code ActiveVariableReverseTensor}. Public only because the synthetic
     * closure classes call it (see the JADX note).
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final DTensor primalAndPullback$makeReverseTensor(ReverseDerivativeID reverseDerivativeID, DTensor dTensor) {
        return dTensor instanceof DScalar ? new ActiveVariableReverseScalar((DScalar) dTensor, reverseDerivativeID) : new ActiveVariableReverseTensor(dTensor, reverseDerivativeID);
    }

    /**
     * Seeds {@code dTensor2} as the upstream tangent of output tensor {@code dTensor},
     * using {@code ReverseTensor.pushback}. A {@code FloatTensor} output is silently
     * skipped (NOTE(review): presumably a constant that does not depend on the input;
     * confirm). The tensor's derivative ID is not compared with this pullback's ID here.
     *
     * @throws IllegalArgumentException if {@code dTensor} is neither a FloatTensor nor a
     *         ReverseTensor
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final void primalAndPullback$setTensorTangent(DTensor dTensor, DTensor dTensor2) {
        if (dTensor instanceof FloatTensor) {
            return;
        }
        if (!(dTensor instanceof ReverseTensor)) {
            throw new IllegalArgumentException(("expected ReverseTensor, found " + dTensor).toString());
        }
        ((ReverseTensor) dTensor).pushback(dTensor2);
    }

    /**
     * Looks up the tangent of an active input tensor in {@code map}, which is presumably
     * the result of {@code reversePass()} passed through the synthetic pullback$2 closure.
     * {@code Intrinsics.checkNotNull} fails if no tangent was recorded for the tensor.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final DTensor primalAndPullback$pullback$extractTensorTangent(Map<ReverseTensor, ? extends DTensor> map, DTensor dTensor) {
        DTensor dTensor2 = map.get(dTensor);
        Intrinsics.checkNotNull(dTensor2);
        return dTensor2;
    }

    /**
     * Body of the pullback returned by the six-argument overload. It is presumably
     * invoked by the synthetic WrappedPullbackKt$primalAndPullback$1 with its captured
     * values.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Mark the one-shot guard.</li>
     *   <li>Set the upstream shape to {@code Shape()} (NOTE(review): presumably the empty,
     *       scalar shape; confirm).</li>
     *   <li>Let {@code setOutputTangent} seed the output tensors through the pullback$1
     *       callback (presumably {@code primalAndPullback$setTensorTangent}).</li>
     *   <li>Run {@code reverseDerivativeID.reversePass()}.</li>
     *   <li>Build the InputTangent with {@code extractInputTangent} over the rebuilt
     *       input.</li>
     * </ol>
     *
     * <p>The guard flag is set before any of that work runs. As a result, if the first
     * invocation throws, the pullback cannot be retried.
     *
     * @param booleanRef          one-shot guard shared with the closure
     * @param reverseDerivativeID the derivative this pullback belongs to
     * @param function3           {@code setOutputTangent}
     * @param output              the raw (not unwrapped) output of f
     * @param function2           {@code extractInputTangent}
     * @param input               the rebuilt input that f was called with
     * @param outputtangent       the upstream tangent supplied by the caller
     * @throws IllegalStateException if this pullback has already been invoked
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final <OutputTangent, InputTangent, Output, Input> InputTangent primalAndPullback$pullback(Ref.BooleanRef booleanRef, ReverseDerivativeID reverseDerivativeID, Function3<? super Output, ? super OutputTangent, ? super Function2<? super DTensor, ? super DTensor, Unit>, Unit> function3, Output output, Function2<? super Input, ? super Function1<? super DTensor, ? extends DTensor>, ? extends InputTangent> function2, Input input, OutputTangent outputtangent) {
        if (booleanRef.element) {
            throw new IllegalStateException("pullback may only be invoked once");
        }
        booleanRef.element = true;
        reverseDerivativeID.setUpstreamShape(Shape.Companion.invoke());
        function3.invoke(output, outputtangent, WrappedPullbackKt$primalAndPullback$pullback$1.INSTANCE);
        // reversePass() is evaluated first (as a constructor argument), then the tangents are extracted.
        return (InputTangent) function2.invoke(input, new WrappedPullbackKt$primalAndPullback$pullback$2(reverseDerivativeID.reversePass()));
    }
}
