package org.diffkt.model;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.diffkt.Convolve;
import org.diffkt.DTensor;
import org.diffkt.Device;
import org.diffkt.OnDevice;
import org.diffkt.Wrapper;
import org.diffkt.model.Layer;
import org.diffkt.model.Trainable;
import org.diffkt.model.TrainableComponent;
import org.diffkt.model.TrainableLayer;
import org.jetbrains.annotations.NotNull;

/* compiled from: TrainableLayer.kt */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��\u0012\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\bf\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010��2\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003J\r\u0010\u0004\u001a\u00028��H\u0016¢\u0006\u0002\u0010\u0005J\r\u0010\u0006\u001a\u00028��H\u0016¢\u0006\u0002\u0010\u0005¨\u0006\u0007"}, d2 = {"Lorg/diffkt/model/TrainableLayer;", "T", "Lorg/diffkt/model/TrainableComponent;", "Lorg/diffkt/model/Layer;", "cpu", "()Lorg/diffkt/model/TrainableLayer;", "gpu", "api"})
/* loaded from: input_file:org/diffkt/model/TrainableLayer.class */
/**
 * A {@link Layer} that is also a {@link TrainableComponent}: its state is exposed as a list of
 * {@link Trainable}s that can be read ({@code getTrainables()}) and replaced
 * ({@code withTrainables(...)}).
 *
 * @param <T> the concrete (self) type of the layer, so that operations such as {@link #cpu()} and
 *     {@link #gpu()} return the implementing type rather than the interface
 */
public interface TrainableLayer<T extends TrainableLayer<T>> extends TrainableComponent<T>, Layer<T> {

    /**
     * Static bodies for this interface's default methods, in the layout the Kotlin compiler uses for
     * its {@code DefaultImpls} holder class.
     *
     * <p>NOTE(review): presumably invoked directly by Kotlin-compiled implementors, so every
     * signature here must stay binary-stable. Confirm before renaming or changing parameters.
     */
    public static final class DefaultImpls {

        /**
         * Rebuilds {@code trainableLayer} from the result of calling {@code cpu()} on each of its
         * trainables, preserving their order.
         *
         * @param trainableLayer the layer whose trainables are converted
         * @return the layer produced by {@code withTrainables} from the converted trainables
         */
        @NotNull
        public static <T extends TrainableLayer<T>> T cpu(@NotNull TrainableLayer<T> trainableLayer) {
            List<Trainable<?>> trainables = trainableLayer.getTrainables();
            // Exactly one output entry per input trainable, so presize to the known size.
            List<Trainable<?>> converted = new ArrayList<>(trainables.size());
            for (Trainable<?> trainable : trainables) {
                converted.add(trainable.cpu());
            }
            return (T) trainableLayer.withTrainables(converted);
        }

        /**
         * Rebuilds {@code trainableLayer} from the result of calling {@code gpu()} on each of its
         * trainables, preserving their order.
         *
         * @param trainableLayer the layer whose trainables are converted
         * @return the layer produced by {@code withTrainables} from the converted trainables
         */
        @NotNull
        public static <T extends TrainableLayer<T>> T gpu(@NotNull TrainableLayer<T> trainableLayer) {
            List<Trainable<?>> trainables = trainableLayer.getTrainables();
            // Exactly one output entry per input trainable, so presize to the known size.
            List<Trainable<?>> converted = new ArrayList<>(trainables.size());
            for (Trainable<?> trainable : trainables) {
                converted.add(trainable.gpu());
            }
            return (T) trainableLayer.withTrainables(converted);
        }

        /**
         * Null-checks the arguments, then delegates to
         * {@code TrainableComponent.DefaultImpls.trainingStep}.
         *
         * <p>The names passed to {@code checkNotNullParameter} ("optim", "tangent") are the original
         * Kotlin parameter names and appear in the resulting exception message.
         *
         * @param trainableLayer the layer to update
         * @param optimizer the optimizer to apply
         * @param tangent the tangent to apply
         * @return the result of the delegated training step, cast to {@code T}
         */
        @NotNull
        public static <T extends TrainableLayer<T>> T trainingStep(@NotNull TrainableLayer<T> trainableLayer, @NotNull Optimizer<?> optimizer, @NotNull Trainable.Tangent tangent) {
            Intrinsics.checkNotNullParameter(optimizer, "optim");
            Intrinsics.checkNotNullParameter(tangent, "tangent");
            return (T) TrainableComponent.DefaultImpls.trainingStep(trainableLayer, optimizer, tangent);
        }

