package org.diffkt.tracing;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin._Assertions;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import kotlin.ranges.RangesKt;
import org.diffkt.BroadcastKt;
import org.diffkt.Convolve;
import org.diffkt.DScalar;
import org.diffkt.DTensor;
import org.diffkt.DivKt;
import org.diffkt.FloatScalar;
import org.diffkt.FloatTensor;
import org.diffkt.MinusKt;
import org.diffkt.NoDerivativeID;
import org.diffkt.PowerKt;
import org.diffkt.ReshapeKt;
import org.diffkt.Shape;
import org.diffkt.StridedFloatTensor;
import org.diffkt.TimesKt;
import org.diffkt.TranscendentalKt;
import org.diffkt.ViewKt;
import org.diffkt.tracing.TracingScalar;
import org.diffkt.tracing.TracingTensor;
import org.jetbrains.annotations.NotNull;

/* compiled from: Simplify.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��~\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\bÂ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u0010\u0010\r\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u000eH\u0002J\"\u0010\u0010\u001a\u000e\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u000e0\u00112\u0006\u0010\u0012\u001a\u00020\u000e2\u0006\u0010\u0013\u001a\u00020\u000eJ\u000e\u0010\u0014\u001a\u00020\u00152\u0006\u0010\u000f\u001a\u00020\u0015J\u000e\u0010\u0014\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u000eJ\u0010\u0010\u0016\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u0017H\u0016J\u0010\u0010\u0018\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u0019H\u0016J\u0010\u0010\u001a\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u001bH\u0016J\u0010\u0010\u001c\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u001dH\u0016J\u0010\u0010\u001e\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u001fH\u0016J\u0010\u0010 
\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020!H\u0016J\u0010\u0010\"\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020#H\u0016J\u0010\u0010$\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020%H\u0016J\u0010\u0010&\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020'H\u0016J\u0010\u0010(\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020)H\u0016J\u0010\u0010*\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020+H\u0016J\u0010\u0010,\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020-H\u0016R\u0011\u0010\u0003\u001a\u00020\u0004¢\u0006\b\n��\u001a\u0004\b\u0005\u0010\u0006R\u001b\u0010\u0007\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\t0\b¢\u0006\n\n\u0002\u0010\f\u001a\u0004\b\n\u0010\u000b¨\u0006."}, d2 = {"Lorg/diffkt/tracing/optimizer;", "Lorg/diffkt/tracing/ShallowTracingRewriter;", "()V", "constantEvaluator", "Lorg/diffkt/tracing/TracingEvaluator;", "getConstantEvaluator", "()Lorg/diffkt/tracing/TracingEvaluator;", "emptyVariables", "", "Lorg/diffkt/DTensor;", "getEmptyVariables", "()[Lorg/diffkt/DTensor;", "[Lorg/diffkt/DTensor;", "makeConstant", "Lorg/diffkt/tracing/TracingTensor;", "x", "removeBroadcast", "Lkotlin/Pair;", "left", "right", "rewrite", "Lorg/diffkt/tracing/Traceable;", "visitDiv", "Lorg/diffkt/tracing/TracingTensor$Div;", "visitIfThenElse", "Lorg/diffkt/tracing/TracingTensor$IfThenElse;", "visitMinus", "Lorg/diffkt/tracing/TracingTensor$Minus;", "visitPlus", "Lorg/diffkt/tracing/TracingTensor$Plus;", "visitPow", "Lorg/diffkt/tracing/TracingTensor$Pow;", "visitSin", "Lorg/diffkt/tracing/TracingTensor$Sin;", "visitSplitPart", "Lorg/diffkt/tracing/TracingTensor$SplitPart;", "visitTimes", "Lorg/diffkt/tracing/TracingTensor$Times;", "visitTimesScalar", "Lorg/diffkt/tracing/TracingTensor$TimesScalar;", "visitUnaryMinus", "Lorg/diffkt/tracing/TracingTensor$UnaryMinus;", "visitView2", "Lorg/diffkt/tracing/TracingTensor$View2;", "visitView3", "Lorg/diffkt/tracing/TracingTensor$View3;", "api"})
/* loaded from: input_file:org/diffkt/tracing/optimizer.class */
final class optimizer extends ShallowTracingRewriter {

    /** The shared instance of this rewriter; the constructor is private, so this is the only one created here. */
    @NotNull
    public static final optimizer INSTANCE = new optimizer();

    // NOTE(review): presumably an empty variable array handed to the evaluator when folding
    // constants — the initialiser is not part of this declaration; confirm against the static init.
    @NotNull
    private static final DTensor[] emptyVariables;

    // NOTE(review): by its name, the evaluator used to compute constant sub-expressions — confirm
    // against the static initialiser and makeConstant.
    @NotNull
    private static final TracingEvaluator constantEvaluator;

    /** Private so that callers go through {@link #INSTANCE} instead of constructing new rewriters. */
    private optimizer() {
    }

