package org.diffkt.tracing;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin._Assertions;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.FloatScalar;
import org.diffkt.TimingStats;
import org.diffkt.Wrapper;
import org.diffkt.random.RandomKey;
import org.diffkt.tracing.TracingRandomKey;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TracingJit.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.W_AXIS, xi = 48, d1 = {"��N\n\u0002\b\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u000b\n\u0002\b\u0003\u001a3\u0010��\u001a\u0002H\u0001\"\b\b��\u0010\u0001*\u00020\u00022\f\u0010\u0003\u001a\b\u0012\u0004\u0012\u0002H\u00010\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H��¢\u0006\u0002\u0010\b\u001a:\u0010\t\u001a\u000e\u0012\u0004\u0012\u0002H\u000b\u0012\u0004\u0012\u0002H\f0\n\"\b\b��\u0010\u000b*\u00020\u0002\"\b\b\u0001\u0010\f*\u00020\u00022\u0012\u0010\r\u001a\u000e\u0012\u0004\u0012\u0002H\u000b\u0012\u0004\u0012\u0002H\f0\u000e\u001a\u0080\u0001\u0010\t\u001a\u000e\u0012\u0004\u0012\u0002H\u000b\u0012\u0004\u0012\u0002H\f0\n\"\b\b��\u0010\u000b*\u00020\u0002\"\b\b\u0001\u0010\f*\u00020\u00022\u0012\u0010\r\u001a\u000e\u0012\u0004\u0012\u0002H\u000b\u0012\u0004\u0012\u0002H\f0\u000e2\u001c\b\u0002\u0010\u000f\u001a\u0016\u0012\u0004\u0012\u0002H\u000b\u0012\u0004\u0012\u00020\u0011\u0012\u0004\u0012\u0002H\u000b\u0018\u00010\u00102\b\b\u0002\u0010\u0012\u001a\u00020\u00132\b\b\u0002\u0010\u0014\u001a\u00020\u00152\b\b\u0002\u0010\u0016\u001a\u00020\u00172\b\b\u0002\u0010\u0018\u001a\u00020\u0019\u001a3\u0010\u001a\u001a\u0002H\u0001\"\b\b��\u0010\u0001*\u00020\u00022\f\u0010\u0003\u001a\b\u0012\u0004\u0012\u0002H\u00010\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H��¢\u0006\u0002\u0010\b\u001a3\u0010\u001b\u001a\u0002H\u0001\"\b\b��\u0010\u0001*\u00020\u00022\f\u0010\u0003\u001a\b\u0012\u0004\u0012\u0002H\u00010\u00042\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H��¢\u0006\u0002\u0010\b¨\u0006\u001c"}, d2 = {"eval", "TValue", "", "tracingResult", "Lorg/diffkt/tracing/DedaggedTracingTensor;", 
"variables", "", "Lorg/diffkt/DTensor;", "(Lorg/diffkt/tracing/DedaggedTracingTensor;Ljava/util/List;)Ljava/lang/Object;", "jit", "Lorg/diffkt/tracing/JittedFunction;", "TInput", "TOutput", "f", "Lkotlin/Function1;", "wrapInput", "Lkotlin/Function2;", "Lorg/diffkt/Wrapper;", "cacheSizeLimit", "", "evaluatorToUse", "Lorg/diffkt/tracing/JitEvaluatorToUse;", "loggingName", "", "shouldLog", "", "jitEval", "scalarEval", "api"})
/* loaded from: input_file:org/diffkt/tracing/TracingJitKt.class */
/**
 * JVM facade for the top-level functions of {@code TracingJit.kt}.
 *
 * <p>{@code jit} wraps a function in a {@link JittedFunction}. {@code eval}, {@code scalarEval}
 * and {@code jitEval} evaluate a {@link DedaggedTracingTensor} against concrete input values.
 */
public final class TracingJitKt {

    /**
     * Throws {@code AssertionError("Assertion failed")} when {@code _Assertions.ENABLED} is set and
     * {@code condition} is false.
     *
     * <p>The caller always evaluates {@code condition}, even when assertions are disabled.
     */
    private static void kotlinAssert(boolean condition) {
        if (_Assertions.ENABLED && !condition) {
            throw new AssertionError("Assertion failed");
        }
    }

