package org.diffkt.tracing;

import java.util.ArrayList;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Ref;
import org.diffkt.Convolve;
import org.diffkt.DScalar;
import org.diffkt.DTensor;
import org.diffkt.FloatScalar;
import org.diffkt.Shape;
import org.diffkt.TimingStats;
import org.diffkt.Wrapper;
import org.diffkt.random.RandomKey;
import org.diffkt.tracing.TracingRandomKey;
import org.diffkt.tracing.TracingScalar;
import org.diffkt.tracing.TracingTensor;
import org.jetbrains.annotations.NotNull;

/* JADX INFO: Add missing generic type declarations: [TOutput, TInput] */
/* compiled from: TracingJit.kt */
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��>\n��\n\u0002\u0018\u0002\n��\n\u0002\b\u0005\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n��\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��*\u0002��\u0003\b\n\u0018��2\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0001J\u0016\u0010\u0013\u001a\u00028\u00012\u0006\u0010\u0014\u001a\u00028��H\u0096\u0002¢\u0006\u0002\u0010\u0015J&\u0010\u0016\u001a\u00020\u0017\"\b\b\u0002\u0010\u0018*\u00020\u00192\f\u0010\u001a\u001a\b\u0012\u0004\u0012\u0002H\u00180\u001b2\u0006\u0010\u000f\u001a\u00020\u0010J\u0016\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u001a\u001a\u00020\u001c2\u0006\u0010\u000f\u001a\u00020\u0010R\u001f\u0010\u0002\u001a\u000e\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u00010\u0003¢\u0006\n\n\u0002\u0010\u0006\u001a\u0004\b\u0004\u0010\u0005R\u0014\u0010\u0007\u001a\u00020\b8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\t\u0010\nR\u0014\u0010\u000b\u001a\u00020\f8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\r\u0010\u000eR\u0011\u0010\u000f\u001a\u00020\u0010¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012¨\u0006\u001d"}, d2 = {"org/diffkt/tracing/TracingJitKt$jit$2", "Lorg/diffkt/tracing/JittedFunction;", "cache", "org/diffkt/tracing/TracingJitKt$jit$2$cache$1", "getCache", "()Lorg/diffkt/tracing/TracingJitKt$jit$2$cache$1;", "Lorg/diffkt/tracing/TracingJitKt$jit$2$cache$1;", "cacheSize", "", "getCacheSize", "()I", "evaluatorToUse", "Lorg/diffkt/tracing/JitEvaluatorToUse;", "getEvaluatorToUse", "()Lorg/diffkt/tracing/JitEvaluatorToUse;", "traceId", "Lorg/diffkt/tracing/TraceId;", "getTraceId", "()Lorg/diffkt/tracing/TraceId;", "invoke", "input", "(Ljava/lang/Object;)Ljava/lang/Object;", "onlyScalarFloatValues", "", "TOutput", "", "trace", "Lorg/diffkt/tracing/DedaggedTracingTensor;", "Lorg/diffkt/tracing/TracingTensor;", "api"})
/* loaded from: input_file:org/diffkt/tracing/TracingJitKt$jit$2.class */
/*
 * JittedFunction returned by TracingJitKt.jit: wraps each input with tracing variables, traces
 * $f once per distinct wrapped input (cache key), caches the simplified/dedagged trace together
 * with an optional JVM translation of it, and evaluates the cached trace on the concrete inputs.
 */
public final class TracingJitKt$jit$2<TInput, TOutput> implements JittedFunction<TInput, TOutput> {

    /** Tags every tracing variable created by this function so traces from other functions can be told apart. */
    @NotNull
    private final TraceId traceId = new TraceId();

    /** Map-like cache from wrapped (traced) input to its compiled cache entry. */
    @NotNull
    private final TracingJitKt$jit$2$cache$1 cache;
    final /* synthetic */ JitEvaluatorToUse $evaluatorToUse;
    // May be null: timing is recorded only when a TimingStats was supplied.
    final /* synthetic */ TimingStats $evaluationLog;
    // May be null: the input is then wrapped with Wrapper.wrap directly.
    // Wildcards match the constructor parameter types so the assignments type-check.
    final /* synthetic */ Function2<? super TInput, ? super Wrapper, ? extends TInput> $wrapInput;
    final /* synthetic */ Function1<? super TInput, ? extends TOutput> $f;

