package org.diffkt.model;

import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import kotlin.random.Random;
import org.diffkt.ConcatKt;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.FloatTensor;
import org.diffkt.MinusKt;
import org.diffkt.PlusKt;
import org.diffkt.TimesKt;
import org.diffkt.TranscendentalKt;
import org.diffkt.UtilsKt;
import org.diffkt.model.Activation;
import org.diffkt.model.RecurrentBase;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: GRU.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��P\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n��\n\u0002\u0010��\n\u0002\b\u0003\u0018��2\u00020\u0001B5\b\u0016\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\n\b\u0002\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n¢\u0006\u0002\u0010\u000bBK\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\n\b\u0002\u0010\u0007\u001a\u0004\u0018\u00010\b\u0012\b\b\u0002\u0010\f\u001a\u00020\n\u0012\u0006\u0010\r\u001a\u00020\u000e\u0012\u0006\u0010\u000f\u001a\u00020\u000e\u0012\u0006\u0010\u0010\u001a\u00020\u000e\u0012\u0006\u0010\u0011\u001a\u00020\u000e¢\u0006\u0002\u0010\u0012J0\u0010\u001d\u001a\u000e\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\b0\u001e2\u0012\u0010\u001f\u001a\u000e\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\b0\u001e2\u0006\u0010 
\u001a\u00020\bH\u0016J\u0013\u0010!\u001a\u00020\"2\b\u0010#\u001a\u0004\u0018\u00010$H\u0096\u0002J\b\u0010%\u001a\u00020\u0003H\u0016J\u001a\u0010&\u001a\u00020\u00012\u0010\u0010\u0018\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u001a0\u0019H\u0016R\u000e\u0010\u0011\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u0013\u0010\u0007\u001a\u0004\u0018\u00010\b¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0011\u0010\u0004\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0016R\u001e\u0010\u0018\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u001a0\u0019X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u001cR\u000e\u0010\u0010\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000f\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��¨\u0006'"}, d2 = {"Lorg/diffkt/model/LinearBeforeResetGRU;", "Lorg/diffkt/model/GRU;", "numInputs", "", "numHidden", "random", "Lkotlin/random/Random;", "initialHidden", "Lorg/diffkt/DTensor;", "acc", "Lorg/diffkt/model/RecurrentBase$RecurrentBase$AccType;", "(IILkotlin/random/Random;Lorg/diffkt/DTensor;Lorg/diffkt/model/RecurrentBase$RecurrentBase$AccType;)V", "accType", "xh2u", "Lorg/diffkt/model/Dense;", "xh2r", "x2n", "h2n", "(IILorg/diffkt/DTensor;Lorg/diffkt/model/RecurrentBase$RecurrentBase$AccType;Lorg/diffkt/model/Dense;Lorg/diffkt/model/Dense;Lorg/diffkt/model/Dense;Lorg/diffkt/model/Dense;)V", "getInitialHidden", "()Lorg/diffkt/DTensor;", "getNumHidden", "()I", "getNumInputs", "trainables", "", "Lorg/diffkt/model/Trainable;", "getTrainables", "()Ljava/util/List;", "cell", "Lkotlin/Pair;", "state", "x", "equals", "", "other", "", "hashCode", "withTrainables", "api"})
/* loaded from: input_file:org/diffkt/model/LinearBeforeResetGRU.class */
/**
 * Gated recurrent unit in the "linear before reset" formulation.
 *
 * <p>One {@link #cell} step, with input {@code x} and previous hidden state {@code h}, computes:
 * <pre>
 *   xh = concat(x, h, 1)
 *   u  = xh2u(xh)                         // update gate
 *   r  = xh2r(xh)                         // reset gate
 *   n  = tanh(x2n(x) + r * h2n(h))        // candidate state
 *   h' = (1 - u) * n + u * h
 * </pre>
 * The reset gate multiplies the <em>output</em> of the linear map {@code h2n(h)}, rather than
 * {@code h} before it enters a layer.
 *
 * <p>NOTE(review): the tensor layout is assumed to be (batch, features), with concat dimension 1
 * joining the feature axes. This is not established in this file; confirm against callers.
 */
