package org.diffkt;

import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: ReverseDerivativeWrappers.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��6\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\u0003\u001a]\u0010��\u001a/\u0012\u0004\u0012\u0002H\u0002\u0012\u001f\u0012\u001d\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u00040\u0003¢\u0006\f\b\u0005\u0012\b\b\u0006\u0012\u0004\b\b(\u0007\u0012\u0004\u0012\u0002H\u00020\u0001\"\b\b��\u0010\u0002*\u00020\b2\u001c\b\u0002\u0010\t\u001a\u0016\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0002\u0018\u00010\u0001H��\u001al\u0010\u000b\u001a>\u0012\u0004\u0012\u0002H\u0002\u0012.\u0012,\u0012\u0013\u0012\u00110\u0004¢\u0006\f\b\u0005\u0012\b\b\u0006\u0012\u0004\b\b(\f\u0012\u0004\u0012\u00020\u00040\u0003¢\u0006\f\b\u0005\u0012\b\b\u0006\u0012\u0004\b\b(\r\u0012\u0004\u0012\u0002H\u00020\u0001\"\b\b��\u0010\u0002*\u00020\b2\u001c\b\u0002\u0010\t\u001a\u0016\u0012\u0004\u0012\u0002H\u0002\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0002\u0018\u00010\u0001H��\u001a\u0087\u0001\u0010\u000e\u001aY\u0012\u0004\u0012\u0002H\u0010\u0012\u0004\u0012\u0002H\u0010\u0012C\u0012A\u0012\u0013\u0012\u00110\u0004¢\u0006\f\b\u0005\u0012\b\b\u0006\u0012\u0004\b\b(\u0011\u0012\u0013\u0012\u00110\u0004¢\u0006\f\b\u0005\u0012\b\b\u0006\u0012\u0004\b\b(\u0012\u0012\u0004\u0012\u00020\u00130\u0001¢\u0006\f\b\u0005\u0012\b\b\u0006\u0012\u0004\b\b(\u0014\u0012\u0004\u0012\u00020\u00130\u000f\"\b\b��\u0010\u0010*\u00020\b2\u001c\b\u0002\u0010\u0015\u001a\u0016\u0012\u0004\u0012\u0002H\u0010\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u0002H\u0010\u0018\u00010\u0001H��¨\u0006\u0016"}, d2 = {"defaultExtractInputTangent", "Lkotlin/Function2;", "Input", "Lkotlin/Function1;", "Lorg/diffkt/DTensor;", "Lkotlin/ParameterName;", "name", "extractTensorTangent", "", 
"wrapInput", "Lorg/diffkt/Wrapper;", "defaultMakeReverseInput", "primal", "makeReverseTensor", "defaultSetOutputTangent", "Lkotlin/Function3;", "Output", "tensor", "tangent", "", "setTensorTangent", "wrapOutput", "api"})
/* loaded from: input_file:org/diffkt/ReverseDerivativeWrappersKt.class */
public final class ReverseDerivativeWrappersKt {
    @NotNull
    public static final <Output> Function3<Output, Output, Function2<? super DTensor, ? super DTensor, Unit>, Unit> defaultSetOutputTangent(@Nullable Function2<? super Output, ? super Wrapper, ? extends Output> function2) {
        return new ReverseDerivativeWrappersKt$defaultSetOutputTangent$1(function2);
    }

    public static /* synthetic */ Function3 defaultSetOutputTangent$default(Function2 function2, int i, Object obj) {
        if ((i & 1) != 0) {
            function2 = null;
        }
        return defaultSetOutputTangent(function2);
    }

    @NotNull
    public static final <Input> Function2<Input, Function1<? super DTensor, ? extends DTensor>, Input> defaultMakeReverseInput(@Nullable Function2<? super Input, ? super Wrapper, ? extends Input> function2) {
        return new ReverseDerivativeWrappersKt$defaultMakeReverseInput$1(function2);
    }

    public static /* synthetic */ Function2 defaultMakeReverseInput$default(Function2 function2, int i, Object obj) {
        if ((i & 1) != 0) {
            function2 = null;
        }
        return defaultMakeReverseInput(function2);
    }

