package org.diffkt.model;

import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Triple;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.diffkt.ConcatKt;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.FloatTensor;
import org.diffkt.Shape;
import org.jetbrains.annotations.NotNull;

/* compiled from: BatchNormTraining.kt */
/**
 * Batch-normalization layer (variant "V1") that keeps its running statistics as mutable state.
 *
 * <p>The learnable parameters live in a single {@link TrainableTensor} ({@code scaleShift}), which
 * is the only entry reported by {@link #getTrainables()}. The running mean and running variance are
 * plain {@link DTensor} fields that every call to {@link #invoke(DTensor)} replaces in place.
 *
 * <p>New layers are created with {@link Companion#invoke(int, float)}. {@link #withTrainables(List)}
 * derives a copy with different parameters.
 */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��8\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0007\u0018�� \u00182\b\u0012\u0004\u0012\u00020��0\u0001:\u0001\u0018B1\b\u0002\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\t\u0012\u0006\u0010\n\u001a\u00020\t¢\u0006\u0002\u0010\u000bJ\u0011\u0010\u0015\u001a\u00020\t2\u0006\u0010\u0016\u001a\u00020\tH\u0096\u0002J\u001a\u0010\u0017\u001a\u00020��2\u0010\u0010\u0010\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00120\u0011H\u0016R\u000e\u0010\b\u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R \u0010\f\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\t0\r8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u000e\u0010\u000fR\u001e\u0010\u0010\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00120\u00118VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0013\u0010\u0014¨\u0006\u0019"}, d2 = {"Lorg/diffkt/model/BatchNormTrainingV1;", "Lorg/diffkt/model/BatchNormTrainingBase;", "numFeatures", "", "momentum", "", "scaleShift", "Lorg/diffkt/model/TrainableTensor;", "runningMean", "Lorg/diffkt/DTensor;", "runningVariance", "(IFLorg/diffkt/model/TrainableTensor;Lorg/diffkt/DTensor;Lorg/diffkt/DTensor;)V", "stats", "Lkotlin/Pair;", "getStats", "()Lkotlin/Pair;", "trainables", "", "Lorg/diffkt/model/Trainable;", "getTrainables", "()Ljava/util/List;", "invoke", "input", "withTrainables", "Companion", "api"})
/* loaded from: input_file:org/diffkt/model/BatchNormTrainingV1.class */
public final class BatchNormTrainingV1 extends BatchNormTrainingBase<BatchNormTrainingV1> {

    /** Factory entry point, mirroring the Kotlin companion object. */
    @NotNull
    public static final Companion Companion = new Companion(null);

    /** Running mean; replaced by every {@link #invoke(DTensor)} call. */
    @NotNull
    private DTensor runningMean;

    /** Running variance; replaced by every {@link #invoke(DTensor)} call. */
    @NotNull
    private DTensor runningVariance;

    /* compiled from: BatchNormTraining.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\u001e\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u0007\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u001b\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\b\b\u0002\u0010\u0007\u001a\u00020\bH\u0086\u0002¨\u0006\t"}, d2 = {"Lorg/diffkt/model/BatchNormTrainingV1$Companion;", "", "()V", "invoke", "Lorg/diffkt/model/BatchNormTrainingV1;", "numFeatures", "", "momentum", "", "api"})
    /* loaded from: input_file:org/diffkt/model/BatchNormTrainingV1$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        /**
         * Creates a layer with freshly initialised parameters and running statistics.
         *
         * <p>{@code scaleShift} is the concatenation of {@code ones(1, numFeatures)} followed by
         * {@code zeros(1, numFeatures)}, using {@code concat}'s default axis. Presumably that axis
         * is 0, which gives a (2, numFeatures) tensor whose first row is the scale and whose second
         * row is the shift — TODO confirm. The running mean starts as zeros and the running
         * variance as ones, both of shape (numFeatures).
         *
         * @param numFeatures number of features handled by the layer
         * @param momentum momentum used when updating the running statistics; must lie in [0, 1]
         * @return a new layer
         * @throws IllegalArgumentException if {@code momentum} is outside [0, 1] or is NaN
         */
        @NotNull
        public final BatchNormTrainingV1 invoke(int numFeatures, float momentum) {
            return new BatchNormTrainingV1(
                    numFeatures,
                    momentum,
                    // concat$default with mask 2 selects concat's default for the axis argument;
                    // the 0 passed in that slot is a placeholder that the mask overrides.
                    new TrainableTensor(ConcatKt.concat$default(
                            FloatTensor.Companion.ones(Shape.Companion.invoke(1, numFeatures)),
                            FloatTensor.Companion.zeros(Shape.Companion.invoke(1, numFeatures)),
                            0, 2, null)),
                    FloatTensor.Companion.zeros(Shape.Companion.invoke(numFeatures)),
                    FloatTensor.Companion.ones(Shape.Companion.invoke(numFeatures)),
                    null);
        }

