package org.diffkt.tracing;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.NotImplementedError;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.diffkt.Convolve;
import org.diffkt.DScalar;
import org.diffkt.Shape;
import org.diffkt.tracing.TracingRandomKey;
import org.diffkt.tracing.TracingScalar;
import org.diffkt.tracing.TracingTensor;
import org.jetbrains.annotations.NotNull;

/* compiled from: DeepTracingRewriter.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��ü\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\b\u0010\u0018��2\u00020\u0001B\u0005¢\u0006\u0002\u0010\u0002J\u000e\u0010\t\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020\u0005J\u0010\u0010\u000b\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020\u0005H\u0016J\u0010\u0010\u0003\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020\u0005H\u0002J\u0010\u0010\u0003\u001a\u00020\f2\u0006\u0010\n\u001a\u00020\fH\u0
002J\u0010\u0010\u0003\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\rH\u0002J\u0010\u0010\u000e\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u000fH\u0016J\u0010\u0010\u0010\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u0011H\u0016J\u0010\u0010\u0012\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u0013H\u0016J\u0010\u0010\u0014\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u0015H\u0016J\u0010\u0010\u0016\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u0017H\u0016J\u0010\u0010\u0018\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u0019H\u0016J\u0010\u0010\u001a\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u001bH\u0016J\u0010\u0010\u001c\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u001dH\u0016J\u0010\u0010\u001e\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u001fH\u0016J\u0010\u0010 \u001a\u00020\r2\u0006\u0010\n\u001a\u00020!H\u0016J\u0010\u0010\"\u001a\u00020\r2\u0006\u0010\n\u001a\u00020#H\u0016J\u0010\u0010$\u001a\u00020\r2\u0006\u0010\n\u001a\u00020%H\u0016J\u0010\u0010&\u001a\u00020\r2\u0006\u0010\n\u001a\u00020'H\u0016J\u0010\u0010(\u001a\u00020\r2\u0006\u0010\n\u001a\u00020)H\u0016J\u0010\u0010*\u001a\u00020\r2\u0006\u0010\n\u001a\u00020+H\u0016J\u0010\u0010,\u001a\u00020\r2\u0006\u0010\n\u001a\u00020-H\u0016J\u0010\u0010.\u001a\u00020\r2\u0006\u0010\n\u001a\u00020/H\u0016J\u0010\u00100\u001a\u00020\r2\u0006\u0010\n\u001a\u000201H\u0016J\u0010\u00102\u001a\u00020\r2\u0006\u0010\n\u001a\u000203H\u0016J\u0010\u00104\u001a\u00020\r2\u0006\u0010\n\u001a\u000205H\u0016J\u0010\u00106\u001a\u00020\r2\u0006\u0010\n\u001a\u000207H\u0016J\u0010\u00108\u001a\u00020\r2\u0006\u0010\n\u001a\u000209H\u0016J\u0010\u0010:\u001a\u00020\r2\u0006\u0010\n\u001a\u00020;H\u0016J\u0010\u0010<\u001a\u00020\r2\u0006\u0010\n\u001a\u00020=H\u0016J\u0010\u0010>\u001a\u00020\r2\u0006\u0010\n\u001a\u00020?H\u0016J\u0010\u0010@\u001a\u00020\r2\u0006\u0010\n\u001a\u00020AH\u0016J\u0010\u0010B\u001a\u00020\r2\u0006\u0010\n\u001a\u00020CH\u0016J\u0010\u0010D\u001a\u00020\r2\u0006\u0010\n\u001a\u00020EH\u0016J\u0010\u0010F\
u001a\u00020\r2\u0006\u0010\n\u001a\u00020GH\u0016J\u0010\u0010H\u001a\u00020\r2\u0006\u0010\n\u001a\u00020IH\u0016J\u0010\u0010J\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020KH\u0016J\u0010\u0010L\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020MH\u0016J\u0010\u0010N\u001a\u00020\u00052\u0006\u0010\n\u001a\u00020OH\u0016J\u0010\u0010P\u001a\u00020\r2\u0006\u0010\n\u001a\u00020QH\u0016J\u0010\u0010R\u001a\u00020\r2\u0006\u0010\n\u001a\u00020SH\u0016J\u0010\u0010T\u001a\u00020\r2\u0006\u0010\n\u001a\u00020UH\u0016J\u0010\u0010V\u001a\u00020\r2\u0006\u0010\n\u001a\u00020WH\u0016J\u0010\u0010X\u001a\u00020\r2\u0006\u0010\n\u001a\u00020YH\u0016J\u0010\u0010Z\u001a\u00020\r2\u0006\u0010\n\u001a\u00020[H\u0016J\u0010\u0010\\\u001a\u00020\r2\u0006\u0010\n\u001a\u00020]H\u0016J\u0010\u0010^\u001a\u00020\r2\u0006\u0010\n\u001a\u00020_H\u0016J\u0010\u0010`\u001a\u00020\r2\u0006\u0010\n\u001a\u00020aH\u0016J\u0010\u0010b\u001a\u00020\r2\u0006\u0010\n\u001a\u00020cH\u0016J\u0010\u0010d\u001a\u00020\r2\u0006\u0010\n\u001a\u00020eH\u0016J\u0010\u0010f\u001a\u00020\r2\u0006\u0010\n\u001a\u00020gH\u0016J\u0010\u0010h\u001a\u00020\r2\u0006\u0010\n\u001a\u00020iH\u0016J\u0010\u0010j\u001a\u00020\r2\u0006\u0010\n\u001a\u00020kH\u0016J\u0010\u0010l\u001a\u00020\r2\u0006\u0010\n\u001a\u00020mH\u0016J\u0010\u0010n\u001a\u00020\r2\u0006\u0010\n\u001a\u00020oH\u0016J\u0010\u0010p\u001a\u00020\r2\u0006\u0010\n\u001a\u00020qH\u0016J\u0010\u0010r\u001a\u00020\r2\u0006\u0010\n\u001a\u00020sH\u0016J\u0010\u0010t\u001a\u00020\r2\u0006\u0010\n\u001a\u00020uH\u0016J\u0010\u0010v\u001a\u00020\r2\u0006\u0010\n\u001a\u00020wH\u0016J\u0010\u0010x\u001a\u00020\r2\u0006\u0010\n\u001a\u00020yH\u0016J\u0010\u0010z\u001a\u00020\r2\u0006\u0010\n\u001a\u00020{H\u0016J\u0010\u0010|\u001a\u00020\r2\u0006\u0010\n\u001a\u00020}H\u0016J\u0010\u0010~\u001a\u00020\r2\u0006\u0010\n\u001a\u00020\u007fH\u0016R-\u0010\u0003\u001a\u001e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00050\u0004j\u000e\u0012\u00
04\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u0005`\u0006¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u0080\u0001"}, d2 = {"Lorg/diffkt/tracing/DeepTracingRewriter;", "Lorg/diffkt/tracing/ShallowTracingRewriter;", "()V", "rewrittenForm", "Ljava/util/HashMap;", "Lorg/diffkt/tracing/Traceable;", "Lkotlin/collections/HashMap;", "getRewrittenForm", "()Ljava/util/HashMap;", "rewrite", "x", "rewriteOne", "Lorg/diffkt/tracing/TracingRandomKey;", "Lorg/diffkt/tracing/TracingTensor;", "visitAtan", "Lorg/diffkt/tracing/TracingTensor$Atan;", "visitAvgPool", "Lorg/diffkt/tracing/TracingTensor$AvgPool;", "visitAvgPoolGrad", "Lorg/diffkt/tracing/TracingTensor$AvgPoolGrad;", "visitBroadcastTo", "Lorg/diffkt/tracing/TracingTensor$BroadcastTo;", "visitCompare", "Lorg/diffkt/tracing/TracingTensor$Compare;", "visitConcat", "Lorg/diffkt/tracing/TracingTensor$Concat;", "visitConvImpl", "Lorg/diffkt/tracing/TracingTensor$ConvImpl;", "visitCos", "Lorg/diffkt/tracing/TracingTensor$Cos;", "visitDigamma", "Lorg/diffkt/tracing/TracingTensor$Digamma;", "visitDiv", "Lorg/diffkt/tracing/TracingTensor$Div;", "visitExp", "Lorg/diffkt/tracing/TracingTensor$Exp;", "visitExpand", "Lorg/diffkt/tracing/TracingTensor$Expand;", "visitFlip", "Lorg/diffkt/tracing/TracingTensor$Flip;", "visitGather", "Lorg/diffkt/tracing/TracingTensor$Gather;", "visitGatherAtIndices", "Lorg/diffkt/tracing/TracingTensor$GatherAtIndices;", "visitIdentityGradient", "Lorg/diffkt/tracing/TracingTensor$IdentityGradient;", "visitIfThenElse", "Lorg/diffkt/tracing/TracingTensor$IfThenElse;", "visitLgamma", "Lorg/diffkt/tracing/TracingTensor$Lgamma;", "visitLn", "Lorg/diffkt/tracing/TracingTensor$Ln;", "visitLogSoftmax", "Lorg/diffkt/tracing/TracingTensor$LogSoftmax;", "visitLogSoftmaxGrad", "Lorg/diffkt/tracing/TracingTensor$LogSoftmaxGrad;", "visitMatmul", "Lorg/diffkt/tracing/TracingTensor$Matmul;", "visitMaxPoolWithIndices", "Lorg/diffkt/tracing/TracingTensor$MaxPoolWithIndices;", "visitMeld", 
"Lorg/diffkt/tracing/TracingTensor$Meld;", "visitMinus", "Lorg/diffkt/tracing/TracingTensor$Minus;", "visitOuterProduct", "Lorg/diffkt/tracing/TracingTensor$OuterProduct;", "visitPlus", "Lorg/diffkt/tracing/TracingTensor$Plus;", "visitPolygamma", "Lorg/diffkt/tracing/TracingTensor$Polygamma;", "visitPow", "Lorg/diffkt/tracing/TracingTensor$Pow;", "visitRandomFloats", "Lorg/diffkt/tracing/TracingTensor$RandomFloats;", "visitRandomSplit", "Lorg/diffkt/tracing/TracingRandomKey$Split;", "visitRandomSplitPart", "Lorg/diffkt/tracing/TracingRandomKey$SplitPart;", "visitRandomVariable", "Lorg/diffkt/tracing/TracingRandomKey$Variable;", "visitRelu", "Lorg/diffkt/tracing/TracingTensor$Relu;", "visitReluGrad", "Lorg/diffkt/tracing/TracingTensor$ReluGrad;", "visitReshape", "Lorg/diffkt/tracing/TracingTensor$Reshape;", "visitReshapeToScalar", "Lorg/diffkt/tracing/TracingScalar$ReshapeToScalar;", "visitScatter", "Lorg/diffkt/tracing/TracingTensor$Scatter;", "visitScatterAtIndices", "Lorg/diffkt/tracing/TracingTensor$ScatterAtIndices;", "visitSigmoid", "Lorg/diffkt/tracing/TracingTensor$Sigmoid;", "visitSin", "Lorg/diffkt/tracing/TracingTensor$Sin;", "visitSplit", "Lorg/diffkt/tracing/TracingTensor$Split;", "visitSplitPart", "Lorg/diffkt/tracing/TracingTensor$SplitPart;", "visitSqrt", "Lorg/diffkt/tracing/TracingTensor$Sqrt;", "visitSqueeze", "Lorg/diffkt/tracing/TracingTensor$Squeeze;", "visitSum", "Lorg/diffkt/tracing/TracingTensor$Sum;", "visitTan", "Lorg/diffkt/tracing/TracingTensor$Tan;", "visitTanh", "Lorg/diffkt/tracing/TracingTensor$Tanh;", "visitTimes", "Lorg/diffkt/tracing/TracingTensor$Times;", "visitTimesScalar", "Lorg/diffkt/tracing/TracingTensor$TimesScalar;", "visitTranspose", "Lorg/diffkt/tracing/TracingTensor$Transpose;", "visitUnaryMinus", "Lorg/diffkt/tracing/TracingTensor$UnaryMinus;", "visitUnsqueeze", "Lorg/diffkt/tracing/TracingTensor$Unsqueeze;", "visitView1", "Lorg/diffkt/tracing/TracingTensor$View1;", "visitView2", 
"Lorg/diffkt/tracing/TracingTensor$View2;", "visitView3", "Lorg/diffkt/tracing/TracingTensor$View3;", "visitZero", "Lorg/diffkt/tracing/TracingTensor$Zero;", "api"})
/* loaded from: input_file:org/diffkt/tracing/DeepTracingRewriter.class */
public class DeepTracingRewriter extends ShallowTracingRewriter {

