package org.diffkt.model;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.Pair;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.diffkt.ConcatKt;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.Device;
import org.diffkt.OnDevice;
import org.diffkt.Shape;
import org.diffkt.SliceKt;
import org.diffkt.SqueezeKt;
import org.diffkt.UnsqueezeKt;
import org.diffkt.Wrapper;
import org.diffkt.model.RecurrentBase;
import org.diffkt.model.Trainable;
import org.diffkt.model.TrainableComponent;
import org.diffkt.model.TrainableLayer;
import org.jetbrains.annotations.NotNull;

/* compiled from: RecurrentBase.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��6\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0010\u0011\n\u0002\b\t\bf\u0018�� )*\u0014\b��\u0010\u0001*\u000e\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u00020��*\u0004\b\u0001\u0010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003:\u0001)J1\u0010\u0015\u001a\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00020\r0\u00162\u0006\u0010\u0017\u001a\u00020\r2\u0006\u0010\u0013\u001a\u00020\t2\u0006\u0010\u0010\u001a\u00028\u0001H\u0016¢\u0006\u0002\u0010\u0018J0\u0010\u0019\u001a\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00020\r0\u00162\u0012\u0010\u001a\u001a\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00020\r0\u00162\u0006\u0010\u001b\u001a\u00020\rH&J\u001f\u0010\u001c\u001a\u00020\r2\u0006\u0010\u001b\u001a\u00020\r2\b\b\u0002\u0010\u0010\u001a\u00028\u0001H\u0016¢\u0006\u0002\u0010\u001dJ1\u0010\u001e\u001a\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00020\r0\u00162\u0006\u0010\u0017\u001a\u00020\r2\u0006\u0010\u0013\u001a\u00020\t2\u0006\u0010\u0010\u001a\u00028\u0001H\u0016¢\u0006\u0002\u0010\u0018J\"\u0010\u001f\u001a\u00020\r2\u0012\u0010 
\u001a\n\u0012\u0006\b\u0001\u0012\u00020\r0!\"\u00020\rH\u0096\u0002¢\u0006\u0002\u0010\"J1\u0010#\u001a\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00020\r0\u00162\u0006\u0010\u0010\u001a\u00028\u00012\u0006\u0010\f\u001a\u00020\r2\u0006\u0010$\u001a\u00020\tH&¢\u0006\u0002\u0010%J\u001e\u0010&\u001a\u00020\r*\u00020\r2\u0006\u0010'\u001a\u00020\t2\b\b\u0002\u0010(\u001a\u00020\tH\u0002R\u0012\u0010\u0004\u001a\u00020\u0005X¦\u0004¢\u0006\u0006\u001a\u0004\b\u0006\u0010\u0007R\u0014\u0010\b\u001a\u00020\t8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\n\u0010\u000bR\u0012\u0010\f\u001a\u00020\rX¦\u0004¢\u0006\u0006\u001a\u0004\b\u000e\u0010\u000fR\u0012\u0010\u0010\u001a\u00028\u0001X¦\u0004¢\u0006\u0006\u001a\u0004\b\u0011\u0010\u0012R\u0014\u0010\u0013\u001a\u00020\t8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0014\u0010\u000b¨\u0006*"}, d2 = {"Lorg/diffkt/model/RecurrentBase;", "Recurrent", "T", "Lorg/diffkt/model/TrainableLayer;", "accType", "Lorg/diffkt/model/RecurrentBase$RecurrentBase$AccType;", "getAccType", "()Lorg/diffkt/model/RecurrentBase$RecurrentBase$AccType;", "batchAxis", "", "getBatchAxis", "()I", "initialOutput", "Lorg/diffkt/DTensor;", "getInitialOutput", "()Lorg/diffkt/DTensor;", "initialState", "getInitialState", "()Ljava/lang/Object;", "sequenceAxis", "getSequenceAxis", "accMap", "Lkotlin/Pair;", "t", "(Lorg/diffkt/DTensor;ILjava/lang/Object;)Lkotlin/Pair;", "cell", "state", "x", "doRecurrence", "(Lorg/diffkt/DTensor;Ljava/lang/Object;)Lorg/diffkt/DTensor;", "fold", "invoke", "inputs", "", "([Lorg/diffkt/DTensor;)Lorg/diffkt/DTensor;", "processForBatching", "batchSize", "(Ljava/lang/Object;Lorg/diffkt/DTensor;I)Lkotlin/Pair;", "axisIdx", "n", "axis", "RecurrentBase", "api"})
/* loaded from: input_file:org/diffkt/model/RecurrentBase.class */
public interface RecurrentBase<Recurrent extends RecurrentBase<Recurrent, T>, T> extends TrainableLayer<Recurrent> {