    @NotNull
    public static final <Input> Function2<Input, Function1<? super DTensor, ? extends DTensor>, Input> defaultExtractInputTangent(@Nullable Function2<? super Input, ? super Wrapper, ? extends Input> function2) {
        return new ReverseDerivativeWrappersKt$defaultExtractInputTangent$1(function2);
    }

    public static /* synthetic */ Function2 defaultExtractInputTangent$default(Function2 function2, int i, Object obj) {
        if ((i & 1) != 0) {
            function2 = null;
        }
        return defaultExtractInputTangent(function2);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Type inference failed for: r0v3, types: [java.lang.Object, org.diffkt.ReverseDerivativeWrappersKt$defaultSetOutputTangent$setOutputTangent$tangentSetter$1] */
    public static final <Output> void defaultSetOutputTangent$setOutputTangent(Function2<? super Output, ? super Wrapper, ? extends Output> function2, Output output, Output output2, final Function2<? super DTensor, ? super DTensor, Unit> function22) {
        final ArrayList arrayList = new ArrayList();
        Wrapper wrapper = new Wrapper() { // from class: org.diffkt.ReverseDerivativeWrappersKt$defaultSetOutputTangent$setOutputTangent$tangentGatherer$1
            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                Intrinsics.checkNotNullParameter(dTensor, "value");
                arrayList.add(dTensor);
                return dTensor;
            }
        };
        ?? r0 = new Wrapper() { // from class: org.diffkt.ReverseDerivativeWrappersKt$defaultSetOutputTangent$setOutputTangent$tangentSetter$1
            private int nextId;

            public final int getNextId() {
                return this.nextId;
            }

            public final void setNextId(int i) {
                this.nextId = i;
            }

            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                Intrinsics.checkNotNullParameter(dTensor, "value");
                DTensor dTensor2 = (DTensor) CollectionsKt.getOrNull(arrayList, this.nextId);
                if (dTensor2 == null) {
                    throw new Exception("inconsistent number of leaves in tangent object, expected at least " + (this.nextId + 1));
                }
                function22.invoke(dTensor, dTensor2);
                this.nextId++;
                return dTensor;
            }
        };
        if (function2 != null) {
            function2.invoke(output2, wrapper);
            function2.invoke(output, (Object) r0);
        } else {
            wrapper.wrap(output2);
            r0.wrap(output);
        }
        int size = arrayList.size();
        int nextId = r0.getNextId();
        if (!(nextId == size)) {
            throw new IllegalArgumentException(("number of leaves in the output object (" + nextId + ") is inconsistent with number of leaves in the tangent object (" + size + ')').toString());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static final <Input> Input defaultMakeReverseInput$makeReverseInput(Function2<? super Input, ? super Wrapper, ? extends Input> function2, Input input, final Function1<? super DTensor, ? extends DTensor> function1) {
        Wrapper wrapper = new Wrapper() { // from class: org.diffkt.ReverseDerivativeWrappersKt$defaultMakeReverseInput$makeReverseInput$revWrapper$1
            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                Intrinsics.checkNotNullParameter(dTensor, "value");
                return (DTensor) function1.invoke(dTensor);
            }
        };
        return function2 != null ? (Input) function2.invoke(input, wrapper) : (Input) wrapper.wrap(input);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static final <Input> Input defaultExtractInputTangent$extractInputTangent(Function2<? super Input, ? super Wrapper, ? extends Input> function2, Input input, final Function1<? super DTensor, ? extends DTensor> function1) {
        Wrapper wrapper = new Wrapper() { // from class: org.diffkt.ReverseDerivativeWrappersKt$defaultExtractInputTangent$extractInputTangent$tangentExtractor$1
            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                Intrinsics.checkNotNullParameter(dTensor, "value");
                return (DTensor) function1.invoke(dTensor);
            }
        };
        return function2 != null ? (Input) function2.invoke(input, wrapper) : (Input) wrapper.wrap(input);
    }
}