    /**
     * Memo table from each node processed by {@code rewrite} to its replacement. Filled in by
     * {@code rewrite} and consulted by {@code rewriteOne} and the private {@code rewrittenForm}
     * lookups.
     */
    @NotNull
    private final HashMap<Traceable, Traceable> rewrittenForm = new HashMap<Traceable, Traceable>();

    /**
     * Returns the live memo table itself (not a defensive copy).
     *
     * @return the mutable map of original node to rewritten node
     */
    @NotNull
    public final HashMap<Traceable, Traceable> getRewrittenForm() {
        return rewrittenForm;
    }

    /**
     * Rewrites every node reachable from {@code traceable} and returns the rewritten root.
     *
     * <p>The reachable nodes are topologically sorted, skipping nodes that already have an entry
     * in the memo table. The sorted list is then processed in reverse. Each node's replacement
     * (from {@link #rewriteOne}) is stored in the memo table. The visit methods look their
     * operands up in that table and fail if an operand is missing, so this reversed order must
     * place operands before their users.
     *
     * @param traceable root of the graph to rewrite
     * @return the memoized rewritten form of {@code traceable}
     */
    @NotNull
    public final Traceable rewrite(@NotNull Traceable traceable) {
        Intrinsics.checkNotNullParameter(traceable, "x");
        // Stop the traversal at nodes rewritten by an earlier call. This is a plain lambda: the
        // decompiled anonymous Function1 with an instance initializer calling `super(1)` is not
        // valid Java.
        Function1<Traceable, Boolean> alreadyRewritten = node -> rewrittenForm.containsKey(node);
        for (Traceable node : CollectionsKt.reversed(
                TopologicalSortKt.topologicalSort(CollectionsKt.listOf(traceable), alreadyRewritten))) {
            rewrittenForm.put(node, rewriteOne(node));
        }
        return rewrittenForm(traceable);
    }

    /**
     * Rewrites a single node. Returns the memoized replacement if one exists; otherwise
     * dispatches the node to this visitor.
     *
     * @param traceable node to rewrite
     * @return cached replacement, or the result of {@code traceable.accept(this)}
     */
    @NotNull
    public Traceable rewriteOne(@NotNull Traceable traceable) {
        Intrinsics.checkNotNullParameter(traceable, "x");
        Traceable cached = rewrittenForm.get(traceable);
        if (cached == null) {
            return (Traceable) traceable.accept(this);
        }
        return cached;
    }

    /**
     * Looks up the memoized replacement of {@code traceable}. Fails via
     * {@code Intrinsics.checkNotNull} if the node has not been rewritten yet.
     */
    private final Traceable rewrittenForm(Traceable traceable) {
        Traceable replacement = rewrittenForm.get(traceable);
        Intrinsics.checkNotNull(replacement);
        return replacement;
    }