    /** Singleton instance of the nested holder object that namespaces {@code AccType}. */
    @NotNull
    public static final C0000RecurrentBase RecurrentBase = C0000RecurrentBase.$$INSTANCE;

    /**
     * Static bodies of this interface's default methods. Every method takes the receiving layer as its
     * first parameter. Implementing classes compiled from Kotlin presumably forward to these (confirm).
     */
    /* compiled from: RecurrentBase.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.C_AXIS, xi = 48)
    /* loaded from: input_file:org/diffkt/model/RecurrentBase$DefaultImpls.class */
    public static final class DefaultImpls {

        /** Default batch axis: axis 0 of the input. */
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> int getBatchAxis(@NotNull RecurrentBase<Recurrent, T> layer) {
            return 0;
        }

        /** Default sequence axis: axis 1 of the input. */
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> int getSequenceAxis(@NotNull RecurrentBase<Recurrent, T> layer) {
            return 1;
        }

        /**
         * Runs the recurrence over {@code input} along {@code axis}, keeping only the last step's result.
         *
         * <p>The starting (state, output) pair comes from {@link RecurrentBase#processForBatching} applied to
         * {@code initialState}, the layer's initial output, and the size of the input's batch axis.
         * {@link RecurrentBase#cell} is then applied once per position along {@code axis}, in order.
         *
         * @param layer the recurrent layer
         * @param input the input tensor; if it has no elements, it is returned unchanged with {@code initialState}
         * @param axis the axis to iterate over
         * @param initialState the state to start from
         * @return the (state, output) pair produced by the final step
         */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Pair<T, DTensor> fold(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull DTensor input, int axis, T initialState) {
            Intrinsics.checkNotNullParameter(input, "t");
            if (input.getSize() == 0) {
                return new Pair<>(initialState, input);
            }
            DTensor first = stepAt(layer, input, 0, axis);
            Pair<T, DTensor> acc = layer.processForBatching(initialState, layer.getInitialOutput(), input.getShape().get(layer.getBatchAxis()));
            // cell() accepts Pair<? extends T, ? extends DTensor>, so the prepared pair is passed through directly.
            acc = layer.cell(acc, first);
            int steps = input.getShape().get(axis);
            for (int step = 1; step < steps; step++) {
                acc = layer.cell(acc, stepAt(layer, input, step, axis));
            }
            return acc;
        }

        /** Returns entry {@code index} of {@code input} along {@code axis}, with that axis squeezed away. */
        private static <Recurrent extends RecurrentBase<Recurrent, T>, T> DTensor stepAt(RecurrentBase<Recurrent, T> layer, DTensor input, int index, int axis) {
            return SqueezeKt.squeeze(axisIdx(layer, input, index, axis), axis);
        }

        /** Returns the length-1 slice {@code [index, index + 1)} of {@code input} along {@code axis} (axis is kept). */
        private static <Recurrent extends RecurrentBase<Recurrent, T>, T> DTensor axisIdx(RecurrentBase<Recurrent, T> layer, DTensor input, int index, int axis) {
            return SliceKt.slice(input, index, index + 1, axis);
        }

