package org.diffkt.model;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import kotlin.random.Random;
import org.diffkt.ConcatKt;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.Device;
import org.diffkt.EmbeddingKt;
import org.diffkt.FloatTensor;
import org.diffkt.IntTensor;
import org.diffkt.OnDevice;
import org.diffkt.Shape;
import org.diffkt.SliceKt;
import org.diffkt.SumKt;
import org.diffkt.UtilsKt;
import org.diffkt.Wrapper;
import org.diffkt.model.Trainable;
import org.diffkt.model.TrainableComponent;
import org.diffkt.model.TrainableLayer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: EmbeddingBag.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��j\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010 \n\u0002\b\u0003\n\u0002\u0010\u000b\n��\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0011\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018�� (2\b\u0012\u0004\u0012\u00020��0\u0001:\u0001(BC\b\u0016\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u001a\b\u0002\u0010\t\u001a\u0014\u0012\u0004\u0012\u00020\u000b\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\f0\n¢\u0006\u0002\u0010\rB\u0015\u0012\u0006\u0010\u000e\u001a\u00020\u000f\u0012\u0006\u0010\u0005\u001a\u00020\u0006¢\u0006\u0002\u0010\u0010J\u0013\u0010\u0019\u001a\u00020\u001a2\b\u0010\u001b\u001a\u0004\u0018\u00010\u001cH\u0096\u0002J\b\u0010\u001d\u001a\u00020\u0003H\u0016J\"\u0010\u001e\u001a\u00020\u001f2\u0012\u0010 \u001a\n\u0012\u0006\b\u0001\u0012\u00020\u001f0!\"\u00020\u001fH\u0096\u0002¢\u0006\u0002\u0010\"J\u0019\u0010\u001e\u001a\u00020\u001f2\u0006\u0010#\u001a\u00020$2\u0006\u0010%\u001a\u00020$H\u0086\u0002J\u001a\u0010&\u001a\u00020��2\u0010\u0010\u0015\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030'0\u0016H\u0016R\u0011\u0010\u0005\u001a\u00020\u0006¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u000e\u001a\u00020\u000f¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u001a\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u000f0\u0016X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018¨\u0006)"}, d2 = {"Lorg/diffkt/model/EmbeddingBag;", "Lorg/diffkt/model/TrainableLayer;", "numEmbeddings", "", "embeddingSize", "reduction", 
"Lorg/diffkt/model/EmbeddingBag$Companion$Reduction;", "random", "Lkotlin/random/Random;", "initializer", "Lkotlin/Function2;", "Lorg/diffkt/Shape;", "Lorg/diffkt/FloatTensor;", "(IILorg/diffkt/model/EmbeddingBag$Companion$Reduction;Lkotlin/random/Random;Lkotlin/jvm/functions/Function2;)V", "trainableWeights", "Lorg/diffkt/model/TrainableTensor;", "(Lorg/diffkt/model/TrainableTensor;Lorg/diffkt/model/EmbeddingBag$Companion$Reduction;)V", "getReduction", "()Lorg/diffkt/model/EmbeddingBag$Companion$Reduction;", "getTrainableWeights", "()Lorg/diffkt/model/TrainableTensor;", "trainables", "", "getTrainables", "()Ljava/util/List;", "equals", "", "other", "", "hashCode", "invoke", "Lorg/diffkt/DTensor;", "inputs", "", "([Lorg/diffkt/DTensor;)Lorg/diffkt/DTensor;", "indices", "Lorg/diffkt/IntTensor;", "bagOffsets", "withTrainables", "Lorg/diffkt/model/Trainable;", "Companion", "api"})
/* loaded from: input_file:org/diffkt/model/EmbeddingBag.class */
public final class EmbeddingBag implements TrainableLayer<EmbeddingBag> {

    /** Kotlin companion object; hosts the {@link Companion.Reduction} hierarchy. */
    @NotNull
    public static final Companion Companion = new Companion(null);

    /** The embedding table; the only trainable parameter of this layer. */
    @NotNull
    private final TrainableTensor trainableWeights;

    /** Strategy used to combine the embedding rows of each bag. */
    @NotNull
    private final Companion.Reduction reduction;

    /** Single-element list holding {@link #trainableWeights}; returned by {@link #getTrainables()}. */
    @NotNull
    private final List<TrainableTensor> trainables;