public final class LinearBeforeResetGRU extends GRU {
    /**
     * Kotlin default-argument bitmask passed to Dense's synthetic constructor.
     *
     * <p>104 = 8 + 32 + 64, i.e. parameter positions 3, 5 and 6. These follow the same
     * bit-per-parameter convention as this class's own synthetic constructors. The
     * {@code false}/{@code null} arguments in those positions are therefore placeholders that Dense
     * presumably replaces with its own defaults.
     */
    private static final int DENSE_DEFAULT_MASK = 104;

    private final int numInputs;
    private final int numHidden;

    @Nullable
    private final DTensor initialHidden;

    /** Update-gate layer, applied to concat(x, h). */
    @NotNull
    private final Dense xh2u;

    /** Reset-gate layer, applied to concat(x, h). */
    @NotNull
    private final Dense xh2r;

    /** Candidate-state layer, applied to the input x. */
    @NotNull
    private final Dense x2n;

    /** Candidate-state layer, applied to the previous hidden state; its output is gated by r. */
    @NotNull
    private final Dense h2n;

    /** The four layers, in the order {@link #withTrainables} expects them back. */
    @NotNull
    private final List<Trainable<?>> trainables;

    /**
     * Builds a GRU from explicitly supplied layers.
     *
     * @param i       number of input features (exposed as {@link #getNumInputs()})
     * @param i2      number of hidden units (exposed as {@link #getNumHidden()}); also forwarded to GRU
     * @param dTensor optional initial hidden state; may be {@code null}
     * @param accType accumulation strategy forwarded to the base class
     * @param dense   update-gate layer {@code xh2u}
     * @param dense2  reset-gate layer {@code xh2r}
     * @param dense3  input-to-candidate layer {@code x2n}
     * @param dense4  hidden-to-candidate layer {@code h2n}
     */
    public LinearBeforeResetGRU(int i, int i2, @Nullable DTensor dTensor, @NotNull RecurrentBase.C0000RecurrentBase.AccType accType, @NotNull Dense dense, @NotNull Dense dense2, @NotNull Dense dense3, @NotNull Dense dense4) {
        super(i2, dTensor, accType);
        // Java requires super(...) to come first, so these Kotlin parameter checks necessarily run
        // after the base-class constructor has received accType.
        Intrinsics.checkNotNullParameter(accType, "accType");
        Intrinsics.checkNotNullParameter(dense, "xh2u");
        Intrinsics.checkNotNullParameter(dense2, "xh2r");
        Intrinsics.checkNotNullParameter(dense3, "x2n");
        Intrinsics.checkNotNullParameter(dense4, "h2n");
        this.numInputs = i;
        this.numHidden = i2;
        this.initialHidden = dTensor;
        this.xh2u = dense;
        this.xh2r = dense2;
        this.x2n = dense3;
        this.h2n = dense4;
        this.trainables = CollectionsKt.listOf(new Dense[]{this.xh2u, this.xh2r, this.x2n, this.h2n});
    }

    /**
     * Kotlin default-argument entry point for the explicit-layer constructor.
     *
     * <p>Mask bit 4 substitutes {@code null} for the initial hidden state. Mask bit 8 substitutes
     * {@code AccType.Fold} for the accumulation type.
     */
    public /* synthetic */ LinearBeforeResetGRU(int i, int i2, DTensor dTensor, RecurrentBase.C0000RecurrentBase.AccType accType, Dense dense, Dense dense2, Dense dense3, Dense dense4, int i3, DefaultConstructorMarker defaultConstructorMarker) {
        this(i, i2, (i3 & 4) != 0 ? null : dTensor, (i3 & 8) != 0 ? RecurrentBase.C0000RecurrentBase.AccType.Fold : accType, dense, dense2, dense3, dense4);
    }

    /** @return the number of input features this cell was built for. */
    public final int getNumInputs() {
        return this.numInputs;
    }