    /**
     * Tensor-typed lookup of a memoized replacement. Fails if the node has no entry, and the
     * cast fails if the entry is not a {@link TracingTensor}.
     */
    private final TracingTensor rewrittenForm(TracingTensor tracingTensor) {
        Traceable replacement = rewrittenForm.get(tracingTensor);
        Intrinsics.checkNotNull(replacement, "null cannot be cast to non-null type org.diffkt.tracing.TracingTensor");
        TracingTensor asTensor = (TracingTensor) replacement;
        return asTensor;
    }

    /**
     * Random-key-typed lookup of a memoized replacement. Fails if the key has no entry, and the
     * cast fails if the entry is not a {@link TracingRandomKey}.
     */
    private final TracingRandomKey rewrittenForm(TracingRandomKey tracingRandomKey) {
        Traceable replacement = rewrittenForm.get(tracingRandomKey);
        Intrinsics.checkNotNull(replacement, "null cannot be cast to non-null type org.diffkt.tracing.TracingRandomKey");
        TracingRandomKey asKey = (TracingRandomKey) replacement;
        return asKey;
    }

    /**
     * Rebuilds a Plus node over the rewritten operands (decompiler-renamed; the Kotlin name is
     * {@code visitPlus}). Returns {@code plus} itself when neither operand changed. Builds a
     * {@link TracingScalar.Plus} when both rewritten operands are scalars.
     */
    @Override
    @NotNull
    public Traceable visitPlus2(@NotNull TracingTensor.Plus plus) {
        Intrinsics.checkNotNullParameter(plus, "x");
        TracingTensor left = rewrittenForm(plus.getLeft());
        TracingTensor right = rewrittenForm(plus.getRight());
        boolean unchanged = left == plus.getLeft() && right == plus.getRight();
        if (unchanged) {
            return plus;
        }
        if (left instanceof TracingScalar && right instanceof TracingScalar) {
            return new TracingScalar.Plus((TracingScalar) left, (TracingScalar) right);
        }
        return new TracingTensor.Plus(left, right);
    }

    /**
     * Rebuilds a Minus node over the rewritten operands (decompiler-renamed; the Kotlin name is
     * {@code visitMinus}). Returns {@code minus} itself when neither operand changed. Builds a
     * {@link TracingScalar.Minus} when both rewritten operands are scalars.
     */
    @Override
    @NotNull
    public Traceable visitMinus2(@NotNull TracingTensor.Minus minus) {
        Intrinsics.checkNotNullParameter(minus, "x");
        TracingTensor left = rewrittenForm(minus.getLeft());
        TracingTensor right = rewrittenForm(minus.getRight());
        boolean unchanged = left == minus.getLeft() && right == minus.getRight();
        if (unchanged) {
            return minus;
        }
        if (left instanceof TracingScalar && right instanceof TracingScalar) {
            return new TracingScalar.Minus((TracingScalar) left, (TracingScalar) right);
        }
        return new TracingTensor.Minus(left, right);
    }

    /**
     * Rebuilds a Times node over the rewritten operands (decompiler-renamed; the Kotlin name is
     * {@code visitTimes}). Returns {@code times} itself when neither operand changed.
     *
     * <p>NOTE(review): unlike Plus/Minus/Div, this never builds a scalar-specific node when both
     * operands are scalars. Confirm that the asymmetry is intended.
     */
    @Override
    @NotNull
    public Traceable visitTimes2(@NotNull TracingTensor.Times times) {
        Intrinsics.checkNotNullParameter(times, "x");
        TracingTensor left = rewrittenForm(times.getLeft());
        TracingTensor right = rewrittenForm(times.getRight());
        if (left != times.getLeft() || right != times.getRight()) {
            return new TracingTensor.Times(left, right);
        }
        return times;
    }

    /**
     * Rebuilds a TimesScalar node over the rewritten operands (decompiler-renamed; the Kotlin
     * name is {@code visitTimesScalar}). The rewritten left operand must still be a
     * {@link TracingScalar}, otherwise the cast fails. Returns {@code timesScalar} itself when
     * neither operand changed. Builds a {@link TracingScalar.TimesScalar} when the rewritten
     * right operand is also a scalar.
     */
    @Override
    @NotNull
    public Traceable visitTimesScalar2(@NotNull TracingTensor.TimesScalar timesScalar) {
        Intrinsics.checkNotNullParameter(timesScalar, "x");
        TracingTensor rewrittenLeft = rewrittenForm((TracingTensor) timesScalar.getLeft());
        Intrinsics.checkNotNull(rewrittenLeft, "null cannot be cast to non-null type org.diffkt.tracing.TracingScalar");
        TracingScalar factor = (TracingScalar) rewrittenLeft;
        TracingTensor operand = rewrittenForm(timesScalar.getRight());
        boolean unchanged = factor == timesScalar.getLeft() && operand == timesScalar.getRight();
        if (unchanged) {
            return timesScalar;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.TimesScalar(factor, (TracingScalar) operand);
        }
        return new TracingTensor.TimesScalar(factor, operand);
    }

    /**
     * Rebuilds a Div node over the rewritten operands (decompiler-renamed; the Kotlin name is
     * {@code visitDiv}). Returns {@code div} itself when neither operand changed. Builds a
     * {@link TracingScalar.Div} when both rewritten operands are scalars.
     */
    @Override
    @NotNull
    public Traceable visitDiv2(@NotNull TracingTensor.Div div) {
        Intrinsics.checkNotNullParameter(div, "x");
        TracingTensor numerator = rewrittenForm(div.getLeft());
        TracingTensor denominator = rewrittenForm(div.getRight());
        boolean unchanged = numerator == div.getLeft() && denominator == div.getRight();
        if (unchanged) {
            return div;
        }
        if (numerator instanceof TracingScalar && denominator instanceof TracingScalar) {
            return new TracingScalar.Div((TracingScalar) numerator, (TracingScalar) denominator);
        }
        return new TracingTensor.Div(numerator, denominator);
    }

    /**
     * Leaf case (decompiler-renamed; the Kotlin name is {@code visitZero}). The node is returned
     * unchanged; none of its contents are inspected or rewritten.
     */
    @Override
    @NotNull
    public Traceable visitZero2(@NotNull TracingTensor.Zero zero) {
        Intrinsics.checkNotNullParameter(zero, "x");
        Traceable result = zero;
        return result;
    }

    /**
     * Returns the node unchanged (decompiler-renamed; the Kotlin name is
     * {@code visitIdentityGradient}).
     *
     * <p>NOTE(review): if IdentityGradient carries tensor operands, they are not rewritten here.
     * Confirm that this is intended.
     */
    @Override
    @NotNull
    public Traceable visitIdentityGradient2(@NotNull TracingTensor.IdentityGradient identityGradient) {
        Intrinsics.checkNotNullParameter(identityGradient, "x");
        Traceable result = identityGradient;
        return result;
    }

