package org.diffkt.model;

import java.nio.ByteBuffer;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import kotlin.random.Random;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.Device;
import org.diffkt.FloatScalar;
import org.diffkt.FloatTensor;
import org.diffkt.MatmulKt;
import org.diffkt.OnDevice;
import org.diffkt.PlusKt;
import org.diffkt.Shape;
import org.diffkt.UtilsKt;
import org.diffkt.Wrapper;
import org.diffkt.model.Activation;
import org.diffkt.model.Trainable;
import org.diffkt.model.TrainableComponent;
import org.diffkt.model.TrainableLayerSingleInput;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Dense.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��d\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010 \n\u0002\b\b\n\u0002\u0010��\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018�� 02\b\u0012\u0004\u0012\u00020��0\u0001:\u00010Bk\b\u0016\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\u001a\b\u0002\u0010\u000b\u001a\u0014\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u000e0\f\u0012\u001a\b\u0002\u0010\u000f\u001a\u0014\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u000e0\f¢\u0006\u0002\u0010\u0010B_\b\u0016\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\u001a\b\u0002\u0010\u000b\u001a\u0014\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u000e0\f\u0012\u001a\b\u0002\u0010\u000f\u001a\u0014\u0012\u0004\u0012\u00020\r\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u000e0\f¢\u0006\u0002\u0010\u0011B'\b\u0002\u0012\u0006\u0010\u0012\u001a\u00020\u0013\u0012\u0006\u0010\u0014\u001a\u00020\u0013\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\u0015J\u000e\u0010$\u001a\u00020\u00192\u0006\u0010%\u001a\u00020\u0019J\u0013\u0010&\u001a\u00020\b2\b\u0010'\u001a\u0004\u0018\u00010(H\u0096\u0002J\b\u0010)\u001a\u00020\u0003H\u0016J\u0011\u0010*\u001a\u00020\u00192\u0006\u0010%\u001a\u00020
\u0019H\u0096\u0002J\u001a\u0010+\u001a\u00020��2\u0010\u0010\u001e\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030,0\u001fH\u0016J\u0010\u0010-\u001a\u00020��2\u0006\u0010.\u001a\u00020/H\u0016R\u0011\u0010\t\u001a\u00020\n¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0017R\u0011\u0010\u0018\u001a\u00020\u00198F¢\u0006\u0006\u001a\u0004\b\u001a\u0010\u001bR\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001dR\u000e\u0010\u0014\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00130\u001fX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b \u0010!R\u0011\u0010\"\u001a\u00020\u00198F¢\u0006\u0006\u001a\u0004\b#\u0010\u001b¨\u00061"}, d2 = {"Lorg/diffkt/model/Dense;", "Lorg/diffkt/model/TrainableLayerSingleInput;", "numInputs", "", "numOutputs", "random", "Lkotlin/random/Random;", "bias", "", "activation", "Lorg/diffkt/model/Activation;", "weightInit", "Lkotlin/Function2;", "Lorg/diffkt/Shape;", "Lorg/diffkt/FloatTensor;", "biasInit", "(IILkotlin/random/Random;ZLorg/diffkt/model/Activation;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V", "(IILorg/diffkt/model/Activation;Lkotlin/random/Random;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V", "trainableW", "Lorg/diffkt/model/TrainableTensor;", "trainableB", "(Lorg/diffkt/model/TrainableTensor;Lorg/diffkt/model/TrainableTensor;Lorg/diffkt/model/Activation;Z)V", "getActivation", "()Lorg/diffkt/model/Activation;", "b", "Lorg/diffkt/DTensor;", "getB", "()Lorg/diffkt/DTensor;", "getBias", "()Z", "trainables", "", "getTrainables", "()Ljava/util/List;", "w", "getW", "apply", "input", "equals", "other", "", "hashCode", "invoke", "withTrainables", "Lorg/diffkt/model/Trainable;", "wrap", "wrapper", "Lorg/diffkt/Wrapper;", "Companion", "api"})
/* loaded from: input_file:org/diffkt/model/Dense.class */
public final class Dense implements TrainableLayerSingleInput<Dense> {

    /*
     * A dense (fully connected) layer computing activation(matmul(input, w) + b).
     *
     * NOTE(review): this file is decompiler output for Dense.kt; the Kotlin source is
     * authoritative. Members marked synthetic/bridge mirror what the Kotlin compiler emitted and
     * are kept so the JVM-level shape of the class is unchanged.
     */

    /** Trainable wrapper around the weight matrix {@code w}. */
    @NotNull
    private final TrainableTensor trainableW;