    /** @return the number of hidden units. */
    public final int getNumHidden() {
        return this.numHidden;
    }

    /** @return the initial hidden state supplied at construction, or {@code null} if none was. */
    @Nullable
    public final DTensor getInitialHidden() {
        return this.initialHidden;
    }

    /**
     * Builds a GRU with freshly initialised layers drawn from {@code random}.
     *
     * <p>The layers are created as follows:
     * <ul>
     *   <li>{@code xh2u} and {@code xh2r} are {@code Dense(i + i2, i2, ...)} with Sigmoid activation.</li>
     *   <li>{@code x2n} is {@code Dense(i, i2, ...)} with Identity activation.</li>
     *   <li>{@code h2n} is {@code Dense(i2, i2, ...)} with Identity activation.</li>
     * </ul>
     *
     * @param i       number of input features
     * @param i2      number of hidden units
     * @param random  source of randomness for layer initialisation
     * @param dTensor optional initial hidden state; may be {@code null}
     * @param accType accumulation strategy forwarded to the base class
     */
    public LinearBeforeResetGRU(int i, int i2, @NotNull Random random, @Nullable DTensor dTensor, @NotNull RecurrentBase.C0000RecurrentBase.AccType accType) {
        this(i, i2, dTensor, accType,
                new Dense(i + i2, i2, random, false, Activation.Sigmoid.INSTANCE, null, null, DENSE_DEFAULT_MASK, null),
                new Dense(i + i2, i2, random, false, Activation.Sigmoid.INSTANCE, null, null, DENSE_DEFAULT_MASK, null),
                new Dense(i, i2, random, false, Activation.Identity.INSTANCE, null, null, DENSE_DEFAULT_MASK, null),
                new Dense(i2, i2, random, false, Activation.Identity.INSTANCE, null, null, DENSE_DEFAULT_MASK, null));
        // Java requires this(...) to come first, so these checks can only run after the layers exist.
        // A null accType is already rejected (as "accType") by the delegated constructor.
        Intrinsics.checkNotNullParameter(random, "random");
        Intrinsics.checkNotNullParameter(accType, "acc");
    }

    /**
     * Kotlin default-argument entry point for the random-initialisation constructor.
     *
     * <p>Mask bit 8 substitutes {@code null} for the initial hidden state. Mask bit 16 substitutes
     * {@code AccType.Fold} for the accumulation type.
     */
    public /* synthetic */ LinearBeforeResetGRU(int i, int i2, Random random, DTensor dTensor, RecurrentBase.C0000RecurrentBase.AccType accType, int i3, DefaultConstructorMarker defaultConstructorMarker) {
        this(i, i2, random, (i3 & 8) != 0 ? null : dTensor, (i3 & 16) != 0 ? RecurrentBase.C0000RecurrentBase.AccType.Fold : accType);
    }

    /** @return the layers in the order {@code [xh2u, xh2r, x2n, h2n]}. */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public List<Trainable<?>> getTrainables() {
        return this.trainables;
    }

    /**
     * Returns a copy of this GRU whose layers are replaced by {@code list}.
     *
     * <p>Sizes, initial hidden state and accumulation type are carried over unchanged. No explicit
     * bridge overload is declared for the raw {@code withTrainables(List)} signature. It would
     * clash with this method's erasure, and javac generates the covariant-return bridge itself.
     *
     * @param list the replacement layers, in the order {@code [xh2u, xh2r, x2n, h2n]}
     * @return the new GRU
     * @throws IllegalArgumentException if {@code list} does not contain exactly 4 elements
     * @throws NullPointerException     if any element is {@code null}
     * @throws ClassCastException       if any element is not a {@link Dense}
     */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public GRU withTrainables(@NotNull List<? extends Trainable<?>> list) {
        Intrinsics.checkNotNullParameter(list, "trainables");
        if (list.size() != 4) {
            throw new IllegalArgumentException(
                    "LinearBeforeResetGRU expects 4 trainables (xh2u, xh2r, x2n, h2n) but got " + list.size());
        }
        return new LinearBeforeResetGRU(
                this.numInputs,
                this.numHidden,
                this.initialHidden,
                getAccType(),
                asDense(list.get(0)),
                asDense(list.get(1)),
                asDense(list.get(2)),
                asDense(list.get(3)));
    }