        /**
         * Synthetic default-argument bridge for {@code axisIdx}. When {@code (mask & 2) != 0}, {@code axis} is
         * replaced with 0. Super calls (a non-null marker) are rejected.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        public static /* synthetic */ DTensor axisIdx$default(RecurrentBase layer, DTensor input, int index, int axis, int mask, Object superMarker) {
            if (superMarker != null) {
                throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: axisIdx");
            }
            int effectiveAxis = (mask & 2) != 0 ? 0 : axis;
            return axisIdx(layer, input, index, effectiveAxis);
        }

        /**
         * Runs the recurrence over {@code input} along {@code axis} and collects every step's output.
         *
         * <p>The recurrence is set up as in {@link #fold}. After each {@link RecurrentBase#cell} step, the
         * output is unsqueezed at {@code axis}. All step outputs are concatenated along {@code axis}.
         *
         * @param layer the recurrent layer
         * @param input the input tensor; if it has no elements, it is returned unchanged with {@code initialState}
         * @param axis the axis to iterate over
         * @param initialState the state to start from
         * @return the final state, paired with the concatenation of all step outputs
         */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Pair<T, DTensor> accMap(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull DTensor input, int axis, T initialState) {
            Intrinsics.checkNotNullParameter(input, "t");
            if (input.getSize() == 0) {
                return new Pair<>(initialState, input);
            }
            int steps = input.getShape().get(axis);
            ArrayList<DTensor> outputs = new ArrayList<>(steps);
            Pair<T, DTensor> acc = layer.processForBatching(initialState, layer.getInitialOutput(), input.getShape().get(layer.getBatchAxis()));
            for (int step = 0; step < steps; step++) {
                acc = layer.cell(acc, stepAt(layer, input, step, axis));
                // Restore the step axis so the per-step outputs can be concatenated back along it.
                outputs.add(UnsqueezeKt.unsqueeze(acc.getSecond(), axis));
            }
            return new Pair<>(acc.getFirst(), ConcatKt.concat(outputs, axis));
        }

        /**
         * Applies the layer to one or two inputs.
         *
         * <p>With one input, the layer's initial state is used. With two, the second input is used as the initial
         * state via an unchecked cast to {@code T}. This presumably targets layers whose state type is
         * {@code DTensor} (confirm against implementations).
         *
         * @throws IllegalArgumentException if {@code inputs} does not contain exactly 1 or 2 tensors
         */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> DTensor invoke(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull DTensor... inputs) {
            Intrinsics.checkNotNullParameter(inputs, "inputs");
            int count = inputs.length;
            if (count < 1 || count > 2) {
                throw new IllegalArgumentException("Expected 1 or 2 inputs, got " + count + " ");
            }
            if (count == 1) {
                // Mask bit 2 requests the default (the layer's initial state) for the state argument.
                return doRecurrence$default(layer, inputs[0], null, 2, null);
            }
            @SuppressWarnings("unchecked")
            T state = (T) inputs[1];
            return layer.doRecurrence(inputs[0], state);
        }

        /**
         * Runs the recurrence over the layer's sequence axis and returns only the output tensor.
         *
         * <p>{@link RecurrentBase#getAccType()} selects {@link RecurrentBase#fold} (last output only) or
         * {@link RecurrentBase#accMap} (all step outputs concatenated).
         *
         * @param layer the recurrent layer
         * @param x the input; must have rank 3
         * @param initialState the state to start from
         * @return the output half of the fold/accMap result
         * @throws IllegalArgumentException if {@code x} is not rank 3
         */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> DTensor doRecurrence(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull DTensor x, T initialState) {
            Intrinsics.checkNotNullParameter(x, "x");
            if (x.getRank() != 3) {
                throw new IllegalArgumentException("input must be rank 3, got rank " + x.getRank() + " ");
            }
            Pair<T, DTensor> result;
            // Switch on the enum directly instead of through ordinal-indexed case labels.
            switch (layer.getAccType()) {
                case Fold:
                    result = layer.fold(x, layer.getSequenceAxis(), initialState);
                    break;
                case AccMap:
                    result = layer.accMap(x, layer.getSequenceAxis(), initialState);
                    break;
                default:
                    throw new NoWhenBranchMatchedException();
            }
            return result.getSecond();
        }