    /**
     * Kotlin companion object of {@link EmbeddingBag}. It carries no state of its own and only
     * namespaces the {@link Reduction} types.
     */
    /* compiled from: EmbeddingBag.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\f\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0003\b\u0086\u0003\u0018��2\u00020\u0001:\u0001\u0003B\u0007\b\u0002¢\u0006\u0002\u0010\u0002¨\u0006\u0004"}, d2 = {"Lorg/diffkt/model/EmbeddingBag$Companion;", "", "()V", "Reduction", "api"})
    /* loaded from: input_file:org/diffkt/model/EmbeddingBag$Companion.class */
    public static final class Companion {

        /**
         * How the embedding rows belonging to one bag are combined into a single result.
         * The constructor is private, so subclasses can only be declared inside this class;
         * {@link Sum} is the only one present here.
         */
        /* compiled from: EmbeddingBag.kt */
        @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\u001a\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\b6\u0018��2\u00020\u0001:\u0001\u0006B\u0007\b\u0004¢\u0006\u0002\u0010\u0002J\u0010\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u0004H&\u0082\u0001\u0001\u0007¨\u0006\b"}, d2 = {"Lorg/diffkt/model/EmbeddingBag$Companion$Reduction;", "", "()V", "reduce", "Lorg/diffkt/DTensor;", "x", "Sum", "Lorg/diffkt/model/EmbeddingBag$Companion$Reduction$Sum;", "api"})
        /* loaded from: input_file:org/diffkt/model/EmbeddingBag$Companion$Reduction.class */
        public static abstract class Reduction {

            /** Singleton reduction that sums a bag's embedding rows. */
            /* compiled from: EmbeddingBag.kt */
            @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\u0014\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\bÆ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u0010\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u0004H\u0016¨\u0006\u0006"}, d2 = {"Lorg/diffkt/model/EmbeddingBag$Companion$Reduction$Sum;", "Lorg/diffkt/model/EmbeddingBag$Companion$Reduction;", "()V", "reduce", "Lorg/diffkt/DTensor;", "x", "api"})
            /* loaded from: input_file:org/diffkt/model/EmbeddingBag$Companion$Reduction$Sum.class */
            public static final class Sum extends Reduction {

                /** The sole instance (Kotlin {@code object}). */
                @NotNull
                public static final Sum INSTANCE = new Sum();

                private Sum() {
                    super(null);
                }

                /**
                 * Sums {@code x} over axis 0.
                 *
                 * @param dTensor the embedding rows of one bag
                 * @return the summed tensor; NOTE(review): the {@code true} flag presumably keeps the
                 *     reduced axis so that per-bag results can be concatenated — confirm against
                 *     {@code SumKt.varargSum}
                 */
                @Override // org.diffkt.model.EmbeddingBag.Companion.Reduction
                @NotNull
                public DTensor reduce(@NotNull DTensor dTensor) {
                    Intrinsics.checkNotNullParameter(dTensor, "x");
                    return SumKt.varargSum(dTensor, new int[]{0}, true);
                }
            }

            private Reduction() {
            }

            /**
             * Combines the embedding rows of a single bag.
             *
             * @param dTensor the slice of looked-up embeddings that forms one bag
             * @return the reduced tensor for that bag
             */
            @NotNull
            public abstract DTensor reduce(@NotNull DTensor dTensor);

            /** Kotlin-generated constructor used by subclasses; the marker is always {@code null}. */
            public /* synthetic */ Reduction(DefaultConstructorMarker defaultConstructorMarker) {
                this();
            }
        }

        private Companion() {
        }

