package org.diffkt.model;

import java.nio.ByteBuffer;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.Device;
import org.diffkt.OnDevice;
import org.diffkt.UtilsKt;
import org.diffkt.Wrapper;
import org.diffkt.model.BatchNormTrainingBase;
import org.diffkt.model.Trainable;
import org.diffkt.model.TrainableComponent;
import org.diffkt.model.TrainableLayerSingleInput;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BatchNormTraining.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��F\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n��\n\u0002\u0010��\n\u0002\b\u0002\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010��2\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B\u001f\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\t¢\u0006\u0002\u0010\nJ\u0013\u0010\u001a\u001a\u00020\u001b2\b\u0010\u001c\u001a\u0004\u0018\u00010\u001dH\u0096\u0002J\b\u0010\u001e\u001a\u00020\u0005H\u0016R\u0014\u0010\u000b\u001a\u00020\f8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\r\u0010\u000eR\u0014\u0010\u0006\u001a\u00020\u0007X\u0084\u0004¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u0004\u001a\u00020\u0005X\u0084\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0014\u0010\b\u001a\u00020\tX\u0084\u0004¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u001e\u0010\u0015\u001a\u000e\u0012\u0004\u0012\u00020\u0017\u0012\u0004\u0012\u00020\u00170\u0016X¦\u0004¢\u0006\u0006\u001a\u0004\b\u0018\u0010\u0019¨\u0006\u001f"}, d2 = {"Lorg/diffkt/model/BatchNormTrainingBase;", "T", "Lorg/diffkt/model/TrainableLayerSingleInput;", "Lorg/diffkt/model/LayerWithInferenceMode;", "numFeatures", "", "momentum", "", "scaleShift", "Lorg/diffkt/model/TrainableTensor;", "(IFLorg/diffkt/model/TrainableTensor;)V", "inferenceMode", "Lorg/diffkt/model/AffineTransform;", "getInferenceMode", "()Lorg/diffkt/model/AffineTransform;", "getMomentum", "()F", "getNumFeatures", "()I", "getScaleShift", "()Lorg/diffkt/model/TrainableTensor;", "stats", "Lkotlin/Pair;", "Lorg/diffkt/DTensor;", "getStats", "()Lkotlin/Pair;", "equals", "", "other", "", "hashCode", "api"})
/* loaded from: input_file:org/diffkt/model/BatchNormTrainingBase.class */
/**
 * Abstract base for batch-normalization layers used in training mode.
 *
 * <p>Holds the number of features, the momentum, and the trainable {@code scaleShift} tensor.
 * Concrete subclasses supply the statistics through {@link #getStats()}. Most layer/trainable
 * operations are forwarded to the Kotlin interface default implementations in
 * {@code TrainableLayerSingleInput.DefaultImpls}.
 *
 * @param <T> the concrete self type returned by the trainable / device operations
 */
public abstract class BatchNormTrainingBase<T extends BatchNormTrainingBase<T>> implements TrainableLayerSingleInput<T>, LayerWithInferenceMode {
    private final int numFeatures;
    private final float momentum;

    @NotNull
    private final TrainableTensor scaleShift;

    /**
     * Creates the shared batch-norm state.
     *
     * @param i number of features (stored as {@code numFeatures})
     * @param f momentum; must lie in the closed range [0, 1]
     * @param trainableTensor the trainable scale/shift tensor; must not be null
     *     (checked via {@code Intrinsics.checkNotNullParameter})
     * @throws IllegalArgumentException if {@code f} is outside [0, 1] or is NaN
     */
    public BatchNormTrainingBase(int i, float f, @NotNull TrainableTensor trainableTensor) {
        Intrinsics.checkNotNullParameter(trainableTensor, "scaleShift");
        this.numFeatures = i;
        this.momentum = f;
        this.scaleShift = trainableTensor;
        // Both comparisons are false for NaN, so NaN is rejected as well.
        if (!(f >= 0.0f && f <= 1.0f)) {
            throw new IllegalArgumentException("Failed requirement.");
        }
    }

    /**
     * Kotlin default-argument bridge: when bit {@code 2} of the mask {@code i2} is set, the
     * momentum argument is replaced by its default value {@code 0.1f}.
     */
    public /* synthetic */ BatchNormTrainingBase(int i, float f, TrainableTensor trainableTensor, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(i, (i2 & 2) != 0 ? 0.1f : f, trainableTensor);
    }

    /** Returns the number of features this layer was constructed with. */
    /* JADX INFO: Access modifiers changed from: protected */
    public final int getNumFeatures() {
        return this.numFeatures;
    }

    /** Returns the momentum, guaranteed by the constructor to lie in [0, 1]. */
    /* JADX INFO: Access modifiers changed from: protected */
    public final float getMomentum() {
        return this.momentum;
    }

    /** Returns the trainable scale/shift tensor. */
    /* JADX INFO: Access modifiers changed from: protected */
    @NotNull
    public final TrainableTensor getScaleShift() {
        return this.scaleShift;
    }

    /**
     * Returns the pair of statistics used when freezing this layer for inference.
     * NOTE(review): presumably (mean, variance) — confirm against subclasses.
     */
    @NotNull
    public abstract Pair<DTensor, DTensor> getStats();