        /**
         * Synthetic default-argument bridge for {@link RecurrentBase#doRecurrence}. When {@code (mask & 2) != 0},
         * {@code state} is replaced with {@link RecurrentBase#getInitialState()}. Super calls (a non-null
         * marker) are rejected.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        public static /* synthetic */ DTensor doRecurrence$default(RecurrentBase layer, DTensor x, Object state, int mask, Object superMarker) {
            if (superMarker != null) {
                throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: doRecurrence");
            }
            // The interface type variable T is not in scope in this static raw method, so Object is used.
            Object effectiveState = (mask & 2) != 0 ? layer.getInitialState() : state;
            return layer.doRecurrence(x, effectiveState);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.cpu}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Recurrent cpu(@NotNull RecurrentBase<Recurrent, T> layer) {
            return (Recurrent) TrainableLayer.DefaultImpls.cpu(layer);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.gpu}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Recurrent gpu(@NotNull RecurrentBase<Recurrent, T> layer) {
            return (Recurrent) TrainableLayer.DefaultImpls.gpu(layer);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.trainingStep}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Recurrent trainingStep(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull Optimizer<?> optimizer, @NotNull Trainable.Tangent tangent) {
            Intrinsics.checkNotNullParameter(optimizer, "optim");
            Intrinsics.checkNotNullParameter(tangent, "tangent");
            return (Recurrent) TrainableLayer.DefaultImpls.trainingStep(layer, optimizer, tangent);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.extractTangent}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> TrainableComponent.Companion.Tangent extractTangent(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull DTensor output, @NotNull Function2<? super DTensor, ? super DTensor, ? extends DTensor> extractor) {
            Intrinsics.checkNotNullParameter(output, "output");
            Intrinsics.checkNotNullParameter(extractor, "extractor");
            return TrainableLayer.DefaultImpls.extractTangent(layer, output, extractor);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.store}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> ByteBuffer store(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull ByteBuffer into) {
            Intrinsics.checkNotNullParameter(into, "into");
            return TrainableLayer.DefaultImpls.store(layer, into);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.load}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Recurrent load(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull ByteBuffer from) {
            Intrinsics.checkNotNullParameter(from, "from");
            return (Recurrent) TrainableLayer.DefaultImpls.load(layer, from);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.wrap}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> Recurrent wrap(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull Wrapper wrapper) {
            Intrinsics.checkNotNullParameter(wrapper, "wrapper");
            return (Recurrent) TrainableLayer.DefaultImpls.wrap(layer, wrapper);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.to}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> OnDevice to(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull Device device) {
            Intrinsics.checkNotNullParameter(device, "device");
            return TrainableLayer.DefaultImpls.to(layer, device);
        }

        /** Delegates to {@code TrainableLayer.DefaultImpls.getSingleInput}. */
        @NotNull
        public static <Recurrent extends RecurrentBase<Recurrent, T>, T> DTensor getSingleInput(@NotNull RecurrentBase<Recurrent, T> layer, @NotNull DTensor[] inputs) {
            Intrinsics.checkNotNullParameter(inputs, "inputs");
            return TrainableLayer.DefaultImpls.getSingleInput(layer, inputs);
        }
    }

    /**
     * Holder object (Kotlin {@code RecurrentBase.RecurrentBase}) that namespaces {@link AccType}. It was
     * renamed by the decompiler to avoid clashing with the enclosing interface name.
     */
    /* compiled from: RecurrentBase.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\f\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0003\b\u0086\u0003\u0018��2\u00020\u0001:\u0001\u0003B\u0007\b\u0002¢\u0006\u0002\u0010\u0002¨\u0006\u0004"}, d2 = {"Lorg/diffkt/model/RecurrentBase$RecurrentBase;", "", "()V", "AccType", "api"})
    /* renamed from: org.diffkt.model.RecurrentBase$RecurrentBase, reason: collision with other inner class name */
    /* loaded from: input_file:org/diffkt/model/RecurrentBase$RecurrentBase.class */
    public static final class C0000RecurrentBase {
        static final /* synthetic */ C0000RecurrentBase $$INSTANCE = new C0000RecurrentBase();