    /**
     * Rebuilds a UnaryMinus node over the rewritten operand (decompiler-renamed; the Kotlin name
     * is {@code visitUnaryMinus}). Returns the original node when the operand is unchanged. Uses
     * the scalar variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitUnaryMinus2(@NotNull TracingTensor.UnaryMinus unaryMinus) {
        Intrinsics.checkNotNullParameter(unaryMinus, "x");
        TracingTensor operand = rewrittenForm(unaryMinus.getX());
        if (operand == unaryMinus.getX()) {
            return unaryMinus;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.UnaryMinus((TracingScalar) operand);
        }
        return new TracingTensor.UnaryMinus(operand);
    }

    /**
     * Rebuilds a Matmul node over the rewritten operands (decompiler-renamed; the Kotlin name is
     * {@code visitMatmul}). Returns {@code matmul} itself when neither operand changed.
     * Otherwise the combined shape {@code getA() + getB() + getD()} decides between the scalar
     * and tensor node variants. The shape parameters a, b, c, d are carried over unchanged.
     */
    @Override
    @NotNull
    public Traceable visitMatmul2(@NotNull TracingTensor.Matmul matmul) {
        Intrinsics.checkNotNullParameter(matmul, "x");
        TracingTensor x = rewrittenForm(matmul.getX());
        TracingTensor y = rewrittenForm(matmul.getY());
        if (x == matmul.getX() && y == matmul.getY()) {
            return matmul;
        }
        // Compute the combined shape only when a new node is actually needed; the original built
        // it eagerly even on the early-return path above, where it was unused.
        Shape combined = matmul.getA().plus(matmul.getB()).plus(matmul.getD());
        if (combined.isScalar()) {
            return new TracingScalar.Matmul(x, y, matmul.getA(), matmul.getB(), matmul.getC(), matmul.getD());
        }
        return new TracingTensor.Matmul(x, y, matmul.getA(), matmul.getB(), matmul.getC(), matmul.getD());
    }

    /**
     * Rebuilds an OuterProduct node over the rewritten operands (decompiler-renamed; the Kotlin
     * name is {@code visitOuterProduct}). Returns the original node when neither operand changed.
     */
    @Override
    @NotNull
    public Traceable visitOuterProduct2(@NotNull TracingTensor.OuterProduct outerProduct) {
        Intrinsics.checkNotNullParameter(outerProduct, "x");
        TracingTensor x = rewrittenForm(outerProduct.getX());
        TracingTensor y = rewrittenForm(outerProduct.getY());
        if (x != outerProduct.getX() || y != outerProduct.getY()) {
            return new TracingTensor.OuterProduct(x, y);
        }
        return outerProduct;
    }

    /**
     * Rebuilds a Sin node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitSin}). Returns the original node when the operand is unchanged. Uses the scalar
     * variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitSin2(@NotNull TracingTensor.Sin sin) {
        Intrinsics.checkNotNullParameter(sin, "x");
        TracingTensor operand = rewrittenForm(sin.getX());
        if (operand == sin.getX()) {
            return sin;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Sin((TracingScalar) operand);
        }
        return new TracingTensor.Sin(operand);
    }

    /**
     * Rebuilds a Cos node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitCos}). Returns the original node when the operand is unchanged. Uses the scalar
     * variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitCos2(@NotNull TracingTensor.Cos cos) {
        Intrinsics.checkNotNullParameter(cos, "x");
        TracingTensor operand = rewrittenForm(cos.getX());
        if (operand == cos.getX()) {
            return cos;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Cos((TracingScalar) operand);
        }
        return new TracingTensor.Cos(operand);
    }

    /**
     * Rebuilds a Tan node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitTan}). Returns the original node when the operand is unchanged. Uses the scalar
     * variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitTan2(@NotNull TracingTensor.Tan tan) {
        Intrinsics.checkNotNullParameter(tan, "x");
        TracingTensor operand = rewrittenForm(tan.getX());
        if (operand == tan.getX()) {
            return tan;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Tan((TracingScalar) operand);
        }
        return new TracingTensor.Tan(operand);
    }

    /**
     * Rebuilds an Atan node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitAtan}). Returns the original node when the operand is unchanged. Uses the
     * scalar variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitAtan2(@NotNull TracingTensor.Atan atan) {
        Intrinsics.checkNotNullParameter(atan, "x");
        TracingTensor operand = rewrittenForm(atan.getX());
        if (operand == atan.getX()) {
            return atan;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Atan((TracingScalar) operand);
        }
        return new TracingTensor.Atan(operand);
    }

    /**
     * Rebuilds an Exp node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitExp}). Returns the original node when the operand is unchanged. Uses the scalar
     * variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitExp2(@NotNull TracingTensor.Exp exp) {
        Intrinsics.checkNotNullParameter(exp, "x");
        TracingTensor operand = rewrittenForm(exp.getX());
        if (operand == exp.getX()) {
            return exp;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Exp((TracingScalar) operand);
        }
        return new TracingTensor.Exp(operand);
    }

    /**
     * Rebuilds an Ln node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitLn}). Returns the original node when the operand is unchanged. Uses the scalar
     * variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitLn2(@NotNull TracingTensor.Ln ln) {
        Intrinsics.checkNotNullParameter(ln, "x");
        TracingTensor operand = rewrittenForm(ln.getX());
        if (operand == ln.getX()) {
            return ln;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Ln((TracingScalar) operand);
        }
        return new TracingTensor.Ln(operand);
    }

    /**
     * Rebuilds an Lgamma node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitLgamma}). Returns the original node when the operand is unchanged. Uses the
     * scalar variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitLgamma2(@NotNull TracingTensor.Lgamma lgamma) {
        Intrinsics.checkNotNullParameter(lgamma, "x");
        TracingTensor operand = rewrittenForm(lgamma.getX());
        if (operand == lgamma.getX()) {
            return lgamma;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Lgamma((TracingScalar) operand);
        }
        return new TracingTensor.Lgamma(operand);
    }

    /**
     * Rebuilds a Digamma node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitDigamma}). Returns the original node when the operand is unchanged. Uses the
     * scalar variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitDigamma2(@NotNull TracingTensor.Digamma digamma) {
        Intrinsics.checkNotNullParameter(digamma, "x");
        TracingTensor operand = rewrittenForm(digamma.getX());
        if (operand == digamma.getX()) {
            return digamma;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Digamma((TracingScalar) operand);
        }
        return new TracingTensor.Digamma(operand);
    }

    /**
     * Rebuilds a Polygamma node over the rewritten operand, carrying over {@code getN()}
     * (decompiler-renamed; the Kotlin name is {@code visitPolygamma}). Returns the original node
     * when the operand is unchanged. Uses the scalar variant when the rewritten operand is a
     * {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitPolygamma2(@NotNull TracingTensor.Polygamma polygamma) {
        Intrinsics.checkNotNullParameter(polygamma, "x");
        TracingTensor operand = rewrittenForm(polygamma.getX());
        if (operand == polygamma.getX()) {
            return polygamma;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Polygamma(polygamma.getN(), (TracingScalar) operand);
        }
        return new TracingTensor.Polygamma(polygamma.getN(), operand);
    }

    /**
     * Rebuilds a Sqrt node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitSqrt}). Returns the original node when the operand is unchanged. Uses the
     * scalar variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitSqrt2(@NotNull TracingTensor.Sqrt sqrt) {
        Intrinsics.checkNotNullParameter(sqrt, "x");
        TracingTensor operand = rewrittenForm(sqrt.getX());
        if (operand == sqrt.getX()) {
            return sqrt;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Sqrt((TracingScalar) operand);
        }
        return new TracingTensor.Sqrt(operand);
    }

    /**
     * Rebuilds a Tanh node over the rewritten operand (decompiler-renamed; the Kotlin name is
     * {@code visitTanh}). Returns the original node when the operand is unchanged. Uses the
     * scalar variant when the rewritten operand is a {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitTanh2(@NotNull TracingTensor.Tanh tanh) {
        Intrinsics.checkNotNullParameter(tanh, "x");
        TracingTensor operand = rewrittenForm(tanh.getX());
        if (operand == tanh.getX()) {
            return tanh;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Tanh((TracingScalar) operand);
        }
        return new TracingTensor.Tanh(operand);
    }

    /**
     * Rebuilds a Meld node over the rewritten values (decompiler-renamed; the Kotlin name is
     * {@code visitMeld}). Returns {@code meld} itself when every value maps to the identical
     * object. Otherwise builds a new Meld from the rewritten list, preserving order.
     */
    @Override
    @NotNull
    public Traceable visitMeld2(@NotNull TracingTensor.Meld meld) {
        Intrinsics.checkNotNullParameter(meld, "x");
        List<TracingTensor> values = meld.getValues();
        List<TracingTensor> rewritten = new ArrayList<>(values.size());
        // Single typed pass that maps and detects changes together. The decompiled version
        // iterated with an undeclared type variable `Iterator<T>` (a compile error) and built a
        // raw zip list just to test "any changed".
        boolean changed = false;
        for (TracingTensor value : values) {
            TracingTensor replacement = rewrittenForm(value);
            rewritten.add(replacement);
            if (replacement != value) {
                changed = true;
            }
        }
        return changed ? new TracingTensor.Meld(rewritten) : meld;
    }