        /**
         * Null-checks the arguments, then delegates to
         * {@code TrainableComponent.DefaultImpls.extractTangent}.
         *
         * @param trainableLayer the layer whose tangent is extracted
         * @param dTensor the output tensor (Kotlin name "output")
         * @param function2 the extractor function (Kotlin name "extractor")
         * @return the tangent returned by the delegate
         */
        @NotNull
        public static <T extends TrainableLayer<T>> TrainableComponent.Companion.Tangent extractTangent(@NotNull TrainableLayer<T> trainableLayer, @NotNull DTensor dTensor, @NotNull Function2<? super DTensor, ? super DTensor, ? extends DTensor> function2) {
            Intrinsics.checkNotNullParameter(dTensor, "output");
            Intrinsics.checkNotNullParameter(function2, "extractor");
            return TrainableComponent.DefaultImpls.extractTangent(trainableLayer, dTensor, function2);
        }

        /**
         * Null-checks {@code byteBuffer}, then delegates to
         * {@code TrainableComponent.DefaultImpls.store}.
         *
         * @param trainableLayer the layer to store
         * @param byteBuffer the destination buffer (Kotlin name "into")
         * @return the buffer returned by the delegate
         */
        @NotNull
        public static <T extends TrainableLayer<T>> ByteBuffer store(@NotNull TrainableLayer<T> trainableLayer, @NotNull ByteBuffer byteBuffer) {
            Intrinsics.checkNotNullParameter(byteBuffer, "into");
            return TrainableComponent.DefaultImpls.store(trainableLayer, byteBuffer);
        }

        /**
         * Null-checks {@code byteBuffer}, then delegates to
         * {@code TrainableComponent.DefaultImpls.load}.
         *
         * @param trainableLayer the layer used as the template for loading
         * @param byteBuffer the source buffer (Kotlin name "from")
         * @return the loaded layer, cast to {@code T}
         */
        @NotNull
        public static <T extends TrainableLayer<T>> T load(@NotNull TrainableLayer<T> trainableLayer, @NotNull ByteBuffer byteBuffer) {
            Intrinsics.checkNotNullParameter(byteBuffer, "from");
            return (T) TrainableComponent.DefaultImpls.load(trainableLayer, byteBuffer);
        }

        /**
         * Null-checks {@code wrapper}, then delegates to
         * {@code TrainableComponent.DefaultImpls.wrap}.
         *
         * @param trainableLayer the layer to wrap
         * @param wrapper the wrapper to apply
         * @return the wrapped layer, cast to {@code T}
         */
        @NotNull
        public static <T extends TrainableLayer<T>> T wrap(@NotNull TrainableLayer<T> trainableLayer, @NotNull Wrapper wrapper) {
            Intrinsics.checkNotNullParameter(wrapper, "wrapper");
            return (T) TrainableComponent.DefaultImpls.wrap(trainableLayer, wrapper);
        }

        /**
         * Null-checks {@code device}, then delegates to
         * {@code TrainableComponent.DefaultImpls.to}.
         *
         * @param trainableLayer the layer to move
         * @param device the target device
         * @return the result of the delegate
         */
        @NotNull
        public static <T extends TrainableLayer<T>> OnDevice to(@NotNull TrainableLayer<T> trainableLayer, @NotNull Device device) {
            Intrinsics.checkNotNullParameter(device, "device");
            return TrainableComponent.DefaultImpls.to(trainableLayer, device);
        }

        /**
         * Null-checks {@code dTensorArr}, then delegates to {@code Layer.DefaultImpls.getSingleInput}.
         *
         * @param trainableLayer the layer receiving the inputs
         * @param dTensorArr the inputs (Kotlin name "inputs")
         * @return the tensor returned by the delegate
         */
        @NotNull
        public static <T extends TrainableLayer<T>> DTensor getSingleInput(@NotNull TrainableLayer<T> trainableLayer, @NotNull DTensor[] dTensorArr) {
            Intrinsics.checkNotNullParameter(dTensorArr, "inputs");
            return Layer.DefaultImpls.getSingleInput(trainableLayer, dTensorArr);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Narrows the return type to the concrete layer type {@code T}.
     */
    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable, org.diffkt.OnDevice
    @NotNull
    T cpu();

    /**
     * {@inheritDoc}
     *
     * <p>Narrows the return type to the concrete layer type {@code T}.
     */
    @Override // org.diffkt.model.TrainableComponent, org.diffkt.model.Trainable, org.diffkt.OnDevice
    @NotNull
    T gpu();
}