    /**
     * Jits {@code f} with every optional setting at its default.
     *
     * <p>The defaults are: no input wrapper, a cache size limit of 5,
     * {@link JitEvaluatorToUse#BestAvailable}, logging name {@code "jit"}, and logging/timing
     * disabled.
     *
     * @param function1 the function to jit ({@code f})
     * @return the jitted function
     */
    @NotNull
    public static final <TInput, TOutput> JittedFunction<TInput, TOutput> jit(@NotNull Function1<? super TInput, ? extends TOutput> function1) {
        Intrinsics.checkNotNullParameter(function1, "f");
        // Same values jit$default substitutes for its default-mask bits 2, 4, 16 and 32.
        return TracingJitKt.<TInput, TOutput>jit(function1, null, 5, JitEvaluatorToUse.BestAvailable, "jit", false);
    }

    /**
     * Wraps {@code f} in a {@link JittedFunction} that evaluates it with {@code evaluatorToUse}.
     *
     * <p>If {@code f} is already a {@link JittedFunction} using the same evaluator, it is returned
     * unchanged. With {@link JitEvaluatorToUse#None}, the result simply calls {@code f}, timing each
     * call when {@code shouldLog} is set. Otherwise the tracing implementation
     * {@code TracingJitKt$jit$2} is built from all of the arguments.
     *
     * @param function1 the function to jit ({@code f})
     * @param function2 optional input wrapper ({@code wrapInput}); may be {@code null}
     * @param i cache size limit ({@code cacheSizeLimit}); forwarded to the tracing implementation
     * @param jitEvaluatorToUse which evaluator to use ({@code evaluatorToUse})
     * @param str logging name ({@code loggingName}); also the prefix of the timing-stats name
     * @param z whether to log ({@code shouldLog}); when true, a {@link TimingStats} named
     *     {@code str + "_timing"} is created
     * @return the jitted function
     */
    @NotNull
    public static final <TInput, TOutput> JittedFunction<TInput, TOutput> jit(@NotNull final Function1<? super TInput, ? extends TOutput> function1, @Nullable Function2<? super TInput, ? super Wrapper, ? extends TInput> function2, int i, @NotNull JitEvaluatorToUse jitEvaluatorToUse, @NotNull String str, boolean z) {
        Intrinsics.checkNotNullParameter(function1, "f");
        Intrinsics.checkNotNullParameter(jitEvaluatorToUse, "evaluatorToUse");
        Intrinsics.checkNotNullParameter(str, "loggingName");
        // Created up front (before the early return below), matching the original evaluation order.
        final TimingStats timingStats = z ? new TimingStats(str + "_timing") : null;

        // Re-jitting an already-jitted function with the same evaluator is a no-op.
        if (function1 instanceof JittedFunction
                && ((JittedFunction) function1).getEvaluatorToUse() == jitEvaluatorToUse) {
            return (JittedFunction) function1;
        }

        if (jitEvaluatorToUse == JitEvaluatorToUse.None) {
            return new JittedFunction<TInput, TOutput>() { // from class: org.diffkt.tracing.TracingJitKt$jit$1

                @NotNull
                private final JitEvaluatorToUse evaluatorToUse = JitEvaluatorToUse.None;

                // Nothing is cached when evaluation is not jitted. The decompiled blank final
                // field did not compile, so its runtime value (the JVM default 0) is spelled out.
                private final int cacheSize = 0;

                @Override // org.diffkt.tracing.JittedFunction
                @NotNull
                public JitEvaluatorToUse getEvaluatorToUse() {
                    return this.evaluatorToUse;
                }

                @Override // org.diffkt.tracing.JittedFunction
                public int getCacheSize() {
                    return this.cacheSize;
                }

                @Override
                @NotNull
                public TOutput invoke(@NotNull TInput tinput) {
                    Intrinsics.checkNotNullParameter(tinput, "p1");
                    if (timingStats != null) {
                        timingStats.start();
                    }
                    // NOTE(review): end() is skipped if function1 throws. Confirm whether
                    // TimingStats tolerates an unmatched start(), or whether this should be a
                    // try/finally.
                    TOutput toutput = function1.invoke(tinput);
                    if (timingStats != null) {
                        timingStats.end();
                    }
                    return toutput;
                }
            };
        }

        return new TracingJitKt$jit$2(jitEvaluatorToUse, i, z, str, timingStats, function2, function1);
    }