        /** Kotlin-generated constructor; the marker is always {@code null}. */
        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Creates a bag layer around an existing embedding table.
     *
     * @param trainableWeights the embedding table (the layer's only trainable)
     * @param reduction how each bag's rows are combined
     */
    public EmbeddingBag(@NotNull TrainableTensor trainableWeights, @NotNull Companion.Reduction reduction) {
        Intrinsics.checkNotNullParameter(trainableWeights, "trainableWeights");
        Intrinsics.checkNotNullParameter(reduction, "reduction");
        this.trainableWeights = trainableWeights;
        this.reduction = reduction;
        this.trainables = CollectionsKt.listOf(this.trainableWeights);
    }

    /** @return the embedding table */
    @NotNull
    public final TrainableTensor getTrainableWeights() {
        return this.trainableWeights;
    }

    /** @return the per-bag reduction strategy */
    @NotNull
    public final Companion.Reduction getReduction() {
        return this.reduction;
    }

    /** @return a single-element list containing the embedding table */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public List<TrainableTensor> getTrainables() {
        return this.trainables;
    }

    /**
     * Creates a bag layer with a freshly initialized {@code numEmbeddings x embeddingSize} table.
     *
     * @param numEmbeddings first dimension of the table's shape
     * @param embeddingSize second dimension of the table's shape
     * @param reduction how each bag's rows are combined
     * @param random randomness source handed to {@code initializer}
     * @param initializer produces the initial table from its shape and {@code random}
     */
    public EmbeddingBag(int numEmbeddings, int embeddingSize, @NotNull Companion.Reduction reduction, @NotNull Random random, @NotNull Function2<? super Shape, ? super Random, ? extends FloatTensor> initializer) {
        this(initialWeights(numEmbeddings, embeddingSize, reduction, random, initializer), reduction);
    }

    /**
     * Validates the constructor arguments and builds the initial embedding table.
     *
     * <p>This is a separate static method so that the null checks run before {@code initializer} is
     * invoked; Java forbids statements ahead of an explicit {@code this(...)} call.
     */
    private static TrainableTensor initialWeights(int numEmbeddings, int embeddingSize, Companion.Reduction reduction, Random random, Function2<? super Shape, ? super Random, ? extends FloatTensor> initializer) {
        Intrinsics.checkNotNullParameter(reduction, "reduction");
        Intrinsics.checkNotNullParameter(random, "random");
        Intrinsics.checkNotNullParameter(initializer, "initializer");
        return new TrainableTensor((DTensor) initializer.invoke(Shape.Companion.invoke(numEmbeddings, embeddingSize), random));
    }

    /**
     * Kotlin default-argument entry point. When bit 16 of {@code i3} is set (the mask bit for the
     * fifth parameter, {@code initializer}), a Gaussian initializer with its own default arguments
     * is used instead of {@code function2}.
     */
    public /* synthetic */ EmbeddingBag(int i, int i2, Companion.Reduction reduction, Random random, Function2 function2, int i3, DefaultConstructorMarker defaultConstructorMarker) {
        this(i, i2, reduction, random, (i3 & 16) != 0 ? Initializer.gaussian$default(Initializer.INSTANCE, 0.0f, 0.0f, 3, null) : function2);
    }

    /**
     * Looks up the embeddings for {@code indices} and reduces them bag by bag.
     *
     * <p>Both tensors are flattened first. Bag {@code k} covers flattened index positions
     * {@code bagOffsets[k]} (inclusive) to {@code bagOffsets[k + 1]} (exclusive). The last bag
     * runs to the end of {@code indices}. Each bag's slice of embedding rows is passed through
     * {@link #getReduction()}, and the per-bag results are concatenated in order.
     *
     * <p>NOTE(review): with an empty {@code bagOffsets} an empty list is handed to concat; whether
     * concat accepts that is not visible here.
     *
     * @param indices row indices into the embedding table
     * @param bagOffsets start position of each bag within the flattened {@code indices}
     * @return the concatenation of the reduced bags
     */
    @NotNull
    public final DTensor invoke(@NotNull IntTensor indices, @NotNull IntTensor bagOffsets) {
        Intrinsics.checkNotNullParameter(indices, "indices");
        Intrinsics.checkNotNullParameter(bagOffsets, "bagOffsets");
        DTensor embedded = EmbeddingKt.embedding$default(this.trainableWeights.getTensor(), IntTensor.flatten$default(indices, 0, 0, 3, null), 0, 4, null);
        IntTensor offsets = IntTensor.flatten$default(bagOffsets, 0, 0, 3, null);
        int numBags = offsets.getSize();
        int totalIndices = indices.getSize();
        List<DTensor> reducedBags = new ArrayList<>(numBags);
        Iterator<Integer> starts = offsets.getDataIterator().iterator();
        for (int bag = 0; starts.hasNext(); bag++) {
            int start = starts.next().intValue();
            // A bag ends where the next one starts; the final bag extends to the last index.
            int end = bag + 1 == numBags ? totalIndices : offsets.at(bag + 1);
            reducedBags.add(this.reduction.reduce(SliceKt.slice$default(embedded, start, end, 0, 4, (Object) null)));
        }
        return ConcatKt.concat$default(reducedBags, 0, 2, null);
    }

    /**
     * Always rejects generic tensor input: this layer must be invoked through
     * {@link #invoke(IntTensor, IntTensor)}.
     *
     * @throws IllegalArgumentException always
     */
    @Override // org.diffkt.model.Layer
    @NotNull
    public DTensor invoke(@NotNull DTensor... dTensorArr) {
        Intrinsics.checkNotNullParameter(dTensorArr, "inputs");
        throw new IllegalArgumentException("EmbeddingBag must be called with IntTensor indices and bagOffsets");
    }

    /**
     * Returns a copy of this layer that uses the given embedding table and keeps the same reduction.
     *
     * @param list exactly one trainable, expected to be a {@link TrainableTensor}
     * @throws IllegalArgumentException if {@code list} does not have exactly one element
     */
    @Override // org.diffkt.model.TrainableComponent
    @NotNull
    public EmbeddingBag withTrainables(@NotNull List<? extends Trainable<?>> list) {
        Intrinsics.checkNotNullParameter(list, "trainables");
        if (list.size() != 1) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        Trainable<?> trainable = list.get(0);
        Intrinsics.checkNotNull(trainable, "null cannot be cast to non-null type org.diffkt.model.TrainableTensor");
        return new EmbeddingBag((TrainableTensor) trainable, this.reduction);
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return UtilsKt.combineHash("EmbeddingBag", this.trainableWeights, this.reduction);
    }

    /** Two layers are equal when their embedding tables and reductions are equal. */
    @Override
    public boolean equals(@Nullable Object obj) {
        if (!(obj instanceof EmbeddingBag)) {
            return false;
        }
        EmbeddingBag other = (EmbeddingBag) obj;
        return Intrinsics.areEqual(other.trainableWeights, this.trainableWeights) && Intrinsics.areEqual(other.reduction, this.reduction);
    }

    // ---- Delegations to the TrainableLayer interface defaults (Kotlin DefaultImpls) ----

    @Override // org.diffkt.OnDevice
    @NotNull
    public EmbeddingBag cpu() {
        return (EmbeddingBag) TrainableLayer.DefaultImpls.cpu(this);
    }

    @Override // org.diffkt.OnDevice
    @NotNull
    public EmbeddingBag gpu() {
        return (EmbeddingBag) TrainableLayer.DefaultImpls.gpu(this);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public EmbeddingBag trainingStep(@NotNull Optimizer<?> optimizer, @NotNull Trainable.Tangent tangent) {
        return (EmbeddingBag) TrainableLayer.DefaultImpls.trainingStep(this, optimizer, tangent);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public TrainableComponent.Companion.Tangent extractTangent(@NotNull DTensor dTensor, @NotNull Function2<? super DTensor, ? super DTensor, ? extends DTensor> function2) {
        return TrainableLayer.DefaultImpls.extractTangent(this, dTensor, function2);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    @NotNull
    public ByteBuffer store(@NotNull ByteBuffer byteBuffer) {
        return TrainableLayer.DefaultImpls.store(this, byteBuffer);
    }

    @Override // org.diffkt.model.Trainable
    @NotNull
    public EmbeddingBag load(@NotNull ByteBuffer byteBuffer) {
        return (EmbeddingBag) TrainableLayer.DefaultImpls.load(this, byteBuffer);
    }

    @Override // org.diffkt.Wrappable
    @NotNull
    public EmbeddingBag wrap(@NotNull Wrapper wrapper) {
        return (EmbeddingBag) TrainableLayer.DefaultImpls.wrap(this, wrapper);
    }

    @Override // org.diffkt.OnDevice
    @NotNull
    public OnDevice to(@NotNull Device device) {
        return TrainableLayer.DefaultImpls.to(this, device);
    }

    @Override // org.diffkt.model.Layer
    @NotNull
    public DTensor getSingleInput(@NotNull DTensor[] dTensorArr) {
        return TrainableLayer.DefaultImpls.getSingleInput(this, dTensorArr);
    }

    // ---- Compiler-generated bridge methods (erased signatures forwarding to the typed overloads) ----

    @Override // org.diffkt.model.TrainableComponent
    public /* bridge */ /* synthetic */ TrainableComponent withTrainables(List list) {
        return withTrainables((List<? extends Trainable<?>>) list);
    }

    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ TrainableComponent trainingStep(Optimizer optimizer, Trainable.Tangent tangent) {
        return trainingStep((Optimizer<?>) optimizer, tangent);
    }

    @Override // org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ Trainable trainingStep(Optimizer optimizer, Trainable.Tangent tangent) {
        return trainingStep((Optimizer<?>) optimizer, tangent);
    }

    @Override // org.diffkt.model.Trainable
    public /* bridge */ /* synthetic */ Trainable.Tangent extractTangent(DTensor dTensor, Function2 function2) {
        return extractTangent(dTensor, (Function2<? super DTensor, ? super DTensor, ? extends DTensor>) function2);
    }
}