        /**
         * Kotlin default-argument entry point for {@link #invoke(int, float)}.
         *
         * <p>When bit value 2 of {@code mask} is set, {@code momentum} is replaced by {@code 0.1f}.
         */
        public static /* synthetic */ BatchNormTrainingV1 invoke$default(Companion companion, int numFeatures, float momentum, int mask, Object marker) {
            float effectiveMomentum = (mask & 2) != 0 ? 0.1f : momentum;
            return companion.invoke(numFeatures, effectiveMomentum);
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Primary constructor. It is private, so layers are built through the companion factory or
     * {@link #withTrainables(List)}.
     *
     * @throws IllegalArgumentException if {@code momentum} is not within [0, 1] (NaN included)
     */
    private BatchNormTrainingV1(int numFeatures, float momentum, TrainableTensor scaleShift,
            DTensor runningMean, DTensor runningVariance) {
        super(numFeatures, momentum, scaleShift);
        this.runningMean = runningMean;
        this.runningVariance = runningVariance;
        // Written as a positive range test so that NaN (for which every comparison is false)
        // is rejected as well.
        if (!(momentum >= 0.0f && momentum <= 1.0f)) {
            throw new IllegalArgumentException("Failed requirement.");
        }
    }

    /**
     * Kotlin default-argument constructor.
     *
     * <p>When bit value 2 of {@code mask} is set, {@code momentum} defaults to {@code 0.1f}.
     */
    /* synthetic */ BatchNormTrainingV1(int numFeatures, float momentum, TrainableTensor scaleShift,
            DTensor runningMean, DTensor runningVariance, int mask, DefaultConstructorMarker marker) {
        this(numFeatures, (mask & 2) != 0 ? 0.1f : momentum, scaleShift, runningMean, runningVariance);
    }

    /**
     * Applies the layer to {@code input} and advances the running statistics.
     *
     * <p>Delegates to {@code BatchNormTrainingKt.batchNormTrainV1} with the current scale/shift
     * tensor, running mean, running variance and momentum. The second and third elements of the
     * returned triple replace the stored running mean and running variance. The first element
     * (presumably the normalised output) is returned.
     *
     * <p>The running-statistic fields are reassigned without synchronisation, so concurrent calls
     * on one instance race.
     *
     * @param input tensor to normalise
     * @return the first element of the triple produced by {@code batchNormTrainV1}
     */
    @Override // org.diffkt.model.LayerSingleInput
    @NotNull
    public DTensor invoke(@NotNull DTensor input) {
        Intrinsics.checkNotNullParameter(input, "input");
        Triple<DTensor, DTensor, DTensor> result = BatchNormTrainingKt.batchNormTrainV1(
                input, getScaleShift().getTensor(), this.runningMean, this.runningVariance, getMomentum());
        this.runningMean = result.getSecond();
        this.runningVariance = result.getThird();
        return result.getFirst();
    }

    /** Returns the current running mean and running variance, in that order. */
    @Override // org.diffkt.model.BatchNormTrainingBase
    @NotNull
    public Pair<DTensor, DTensor> getStats() {
        return new Pair<>(this.runningMean, this.runningVariance);
    }

    /** Returns a single-element list holding the scale/shift trainable. */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public List<Trainable<?>> getTrainables() {
        return CollectionsKt.listOf(getScaleShift());
    }

    /**
     * Returns a copy of this layer that uses the supplied trainable as its scale/shift tensor.
     *
     * <p>The copy keeps this layer's feature count and momentum. It also shares this layer's
     * current running-mean and running-variance tensors.
     *
     * <p>The erased {@code TrainableComponent withTrainables(List)} bridge for this covariant
     * override is generated by javac. It must not be declared by hand, because two methods with the
     * same erasure in one class do not compile.
     *
     * @param trainables replacement trainables; must contain exactly one {@link TrainableTensor}
     * @return a new layer using the given scale/shift tensor
     * @throws IllegalArgumentException if {@code trainables} does not hold exactly one element, or
     *     that element is not a {@link TrainableTensor}
     */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public BatchNormTrainingV1 withTrainables(@NotNull List<? extends Trainable<?>> trainables) {
        Intrinsics.checkNotNullParameter(trainables, "trainables");
        if (trainables.size() != 1) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        Trainable<?> replacement = trainables.get(0);
        if (!(replacement instanceof TrainableTensor)) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        return new BatchNormTrainingV1(getNumFeatures(), getMomentum(), (TrainableTensor) replacement,
                this.runningMean, this.runningVariance);
    }

    /**
     * Synthetic accessor used by {@link Companion#invoke(int, float)} to reach the private
     * constructor. The marker argument is ignored. Kept for binary compatibility.
     */
    public /* synthetic */ BatchNormTrainingV1(int numFeatures, float momentum, TrainableTensor scaleShift,
            DTensor runningMean, DTensor runningVariance, DefaultConstructorMarker marker) {
        this(numFeatures, momentum, scaleShift, runningMean, runningVariance);
    }
}