    /**
     * Builds the inference-mode affine transform by passing the scale/shift tensor and both
     * components of {@link #getStats()} to {@code BatchNormTrainingKt.freezeBatchNorm}.
     */
    @Override // org.diffkt.model.LayerWithInferenceMode
    @NotNull
    public AffineTransform getInferenceMode() {
        Pair<DTensor, DTensor> stats = getStats();
        return BatchNormTrainingKt.freezeBatchNorm(this.scaleShift.getTensor(), (DTensor) stats.component1(), (DTensor) stats.component2());
    }

    /**
     * Combines the class tag, feature count, momentum and scale/shift tensor into a hash.
     */
    @Override
    public int hashCode() {
        // equals() compares momentum with ==, under which -0.0f == 0.0f, but Float.hashCode
        // distinguishes the two zeros. Adding +0.0f maps -0.0f to +0.0f so equal objects hash
        // equally. NaN cannot occur here because the constructor rejects it.
        float normalizedMomentum = this.momentum + 0.0f;
        return UtilsKt.combineHash("BatchNormTrainingBase", Integer.valueOf(this.numFeatures), Float.valueOf(normalizedMomentum), this.scaleShift);
    }

    /**
     * Two instances are equal when {@code obj} is any {@code BatchNormTrainingBase} with the same
     * feature count, the same momentum (compared with {@code ==}) and an equal scale/shift
     * tensor. The concrete subclass is not compared.
     */
    @Override
    public boolean equals(@Nullable Object obj) {
        if (!(obj instanceof BatchNormTrainingBase)) {
            return false;
        }
        BatchNormTrainingBase<?> that = (BatchNormTrainingBase<?>) obj;
        return that.numFeatures == this.numFeatures
                && that.momentum == this.momentum
                && Intrinsics.areEqual(that.scaleShift, this.scaleShift);
    }

    // ---- Delegations to the Kotlin interface default implementations ----

    /** Moves this layer to the CPU (delegates to the interface default). */
    @Override // org.diffkt.OnDevice
    @NotNull
    public T cpu() {
        return (T) TrainableLayerSingleInput.DefaultImpls.cpu(this);
    }

    /** Moves this layer to the GPU (delegates to the interface default). */
    @Override // org.diffkt.OnDevice
    @NotNull
    public T gpu() {
        return (T) TrainableLayerSingleInput.DefaultImpls.gpu(this);
    }

    /** Applies one optimizer step using {@code tangent} (delegates to the interface default). */
    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public T trainingStep(@NotNull Optimizer<?> optimizer, @NotNull Trainable.Tangent tangent) {
        return (T) TrainableLayerSingleInput.DefaultImpls.trainingStep(this, optimizer, tangent);
    }

    /** Extracts this component's tangent (delegates to the interface default). */
    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public TrainableComponent.Companion.Tangent extractTangent(@NotNull DTensor dTensor, @NotNull Function2<? super DTensor, ? super DTensor, ? extends DTensor> function2) {
        return TrainableLayerSingleInput.DefaultImpls.extractTangent(this, dTensor, function2);
    }

    /** Writes this layer's state into {@code byteBuffer} (delegates to the interface default). */
    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public ByteBuffer store(@NotNull ByteBuffer byteBuffer) {
        return TrainableLayerSingleInput.DefaultImpls.store(this, byteBuffer);
    }

    /** Reads this layer's state from {@code byteBuffer} (delegates to the interface default). */
    @Override // org.diffkt.model.Trainable
    @NotNull
    public T load(@NotNull ByteBuffer byteBuffer) {
        return (T) TrainableLayerSingleInput.DefaultImpls.load(this, byteBuffer);
    }

    /** Wraps this layer with {@code wrapper} (delegates to the interface default). */
    @Override // org.diffkt.Wrappable
    @NotNull
    public T wrap(@NotNull Wrapper wrapper) {
        return (T) TrainableLayerSingleInput.DefaultImpls.wrap(this, wrapper);
    }

    /** Moves this layer to {@code device} (delegates to the interface default). */
    @Override // org.diffkt.OnDevice
    @NotNull
    public OnDevice to(@NotNull Device device) {
        return TrainableLayerSingleInput.DefaultImpls.to(this, device);
    }

    /** Applies the layer to its inputs (delegates to the interface default). */
    @Override // org.diffkt.model.Layer
    @NotNull
    public DTensor invoke(@NotNull DTensor... dTensorArr) {
        return TrainableLayerSingleInput.DefaultImpls.invoke(this, dTensorArr);
    }

    /** Extracts the single input from {@code dTensorArr} (delegates to the interface default). */
    @Override // org.diffkt.model.Layer
    @NotNull
    public DTensor getSingleInput(@NotNull DTensor[] dTensorArr) {
        return TrainableLayerSingleInput.DefaultImpls.getSingleInput(this, dTensorArr);
    }

    // ---- Compiler-generated bridge methods (erased signatures); kept verbatim ----

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ TrainableComponent trainingStep(Optimizer optimizer, Trainable.Tangent tangent) {
        return trainingStep((Optimizer<?>) optimizer, tangent);
    }

    @Override // org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ Trainable trainingStep(Optimizer optimizer, Trainable.Tangent tangent) {
        return trainingStep((Optimizer<?>) optimizer, tangent);
    }

    @Override // org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ Trainable.Tangent extractTangent(DTensor dTensor, Function2 function2) {
        return extractTangent(dTensor, (Function2<? super DTensor, ? super DTensor, ? extends DTensor>) function2);
    }
}