    /**
     * Captures the arguments of jit(...).
     *
     * @param jitEvaluatorToUse which evaluator cached traces should be run with
     * @param i third argument of the cache constructor (presumably its capacity; TODO confirm)
     * @param z first argument of the cache constructor
     * @param str second argument of the cache constructor
     * @param timingStats optional log whose start()/end() bracket every invoke call
     * @param function2 optional custom input wrapper; when null, Wrapper.wrap is used
     * @param function1 the function being jitted
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public TracingJitKt$jit$2(JitEvaluatorToUse jitEvaluatorToUse, int i, boolean z, String str, TimingStats timingStats, Function2<? super TInput, ? super Wrapper, ? extends TInput> function2, Function1<? super TInput, ? extends TOutput> function1) {
        this.$evaluatorToUse = jitEvaluatorToUse;
        this.$evaluationLog = timingStats;
        this.$wrapInput = function2;
        this.$f = function1;
        this.cache = new TracingJitKt$jit$2$cache$1(z, str, i);
    }

    @NotNull
    public final TraceId getTraceId() {
        return this.traceId;
    }

    @Override // org.diffkt.tracing.JittedFunction
    @NotNull
    public JitEvaluatorToUse getEvaluatorToUse() {
        return this.$evaluatorToUse;
    }

    /** Returns the number of traces currently held in the cache. */
    @Override // org.diffkt.tracing.JittedFunction
    public int getCacheSize() {
        return this.cache.size();
    }

    @NotNull
    public final TracingJitKt$jit$2$cache$1 getCache() {
        return this.cache;
    }

    /**
     * Evaluates the jitted function on {@code tinput}, tracing and caching it first if this input
     * structure has not been seen before.
     *
     * @param tinput the input; must not be null
     * @return the function's result for {@code tinput}
     * @throws IllegalArgumentException if an evaluator other than BestAvailable that requires
     *     scalars was requested but an input or intermediate result is not a float scalar, or if
     *     the Jvm evaluator was requested but the JVM translation was refused
     */
    @NotNull
    public TOutput invoke(@NotNull TInput tinput) {
        Intrinsics.checkNotNullParameter(tinput, "input");
        TimingStats timingStats = this.$evaluationLog;
        if (timingStats != null) {
            timingStats.start();
        }
        try {
            return evaluate(tinput);
        } finally {
            // Always close the timing bracket, even when evaluation throws.
            if (timingStats != null) {
                timingStats.end();
            }
        }
    }

