package com.simiacryptus.mindseye.eval;

import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrayList;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefLists;
import com.simiacryptus.ref.wrappers.RefString;
import java.lang.invoke.SerializedLambda;
import java.util.function.Function;

/* loaded from: input_file:com/simiacryptus/mindseye/eval/BatchedTrainable.class */
/**
 * A {@link DataTrainable} wrapper that evaluates its data in chunks no larger than
 * {@link #getBatchSize()} and combines the per-chunk {@link PointSample}s with
 * {@code PointSample.add}.
 *
 * <p>NOTE(review): this file looks like decompiler output (synthetic members are spelled out,
 * generic type arguments are missing, casts are explicit). Prefer the original source when
 * making behavioural changes.
 */
public abstract class BatchedTrainable extends TrainableWrapper<DataTrainable> implements DataTrainable {
    /** Maximum number of data items handed to the inner trainable per evaluation; never validated here. */
    protected final int batchSize;
    /** When true, {@link #measure(TrainingMonitor)} logs a one-line summary to a non-null monitor. */
    private boolean verbose;
    // Assertion switch normally generated by the compiler for `assert`; initialised in the static block below.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Wraps an existing data trainable.
     *
     * @param dataTrainable the trainable to delegate to; passed to the superclass constructor
     * @param i             the batch size stored in {@link #batchSize}
     */
    public BatchedTrainable(DataTrainable dataTrainable, int i) {
        super(dataTrainable);
        this.verbose = false;
        this.batchSize = i;
    }

    /**
     * Convenience constructor that wraps {@code layer} in a new {@code BasicTrainable}.
     *
     * @param layer the layer given to {@code BasicTrainable}
     * @param i     the batch size stored in {@link #batchSize}
     */
    public BatchedTrainable(Layer layer, int i) {
        this(new BasicTrainable(layer), i);
    }

    /**
     * @return the batch size supplied at construction
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * @return whether {@link #measure(TrainingMonitor)} logs a summary line
     */
    public boolean isVerbose() {
        return this.verbose;
    }

    /**
     * @param z true to enable the summary log line in {@link #measure(TrainingMonitor)}
     */
    public void setVerbose(boolean z) {
        this.verbose = z;
    }

    /**
     * Measures the current data, splitting it into partitions when it holds more items than
     * {@link #batchSize}.
     *
     * <p>When partitioning, the partition size is {@code ceil(size / ceil(size / batchSize))}:
     * the number of batches is computed first, and the items are then divided by that count.
     * Each partition is set on the inner trainable, measured through {@code super.measure},
     * and the results are reduced pairwise with {@code PointSample.add}. Otherwise the whole
     * data list is set on the inner trainable and measured once. In both cases the inner
     * trainable's data is replaced with an empty {@code RefArrayList} after each measurement.
     *
     * <p>NOTE(review): there is no try/finally, so if {@code super.measure} throws, the inner
     * trainable keeps the batch that was just set and the local references are not released.
     *
     * @param trainingMonitor monitor forwarded to {@code super.measure}; may be null, in which
     *                        case nothing is logged
     * @return the combined sample for all data
     * @throws RuntimeException (without a message) if a partition is larger than {@link #batchSize}
     */
    @Override // com.simiacryptus.mindseye.eval.TrainableWrapper, com.simiacryptus.mindseye.eval.Trainable
    public PointSample measure(TrainingMonitor trainingMonitor) {
        RefList asList = RefArrays.asList(getData());
        int size = asList.size();
        // The whole evaluation runs inside TimedResult.time so the elapsed time can be logged below.
        TimedResult time = TimedResult.time((UncheckedSupplier) RefUtil.wrapInterface(() -> {
            DataTrainable inner = getInner();
            if (this.batchSize < size) {
                // Inner ceil = number of batches; outer ceil = items per batch for that many batches.
                // NOTE(review): batchSize <= 0 is not rejected anywhere and reaches this division.
                RefList partition = RefLists.partition(asList.addRef(), (int) Math.ceil((size * 1.0d) / ((int) Math.ceil((size * 1.0d) / this.batchSize))));
                // `inner` is handed to wrapInterface as an extra reference rather than freed in this
                // branch — presumably the wrapper releases it; confirm against RefUtil.
                PointSample pointSample = (PointSample) RefUtil.get(partition.stream().map((Function) RefUtil.wrapInterface(refList -> {
                    // Sanity check on the partition size; release the batch before failing.
                    // NOTE(review): the exception carries no message.
                    if (this.batchSize < refList.size()) {
                        refList.freeRef();
                        throw new RuntimeException();
                    }
                    // Decompiled form of `assert inner != null`.
                    if (!$assertionsDisabled && inner == null) {
                        throw new AssertionError();
                    }
                    // Give the inner trainable its own reference, then drop the one held by this lambda.
                    inner.setData(refList.addRef());
                    refList.freeRef();
                    PointSample measure = super.measure(trainingMonitor);
                    // Replace the batch with an empty list once it has been measured.
                    inner.setData(new RefArrayList());
                    return measure;
                }, new Object[]{inner})).reduce((pointSample2, pointSample3) -> {
                    // add() receives an extra reference to the right operand (or null); both
                    // operand references held here are released once the sum exists.
                    PointSample add = pointSample2.add(pointSample3 == null ? null : pointSample3.m35addRef());
                    if (null != pointSample3) {
                        pointSample3.freeRef();
                    }
                    pointSample2.freeRef();
                    return add;
                }));
                partition.freeRef();
                return pointSample;
            }
            // Data fits in a single batch: measure everything in one call.
            // Decompiled form of `assert inner != null`.
            if (!$assertionsDisabled && inner == null) {
                throw new AssertionError();
            }
            inner.setData(asList.addRef());
            PointSample measure = super.measure(trainingMonitor);
            inner.setData(new RefArrayList());
            inner.freeRef();
            return measure;
        }, new Object[]{asList}));
        PointSample pointSample = (PointSample) time.getResult();
        // Summary line: item count, elapsed seconds (nanoseconds / 1e9), sample mean, delta magnitude.
        if (null != trainingMonitor && isVerbose()) {
            trainingMonitor.log(RefString.format("Evaluated %s items in %.4fs (%s/%s)", new Object[]{Integer.valueOf(size), Double.valueOf(time.timeNanos / 1.0E9d), Double.valueOf(pointSample.getMean()), Double.valueOf(pointSample.delta.getMagnitude())}));
        }
        time.freeRef();
        return pointSample;
    }

    /**
     * Releases this object's resources; only delegates to the superclass implementation.
     */
    @Override // com.simiacryptus.mindseye.eval.TrainableWrapper, com.simiacryptus.mindseye.eval.TrainableDataMask
    public void _free() {
        super._free();
    }

    /**
     * Covariant override that narrows the result of the superclass method to {@code BatchedTrainable}.
     *
     * @return the object returned by {@code super.mo1addRef()}, cast to {@code BatchedTrainable}
     */
    @Override // com.simiacryptus.mindseye.eval.TrainableWrapper, com.simiacryptus.mindseye.eval.TrainableDataMask, com.simiacryptus.mindseye.eval.Trainable, com.simiacryptus.mindseye.eval.DataTrainable
    /* renamed from: addRef */
    public BatchedTrainable mo1addRef() {
        return (BatchedTrainable) super.mo1addRef();
    }

    // Synthetic hook for deserializing the supplier lambda created in measure(): it matches the
    // implementation method name and descriptors, then rebuilds the lambda from the captured
    // arguments (receiver, item count, data list, monitor).
    // NOTE(review): as written this is not valid Java — `boolean z = -1`, a switch on a boolean,
    // and use of `this`/`super`/getInner() inside a static method (the captured `batchedTrainable`
    // is never used). It appears to be a decompiler rendering of the compiler-generated method
    // and duplicates the measure() lambda body; it likely should not be maintained by hand.
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -1029779528:
                if (implMethodName.equals("lambda$measure$8766d393$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                // Accept only the UncheckedSupplier lambda whose implementation lives in this class
                // with the expected (int, RefList, TrainingMonitor) -> PointSample descriptor.
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/simiacryptus/lang/UncheckedSupplier") && serializedLambda.getFunctionalInterfaceMethodName().equals("get") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/simiacryptus/mindseye/eval/BatchedTrainable") && serializedLambda.getImplMethodSignature().equals("(ILcom/simiacryptus/ref/wrappers/RefList;Lcom/simiacryptus/mindseye/opt/TrainingMonitor;)Lcom/simiacryptus/mindseye/lang/PointSample;")) {
                    BatchedTrainable batchedTrainable = (BatchedTrainable) serializedLambda.getCapturedArg(0);
                    int intValue = ((Integer) serializedLambda.getCapturedArg(1)).intValue();
                    RefList refList = (RefList) serializedLambda.getCapturedArg(2);
                    TrainingMonitor trainingMonitor = (TrainingMonitor) serializedLambda.getCapturedArg(3);
                    // Same logic as the supplier in measure(), with `intValue` standing in for the
                    // item count and `refList` for the data list.
                    return () -> {
                        DataTrainable inner = getInner();
                        if (this.batchSize < intValue) {
                            RefList partition = RefLists.partition(refList.addRef(), (int) Math.ceil((intValue * 1.0d) / ((int) Math.ceil((intValue * 1.0d) / this.batchSize))));
                            PointSample pointSample = (PointSample) RefUtil.get(partition.stream().map((Function) RefUtil.wrapInterface(refList2 -> {
                                if (this.batchSize < refList2.size()) {
                                    refList2.freeRef();
                                    throw new RuntimeException();
                                }
                                if (!$assertionsDisabled && inner == null) {
                                    throw new AssertionError();
                                }
                                inner.setData(refList2.addRef());
                                refList2.freeRef();
                                PointSample measure = super.measure(trainingMonitor);
                                inner.setData(new RefArrayList());
                                return measure;
                            }, new Object[]{inner})).reduce((pointSample2, pointSample3) -> {
                                PointSample add = pointSample2.add(pointSample3 == null ? null : pointSample3.m35addRef());
                                if (null != pointSample3) {
                                    pointSample3.freeRef();
                                }
                                pointSample2.freeRef();
                                return add;
                            }));
                            partition.freeRef();
                            return pointSample;
                        }
                        if (!$assertionsDisabled && inner == null) {
                            throw new AssertionError();
                        }
                        inner.setData(refList.addRef());
                        PointSample measure = super.measure(trainingMonitor);
                        inner.setData(new RefArrayList());
                        inner.freeRef();
                        return measure;
                    };
                }
                break;
        }
        // No known lambda matched the serialized form.
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }

    // Assertions are treated as disabled unless the class reports a desired assertion status of true.
    static {
        $assertionsDisabled = !BatchedTrainable.class.desiredAssertionStatus();
    }
}