        /**
         * How {@code doRecurrence} accumulates results: {@code Fold} keeps only the last step's output, and
         * {@code AccMap} concatenates the outputs of all steps.
         */
        /* compiled from: RecurrentBase.kt */
        @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\f\n\u0002\u0018\u0002\n\u0002\u0010\u0010\n\u0002\b\u0004\b\u0086\u0001\u0018��2\b\u0012\u0004\u0012\u00020��0\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002j\u0002\b\u0003j\u0002\b\u0004¨\u0006\u0005"}, d2 = {"Lorg/diffkt/model/RecurrentBase$RecurrentBase$AccType;", "", "(Ljava/lang/String;I)V", "Fold", "AccMap", "api"})
        /* renamed from: org.diffkt.model.RecurrentBase$RecurrentBase$AccType */
        /* loaded from: input_file:org/diffkt/model/RecurrentBase$RecurrentBase$AccType.class */
        public enum AccType {
            Fold,
            AccMap
        }

        private C0000RecurrentBase() {
        }
    }

    /**
     * Kotlin-generated ordinal-to-branch table for {@code AccType}. It is retained for binary compatibility
     * but is no longer referenced by {@code DefaultImpls.doRecurrence}, which switches on the enum directly.
     */
    /* compiled from: RecurrentBase.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.C_AXIS, xi = 48)
    /* loaded from: input_file:org/diffkt/model/RecurrentBase$WhenMappings.class */
    public /* synthetic */ class WhenMappings {
        public static final /* synthetic */ int[] $EnumSwitchMapping$0;

        static {
            int[] iArr = new int[C0000RecurrentBase.AccType.values().length];
            iArr[C0000RecurrentBase.AccType.Fold.ordinal()] = 1;
            iArr[C0000RecurrentBase.AccType.AccMap.ordinal()] = 2;
            $EnumSwitchMapping$0 = iArr;
        }
    }

    /** State used by {@link #doRecurrence} when the caller supplies none (e.g. single-input {@link #invoke}). */
    T getInitialState();

    /** Output passed, with the initial state, to {@link #processForBatching} before the first {@link #cell} step. */
    @NotNull
    DTensor getInitialOutput();

    /** Axis whose size is passed as the batch size to {@link #processForBatching}; the default is 0. */
    int getBatchAxis();

    /** Axis iterated over by {@link #doRecurrence}; the default is 1. */
    int getSequenceAxis();

    /** Selects whether {@link #doRecurrence} uses {@link #fold} or {@link #accMap}. */
    @NotNull
    C0000RecurrentBase.AccType getAccType();

    /**
     * One recurrence step.
     *
     * @param pair the current (state, output) pair
     * @param dTensor the input slice for this step, with the sequence axis squeezed away
     * @return the next (state, output) pair
     */
    @NotNull
    Pair<T, DTensor> cell(@NotNull Pair<? extends T, ? extends DTensor> pair, @NotNull DTensor dTensor);

    /**
     * Builds the starting (state, output) pair from the initial state, the initial output and the batch size.
     * It presumably expands them to the batch dimension (confirm in implementations).
     */
    @NotNull
    Pair<T, DTensor> processForBatching(T t, @NotNull DTensor dTensor, int i);

    /** Runs the recurrence along axis {@code i} starting from state {@code t}, returning only the final (state, output). */
    @NotNull
    Pair<T, DTensor> fold(@NotNull DTensor dTensor, int i, T t);

    /** Runs the recurrence along axis {@code i} starting from state {@code t}, returning all step outputs concatenated. */
    @NotNull
    Pair<T, DTensor> accMap(@NotNull DTensor dTensor, int i, T t);

    /** Applies the layer to one input, or to an input plus an initial state. */
    @NotNull
    DTensor invoke(@NotNull DTensor... dTensorArr);

    /** Runs the recurrence on a rank-3 input from initial state {@code t} and returns the output tensor. */
    @NotNull
    DTensor doRecurrence(@NotNull DTensor dTensor, T t);
}