    /**
     * Kotlin default-argument bridge for
     * {@link #jit(Function1, Function2, int, JitEvaluatorToUse, String, boolean)}.
     *
     * <p>Each set bit of {@code i2} replaces one argument with its default:
     * <ul>
     *   <li>bit 2: no input wrapper ({@code null})</li>
     *   <li>bit 4: cache size limit 5</li>
     *   <li>bit 8: {@link JitEvaluatorToUse#BestAvailable}</li>
     *   <li>bit 16: logging name {@code "jit"}</li>
     *   <li>bit 32: logging disabled</li>
     * </ul>
     * {@code obj} is unused.
     */
    public static /* synthetic */ JittedFunction jit$default(Function1 function1, Function2 function2, int i, JitEvaluatorToUse jitEvaluatorToUse, String str, boolean z, int i2, Object obj) {
        if ((i2 & 2) != 0) {
            function2 = null;
        }
        if ((i2 & 4) != 0) {
            i = 5;
        }
        if ((i2 & 8) != 0) {
            jitEvaluatorToUse = JitEvaluatorToUse.BestAvailable;
        }
        if ((i2 & 16) != 0) {
            str = "jit";
        }
        if ((i2 & 32) != 0) {
            z = false;
        }
        return jit(function1, function2, i, jitEvaluatorToUse, str, z);
    }

    /**
     * Evaluates a dedagged trace in which every value is a float scalar, using one flat
     * {@code float[]} of slots.
     *
     * <p>Slots {@code [0, numInputs)} receive the inputs in order. Each assignment
     * {@code (slot, traceable)} then writes {@code slot} by calling
     * {@link TracingTensor#floatEval} on the slot array. When Kotlin assertions are enabled,
     * {@code slot} is checked to be the next unwritten one. Finally, every {@link TracingTensor}
     * inside the trace's value is replaced by a {@link FloatScalar} evaluated against the slots.
     *
     * @param dedaggedTracingTensor the trace to evaluate ({@code tracingResult})
     * @param list the input values ({@code variables}); each must be a non-null {@link FloatScalar}
     * @return the trace's value with traced tensors replaced by concrete scalars
     */
    @NotNull
    public static final <TValue> TValue scalarEval(@NotNull DedaggedTracingTensor<TValue> dedaggedTracingTensor, @NotNull List<? extends DTensor> list) {
        Intrinsics.checkNotNullParameter(dedaggedTracingTensor, "tracingResult");
        Intrinsics.checkNotNullParameter(list, "variables");
        final float[] slots = new float[dedaggedTracingTensor.getNumInputs() + dedaggedTracingTensor.getNumTemps()];

        int inputCount = 0;
        for (DTensor input : list) {
            Intrinsics.checkNotNull(input, "null cannot be cast to non-null type org.diffkt.FloatScalar");
            slots[inputCount++] = ((FloatScalar) input).getValue();
        }
        kotlinAssert(inputCount == dedaggedTracingTensor.getNumInputs());

        int expectedSlot = dedaggedTracingTensor.getNumInputs();
        Iterator<?> assignments = dedaggedTracingTensor.getAssignments().iterator();
        while (assignments.hasNext()) {
            Pair<?, ?> assignment = (Pair<?, ?>) assignments.next();
            int slot = ((Number) assignment.component1()).intValue();
            Traceable traceable = (Traceable) assignment.component2();
            kotlinAssert(slot == expectedSlot);
            Intrinsics.checkNotNull(traceable, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            slots[slot] = ((TracingTensor) traceable).floatEval(slots);
            expectedSlot++;
        }

        Wrapper outputEvaluator = new Wrapper() { // from class: org.diffkt.tracing.TracingJitKt$scalarEval$outputEvaluator$1
            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor value) {
                Intrinsics.checkNotNullParameter(value, "value");
                return value instanceof TracingTensor ? new FloatScalar(((TracingTensor) value).floatEval(slots)) : value;
            }
        };
        return (TValue) outputEvaluator.wrap(dedaggedTracingTensor.getValue());
    }