    /**
     * Rebuilds a Split node over the rewritten operand, carrying over {@code getShapes()}
     * (decompiler-renamed; the Kotlin name is {@code visitSplit}). Returns the original node when
     * the operand is unchanged. Uses the scalar variant when the rewritten operand is a
     * {@link TracingScalar}.
     */
    @Override
    @NotNull
    public Traceable visitSplit2(@NotNull TracingTensor.Split split) {
        Intrinsics.checkNotNullParameter(split, "x");
        TracingTensor operand = rewrittenForm(split.getX());
        if (operand == split.getX()) {
            return split;
        }
        if (operand instanceof TracingScalar) {
            return new TracingScalar.Split((TracingScalar) operand, split.getShapes());
        }
        return new TracingTensor.Split(operand, split.getShapes());
    }

    /**
     * Rebuilds a SplitPart node over the rewritten source (decompiler-renamed; the Kotlin name is
     * {@code visitSplitPart}). Returns the original node when the source is unchanged. The
     * node's own {@code isScalar()} picks the scalar or tensor variant; the tensor variant also
     * receives {@code getShape()}.
     */
    @Override
    @NotNull
    public Traceable visitSplitPart2(@NotNull TracingTensor.SplitPart splitPart) {
        Intrinsics.checkNotNullParameter(splitPart, "x");
        TracingTensor source = rewrittenForm(splitPart.getFrom());
        if (source == splitPart.getFrom()) {
            return splitPart;
        }
        if (splitPart.isScalar()) {
            return new TracingScalar.SplitPart(source, splitPart.getIndex());
        }
        return new TracingTensor.SplitPart(source, splitPart.getIndex(), splitPart.getShape());
    }

    /**
     * Rebuilds a Concat node over the rewritten slices, keeping {@code getAxis()}
     * (decompiler-renamed; the Kotlin name is {@code visitConcat}). Returns {@code concat} itself
     * when every slice maps to the identical object. Otherwise builds a new Concat from the
     * rewritten list, preserving order.
     */
    @Override
    @NotNull
    public Traceable visitConcat2(@NotNull TracingTensor.Concat concat) {
        Intrinsics.checkNotNullParameter(concat, "x");
        List<TracingTensor> slices = concat.getSlices();
        List<TracingTensor> rewritten = new ArrayList<>(slices.size());
        // Single typed pass that maps and detects changes together. The decompiled version
        // iterated with an undeclared type variable `Iterator<T>` (a compile error) and built a
        // raw zip list just to test "any changed".
        boolean changed = false;
        for (TracingTensor slice : slices) {
            TracingTensor replacement = rewrittenForm(slice);
            rewritten.add(replacement);
            if (replacement != slice) {
                changed = true;
            }
        }
        return changed ? new TracingTensor.Concat(rewritten, concat.getAxis()) : concat;
    }

    /**
     * Rebuilds a BroadcastTo node over the rewritten operand, keeping {@code getShape()}
     * (decompiler-renamed; the Kotlin name is {@code visitBroadcastTo}). Returns the original
     * node when the operand is unchanged.
     */
    @Override
    @NotNull
    public Traceable visitBroadcastTo2(@NotNull TracingTensor.BroadcastTo broadcastTo) {
        Intrinsics.checkNotNullParameter(broadcastTo, "x");
        TracingTensor operand = rewrittenForm(broadcastTo.getX());
        if (operand == broadcastTo.getX()) {
            return broadcastTo;
        }
        return new TracingTensor.BroadcastTo(operand, broadcastTo.getShape());
    }