    /**
     * Applies {@code visit2} over and over until a pass hands back the very same object it was given.
     *
     * @param traceable the expression to simplify
     * @return the first expression that a further pass leaves untouched (compared by reference)
     */
    @NotNull
    public final Traceable rewrite(@NotNull Traceable traceable) {
        Intrinsics.checkNotNullParameter(traceable, "x");
        Traceable current = traceable;
        Traceable next = visit2(current);
        // Fixpoint iteration: stop as soon as a pass returns its input instance.
        while (next != current) {
            current = next;
            next = visit2(current);
        }
        return current;
    }

    /**
     * Tensor-typed convenience overload of {@link #rewrite(Traceable)}.
     *
     * @param tracingTensor the tensor expression to simplify
     * @return the simplified expression, cast back to {@code TracingTensor}
     * @throws ClassCastException if the general rewrite produced something that is not a {@code TracingTensor}
     */
    @NotNull
    public final TracingTensor rewrite(@NotNull TracingTensor tracingTensor) {
        Intrinsics.checkNotNullParameter(tracingTensor, "x");
        final Traceable simplified = rewrite((Traceable) tracingTensor);
        Intrinsics.checkNotNull(simplified, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        return (TracingTensor) simplified;
    }

    /**
     * Hoists a negation out of a quotient: {@code (-a) / b} becomes {@code -(a / b)}.
     * Any other division is returned unchanged.
     *
     * @param div the division node being visited
     * @return the rewritten expression, or {@code div} itself when the rule does not apply
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitDiv */
    public Traceable visitDiv2(@NotNull TracingTensor.Div div) {
        Intrinsics.checkNotNullParameter(div, "x");
        if (div.getLeft() instanceof TracingScalar.UnaryMinus) {
            // Divide the un-negated numerator, then negate the whole quotient.
            DTensor negatedQuotient = MinusKt.unaryMinus(DivKt.div(((TracingScalar.UnaryMinus) div.getLeft()).getX(), div.getRight()));
            Intrinsics.checkNotNull(negatedQuotient, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedQuotient;
        }
        return div;
    }

    /**
     * Peels broadcasting wrappers off a pair of operands.
     *
     * <p>Every outer {@code Expand} is removed from both operands. After that, if exactly one
     * operand is an {@code Unsqueeze}, its outer {@code Unsqueeze} layers are removed as long as
     * all dimensions of the wrapped tensor before the unsqueezed axis have size 1.
     *
     * @param tracingTensor  the left operand
     * @param tracingTensor2 the right operand
     * @return the pair (left, right) after the wrappers have been removed
     */
    @NotNull
    public final Pair<TracingTensor, TracingTensor> removeBroadcast(@NotNull TracingTensor tracingTensor, @NotNull TracingTensor tracingTensor2) {
        Intrinsics.checkNotNullParameter(tracingTensor, "left");
        Intrinsics.checkNotNullParameter(tracingTensor2, "right");

        TracingTensor lhs = tracingTensor;
        while (lhs instanceof TracingTensor.Expand) {
            lhs = ((TracingTensor.Expand) lhs).getX();
        }
        TracingTensor rhs = tracingTensor2;
        while (rhs instanceof TracingTensor.Expand) {
            rhs = ((TracingTensor.Expand) rhs).getX();
        }

        // Only peel Unsqueeze layers when exactly one side carries them.
        final boolean lhsUnsqueezed = lhs instanceof TracingTensor.Unsqueeze;
        final boolean rhsUnsqueezed = rhs instanceof TracingTensor.Unsqueeze;
        if (lhsUnsqueezed && !rhsUnsqueezed) {
            lhs = stripLeadingUnsqueezes(lhs);
        } else if (rhsUnsqueezed && !lhsUnsqueezed) {
            rhs = stripLeadingUnsqueezes(rhs);
        }
        return new Pair<>(lhs, rhs);
    }

    /**
     * Removes outer {@code Unsqueeze} layers while each layer's wrapped tensor has size 1 in every
     * dimension below the unsqueezed axis; stops at the first layer that fails that test.
     */
    private static TracingTensor stripLeadingUnsqueezes(TracingTensor tensor) {
        TracingTensor current = tensor;
        while (current instanceof TracingTensor.Unsqueeze) {
            TracingTensor.Unsqueeze layer = (TracingTensor.Unsqueeze) current;
            int axis = layer.getAxis();
            for (int dim = 0; dim < axis; dim++) {
                if (layer.getX().getShape().get(dim) != 1) {
                    // A non-unit leading dimension: this layer must stay.
                    return current;
                }
            }
            current = layer.getX();
        }
        return current;
    }

    /**
     * Simplifies a scalar-times-tensor product. Rules are tried in this order and the first match wins:
     * <ol>
     *   <li>a scalar constant 1 on either side: return the other operand;</li>
     *   <li>a scalar {@code Zero} or scalar constant 0 on either side: return that operand;</li>
     *   <li>a scalar constant -1 on either side: return the negation of the other operand;</li>
     *   <li>a negated scalar on either side: pull the negation outside the product;</li>
     *   <li>two negated tensors: multiply the un-negated operands;</li>
     *   <li>a scalar power with exponent -1: turn the product into a division by its base;</li>
     *   <li>otherwise return the node itself.</li>
     * </ol>
     *
     * @param timesScalar the product node being visited
     * @return the simplified expression, or {@code timesScalar} when nothing applies
     * @throws ClassCastException if a scalar constant operand does not hold a {@code FloatScalar}
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitTimesScalar */
    public Traceable visitTimesScalar2(@NotNull TracingTensor.TimesScalar timesScalar) {
        Intrinsics.checkNotNullParameter(timesScalar, "x");
        final TracingTensor lhs = timesScalar.getLeft();
        final TracingTensor rhs = timesScalar.getRight();

        // Multiplicative identity.
        if (scalarConstantEquals(rhs, 1.0f)) {
            return lhs;
        }
        if (scalarConstantEquals(lhs, 1.0f)) {
            return rhs;
        }

        // Annihilator: symbolic zero first, then a literal 0 constant.
        if (rhs instanceof TracingScalar.Zero) {
            return rhs;
        }
        if (lhs instanceof TracingScalar.Zero) {
            return lhs;
        }
        if (scalarConstantEquals(rhs, 0.0f)) {
            return rhs;
        }
        if (scalarConstantEquals(lhs, 0.0f)) {
            return lhs;
        }

        // Multiplication by -1 is a negation of the other operand.
        if (scalarConstantEquals(lhs, -1.0f)) {
            DTensor negatedRhs = rhs.mo153getOperations().unaryMinus(rhs);
            Intrinsics.checkNotNull(negatedRhs, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedRhs;
        }
        if (scalarConstantEquals(rhs, -1.0f)) {
            DTensor negatedLhs = lhs.mo153getOperations().unaryMinus(lhs);
            Intrinsics.checkNotNull(negatedLhs, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedLhs;
        }

        // a * (-b)  ==>  -(a * b)
        if (rhs instanceof TracingScalar.UnaryMinus) {
            DTensor negatedProduct = MinusKt.unaryMinus(TimesKt.times((DScalar) lhs, (DTensor) ((TracingScalar.UnaryMinus) rhs).getX()));
            Intrinsics.checkNotNull(negatedProduct, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedProduct;
        }
        // (-a) * b  ==>  -(a * b)
        if (lhs instanceof TracingScalar.UnaryMinus) {
            DTensor negatedProduct = MinusKt.unaryMinus(TimesKt.times(((TracingScalar.UnaryMinus) lhs).getX(), timesScalar.getRight()));
            Intrinsics.checkNotNull(negatedProduct, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedProduct;
        }
        // (-a) * (-b)  ==>  a * b
        if (lhs instanceof TracingTensor.UnaryMinus && rhs instanceof TracingTensor.UnaryMinus) {
            DTensor product = TimesKt.times(((TracingTensor.UnaryMinus) lhs).getX(), ((TracingTensor.UnaryMinus) rhs).getX());
            Intrinsics.checkNotNull(product, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) product;
        }

        // a^-1 * b  ==>  b / a   (and the mirrored case)
        if (lhs instanceof TracingScalar.Pow && ((TracingScalar.Pow) lhs).getExponent() == -1.0f) {
            DTensor quotient = DivKt.div(rhs, ((TracingScalar.Pow) lhs).getBase());
            Intrinsics.checkNotNull(quotient, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) quotient;
        }
        if (rhs instanceof TracingScalar.Pow && ((TracingScalar.Pow) rhs).getExponent() == -1.0f) {
            DTensor quotient = DivKt.div(lhs, ((TracingScalar.Pow) rhs).getBase());
            Intrinsics.checkNotNull(quotient, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) quotient;
        }

        // The operands are never reassigned above, so this identity test is expected to hold;
        // the rebuild below is kept for parity with the other visitors.
        final boolean operandsUntouched = lhs == timesScalar.getLeft() && rhs == timesScalar.getRight();
        if (!operandsUntouched) {
            DTensor rebuilt = timesScalar.mo153getOperations().timesScalar((DScalar) lhs, rhs, NoDerivativeID.INSTANCE);
            Intrinsics.checkNotNull(rebuilt, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) rebuilt;
        }
        return timesScalar;
    }

    /**
     * Tells whether {@code operand} is a scalar constant whose float value equals {@code expected}.
     * The constant's payload is cast to {@code FloatScalar} without a type test.
     */
    private static boolean scalarConstantEquals(TracingTensor operand, float expected) {
        if (!(operand instanceof TracingScalar.Constant)) {
            return false;
        }
        DTensor payload = ((TracingScalar.Constant) operand).getValues();
        Intrinsics.checkNotNull(payload, "null cannot be cast to non-null type org.diffkt.FloatScalar");
        return ((FloatScalar) payload).getValue() == expected;
    }

    /**
     * Simplifies an element-wise product after removing broadcasting wrappers from both operands.
     * A scalar operand delegates to {@link #visitTimesScalar2}; a float-tensor constant whose
     * elements are all 1 is dropped in favour of the other operand.
     *
     * @param times the product node being visited
     * @return the simplified expression, or {@code times} when the operands were left as they were
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitTimes */
    public Traceable visitTimes2(@NotNull TracingTensor.Times times) {
        Intrinsics.checkNotNullParameter(times, "x");
        Pair<TracingTensor, TracingTensor> operands = removeBroadcast(times.getLeft(), times.getRight());
        TracingTensor lhs = (TracingTensor) operands.getFirst();
        TracingTensor rhs = (TracingTensor) operands.getSecond();

        // A scalar operand always goes in the first slot of TimesScalar.
        if (rhs instanceof TracingScalar) {
            return visitTimesScalar2(new TracingTensor.TimesScalar((TracingScalar) rhs, lhs));
        }
        if (lhs instanceof TracingScalar) {
            return visitTimesScalar2(new TracingTensor.TimesScalar((TracingScalar) lhs, rhs));
        }

        Function1<Float, Boolean> isOne = value -> value == 1.0f;
        if (rhs instanceof TracingTensor.Constant
                && ((TracingTensor.Constant) rhs).getValues() instanceof FloatTensor
                && ((FloatTensor) ((TracingTensor.Constant) rhs).getValues()).all(isOne)) {
            return lhs;
        }
        if (lhs instanceof TracingTensor.Constant
                && ((TracingTensor.Constant) lhs).getValues() instanceof FloatTensor
                && ((FloatTensor) ((TracingTensor.Constant) lhs).getValues()).all(isOne)) {
            return rhs;
        }

        // Rebuild only when a broadcast wrapper was actually removed from an operand.
        final boolean operandsUntouched = lhs == times.getLeft() && rhs == times.getRight();
        if (operandsUntouched) {
            return times;
        }
        DTensor rebuilt = times.mo153getOperations().times(lhs, rhs, NoDerivativeID.INSTANCE);
        Intrinsics.checkNotNull(rebuilt, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        return (TracingTensor) rebuilt;
    }

    /**
     * Simplifies a sum after removing broadcasting wrappers from both operands.
     * Zero operands (symbolic {@code Zero} or an all-zero float-tensor constant) are dropped, and
     * {@code a + (-b)} becomes {@code a - b}.
     *
     * @param plus the sum node being visited
     * @return the simplified expression, or {@code plus} when the operands equal the originals
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitPlus */
    public Traceable visitPlus2(@NotNull TracingTensor.Plus plus) {
        Intrinsics.checkNotNullParameter(plus, "x");
        Pair<TracingTensor, TracingTensor> operands = removeBroadcast(plus.getLeft(), plus.getRight());
        TracingTensor lhs = (TracingTensor) operands.getFirst();
        TracingTensor rhs = (TracingTensor) operands.getSecond();

        if (lhs instanceof TracingTensor.Zero) {
            return rhs;
        }

        Function1<Float, Boolean> isZero = value -> value == 0.0f;
        // Right operand is zero, either symbolically or as an all-zero constant.
        if (rhs instanceof TracingTensor.Zero
                || (rhs instanceof TracingTensor.Constant
                    && ((TracingTensor.Constant) rhs).getValues() instanceof FloatTensor
                    && ((FloatTensor) ((TracingTensor.Constant) rhs).getValues()).all(isZero))) {
            return lhs;
        }
        if (lhs instanceof TracingTensor.Constant
                && ((TracingTensor.Constant) lhs).getValues() instanceof FloatTensor
                && ((FloatTensor) ((TracingTensor.Constant) lhs).getValues()).all(isZero)) {
            return rhs;
        }

        // a + (-b)  ==>  a - b
        if (rhs instanceof TracingTensor.UnaryMinus) {
            DTensor difference = MinusKt.minus(lhs, ((TracingTensor.UnaryMinus) rhs).getX());
            Intrinsics.checkNotNull(difference, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) difference;
        }

        final boolean operandsUnchanged = Intrinsics.areEqual(lhs, plus.getLeft()) && Intrinsics.areEqual(rhs, plus.getRight());
        if (operandsUnchanged) {
            return plus;
        }
        // With both operands reduced to scalars, build the scalar node directly.
        if (lhs instanceof TracingScalar && rhs instanceof TracingScalar) {
            return new TracingScalar.Plus((TracingScalar) lhs, (TracingScalar) rhs);
        }
        DTensor rebuilt = plus.mo153getOperations().plus(lhs, rhs, NoDerivativeID.INSTANCE);
        Intrinsics.checkNotNull(rebuilt, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        return (TracingTensor) rebuilt;
    }

    /**
     * Simplifies a difference after removing broadcasting wrappers from both operands.
     * {@code a - 0} becomes {@code a}, {@code 0 - b} becomes {@code -b}, and
     * {@code (-a) - b} becomes {@code -(a + b)}; "0" covers both the symbolic {@code Zero}
     * and an all-zero float-tensor constant.
     *
     * @param minus the difference node being visited
     * @return the simplified expression, or {@code minus} when the operands equal the originals
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitMinus */
    public Traceable visitMinus2(@NotNull TracingTensor.Minus minus) {
        Intrinsics.checkNotNullParameter(minus, "x");
        Pair<TracingTensor, TracingTensor> operands = removeBroadcast(minus.getLeft(), minus.getRight());
        TracingTensor lhs = (TracingTensor) operands.getFirst();
        TracingTensor rhs = (TracingTensor) operands.getSecond();

        if (rhs instanceof TracingTensor.Zero) {
            return lhs;
        }

        Function1<Float, Boolean> isZero = value -> value == 0.0f;
        // 0 - b  ==>  -b, for a symbolic zero or an all-zero constant on the left.
        if (lhs instanceof TracingTensor.Zero
                || (lhs instanceof TracingTensor.Constant
                    && ((TracingTensor.Constant) lhs).getValues() instanceof FloatTensor
                    && ((FloatTensor) ((TracingTensor.Constant) lhs).getValues()).all(isZero))) {
            DTensor negatedRhs = rhs.mo153getOperations().unaryMinus(rhs);
            Intrinsics.checkNotNull(negatedRhs, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedRhs;
        }
        if (rhs instanceof TracingTensor.Constant
                && ((TracingTensor.Constant) rhs).getValues() instanceof FloatTensor
                && ((FloatTensor) ((TracingTensor.Constant) rhs).getValues()).all(isZero)) {
            return lhs;
        }

        // (-a) - b  ==>  -(a + b)
        if (lhs instanceof TracingTensor.UnaryMinus) {
            DTensor negatedSum = TracingTensorOperations.INSTANCE.unaryMinus(TracingTensorOperations.INSTANCE.plus(((TracingTensor.UnaryMinus) lhs).getX(), rhs, NoDerivativeID.INSTANCE));
            Intrinsics.checkNotNull(negatedSum, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedSum;
        }

        final boolean operandsUnchanged = Intrinsics.areEqual(lhs, minus.getLeft()) && Intrinsics.areEqual(rhs, minus.getRight());
        if (operandsUnchanged) {
            return minus;
        }
        DTensor rebuilt = minus.mo153getOperations().minus(lhs, rhs, NoDerivativeID.INSTANCE);
        Intrinsics.checkNotNull(rebuilt, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        return (TracingTensor) rebuilt;
    }

    /**
     * Simplifies a negation:
     * {@code -(s * (-t))} becomes {@code s * t}, {@code -(-t)} becomes {@code t},
     * a constant is negated eagerly (a constant equal to {@code FloatScalar.ZERO} stays zero),
     * and the negation of a symbolic {@code Zero} is that {@code Zero}.
     *
     * @param unaryMinus the negation node being visited
     * @return the simplified expression, or {@code unaryMinus} when no rule applies
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitUnaryMinus */
    public Traceable visitUnaryMinus2(@NotNull TracingTensor.UnaryMinus unaryMinus) {
        Intrinsics.checkNotNullParameter(unaryMinus, "x");
        final TracingTensor operand = unaryMinus.getX();

        if (operand instanceof TracingTensor.TimesScalar) {
            TracingTensor.TimesScalar product = (TracingTensor.TimesScalar) operand;
            TracingScalar scalarFactor = product.getLeft();
            TracingTensor tensorFactor = product.getRight();
            if (tensorFactor instanceof TracingTensor.UnaryMinus) {
                // The two negations cancel.
                DTensor unnegated = TimesKt.times((DScalar) scalarFactor, (DTensor) ((TracingTensor.UnaryMinus) tensorFactor).getX());
                Intrinsics.checkNotNull(unnegated, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
                return (TracingTensor) unnegated;
            }
        }
        if (operand instanceof TracingTensor.UnaryMinus) {
            return ((TracingTensor.UnaryMinus) operand).getX();
        }
        if (operand instanceof TracingTensor.Constant) {
            TracingTensor.Constant constant = (TracingTensor.Constant) operand;
            // Keep the canonical ZERO rather than producing a negated zero.
            return TracingTensor.Constant.Companion.wrap(Intrinsics.areEqual(constant.getValues(), FloatScalar.Companion.getZERO()) ? FloatScalar.Companion.getZERO() : MinusKt.unaryMinus(constant.getValues()));
        }
        if (operand instanceof TracingTensor.Zero) {
            return operand;
        }
        return unaryMinus;
    }

    /**
     * Uses the oddness of sine: {@code sin(-a)} becomes {@code -sin(a)}.
     * A sine of anything else is returned unchanged.
     *
     * @param sin the sine node being visited
     * @return the rewritten expression, or {@code sin} itself when its argument is not a negation
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitSin */
    public Traceable visitSin2(@NotNull TracingTensor.Sin sin) {
        Intrinsics.checkNotNullParameter(sin, "x");
        final TracingTensor argument = sin.getX();
        if (argument instanceof TracingTensor.UnaryMinus) {
            DTensor negatedSine = MinusKt.unaryMinus(TranscendentalKt.sin(((TracingTensor.UnaryMinus) argument).getX()));
            Intrinsics.checkNotNull(negatedSine, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) negatedSine;
        }
        return sin;
    }

    /**
     * Collapses a power of a power: {@code (b^m)^n} becomes {@code b^(m*n)}.
     * A power whose base is not itself a power is returned unchanged.
     *
     * @param pow the power node being visited
     * @return the rewritten expression, or {@code pow} itself when its base is not a {@code Pow}
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitPow */
    public Traceable visitPow2(@NotNull TracingTensor.Pow pow) {
        Intrinsics.checkNotNullParameter(pow, "x");
        final TracingTensor inner = pow.getBase();
        if (inner instanceof TracingTensor.Pow) {
            // Inner exponent times outer exponent, applied to the innermost base.
            DTensor merged = PowerKt.pow(((TracingTensor.Pow) inner).getBase(), ((TracingTensor.Pow) inner).getExponent() * pow.getExponent());
            Intrinsics.checkNotNull(merged, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) merged;
        }
        return pow;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitView2 */
    public Traceable visitView22(@NotNull TracingTensor.View2 view2) {
        Intrinsics.checkNotNullParameter(view2, "x");
        TracingTensor x = view2.getX();
        if (x instanceof TracingTensor.Plus) {
            TracingTensor rewrite = rewrite(visitView2$makeView2(view2, ((TracingTensor.Plus) view2.getX()).getLeft(), view2.getX().getShape()));
            TracingTensor rewrite2 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Plus) view2.getX()).getRight(), view2.getX().getShape()));
            if (Intrinsics.areEqual(rewrite, rewrite2)) {
                TracingScalar.Constant constant = new TracingScalar.Constant(new FloatScalar(2.0f));
                if (!(view2 instanceof DScalar)) {
                    return new TracingTensor.Plus(constant, rewrite2);
                }
                Intrinsics.checkNotNull(rewrite2, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
                return new TracingScalar.TimesScalar(constant, (TracingScalar) rewrite2);
            }
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.Plus(rewrite, rewrite2);
            }
            Intrinsics.checkNotNull(rewrite, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            Intrinsics.checkNotNull(rewrite2, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.Plus((TracingScalar) rewrite, (TracingScalar) rewrite2);
        }
        if (x instanceof TracingTensor.Minus) {
            TracingTensor rewrite3 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Minus) view2.getX()).getLeft(), view2.getX().getShape()));
            TracingTensor rewrite4 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Minus) view2.getX()).getRight(), view2.getX().getShape()));
            if (Intrinsics.areEqual(rewrite3, rewrite4)) {
                return TracingTensor.Constant.Companion.wrap(FloatTensor.Companion.zeros(view2.getShape()));
            }
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.Plus(rewrite3, rewrite4);
            }
            Intrinsics.checkNotNull(rewrite3, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            Intrinsics.checkNotNull(rewrite4, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.Minus((TracingScalar) rewrite3, (TracingScalar) rewrite4);
        }
        if (x instanceof TracingTensor.Times) {
            TracingTensor rewrite5 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Times) view2.getX()).getLeft(), view2.getX().getShape()));
            TracingTensor rewrite6 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Times) view2.getX()).getRight(), view2.getX().getShape()));
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.Times(rewrite5, rewrite6);
            }
            Intrinsics.checkNotNull(rewrite5, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            Intrinsics.checkNotNull(rewrite6, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.TimesScalar((TracingScalar) rewrite5, (TracingScalar) rewrite6);
        }
        if (x instanceof TracingTensor.TimesScalar) {
            TracingScalar left = ((TracingTensor.TimesScalar) view2.getX()).getLeft();
            TracingTensor rewrite7 = rewrite(visitView2$makeView2(view2, ((TracingTensor.TimesScalar) view2.getX()).getRight(), view2.getX().getShape()));
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.TimesScalar(left, rewrite7);
            }
            Intrinsics.checkNotNull(rewrite7, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.TimesScalar(left, (TracingScalar) rewrite7);
        }
        if (x instanceof TracingTensor.UnaryMinus) {
            TracingTensor rewrite8 = rewrite(visitView2$makeView2(view2, ((TracingTensor.UnaryMinus) view2.getX()).getX(), view2.getX().getShape()));
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.UnaryMinus(rewrite8);
            }
            Intrinsics.checkNotNull(rewrite8, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.UnaryMinus((TracingScalar) rewrite8);
        }
        if (x instanceof TracingTensor.Sin) {
            TracingTensor rewrite9 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Sin) view2.getX()).getX(), view2.getX().getShape()));
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.Sin(rewrite9);
            }
            Intrinsics.checkNotNull(rewrite9, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.Sin((TracingScalar) rewrite9);
        }
        if (x instanceof TracingTensor.Cos) {
            TracingTensor rewrite10 = rewrite(visitView2$makeView2(view2, ((TracingTensor.Cos) view2.getX()).getX(), view2.getX().getShape()));
            if (!(view2 instanceof DScalar)) {
                return new TracingTensor.Cos(rewrite10);
            }
            Intrinsics.checkNotNull(rewrite10, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
            return new TracingScalar.Cos((TracingScalar) rewrite10);
        }
        if (x instanceof TracingTensor.Constant) {
            return TracingTensor.Companion.wrap(ViewKt.view(((TracingTensor.Constant) view2.getX()).getValues(), view2.getIndex(), view2.getAxis()));
        }
        if (x instanceof TracingTensor.Zero) {
            return view2 instanceof DScalar ? new TracingScalar.Zero() : new TracingTensor.Zero(view2.getShape());
        }
        if (x instanceof TracingTensor.IdentityGradient) {
            return TracingTensor.Companion.wrap(ViewKt.view(StridedFloatTensor.Companion.identityGradient(((TracingTensor.IdentityGradient) view2.getX()).getHalfShape()), view2.getIndex(), view2.getAxis()));
        }
        if (x instanceof TracingTensor.Reshape) {
            if (view2.getX().getRank() == 1 && (((TracingTensor.Reshape) view2.getX()).getX() instanceof DScalar) && view2.getIndex() == 0) {
                return ((TracingTensor.Reshape) view2.getX()).getX();
            }
        } else if ((x instanceof TracingTensor.BroadcastTo) && view2.getX().getRank() == 1) {
            return new TracingScalar.View2(((TracingTensor.BroadcastTo) view2.getX()).getX(), 0, 0);
        }
        Traceable visitView22 = super.visitView22(view2);
        Intrinsics.checkNotNull(visitView22, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        return (TracingTensor) visitView22;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Simplifies a {@code View3} node whose operand is a symbolic zero.
     *
     * <p>When {@code view3.getX()} is a {@link TracingTensor.Zero}, the view is replaced by the
     * value that {@code TracingTensorOperations.zeroOfSameKind} produces for the view's own shape.
     * Every other view is returned untouched.
     *
     * @param view3 the view node being visited; must not be null
     * @return the replacement zero node, or {@code view3} itself when no simplification applies
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitView3 */
    public Traceable visitView32(@NotNull TracingTensor.View3 view3) {
        Intrinsics.checkNotNullParameter(view3, "x");
        if (view3.getX() instanceof TracingTensor.Zero) {
            // A view into a zero tensor is itself zero; build it directly with the view's shape.
            DTensor zero = TracingTensorOperations.INSTANCE.zeroOfSameKind(view3, view3.getShape());
            Intrinsics.checkNotNull(zero, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) zero;
        }
        return view3;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Simplifies one part of a {@code Split} node.
     *
     * <p>The rewrites are tried in this order; the first one that applies wins:
     * <ol>
     *   <li>the source is not a {@link TracingTensor.Split}: return the part unchanged;</li>
     *   <li>the split operand is a {@link TracingTensor.Meld} whose value shapes equal the split
     *       shapes: return the melded value at the part's index;</li>
     *   <li>the split has exactly one shape: return the operand reshaped to that shape;</li>
     *   <li>the operand is a {@link TracingTensor.IdentityGradient}: evaluate the part to a
     *       constant via {@code makeConstant};</li>
     *   <li>the operand is a {@link TracingTensor.Zero}: return a scalar or tensor zero;</li>
     *   <li>the operand has rank 1 and every split shape is scalar: return a scalar
     *       {@code View2} of the operand at the part's index along axis 0;</li>
     *   <li>otherwise return the part unchanged.</li>
     * </ol>
     *
     * @param splitPart the split part being visited; must not be null
     * @return the simplified node, or {@code splitPart} itself when nothing applies
     * @throws AssertionError if Kotlin assertions are enabled and a single-shape split is
     *         addressed with a non-zero index
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitSplitPart */
    public Traceable visitSplitPart2(@NotNull TracingTensor.SplitPart splitPart) {
        Intrinsics.checkNotNullParameter(splitPart, "x");
        if (!(splitPart.getFrom() instanceof TracingTensor.Split)) {
            return splitPart;
        }
        TracingTensor.Split split = (TracingTensor.Split) splitPart.getFrom();
        TracingTensor x = split.getX();
        List<Shape> shapes = split.getShapes();

        // split(meld(values)) with matching shapes undoes the meld: pick the value directly.
        if (x instanceof TracingTensor.Meld) {
            List<TracingTensor> values = ((TracingTensor.Meld) x).getValues();
            List<Shape> meldShapes = new ArrayList<>(values.size());
            for (TracingTensor value : values) {
                meldShapes.add(value.getShape());
            }
            if (Intrinsics.areEqual(meldShapes, shapes)) {
                return values.get(splitPart.getIndex());
            }
        }

        // A split into a single part is just a reshape of the operand.
        if (shapes.size() == 1) {
            if (_Assertions.ENABLED && splitPart.getIndex() != 0) {
                throw new AssertionError("Assertion failed");
            }
            DTensor reshape = ReshapeKt.reshape(x, shapes.get(0));
            Intrinsics.checkNotNull(reshape, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
            return (TracingTensor) reshape;
        }

        if (x instanceof TracingTensor.IdentityGradient) {
            return makeConstant(splitPart);
        }
        if (x instanceof TracingTensor.Zero) {
            return splitPart instanceof TracingScalar ? new TracingScalar.Zero() : new TracingTensor.Zero(splitPart.getShape());
        }

        // Splitting a rank-1 operand into scalars is element access.
        if (x.getRank() == 1) {
            boolean allScalar = true; // stays true for an empty shape list, as before
            for (Shape shape : shapes) {
                if (!shape.isScalar()) {
                    allScalar = false;
                    break;
                }
            }
            if (allScalar) {
                return new TracingScalar.View2(x, splitPart.getIndex(), 0);
            }
        }
        return splitPart;
    }

    /**
     * Returns the shared {@code emptyVariables} array.
     *
     * <p>The array is created once in the static initializer as a zero-length {@code DTensor[]}
     * and is the variable list handed to the constant evaluator. The same array instance is
     * returned on every call (no copy is made).
     *
     * @return the shared zero-length variable array
     */
    @NotNull
    public final DTensor[] getEmptyVariables() {
        return emptyVariables;
    }

    /**
     * Returns the shared {@code constantEvaluator}.
     *
     * <p>The evaluator is built once in the static initializer from the empty variable array and
     * a fresh {@code TraceId}; {@code makeConstant} uses it to evaluate nodes to constants.
     *
     * @return the shared evaluator instance
     */
    @NotNull
    public final TracingEvaluator getConstantEvaluator() {
        return constantEvaluator;
    }

    /**
     * Evaluates a traced node with the shared constant evaluator and wraps the result back into
     * a {@link TracingTensor}.
     *
     * @param tracingTensor the node to evaluate
     * @return the evaluation result wrapped via {@code TracingTensor.Companion.wrap}
     */
    private final TracingTensor makeConstant(TracingTensor tracingTensor) {
        // Typed as Traceable so the same evaluate(...) overload is chosen as with an explicit cast.
        final Traceable traced = tracingTensor;
        return TracingTensor.Companion.wrap(constantEvaluator.evaluate(traced));
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Collapses a conditional whose two branches are equal.
     *
     * <p>If {@code getWhenTrue()} and {@code getWhenFalse()} compare equal (per
     * {@code Intrinsics.areEqual}), the condition is irrelevant and the true branch is returned.
     * Otherwise the node is returned unchanged.
     *
     * @param ifThenElse the conditional node being visited; must not be null
     * @return the true branch when both branches are equal, else {@code ifThenElse}
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitIfThenElse */
    public Traceable visitIfThenElse2(@NotNull TracingTensor.IfThenElse ifThenElse) {
        Intrinsics.checkNotNullParameter(ifThenElse, "x");
        if (Intrinsics.areEqual(ifThenElse.getWhenTrue(), ifThenElse.getWhenFalse())) {
            return ifThenElse.getWhenTrue();
        }
        return ifThenElse;
    }

    /**
     * Builds the view of one operand while a {@code View2} is being pushed through an operation.
     *
     * <p>If both the operand and the original view are scalars, the operand is returned as is.
     * Otherwise the operand is broadcast to {@code shape} and wrapped in a new view that reuses
     * the original view's index and axis: a {@code TracingScalar.View2} when the original view
     * is a scalar, a {@code TracingTensor.View2} otherwise.
     *
     * @param view2         the original view supplying index, axis and scalar-ness
     * @param tracingTensor the operand to take the view of
     * @param shape         the shape the operand is broadcast to before viewing
     * @return the operand itself, or a new view over the broadcast operand
     */
    private static final TracingTensor visitView2$makeView2(TracingTensor.View2 view2, TracingTensor tracingTensor, Shape shape) {
        final boolean scalarView = view2 instanceof DScalar;
        if (scalarView && (tracingTensor instanceof DScalar)) {
            return tracingTensor;
        }
        DTensor broadcast = BroadcastKt.broadcastTo(tracingTensor, shape);
        Intrinsics.checkNotNull(broadcast, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        TracingTensor operand = (TracingTensor) broadcast;
        if (scalarView) {
            return new TracingScalar.View2(operand, view2.getIndex(), view2.getAxis());
        }
        // NOTE(review): the tensor view is constructed with the broadcast shape, not
        // view2.getShape() - confirm this is what the View2 constructor expects.
        return new TracingTensor.View2(operand, view2.getIndex(), view2.getAxis(), shape);
    }

    // Class initialisation: create the shared zero-length variable array and the evaluator
    // that makeConstant uses to evaluate nodes to constants.
    static {
        // A zero-length array has no elements to fill, so no initialisation loop is needed.
        emptyVariables = new DTensor[0];
        constantEvaluator = new TracingEvaluator(emptyVariables, new TraceId());
    }
}