    /**
     * Evaluates a dedagged trace through the function produced by
     * {@code JvmTracingTranslatorKt.jvmTracingTranslator}, passing the inputs as a {@code float[]}.
     *
     * @param dedaggedTracingTensor the trace to evaluate ({@code tracingResult})
     * @param list the input values ({@code variables}); each must be a non-null {@link FloatScalar}
     * @return whatever the translated function returns for the inputs, cast to {@code TValue}
     * @throws NullPointerException if the translator is {@code null}. When Kotlin assertions are
     *     enabled, an {@link AssertionError} is thrown first.
     */
    @NotNull
    public static final <TValue> TValue jitEval(@NotNull DedaggedTracingTensor<TValue> dedaggedTracingTensor, @NotNull List<? extends DTensor> list) {
        Intrinsics.checkNotNullParameter(dedaggedTracingTensor, "tracingResult");
        Intrinsics.checkNotNullParameter(list, "variables");
        Function1<? super float[], ?> translator = JvmTracingTranslatorKt.jvmTracingTranslator(dedaggedTracingTensor);

        // Unbox straight into a float[], as scalarEval does, instead of via a List<Float>.
        float[] inputs = new float[list.size()];
        int index = 0;
        for (DTensor input : list) {
            Intrinsics.checkNotNull(input, "null cannot be cast to non-null type org.diffkt.FloatScalar");
            inputs[index++] = ((FloatScalar) input).getValue();
        }

        kotlinAssert(translator != null);
        Intrinsics.checkNotNull(translator);
        return (TValue) translator.invoke(inputs);
    }

    /**
     * Evaluates a dedagged trace with a {@link TracingEvaluator} over an array of {@link DTensor}
     * slots.
     *
     * <p>Slots {@code [0, numInputs)} receive the inputs in order, and the remaining temp slots
     * start out {@code null}. Each assignment {@code (slot, traceable)} writes {@code slot} with
     * {@code tracingEvaluator.evaluate(traceable)}. When Kotlin assertions are enabled,
     * {@code slot} is checked to be the next unwritten one. Finally, tracing tensors and tracing
     * random keys inside the trace's value are replaced by their evaluated counterparts.
     *
     * @param dedaggedTracingTensor the trace to evaluate ({@code tracingResult})
     * @param list the input values ({@code variables})
     * @return the trace's value with traced tensors and random keys evaluated
     */
    @NotNull
    public static final <TValue> TValue eval(@NotNull DedaggedTracingTensor<TValue> dedaggedTracingTensor, @NotNull List<? extends DTensor> list) {
        Intrinsics.checkNotNullParameter(dedaggedTracingTensor, "tracingResult");
        Intrinsics.checkNotNullParameter(list, "variables");
        TraceId traceId = dedaggedTracingTensor.getTraceId();
        // A new reference array is already null-filled, so no explicit clearing is needed.
        DTensor[] slots = new DTensor[dedaggedTracingTensor.getNumInputs() + dedaggedTracingTensor.getNumTemps()];

        int inputCount = 0;
        for (DTensor input : list) {
            slots[inputCount++] = input;
        }
        kotlinAssert(inputCount == dedaggedTracingTensor.getNumInputs());

        // The evaluator is built over the same array that the loop below fills. Presumably this
        // lets later assignments read earlier slots; confirm against TracingEvaluator.
        final TracingEvaluator tracingEvaluator = new TracingEvaluator(slots, traceId);
        int expectedSlot = dedaggedTracingTensor.getNumInputs();
        Iterator<?> assignments = dedaggedTracingTensor.getAssignments().iterator();
        while (assignments.hasNext()) {
            Pair<?, ?> assignment = (Pair<?, ?>) assignments.next();
            int slot = ((Number) assignment.component1()).intValue();
            Traceable traceable = (Traceable) assignment.component2();
            kotlinAssert(slot == expectedSlot);
            slots[slot] = tracingEvaluator.evaluate(traceable);
            expectedSlot++;
        }

        Wrapper outputEvaluator = new Wrapper() { // from class: org.diffkt.tracing.TracingJitKt$eval$outputEvaluator$1
            @Override // org.diffkt.Wrapper
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor value) {
                Intrinsics.checkNotNullParameter(value, "value");
                return value instanceof TracingTensor ? tracingEvaluator.evaluate((Traceable) value) : value;
            }

            @Override // org.diffkt.Wrapper
            @NotNull
            public RandomKey wrapRandomKey(@NotNull RandomKey randomKey) {
                Intrinsics.checkNotNullParameter(randomKey, "value");
                if (!(randomKey instanceof TracingRandomKey)) {
                    return randomKey;
                }
                // A traced random key evaluates to a RandomTensorWrapper that carries the concrete key.
                DTensor evaluated = tracingEvaluator.evaluate((Traceable) randomKey);
                Intrinsics.checkNotNull(evaluated, "null cannot be cast to non-null type org.diffkt.tracing.TracingRandomKey.RandomTensorWrapper");
                return ((TracingRandomKey.RandomTensorWrapper) evaluated).getKey();
            }
        };
        return (TValue) outputEvaluator.wrap(dedaggedTracingTensor.getValue());
    }
}
