package com.simiacryptus.mindseye.eval;

import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefSupplier;
import com.simiacryptus.ref.wrappers.RefSystem;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.function.WeakCachedSupplier;
import java.util.Random;

/* loaded from: input_file:com/simiacryptus/mindseye/eval/SampledArrayTrainable.class */
/**
 * A {@link SampledTrainable} that feeds a subset of a larger pool of training rows to an inner
 * {@link ArrayTrainable}.
 *
 * <p>Each row of the pool is obtained lazily from a {@link RefSupplier}. Rows whose supplier is
 * {@code null}, or whose supplier yields {@code null}, are skipped when a sample is built. The
 * sample is rebuilt whenever the training size or the sampling seed changes.
 */
public class SampledArrayTrainable extends TrainableWrapper<ArrayTrainable> implements SampledTrainable, TrainableDataMask {
    /** Full pool of training rows; the constructors reject an empty pool. */
    private final RefList<? extends RefSupplier<Tensor[]>> trainingData;
    /** Lower bound applied by {@link #getTrainingSize()}. */
    private int minSamples = 0;
    /** Seed of the {@link Random} that chooses sample indices in the random sampling mode. */
    private long seed = ((Random) Util.R.get()).nextInt();
    /** Requested sample size, before clamping by {@link #getTrainingSize()}. */
    private int trainingSize;

    /**
     * Creates a trainable over a supplier list, passing {@code trainingSize} as both the
     * requested sample size and the inner {@link ArrayTrainable} size argument.
     *
     * @param data         pool of training rows; ownership is transferred to this object
     * @param network      layer evaluated by the inner {@link ArrayTrainable}; may be {@code null}
     * @param trainingSize requested number of rows per sample
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public SampledArrayTrainable(RefList<? extends RefSupplier<Tensor[]>> data, Layer network, int trainingSize) {
        this(data, network, trainingSize, trainingSize);
    }

    /**
     * Creates a trainable over a supplier list.
     *
     * @param data         pool of training rows; ownership is transferred to this object and it
     *                     is released here if validation fails
     * @param network      layer handed (with an added reference) to the inner
     *                     {@link ArrayTrainable}; the caller's reference is released here
     * @param trainingSize requested number of rows per sample
     * @param batchSize    forwarded unchanged to the {@link ArrayTrainable} constructor
     *                     (presumably its batch size; TODO confirm against ArrayTrainable)
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public SampledArrayTrainable(RefList<? extends RefSupplier<Tensor[]>> data, Layer network, int trainingSize, int batchSize) {
        super(new ArrayTrainable((Tensor[][]) null, network == null ? null : network.mo21addRef(), batchSize));
        if (null != network) {
            network.freeRef();
        }
        if (0 == data.size()) {
            data.freeRef();
            throw new IllegalArgumentException("training data must not be empty");
        }
        this.trainingData = data;
        this.trainingSize = trainingSize;
        reseed(RefSystem.nanoTime());
    }

    /**
     * Creates a trainable over an in-memory array of rows, passing {@code trainingSize} as both
     * the requested sample size and the inner {@link ArrayTrainable} size argument.
     *
     * @param data         training rows; ownership is transferred to this object
     * @param network      layer evaluated by the inner {@link ArrayTrainable}; may be {@code null}
     * @param trainingSize requested number of rows per sample
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public SampledArrayTrainable(Tensor[][] data, Layer network, int trainingSize) {
        this(data, network, trainingSize, trainingSize);
    }

    /**
     * Creates a trainable over an in-memory array of rows. Each row is wrapped in its own
     * supplier so it can be sampled exactly like the supplier-list form.
     *
     * @param data         training rows; ownership is transferred to this object
     * @param network      layer handed (with an added reference) to the inner
     *                     {@link ArrayTrainable}; the caller's reference is released here
     * @param trainingSize requested number of rows per sample
     * @param batchSize    forwarded unchanged to the {@link ArrayTrainable} constructor
     *                     (presumably its batch size; TODO confirm against ArrayTrainable)
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public SampledArrayTrainable(Tensor[][] data, Layer network, int trainingSize, int batchSize) {
        super(new ArrayTrainable(network == null ? null : network.mo21addRef(), batchSize));
        if (null != network) {
            network.freeRef();
        }
        if (0 == data.length) {
            RefUtil.freeRef(data);
            throw new IllegalArgumentException("training data must not be empty");
        }
        // Each row becomes a WeakCachedSupplier whose loader returns a new reference to the row;
        // wrapInterface ties the row's lifetime to the wrapping supplier.
        this.trainingData = (RefList) RefArrays.stream(data).map(row ->
                (RefSupplier) RefUtil.wrapInterface(new WeakCachedSupplier(() -> (Tensor[]) RefUtil.addRef(row)), row)
        ).collect(RefCollectors.toList());
        this.trainingSize = trainingSize;
        reseed(RefSystem.nanoTime());
    }

    /** @return the lower bound applied to the effective training size */
    public int getMinSamples() {
        return this.minSamples;
    }