    /** Wraps the input, finds or builds the cache entry, and runs it; timing is handled by invoke. */
    private TOutput evaluate(TInput tinput) {
        final TraceId id = this.traceId;
        // True only while a scalar-capable evaluator was requested and every input seen so far is a
        // FloatScalar; on a cache miss it is further narrowed by the intermediate-result check.
        final Ref.BooleanRef onlyScalars = new Ref.BooleanRef();
        onlyScalars.element = this.$evaluatorToUse.compareTo(JitEvaluatorToUse.Scalar) >= 0;
        // Concrete input values; position k holds the value behind tracing variable index k.
        final ArrayList<DTensor> inputs = new ArrayList<>();
        final Ref.IntRef nextIndex = new Ref.IntRef();
        Wrapper tracingWrapper = new Wrapper() {
            @Override
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                Intrinsics.checkNotNullParameter(dTensor, "value");
                int index = nextIndex.element++;
                inputs.add(dTensor);
                if (dTensor instanceof FloatScalar) {
                    return new TracingScalar.Variable(index, null, id, 2, null);
                }
                onlyScalars.element = false;
                if (dTensor instanceof DScalar) {
                    return new TracingScalar.Variable(index, null, id, 2, null);
                }
                return new TracingTensor.Variable(index, null, dTensor.getShape(), id, 2, null);
            }

            @Override
            @NotNull
            public RandomKey wrapRandomKey(@NotNull RandomKey randomKey) {
                Intrinsics.checkNotNullParameter(randomKey, "value");
                onlyScalars.element = false;
                int index = nextIndex.element++;
                inputs.add(new TracingRandomKey.RandomTensorWrapper(randomKey));
                return new TracingRandomKey.Variable(index, id, null, 4, null);
            }
        };
        TInput wrappedInput = this.$wrapInput != null
                ? this.$wrapInput.invoke(tinput, tracingWrapper)
                : tracingWrapper.wrap(tinput);
        // This check must run after wrapping: only the wrapper discovers non-FloatScalar inputs.
        if (!onlyScalars.element && scalarEvaluatorExplicitlyRequested()) {
            throw new IllegalArgumentException("Requested evaluatorToUse=" + this.$evaluatorToUse + " but some input was not a FloatScalar");
        }
        TracingJitKt$jit$JitCacheEntry entry = (TracingJitKt$jit$JitCacheEntry) this.cache.get(wrappedInput);
        if (entry == null) {
            DedaggedTracingTensor<TOutput> trace = DedagKt.dedeepen(DedagKt.dedag$default(SimplifyKt.simplify(this.$f.invoke(wrappedInput)), nextIndex.element, id, false, 8, null), 500);
            onlyScalars.element = onlyScalars.element && onlyScalarFloatValues(trace, id);
            if (!onlyScalars.element && scalarEvaluatorExplicitlyRequested()) {
                throw new IllegalArgumentException("Requested evaluatorToUse=" + this.$evaluatorToUse + " but some intermediate result was not a scalar");
            }
            Function1 jvmFunction = null;
            if (onlyScalars.element && (this.$evaluatorToUse == JitEvaluatorToUse.Jvm || this.$evaluatorToUse == JitEvaluatorToUse.BestAvailable)) {
                jvmFunction = JvmTracingTranslatorKt.jvmTracingTranslator(trace);
                // A null translation is acceptable for BestAvailable (falls back), fatal for Jvm.
                if (jvmFunction == null && this.$evaluatorToUse == JitEvaluatorToUse.Jvm) {
                    throw new IllegalArgumentException("Requested jvm evaluator but the function would be too large.");
                }
            }
            entry = new TracingJitKt$jit$JitCacheEntry(trace, jvmFunction);
            this.cache.put(wrappedInput, entry);
        }
        @SuppressWarnings("unchecked") // the trace was built from $f, whose result type is TOutput
        TOutput output = (TOutput) evaluateEntry(entry.component1(), entry.component2(), inputs, onlyScalars.element);
        return output;
    }

    /**
     * Returns true when a scalar-capable evaluator other than BestAvailable was requested, i.e.
     * when non-scalar inputs or intermediates are an error rather than a reason to fall back.
     * Relies on JitEvaluatorToUse's declaration order via compareTo, as the rest of this class does.
     */
    private boolean scalarEvaluatorExplicitlyRequested() {
        return this.$evaluatorToUse.compareTo(JitEvaluatorToUse.Scalar) >= 0
                && this.$evaluatorToUse != JitEvaluatorToUse.BestAvailable;
    }

    /**
     * Picks an evaluator for a cached trace.
     *
     * <p>The order of preference is the JVM translation, then the scalar evaluator, then the
     * general tracing evaluator. The scalar paths are used only when every value is a float
     * scalar and the trace reports it can be scalar-evaluated.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Object evaluateEntry(DedaggedTracingTensor trace, Function1 jvmFunction, ArrayList<DTensor> inputs, boolean onlyScalars) {
        if (!onlyScalars || !trace.getCanScalarEval()) {
            return TracingJitKt.eval(trace, inputs);
        }
        if (jvmFunction != null) {
            return jvmFunction.invoke(floatValues(inputs));
        }
        return TracingJitKt.scalarEval(trace, inputs);
    }

    /** Unboxes the inputs, which must all be FloatScalars, into a primitive array in index order. */
    private static float[] floatValues(ArrayList<DTensor> inputs) {
        float[] values = new float[inputs.size()];
        for (int k = 0; k < values.length; k++) {
            values[k] = ((FloatScalar) inputs.get(k)).getValue();
        }
        return values;
    }

    /**
     * Checks whether every node reachable from {@code tracingTensor} is scalar-compatible.
     *
     * <p>Nodes are visited in topologicalSort order. A node is scalar-compatible when all of the
     * following hold:
     * <ul>
     *   <li>every TracingTensor has the empty (scalar) shape;</li>
     *   <li>every TracingScalar.Variable and TracingTensor.RandomFloats belongs to {@code traceId};</li>
     *   <li>every TracingScalar.Constant wraps a FloatScalar.</li>
     * </ul>
     *
     * @param tracingTensor root of the trace to inspect
     * @param traceId the trace that variables must belong to
     * @return false at the first offending node, true otherwise
     */
    public final boolean onlyScalarFloatValues(@NotNull TracingTensor tracingTensor, @NotNull TraceId traceId) {
        Intrinsics.checkNotNullParameter(tracingTensor, "trace");
        Intrinsics.checkNotNullParameter(traceId, "traceId");
        Shape scalarShape = Shape.Companion.invoke();
        // The predicate always answers false; presumably "prune nothing" (TODO confirm against topologicalSort).
        for (Traceable node : TopologicalSortKt.topologicalSort(CollectionsKt.listOf(tracingTensor), new Function1<Traceable, Boolean>() {
            @NotNull
            public final Boolean invoke(@NotNull Traceable candidate) {
                Intrinsics.checkNotNullParameter(candidate, "it");
                return false;
            }
        })) {
            if ((node instanceof TracingTensor) && !Intrinsics.areEqual(((TracingTensor) node).getShape(), scalarShape)) {
                return false;
            }
            if (node instanceof TracingScalar.Variable) {
                if (!Intrinsics.areEqual(((TracingScalar.Variable) node).getTraceId(), traceId)) {
                    return false;
                }
            } else if (node instanceof TracingScalar.Constant) {
                if (!(((TracingScalar.Constant) node).getValues() instanceof FloatScalar)) {
                    return false;
                }
            } else if ((node instanceof TracingTensor.RandomFloats) && !Intrinsics.areEqual(((TracingTensor.RandomFloats) node).getTraceId(), traceId)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether a dedagged trace can be evaluated with only float scalars.
     *
     * <p>Returns true only when both of the following hold:
     * <ul>
     *   <li>every TracingTensor assignment passes
     *       {@link #onlyScalarFloatValues(TracingTensor, TraceId)};</li>
     *   <li>every DTensor inside the trace's result value is a FloatScalar or a TracingTensor
     *       that passes the same check.</li>
     * </ul>
     *
     * @param dedaggedTracingTensor the trace to inspect
     * @param traceId the trace that variables must belong to
     * @return whether every assignment and result leaf is scalar-compatible
     */
    public final <TOutput> boolean onlyScalarFloatValues(@NotNull DedaggedTracingTensor<TOutput> dedaggedTracingTensor, @NotNull final TraceId traceId) {
        Intrinsics.checkNotNullParameter(dedaggedTracingTensor, "trace");
        Intrinsics.checkNotNullParameter(traceId, "traceId");
        final Ref.BooleanRef onlyScalars = new Ref.BooleanRef();
        onlyScalars.element = true;
        for (Pair<Integer, Traceable> assignment : dedaggedTracingTensor.getAssignments()) {
            Traceable value = assignment.getSecond();
            if (value instanceof TracingTensor) {
                // && short-circuits: once one assignment fails, the rest are not re-checked.
                onlyScalars.element = onlyScalars.element && onlyScalarFloatValues((TracingTensor) value, traceId);
            }
        }
        // Scan the DTensor leaves of the result value; wrapDTensor returns each leaf unchanged.
        new Wrapper() {
            @Override
            @NotNull
            public DTensor wrapDTensor(@NotNull DTensor dTensor) {
                Intrinsics.checkNotNullParameter(dTensor, "value");
                onlyScalars.element = onlyScalars.element
                        && ((dTensor instanceof FloatScalar)
                            || ((dTensor instanceof TracingTensor) && TracingJitKt$jit$2.this.onlyScalarFloatValues((TracingTensor) dTensor, traceId)));
                return dTensor;
            }
        }.wrap(dedaggedTracingTensor.getValue());
        return onlyScalars.element;
    }
}