    /**
     * Rewrites a {@link TracingTensor.ConvImpl} node by rewriting its signal and filter.
     *
     * <p>If both rewritten inputs are identical (by reference) to the originals, the node is
     * returned as-is. Otherwise a new {@code ConvImpl} is built with the original strides and
     * padding. JADX reports this method was renamed from {@code visitConvImpl}.
     *
     * @param convImpl the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitConvImpl2(@NotNull TracingTensor.ConvImpl convImpl) {
        Intrinsics.checkNotNullParameter(convImpl, "x");
        TracingTensor newSignal = rewrittenForm(convImpl.getSignal());
        TracingTensor newFilter = rewrittenForm(convImpl.getFilter());
        boolean untouched = newSignal == convImpl.getSignal() && newFilter == convImpl.getFilter();
        if (untouched) {
            return convImpl;
        }
        return new TracingTensor.ConvImpl(
                newSignal,
                newFilter,
                convImpl.getHStride(),
                convImpl.getVStride(),
                convImpl.getPadding());
    }

    /**
     * Rewrites a {@link TracingTensor.Expand} node by rewriting its operand.
     *
     * <p>Returns {@code expand} unchanged when the rewritten operand is the same object. Otherwise
     * it returns a new {@code Expand} with the original shape. JADX reports this method was renamed
     * from {@code visitExpand}.
     *
     * @param expand the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitExpand2(@NotNull TracingTensor.Expand expand) {
        Intrinsics.checkNotNullParameter(expand, "x");
        TracingTensor rewrittenX = rewrittenForm(expand.getX());
        if (rewrittenX == expand.getX()) {
            return expand;
        }
        return new TracingTensor.Expand(rewrittenX, expand.getShape());
    }

    /**
     * Rewrites a {@link TracingTensor.Flip} node by rewriting its operand.
     *
     * <p>Returns {@code flip} unchanged when the rewritten operand is the same object. Otherwise it
     * returns a new {@code Flip} over the original axes. JADX reports this method was renamed from
     * {@code visitFlip}.
     *
     * @param flip the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitFlip2(@NotNull TracingTensor.Flip flip) {
        Intrinsics.checkNotNullParameter(flip, "x");
        TracingTensor rewrittenX = rewrittenForm(flip.getX());
        if (rewrittenX == flip.getX()) {
            return flip;
        }
        return new TracingTensor.Flip(rewrittenX, flip.getAxes());
    }

    /**
     * Rewrites a {@link TracingTensor.LogSoftmax} node by rewriting its operand.
     *
     * <p>Returns {@code logSoftmax} unchanged when the rewritten operand is the same object.
     * Otherwise it returns a new {@code LogSoftmax} along the original axis. JADX reports this
     * method was renamed from {@code visitLogSoftmax}.
     *
     * @param logSoftmax the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitLogSoftmax2(@NotNull TracingTensor.LogSoftmax logSoftmax) {
        Intrinsics.checkNotNullParameter(logSoftmax, "x");
        TracingTensor rewrittenX = rewrittenForm(logSoftmax.getX());
        if (rewrittenX == logSoftmax.getX()) {
            return logSoftmax;
        }
        return new TracingTensor.LogSoftmax(rewrittenX, logSoftmax.getAxis());
    }

    /**
     * Rewrites a {@link TracingTensor.LogSoftmaxGrad} node by rewriting its three tensor inputs.
     *
     * <p>{@code x} and {@code logSoftmax} are compared to their rewritten forms by reference
     * identity. {@code upstream} is compared with {@code Intrinsics.areEqual} (i.e.
     * {@code equals}). If all three are unchanged the node is returned as-is; otherwise a new node
     * is built with the original axis. JADX reports this method was renamed from
     * {@code visitLogSoftmaxGrad}.
     *
     * @param logSoftmaxGrad the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitLogSoftmaxGrad2(@NotNull TracingTensor.LogSoftmaxGrad logSoftmaxGrad) {
        Intrinsics.checkNotNullParameter(logSoftmaxGrad, "x");
        TracingTensor newX = rewrittenForm(logSoftmaxGrad.getX());
        TracingTensor newLogSoftmax = rewrittenForm(logSoftmaxGrad.getLogSoftmax());
        TracingTensor newUpstream = rewrittenForm(logSoftmaxGrad.getUpstream());
        if (newX != logSoftmaxGrad.getX()
                || newLogSoftmax != logSoftmaxGrad.getLogSoftmax()
                || !Intrinsics.areEqual(newUpstream, logSoftmaxGrad.getUpstream())) {
            return new TracingTensor.LogSoftmaxGrad(newX, logSoftmaxGrad.getAxis(), newLogSoftmax, newUpstream);
        }
        return logSoftmaxGrad;
    }

    /**
     * Rewrites a {@link TracingTensor.Pow} node by rewriting its base.
     *
     * <p>Returns {@code pow} unchanged when the rewritten base is the same object. Otherwise the
     * replacement is a {@code TracingScalar.Pow} if the rewritten base is a {@code TracingScalar},
     * and a {@code TracingTensor.Pow} otherwise, in both cases with the original exponent. JADX
     * reports this method was renamed from {@code visitPow}.
     *
     * @param pow the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitPow2(@NotNull TracingTensor.Pow pow) {
        Intrinsics.checkNotNullParameter(pow, "x");
        TracingTensor newBase = rewrittenForm(pow.getBase());
        if (newBase == pow.getBase()) {
            return pow;
        }
        if (newBase instanceof TracingScalar) {
            return new TracingScalar.Pow((TracingScalar) newBase, pow.getExponent());
        }
        return new TracingTensor.Pow(newBase, pow.getExponent());
    }

    /**
     * Rewrites a {@link TracingTensor.View1} node by rewriting its operand.
     *
     * <p>Returns {@code view1} unchanged when the rewritten operand is the same object. Otherwise
     * it returns a new {@code View1} with the original indexes and shape. JADX reports this method
     * was renamed from {@code visitView1}.
     *
     * @param view1 the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitView12(@NotNull TracingTensor.View1 view1) {
        Intrinsics.checkNotNullParameter(view1, "x");
        TracingTensor rewrittenX = rewrittenForm(view1.getX());
        if (rewrittenX == view1.getX()) {
            return view1;
        }
        return new TracingTensor.View1(rewrittenX, view1.getIndexes(), view1.getShape());
    }

    /**
     * Rewrites a {@link TracingTensor.View2} node by rewriting its operand.
     *
     * <p>Returns {@code view2} unchanged when the rewritten operand is the same object. Otherwise
     * the replacement depends on the original node. If it is a {@code DScalar}, the replacement is
     * a {@code TracingScalar.View2} built from index and axis. If not, it is a
     * {@code TracingTensor.View2} that also carries the shape. JADX reports this method was renamed
     * from {@code visitView2}.
     *
     * @param view2 the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitView22(@NotNull TracingTensor.View2 view2) {
        Intrinsics.checkNotNullParameter(view2, "x");
        TracingTensor rewrittenX = rewrittenForm(view2.getX());
        if (rewrittenX == view2.getX()) {
            return view2;
        }
        if (view2 instanceof DScalar) {
            return new TracingScalar.View2(rewrittenX, view2.getIndex(), view2.getAxis());
        }
        return new TracingTensor.View2(rewrittenX, view2.getIndex(), view2.getAxis(), view2.getShape());
    }

    /**
     * Rewrites a {@link TracingTensor.View3} node by rewriting its operand.
     *
     * <p>Returns {@code view3} unchanged when the rewritten operand is the same object. Otherwise
     * it returns a new {@code View3} with the original index and axis. JADX reports this method was
     * renamed from {@code visitView3}.
     *
     * @param view3 the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitView32(@NotNull TracingTensor.View3 view3) {
        Intrinsics.checkNotNullParameter(view3, "x");
        TracingTensor rewrittenX = rewrittenForm(view3.getX());
        if (rewrittenX == view3.getX()) {
            return view3;
        }
        return new TracingTensor.View3(rewrittenX, view3.getIndex(), view3.getAxis());
    }

    /**
     * Rewrites a {@link TracingTensor.Reshape} node by rewriting its operand.
     *
     * <p>Returns {@code reshape} unchanged when the rewritten operand is the same object. Otherwise
     * it returns a new {@code Reshape} with the original shape. JADX reports this method was
     * renamed from {@code visitReshape}.
     *
     * @param reshape the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitReshape2(@NotNull TracingTensor.Reshape reshape) {
        Intrinsics.checkNotNullParameter(reshape, "x");
        TracingTensor rewrittenX = rewrittenForm(reshape.getX());
        if (rewrittenX == reshape.getX()) {
            return reshape;
        }
        return new TracingTensor.Reshape(rewrittenX, reshape.getShape());
    }

    /**
     * Rewrites a {@link TracingScalar.ReshapeToScalar} node by rewriting its operand.
     *
     * <p>Returns {@code reshapeToScalar} unchanged when the rewritten operand is the same object.
     * Otherwise it returns a new {@code ReshapeToScalar} over the rewritten operand. JADX reports
     * this method was renamed from {@code visitReshapeToScalar}.
     *
     * @param reshapeToScalar the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitReshapeToScalar2(@NotNull TracingScalar.ReshapeToScalar reshapeToScalar) {
        Intrinsics.checkNotNullParameter(reshapeToScalar, "x");
        TracingTensor rewrittenX = rewrittenForm(reshapeToScalar.getX());
        if (rewrittenX == reshapeToScalar.getX()) {
            return reshapeToScalar;
        }
        return new TracingScalar.ReshapeToScalar(rewrittenX);
    }

    /**
     * Rewrites a {@link TracingTensor.Squeeze} node by rewriting its operand.
     *
     * <p>Returns {@code squeeze} unchanged when the rewritten operand is the same object. Otherwise
     * the replacement depends on the original node. If it is a {@code DScalar}, the replacement is
     * a {@code TracingScalar.Squeeze} over the rewritten operand. If not, it is a
     * {@code TracingTensor.Squeeze} that keeps the original axis. JADX reports this method was
     * renamed from {@code visitSqueeze}.
     *
     * @param squeeze the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitSqueeze2(@NotNull TracingTensor.Squeeze squeeze) {
        Intrinsics.checkNotNullParameter(squeeze, "x");
        TracingTensor rewrittenX = rewrittenForm(squeeze.getX());
        if (rewrittenX == squeeze.getX()) {
            return squeeze;
        }
        if (squeeze instanceof DScalar) {
            return new TracingScalar.Squeeze(rewrittenX);
        }
        return new TracingTensor.Squeeze(rewrittenX, squeeze.getAxis());
    }

    /**
     * Rewrites a {@link TracingTensor.Unsqueeze} node by rewriting its operand.
     *
     * <p>Returns {@code unsqueeze} unchanged when the rewritten operand is the same object.
     * Otherwise it returns a new {@code Unsqueeze} at the original axis. JADX reports this method
     * was renamed from {@code visitUnsqueeze}.
     *
     * @param unsqueeze the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitUnsqueeze2(@NotNull TracingTensor.Unsqueeze unsqueeze) {
        Intrinsics.checkNotNullParameter(unsqueeze, "x");
        TracingTensor rewrittenX = rewrittenForm(unsqueeze.getX());
        if (rewrittenX == unsqueeze.getX()) {
            return unsqueeze;
        }
        return new TracingTensor.Unsqueeze(rewrittenX, unsqueeze.getAxis());
    }

    /**
     * Rewrites a {@link TracingTensor.Transpose} node by rewriting its operand.
     *
     * <p>Returns {@code transpose} unchanged when the rewritten operand is the same object.
     * Otherwise it returns a new {@code Transpose} with the original axes. JADX reports this method
     * was renamed from {@code visitTranspose}.
     *
     * @param transpose the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitTranspose2(@NotNull TracingTensor.Transpose transpose) {
        Intrinsics.checkNotNullParameter(transpose, "x");
        TracingTensor rewrittenX = rewrittenForm(transpose.getX());
        if (rewrittenX == transpose.getX()) {
            return transpose;
        }
        return new TracingTensor.Transpose(rewrittenX, transpose.getAxes());
    }

    /**
     * Rewrites a {@link TracingTensor.Relu} node by rewriting its operand.
     *
     * <p>Returns {@code relu} unchanged when the rewritten operand is the same object. Otherwise
     * the replacement is a {@code TracingScalar.Relu} if the rewritten operand is a
     * {@code TracingScalar}, else a {@code TracingTensor.Relu}. JADX reports this method was
     * renamed from {@code visitRelu}.
     *
     * @param relu the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitRelu2(@NotNull TracingTensor.Relu relu) {
        Intrinsics.checkNotNullParameter(relu, "x");
        TracingTensor rewrittenX = rewrittenForm(relu.getX());
        if (rewrittenX == relu.getX()) {
            return relu;
        }
        if (rewrittenX instanceof TracingScalar) {
            return new TracingScalar.Relu(rewrittenX);
        }
        return new TracingTensor.Relu(rewrittenX);
    }

    /**
     * Rewrites a {@link TracingTensor.ReluGrad} node by rewriting its operand and upstream.
     *
     * <p>The operand is compared to its rewritten form by reference identity, and the upstream
     * with {@code Intrinsics.areEqual}. When neither changed the node is returned as-is. Otherwise
     * a {@code TracingScalar.ReluGrad} is built if the rewritten operand is a
     * {@code TracingScalar}, else a {@code TracingTensor.ReluGrad}. JADX reports this method was
     * renamed from {@code visitReluGrad}.
     *
     * @param reluGrad the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitReluGrad2(@NotNull TracingTensor.ReluGrad reluGrad) {
        Intrinsics.checkNotNullParameter(reluGrad, "x");
        TracingTensor rewrittenX = rewrittenForm(reluGrad.getX());
        TracingTensor rewrittenUpstream = rewrittenForm(reluGrad.getUpstream());
        boolean xChanged = rewrittenX != reluGrad.getX();
        if (!xChanged && Intrinsics.areEqual(rewrittenUpstream, reluGrad.getUpstream())) {
            return reluGrad;
        }
        if (rewrittenX instanceof TracingScalar) {
            return new TracingScalar.ReluGrad(rewrittenX, rewrittenUpstream);
        }
        return new TracingTensor.ReluGrad(rewrittenX, rewrittenUpstream);
    }

    /**
     * Rewrites a {@link TracingTensor.Sigmoid} node by rewriting its operand.
     *
     * <p>Returns {@code sigmoid} unchanged when the rewritten operand is the same object. Otherwise
     * the replacement is a {@code TracingScalar.Sigmoid} if the rewritten operand is a
     * {@code TracingScalar}, else a {@code TracingTensor.Sigmoid}. JADX reports this method was
     * renamed from {@code visitSigmoid}.
     *
     * @param sigmoid the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitSigmoid2(@NotNull TracingTensor.Sigmoid sigmoid) {
        Intrinsics.checkNotNullParameter(sigmoid, "x");
        TracingTensor rewrittenX = rewrittenForm(sigmoid.getX());
        if (rewrittenX == sigmoid.getX()) {
            return sigmoid;
        }
        if (rewrittenX instanceof TracingScalar) {
            return new TracingScalar.Sigmoid(rewrittenX);
        }
        return new TracingTensor.Sigmoid(rewrittenX);
    }

    /**
     * Rewrites a {@link TracingTensor.Sum} node by rewriting its operand.
     *
     * <p>Returns {@code sum} unchanged when the rewritten operand is the same object. Otherwise the
     * replacement depends on the original node. If it is a {@code TracingScalar}, the replacement
     * is a {@code TracingScalar.Sum} over the rewritten operand only. If not, it is a
     * {@code TracingTensor.Sum} that keeps the original axes and keepDims flag. JADX reports this
     * method was renamed from {@code visitSum}.
     *
     * @param sum the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitSum2(@NotNull TracingTensor.Sum sum) {
        Intrinsics.checkNotNullParameter(sum, "x");
        TracingTensor rewrittenX = rewrittenForm(sum.getX());
        if (rewrittenX == sum.getX()) {
            return sum;
        }
        if (sum instanceof TracingScalar) {
            return new TracingScalar.Sum(rewrittenX);
        }
        return new TracingTensor.Sum(rewrittenX, sum.getAxes(), sum.getKeepDims());
    }

    /**
     * Deep rewriting of {@link TracingTensor.AvgPool} nodes is not implemented.
     *
     * <p>JADX reports this method was renamed from {@code visitAvgPool}.
     *
     * @param avgPool the node that would be rewritten; must not be null
     * @return never returns normally
     * @throws NotImplementedError always
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitAvgPool2(@NotNull TracingTensor.AvgPool avgPool) {
        Intrinsics.checkNotNullParameter(avgPool, "x");
        NotImplementedError unsupported = new NotImplementedError("An operation is not implemented: Not yet implemented");
        throw unsupported;
    }

    /**
     * Deep rewriting of {@link TracingTensor.AvgPoolGrad} nodes is not implemented.
     *
     * <p>JADX reports this method was renamed from {@code visitAvgPoolGrad}.
     *
     * @param avgPoolGrad the node that would be rewritten; must not be null
     * @return never returns normally
     * @throws NotImplementedError always
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitAvgPoolGrad2(@NotNull TracingTensor.AvgPoolGrad avgPoolGrad) {
        Intrinsics.checkNotNullParameter(avgPoolGrad, "x");
        NotImplementedError unsupported = new NotImplementedError("An operation is not implemented: Not yet implemented");
        throw unsupported;
    }

    /**
     * Deep rewriting of {@link TracingTensor.MaxPoolWithIndices} nodes is not implemented.
     *
     * <p>JADX reports this method was renamed from {@code visitMaxPoolWithIndices}.
     *
     * @param maxPoolWithIndices the node that would be rewritten; must not be null
     * @return never returns normally
     * @throws NotImplementedError always
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitMaxPoolWithIndices2(@NotNull TracingTensor.MaxPoolWithIndices maxPoolWithIndices) {
        Intrinsics.checkNotNullParameter(maxPoolWithIndices, "x");
        NotImplementedError unsupported = new NotImplementedError("An operation is not implemented: Not yet implemented");
        throw unsupported;
    }

    /**
     * Rewrites a {@link TracingTensor.Gather} node by rewriting its operand.
     *
     * <p>Returns {@code gather} unchanged when the rewritten operand is the same object. Otherwise
     * it returns a new {@code Gather} with the original indexes, axis and padding index. JADX
     * reports this method was renamed from {@code visitGather}.
     *
     * @param gather the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitGather2(@NotNull TracingTensor.Gather gather) {
        Intrinsics.checkNotNullParameter(gather, "x");
        TracingTensor rewrittenX = rewrittenForm(gather.getX());
        if (rewrittenX == gather.getX()) {
            return gather;
        }
        return new TracingTensor.Gather(rewrittenX, gather.getIndexes(), gather.getAxis(), gather.getPaddingIndex());
    }

    /**
     * Deep rewriting of {@link TracingTensor.GatherAtIndices} nodes is not implemented.
     *
     * <p>JADX reports this method was renamed from {@code visitGatherAtIndices}.
     *
     * @param gatherAtIndices the node that would be rewritten; must not be null
     * @return never returns normally
     * @throws NotImplementedError always
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitGatherAtIndices2(@NotNull TracingTensor.GatherAtIndices gatherAtIndices) {
        Intrinsics.checkNotNullParameter(gatherAtIndices, "x");
        NotImplementedError unsupported = new NotImplementedError("An operation is not implemented: Not yet implemented");
        throw unsupported;
    }

    /**
     * Rewrites a {@link TracingTensor.Scatter} node by rewriting its operand.
     *
     * <p>Returns {@code scatter} unchanged when the rewritten operand is the same object. Otherwise
     * it returns a new {@code Scatter} with the original indexes, axis, shape and padding index.
     * JADX reports this method was renamed from {@code visitScatter}.
     *
     * @param scatter the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitScatter2(@NotNull TracingTensor.Scatter scatter) {
        Intrinsics.checkNotNullParameter(scatter, "x");
        TracingTensor rewrittenX = rewrittenForm(scatter.getX());
        if (rewrittenX == scatter.getX()) {
            return scatter;
        }
        return new TracingTensor.Scatter(
                rewrittenX,
                scatter.getIndexes(),
                scatter.getAxis(),
                scatter.getShape(),
                scatter.getPaddingIndex());
    }

    /**
     * Deep rewriting of {@link TracingTensor.ScatterAtIndices} nodes is not implemented.
     *
     * <p>JADX reports this method was renamed from {@code visitScatterAtIndices}.
     *
     * @param scatterAtIndices the node that would be rewritten; must not be null
     * @return never returns normally
     * @throws NotImplementedError always
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitScatterAtIndices2(@NotNull TracingTensor.ScatterAtIndices scatterAtIndices) {
        Intrinsics.checkNotNullParameter(scatterAtIndices, "x");
        NotImplementedError unsupported = new NotImplementedError("An operation is not implemented: Not yet implemented");
        throw unsupported;
    }

    /**
     * Rewrites a {@link TracingTensor.Compare} node by rewriting its left and right operands.
     *
     * <p>Each original operand is compared to its rewritten form with {@code Intrinsics.areEqual}
     * (original first). If both match, the node is returned as-is. Otherwise a
     * {@code TracingScalar.Compare} is built when the original node is a {@code TracingScalar},
     * else a {@code TracingTensor.Compare}, both keeping the original comparison. JADX reports this
     * method was renamed from {@code visitCompare}.
     *
     * @param compare the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitCompare2(@NotNull TracingTensor.Compare compare) {
        Intrinsics.checkNotNullParameter(compare, "x");
        TracingTensor newLeft = rewrittenForm(compare.getLeft());
        TracingTensor newRight = rewrittenForm(compare.getRight());
        if (Intrinsics.areEqual(compare.getLeft(), newLeft) && Intrinsics.areEqual(compare.getRight(), newRight)) {
            return compare;
        }
        if (compare instanceof TracingScalar) {
            return new TracingScalar.Compare(newLeft, newRight, compare.getComparison());
        }
        return new TracingTensor.Compare(newLeft, newRight, compare.getComparison());
    }

    /**
     * Rewrites a {@link TracingTensor.IfThenElse} node by rewriting its condition and both
     * branches.
     *
     * <p>Each original input is compared to its rewritten form with {@code Intrinsics.areEqual}
     * (original first). If all three match, the node is returned as-is. Otherwise a
     * {@code TracingScalar.IfThenElse} is built when the original node is a {@code TracingScalar},
     * else a {@code TracingTensor.IfThenElse}. JADX reports this method was renamed from
     * {@code visitIfThenElse}.
     *
     * @param ifThenElse the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitIfThenElse2(@NotNull TracingTensor.IfThenElse ifThenElse) {
        Intrinsics.checkNotNullParameter(ifThenElse, "x");
        TracingTensor newCond = rewrittenForm(ifThenElse.getCond());
        TracingTensor newWhenTrue = rewrittenForm(ifThenElse.getWhenTrue());
        TracingTensor newWhenFalse = rewrittenForm(ifThenElse.getWhenFalse());
        boolean unchanged = Intrinsics.areEqual(ifThenElse.getCond(), newCond)
                && Intrinsics.areEqual(ifThenElse.getWhenTrue(), newWhenTrue)
                && Intrinsics.areEqual(ifThenElse.getWhenFalse(), newWhenFalse);
        if (unchanged) {
            return ifThenElse;
        }
        if (ifThenElse instanceof TracingScalar) {
            return new TracingScalar.IfThenElse(newCond, newWhenTrue, newWhenFalse);
        }
        return new TracingTensor.IfThenElse(newCond, newWhenTrue, newWhenFalse);
    }

    /**
     * Rewrites a {@link TracingTensor.RandomFloats} node by rewriting its random key.
     *
     * <p>Returns {@code randomFloats} unchanged when the rewritten key is the same object.
     * Otherwise it returns a new {@code RandomFloats} with the original shape. JADX reports this
     * method was renamed from {@code visitRandomFloats}.
     *
     * @param randomFloats the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitRandomFloats2(@NotNull TracingTensor.RandomFloats randomFloats) {
        Intrinsics.checkNotNullParameter(randomFloats, "x");
        TracingRandomKey newKey = rewrittenForm(randomFloats.getKey());
        if (newKey == randomFloats.getKey()) {
            return randomFloats;
        }
        return new TracingTensor.RandomFloats(newKey, randomFloats.getShape());
    }

    /**
     * Rewrites a {@link TracingRandomKey.Split} node by rewriting the key being split.
     *
     * <p>The rewritten key is compared to the original with {@code Intrinsics.areEqual} (rewritten
     * first). If they are equal, {@code split} is returned as-is. Otherwise a new {@code Split} is
     * built with the original {@code n}. JADX reports this method was renamed from
     * {@code visitRandomSplit}.
     *
     * @param split the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitRandomSplit2(@NotNull TracingRandomKey.Split split) {
        Intrinsics.checkNotNullParameter(split, "x");
        TracingRandomKey newKey = rewrittenForm(split.getKey());
        if (Intrinsics.areEqual(newKey, split.getKey())) {
            return split;
        }
        return new TracingRandomKey.Split(newKey, split.getN());
    }

    /**
     * Rewrites a {@link TracingRandomKey.SplitPart} node by rewriting the split it selects from.
     *
     * <p>The rewritten split is compared to the original with {@code Intrinsics.areEqual}
     * (rewritten first). If they are equal, {@code splitPart} is returned as-is. Otherwise a new
     * {@code SplitPart} is built with the original split index. JADX reports this method was
     * renamed from {@code visitRandomSplitPart}.
     *
     * @param splitPart the node to rewrite; must not be null
     * @return the original node or its rewritten replacement
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    public Traceable visitRandomSplitPart2(@NotNull TracingRandomKey.SplitPart splitPart) {
        Intrinsics.checkNotNullParameter(splitPart, "x");
        TracingRandomKey newSplit = rewrittenForm(splitPart.getSplit());
        if (Intrinsics.areEqual(newSplit, splitPart.getSplit())) {
            return splitPart;
        }
        return new TracingRandomKey.SplitPart(newSplit, splitPart.getSplitIndex());
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Leaf case for {@link TracingRandomKey.Variable}: the node has nothing to rewrite, so it is
     * returned unchanged.
     *
     * @param variable the node being visited; must not be null
     * @return {@code variable} itself
     */
    @Override // org.diffkt.tracing.ShallowTracingRewriter, org.diffkt.tracing.TracingVisitor
    @NotNull
    /* renamed from: visitRandomVariable */
    public Traceable visitRandomVariable2(@NotNull TracingRandomKey.Variable variable) {
        Intrinsics.checkNotNullParameter(variable, "x");
        return variable;
    }
}