    /**
     * Sets the lower bound applied to the effective training size. This does not rebuild the
     * current sample; the new bound takes effect on the next refresh.
     *
     * @param minSamples new lower bound
     */
    public void setMinSamples(int minSamples) {
        this.minSamples = minSamples;
    }

    /**
     * Returns the effective sample size: the requested size capped at the pool size, then
     * raised to at least {@link #getMinSamples()}. The result can exceed the pool size when
     * {@code minSamples} does.
     */
    @Override
    public int getTrainingSize() {
        return Math.max(this.minSamples, Math.min(this.trainingData.size(), this.trainingSize));
    }

    /**
     * Sets the requested sample size and immediately rebuilds the sample.
     *
     * @param trainingSize requested number of rows per sample
     */
    @Override
    public void setTrainingSize(int trainingSize) {
        this.trainingSize = trainingSize;
        refreshSampledData();
    }

    /** Updates the sampling seed, rebuilding the sample only if the seed actually changed. */
    private void setSeed(int newSeed) {
        if (this.seed == newSeed) {
            return;
        }
        this.seed = newSeed;
        refreshSampledData();
    }

    /** @return a caching wrapper around a new reference to this trainable */
    @Override
    public SampledCachedTrainable<? extends SampledTrainable> cached() {
        return new SampledCachedTrainable<>(mo1addRef());
    }

    /**
     * Draws a new sampling seed and propagates {@code seed} to the inner trainable and to the
     * wrapper base class.
     *
     * <p>Note: the sampling seed comes from {@code Util.R}, not from {@code seed}.
     *
     * @param seed seed passed to the inner trainable and {@code super.reseed}
     * @return always {@code true}
     */
    @Override
    public boolean reseed(long seed) {
        setSeed(((Random) Util.R.get()).nextInt());
        ArrayTrainable inner = getInner();
        assert inner != null;
        inner.reseed(seed);
        inner.freeRef();
        super.reseed(seed);
        return true;
    }

    /** Releases the wrapper's resources, then this object's reference to the row pool. */
    @Override
    public void _free() {
        super._free();
        this.trainingData.freeRef();
    }

    /** Returns a new counted reference to this object (decompiled name of {@code addRef}). */
    @Override
    public SampledArrayTrainable mo1addRef() {
        return (SampledArrayTrainable) super.mo1addRef();
    }

    /**
     * Rebuilds the sample and installs it as the inner trainable's training data.
     *
     * <p>When {@code 0 < n < poolSize - 1}, where {@code n = getTrainingSize()}, up to {@code n}
     * distinct rows are chosen in a random order determined by the current seed. Otherwise the
     * first {@code n} available rows of the pool are taken, in order.
     *
     * <p>NOTE(review): a non-positive effective size falls into the prefix branch. There,
     * {@code limit(0)} yields an empty sample and a negative size makes {@code limit} throw
     * {@link IllegalArgumentException}; confirm whether "use all rows" was intended.
     */
    protected void refreshSampledData() {
        final int poolSize = this.trainingData.size();
        assert 0 < poolSize;
        final int sampleCount = getTrainingSize();
        final Tensor[][] sample = (0 < sampleCount && sampleCount < poolSize - 1)
                ? sampleRandomRows(poolSize, sampleCount)
                : samplePrefixRows(sampleCount);
        ArrayTrainable inner = getInner();
        assert inner != null;
        inner.setTrainingData(sample);
        inner.freeRef();
    }

    /**
     * Returns up to {@code count} available rows chosen at distinct random indices of the pool.
     * The generator is seeded from {@link #seed}.
     */
    private Tensor[][] sampleRandomRows(int poolSize, int count) {
        final Random random = new Random(this.seed);
        return RefIntStream.generate(() -> random.nextInt(poolSize))
                .distinct()
                // Once every index has been produced, distinct() can never emit again. Stopping
                // here prevents an endless loop when fewer than `count` rows are available.
                .limit(poolSize)
                .mapToObj(this.trainingData::get)
                .filter(SampledArrayTrainable::hasData)
                .limit(count)
                .map(supplier -> supplier.get())
                .toArray(n -> new Tensor[n][]);
    }

    /** Returns the first {@code count} available rows of the pool, in pool order. */
    private Tensor[][] samplePrefixRows(int count) {
        return this.trainingData.stream()
                .filter(SampledArrayTrainable::hasData)
                .limit(count)
                .map(supplier -> supplier.get())
                .toArray(n -> new Tensor[n][]);
    }

    /**
     * Returns whether {@code supplier} is non-null and currently yields a non-null row. The
     * probed row reference is released before returning.
     */
    private static boolean hasData(RefSupplier<Tensor[]> supplier) {
        if (supplier == null) {
            return false;
        }
        Tensor[] row = supplier.get();
        if (row == null) {
            return false;
        }
        RefUtil.freeRef(row);
        return true;
    }
}