    /** Narrows a trainable to {@link Dense}, mirroring Kotlin's {@code as Dense} null/cast checks. */
    @NotNull
    private static Dense asDense(@Nullable Trainable<?> trainable) {
        Intrinsics.checkNotNull(trainable, "null cannot be cast to non-null type org.diffkt.model.Dense");
        return (Dense) trainable;
    }

    /**
     * Advances the recurrence by one step.
     *
     * <p>Only {@code pair.getFirst()} is read as the hidden state. The new hidden state is
     * returned in both slots of the result pair.
     *
     * @param pair    recurrent state; its first element is the previous hidden state {@code h}
     * @param dTensor the input {@code x} for this step
     * @return {@code (h', h')}, where {@code h' = (1 - u) * tanh(x2n(x) + r * h2n(h)) + u * h}
     */
    @Override // org.diffkt.model.RecurrentBase
    @NotNull
    public Pair<DTensor, DTensor> cell(@NotNull Pair<? extends DTensor, ? extends DTensor> pair, @NotNull DTensor dTensor) {
        Intrinsics.checkNotNullParameter(pair, "state");
        Intrinsics.checkNotNullParameter(dTensor, "x");
        DTensor hidden = (DTensor) pair.getFirst();
        DTensor inputAndHidden = ConcatKt.concat(dTensor, hidden, 1);
        DTensor update = this.xh2u.invoke(inputAndHidden);
        DTensor reset = this.xh2r.invoke(inputAndHidden);
        // Intermediates are computed in the same order as the original single nested expression.
        DTensor oneMinusUpdate = MinusKt.minus(FloatTensor.Companion.ones(reset.getShape()), update);
        DTensor candidate = TranscendentalKt.tanh(
                PlusKt.plus(this.x2n.invoke(dTensor), TimesKt.times(reset, this.h2n.invoke(hidden))));
        DTensor next = PlusKt.plus(TimesKt.times(oneMinusUpdate, candidate), TimesKt.times(update, hidden));
        return new Pair<>(next, next);
    }

    /**
     * Hashes the class tag, the sizes and the base-class hash.
     *
     * <p>The layers are deliberately omitted. That is still consistent with {@link #equals}, since
     * every hashed component is also compared there.
     */
    @Override // org.diffkt.model.GRU
    public int hashCode() {
        return UtilsKt.combineHash("LinearBeforeResetGRU", Integer.valueOf(this.numInputs), Integer.valueOf(this.numHidden), Integer.valueOf(super.hashCode()));
    }

    /**
     * Two instances are equal when all of the following match, then {@code super.equals}:
     * <ul>
     *   <li>sizes</li>
     *   <li>accumulation type</li>
     *   <li>all four layers</li>
     * </ul>
     *
     * <p>NOTE(review): {@code initialHidden} is not compared here. It is passed to the GRU
     * constructor, so it is presumably covered by {@code super.equals}; TODO confirm.
     */
    @Override // org.diffkt.model.GRU
    public boolean equals(@Nullable Object obj) {
        if (!(obj instanceof LinearBeforeResetGRU)) {
            return false;
        }
        LinearBeforeResetGRU that = (LinearBeforeResetGRU) obj;
        return that.numInputs == this.numInputs
                && that.numHidden == this.numHidden
                && that.getAccType() == getAccType()
                && Intrinsics.areEqual(that.xh2r, this.xh2r)
                && Intrinsics.areEqual(that.xh2u, this.xh2u)
                && Intrinsics.areEqual(that.x2n, this.x2n)
                && Intrinsics.areEqual(that.h2n, this.h2n)
                && super.equals(obj);
    }
}