    /**
     * Trainable wrapper around the bias {@code b}. When {@link #bias} is false this holds
     * {@code FloatScalar.ZERO}, and it is not listed in {@link #getTrainables()}.
     */
    @NotNull
    private final TrainableTensor trainableB;

    /** Activation applied to {@code matmul(input, w) + b}. */
    @NotNull
    private final Activation activation;

    /** Whether the bias is a real, trainable parameter. */
    private final boolean bias;

    /** The parameters exposed for training: {@code [w, b]} with a bias, otherwise {@code [w]}. */
    @NotNull
    private final List<TrainableTensor> trainables;

    @NotNull
    public static final Companion Companion = new Companion(null);

    /** Activation used when the caller does not supply one. */
    @NotNull
    private static final Activation.Identity defaultActivation = Activation.Identity.INSTANCE;

    /* compiled from: Dense.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��*\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\"\u0010\u0005\u001a\u0014\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\t0\u00062\u0006\u0010\n\u001a\u00020\u000bH\u0002R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082\u0004¢\u0006\u0002\n��¨\u0006\f"}, d2 = {"Lorg/diffkt/model/Dense$Companion;", "", "()V", "defaultActivation", "Lorg/diffkt/model/Activation$Identity;", "defaultInit", "Lkotlin/Function2;", "Lorg/diffkt/Shape;", "Lkotlin/random/Random;", "Lorg/diffkt/FloatTensor;", "numInputs", "", "api"})
    /* loaded from: input_file:org/diffkt/model/Dense$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        /**
         * Builds the default initializer for both weight and bias: {@code Initializer.uniform(-s, s)}
         * with {@code s = sqrt(1 / numInputs)}.
         *
         * <p>NOTE(review): for {@code numInputs == 0} the bound is infinite; confirm whether that
         * case should be rejected.
         *
         * @param i the layer's number of inputs
         * @return an initializer taking a shape and a random source
         */
        /* JADX INFO: Access modifiers changed from: private */
        public final Function2<Shape, Random, FloatTensor> defaultInit(int i) {
            float sqrt = (float) Math.sqrt(1.0f / i);
            return Initializer.INSTANCE.uniform(-sqrt, sqrt);
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Primary constructor used by the public constructors, {@link #withTrainables}, and
     * {@link #wrap}. It derives the trainables list from {@code z}.
     */
    private Dense(TrainableTensor trainableTensor, TrainableTensor trainableTensor2, Activation activation, boolean z) {
        this.trainableW = trainableTensor;
        this.trainableB = trainableTensor2;
        this.activation = activation;
        this.bias = z;
        this.trainables = this.bias ? CollectionsKt.listOf(new TrainableTensor[]{this.trainableW, this.trainableB}) : CollectionsKt.listOf(this.trainableW);
    }

    /** @return the activation applied after the affine transform */
    @NotNull
    public final Activation getActivation() {
        return this.activation;
    }

    /** @return whether this layer has a trainable bias */
    public final boolean getBias() {
        return this.bias;
    }

    /**
     * Validates the main public constructor's reference arguments, then creates the initial
     * weight of shape {@code (numInputs, numOutputs)}.
     *
     * <p>Java requires {@code this(...)} to be the first statement of a constructor. Running the
     * checks here, as the first argument of that call, guarantees they happen before either
     * initializer is invoked, because Java evaluates arguments left to right. The decompiled
     * constructor instead performed them afterwards.
     */
    private static TrainableTensor checkArgsAndInitWeight(int numInputs, int numOutputs, Random random, Activation activation, Function2<? super Shape, ? super Random, ? extends FloatTensor> weightInit, Function2<? super Shape, ? super Random, ? extends FloatTensor> biasInit) {
        Intrinsics.checkNotNullParameter(random, "random");
        Intrinsics.checkNotNullParameter(activation, "activation");
        Intrinsics.checkNotNullParameter(weightInit, "weightInit");
        Intrinsics.checkNotNullParameter(biasInit, "biasInit");
        return new TrainableTensor((DTensor) weightInit.invoke(Shape.Companion.invoke(numInputs, numOutputs), random));
    }

    /**
     * Creates a layer with freshly initialized parameters.
     *
     * @param i number of inputs (rows of {@code w})
     * @param i2 number of outputs (columns of {@code w}, length of {@code b})
     * @param random random source passed to both initializers
     * @param z whether to create a trainable bias; if false, {@code b} is {@code FloatScalar.ZERO}
     *     and {@code biasInit} is never invoked
     * @param activation activation applied to the affine output
     * @param function2 weight initializer ({@code weightInit}), called with shape {@code (i, i2)}
     * @param function22 bias initializer ({@code biasInit}), called with shape {@code (i2)}
     */
    public Dense(int i, int i2, @NotNull Random random, boolean z, @NotNull Activation activation, @NotNull Function2<? super Shape, ? super Random, ? extends FloatTensor> function2, @NotNull Function2<? super Shape, ? super Random, ? extends FloatTensor> function22) {
        this(checkArgsAndInitWeight(i, i2, random, activation, function2, function22), new TrainableTensor(z ? (FloatTensor) function22.invoke(Shape.Companion.invoke(i2), random) : FloatScalar.Companion.getZERO()), activation, z);
    }

    /**
     * Synthetic default-arguments constructor. Bits of {@code i3} select defaults:
     * 8 → bias = true, 16 → identity activation, 32 / 64 → {@code Companion.defaultInit(i)}
     * for weight / bias.
     */
    public /* synthetic */ Dense(int i, int i2, Random random, boolean z, Activation activation, Function2 function2, Function2 function22, int i3, DefaultConstructorMarker defaultConstructorMarker) {
        this(i, i2, random, (i3 & 8) != 0 ? true : z, (i3 & 16) != 0 ? defaultActivation : activation, (i3 & 32) != 0 ? Companion.defaultInit(i) : function2, (i3 & 64) != 0 ? Companion.defaultInit(i) : function22);
    }

    /**
     * Creates a layer with a bias, taking the activation before the random source.
     * Argument validation happens in the delegated-to constructor, before any initializer runs.
     */
    public Dense(int i, int i2, @NotNull Activation activation, @NotNull Random random, @NotNull Function2<? super Shape, ? super Random, ? extends FloatTensor> function2, @NotNull Function2<? super Shape, ? super Random, ? extends FloatTensor> function22) {
        this(i, i2, random, true, activation, function2, function22);
    }

    /**
     * Synthetic default-arguments constructor. Bits of {@code i3} select defaults:
     * 16 / 32 → {@code Companion.defaultInit(i)} for weight / bias.
     */
    public /* synthetic */ Dense(int i, int i2, Activation activation, Random random, Function2 function2, Function2 function22, int i3, DefaultConstructorMarker defaultConstructorMarker) {
        this(i, i2, activation, random, (i3 & 16) != 0 ? Companion.defaultInit(i) : function2, (i3 & 32) != 0 ? Companion.defaultInit(i) : function22);
    }

    /** @return the current weight tensor (shape {@code (numInputs, numOutputs)} when built by a public constructor) */
    @NotNull
    public final DTensor getW() {
        return this.trainableW.getTensor();
    }

    /** @return the current bias tensor; {@code FloatScalar.ZERO} when built without a bias */
    @NotNull
    public final DTensor getB() {
        return this.trainableB.getTensor();
    }

    /** @return {@code [w, b]} when the layer has a bias, otherwise {@code [w]} */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public List<TrainableTensor> getTrainables() {
        return this.trainables;
    }

    /**
     * Returns a copy of this layer with replaced parameters.
     *
     * @param list element 0 becomes the new weight; element 1, if present, the new bias.
     *     When only one element is given, the current bias is kept. Extra elements are ignored.
     * @return a new layer with the same activation and bias flag
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws NullPointerException if a used element is null
     * @throws ClassCastException if a used element is not a {@code TrainableTensor}
     */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public Dense withTrainables(@NotNull List<? extends Trainable<?>> list) {
        Intrinsics.checkNotNullParameter(list, "trainables");
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Dense.withTrainables requires at least one trainable (the weight), but got none");
        }
        Trainable<?> newWeight = list.get(0);
        Intrinsics.checkNotNull(newWeight, "null cannot be cast to non-null type org.diffkt.model.TrainableTensor");
        TrainableTensor weight = (TrainableTensor) newWeight;
        TrainableTensor biasTensor;
        if (list.size() > 1) {
            Trainable<?> newBias = list.get(1);
            Intrinsics.checkNotNull(newBias, "null cannot be cast to non-null type org.diffkt.model.TrainableTensor");
            biasTensor = (TrainableTensor) newBias;
        } else {
            biasTensor = this.trainableB;
        }
        return new Dense(weight, biasTensor, this.activation, this.bias);
    }

    /**
     * Applies the layer after checking the input rank.
     *
     * @throws IllegalArgumentException if {@code dTensor} has rank below 2
     */
    @Override // org.diffkt.model.LayerSingleInput
    @NotNull
    public DTensor invoke(@NotNull DTensor dTensor) {
        Intrinsics.checkNotNullParameter(dTensor, "input");
        if (dTensor.getRank() >= 2) {
            return apply(dTensor);
        }
        throw new IllegalArgumentException("input rank to Dense must be at least 2, but was " + dTensor.getRank());
    }

    /**
     * Computes {@code activation(matmul(input, w) + b)} without the rank check done by
     * {@link #invoke(DTensor)}. Without a bias, {@code b} is a zero scalar and the addition is a no-op.
     */
    @NotNull
    public final DTensor apply(@NotNull DTensor dTensor) {
        Intrinsics.checkNotNullParameter(dTensor, "input");
        return this.activation.invoke(PlusKt.plus(MatmulKt.matmul(dTensor, getW()), getB()));
    }

    /** Returns a copy whose weight and bias have been passed through {@code wrapper}. */
    @Override // org.diffkt.Wrappable
    @NotNull
    public Dense wrap(@NotNull Wrapper wrapper) {
        Intrinsics.checkNotNullParameter(wrapper, "wrapper");
        return new Dense((TrainableTensor) wrapper.wrap(this.trainableW), (TrainableTensor) wrapper.wrap(this.trainableB), this.activation, this.bias);
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    public int hashCode() {
        return UtilsKt.combineHash("Dense", this.trainableW, this.trainableB, this.activation, Boolean.valueOf(this.bias));
    }

    /** Two layers are equal when weight, bias, activation and bias flag are all equal. */
    public boolean equals(@Nullable Object obj) {
        if (!(obj instanceof Dense)) {
            return false;
        }
        Dense other = (Dense) obj;
        return Intrinsics.areEqual(other.trainableW, this.trainableW) && Intrinsics.areEqual(other.trainableB, this.trainableB) && Intrinsics.areEqual(other.activation, this.activation) && other.bias == this.bias;
    }

    // The members below delegate to TrainableLayerSingleInput.DefaultImpls.

    @Override // org.diffkt.model.Layer
    @NotNull
    public DTensor invoke(@NotNull DTensor... dTensorArr) {
        return TrainableLayerSingleInput.DefaultImpls.invoke(this, dTensorArr);
    }

    @Override // org.diffkt.OnDevice
    @NotNull
    public Dense cpu() {
        return (Dense) TrainableLayerSingleInput.DefaultImpls.cpu(this);
    }

    @Override // org.diffkt.OnDevice
    @NotNull
    public Dense gpu() {
        return (Dense) TrainableLayerSingleInput.DefaultImpls.gpu(this);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public Dense trainingStep(@NotNull Optimizer<?> optimizer, @NotNull Trainable.Tangent tangent) {
        return (Dense) TrainableLayerSingleInput.DefaultImpls.trainingStep(this, optimizer, tangent);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public TrainableComponent.Companion.Tangent extractTangent(@NotNull DTensor dTensor, @NotNull Function2<? super DTensor, ? super DTensor, ? extends DTensor> function2) {
        return TrainableLayerSingleInput.DefaultImpls.extractTangent(this, dTensor, function2);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public ByteBuffer store(@NotNull ByteBuffer byteBuffer) {
        return TrainableLayerSingleInput.DefaultImpls.store(this, byteBuffer);
    }

    @Override // org.diffkt.model.Trainable
    @NotNull
    public Dense load(@NotNull ByteBuffer byteBuffer) {
        return (Dense) TrainableLayerSingleInput.DefaultImpls.load(this, byteBuffer);
    }

    @Override // org.diffkt.OnDevice
    @NotNull
    public OnDevice to(@NotNull Device device) {
        return TrainableLayerSingleInput.DefaultImpls.to(this, device);
    }

    @Override // org.diffkt.model.Layer
    @NotNull
    public DTensor getSingleInput(@NotNull DTensor[] dTensorArr) {
        return TrainableLayerSingleInput.DefaultImpls.getSingleInput(this, dTensorArr);
    }

    // Compiler-generated bridge methods (marked bridge/synthetic by the decompiler).
    // NOTE(review): as Java source these clash by erasure with the typed overloads above; they
    // document the bytecode shape rather than compile as-is.

    @Override // org.diffkt.model.TrainableComponent
    public /* bridge */ /* synthetic */ TrainableComponent withTrainables(List list) {
        return withTrainables((List<? extends Trainable<?>>) list);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ TrainableComponent trainingStep(Optimizer optimizer, Trainable.Tangent tangent) {
        return trainingStep((Optimizer<?>) optimizer, tangent);
    }

    @Override // org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ Trainable trainingStep(Optimizer optimizer, Trainable.Tangent tangent) {
        return trainingStep((Optimizer<?>) optimizer, tangent);
    }

    @Override // org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ Trainable.Tangent extractTangent(DTensor dTensor, Function2 function2) {
        return extractTangent(dTensor, (Function2<? super DTensor, ? super DTensor, ? extends DTensor>) function2);
    }
}
