package com.simiacryptus.mindseye.opt;

import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.eval.SampledCachedTrainable;
import com.simiacryptus.mindseye.eval.SampledTrainable;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.eval.TrainableBase;
import com.simiacryptus.mindseye.eval.TrainableWrapper;
import com.simiacryptus.mindseye.lang.DoubleBuffer;
import com.simiacryptus.mindseye.lang.IterativeStopException;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.lang.State;
import com.simiacryptus.mindseye.lang.StateSet;
import com.simiacryptus.mindseye.layers.StochasticComponent;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch;
import com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchStrategy;
import com.simiacryptus.mindseye.opt.orient.LBFGS;
import com.simiacryptus.mindseye.opt.orient.OrientationStrategy;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefArrayList;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefSet;
import com.simiacryptus.ref.wrappers.RefString;
import com.simiacryptus.ref.wrappers.RefSystem;
import com.simiacryptus.util.FastRandom;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.data.DoubleStatistics;
import java.lang.invoke.SerializedLambda;
import java.lang.management.ManagementFactory;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.ToDoubleFunction;

/* loaded from: input_file:com/simiacryptus/mindseye/opt/ValidatingTrainer.class */
public class ValidatingTrainer extends ReferenceCountingBase {
    // --- collaborators fixed at construction -------------------------------------------------
    // Phases executed each epoch; run() only consumes the first phase's EpochResult.
    private final RefList<TrainingPhase> regimen;
    // Wrapper around the validation Trainable that accumulates validatingMeasurementTime.
    private final Trainable validationSubject;

    // --- configuration set in the constructor --------------------------------------------------
    private double terminateThreshold;
    // Wall-clock budget for run(); converted into an absolute deadline when run() starts.
    private Duration timeout;
    // Initial per-epoch sample size, seeded from the training trainable.
    private int trainingSize;

    // Decompiler-emitted stand-in for the compiler's assertion flag.
    // NOTE(review): presumably assigned in a static initializer outside this view — confirm.
    static final /* synthetic */ boolean $assertionsDisabled;

    // --- runtime counters ----------------------------------------------------------------------
    private final AtomicInteger disappointments = new AtomicInteger();
    private final AtomicLong trainingMeasurementTime = new AtomicLong();
    private final AtomicLong validatingMeasurementTime = new AtomicLong();

    // --- tunables (see the corresponding accessors) ------------------------------------------
    private double adjustmentFactor = 0.5;
    private double adjustmentTolerance = 0.1;
    private AtomicInteger currentIteration = new AtomicInteger();
    private int disappointmentThreshold = 0;
    private int epochIterations = 1;
    private int improvmentStaleThreshold = 3;
    private int maxEpochIterations = 20;
    private int maxIterations = Integer.MAX_VALUE;
    private int maxTrainingSize = Integer.MAX_VALUE;
    private int minEpochIterations = 1;
    private int minTrainingSize = 100;
    private TrainingMonitor monitor = new TrainingMonitor();
    private double overtrainingTarget = 2.0;
    private double pessimism = 10.0;
    private double trainingTarget = 0.7;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/simiacryptus/mindseye/opt/ValidatingTrainer$EpochParams.class */
    public static class EpochParams extends ReferenceCountingBase {
        final long timeoutMs;
        int iterations;
        int trainingSize;
        PointSample validation;

        private EpochParams(long j, int i, int i2, PointSample pointSample) {
            this.timeoutMs = j;
            this.iterations = i;
            this.trainingSize = i2;
            PointSample m35addRef = pointSample == null ? null : pointSample.m35addRef();
            this.validation = m35addRef == null ? null : m35addRef.m35addRef();
            if (null != m35addRef) {
                m35addRef.freeRef();
            }
            if (null != pointSample) {
                pointSample.freeRef();
            }
        }

        public void _free() {
            super._free();
            if (null != this.validation) {
                this.validation.freeRef();
            }
            this.validation = null;
        }

        /* renamed from: addRef, reason: merged with bridge method [inline-methods] */
        public EpochParams m69addRef() {
            return super.addRef();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/simiacryptus/mindseye/opt/ValidatingTrainer$EpochResult.class */
    public static class EpochResult extends ReferenceCountingBase {
        final boolean continueTraining;
        final PointSample currentPoint;
        final int iterations;
        final double priorMean;

        public EpochResult(boolean z, double d, PointSample pointSample, int i) {
            this.priorMean = d;
            PointSample m35addRef = pointSample == null ? null : pointSample.m35addRef();
            this.currentPoint = m35addRef == null ? null : m35addRef.m35addRef();
            if (null != m35addRef) {
                m35addRef.freeRef();
            }
            if (null != pointSample) {
                pointSample.freeRef();
            }
            this.continueTraining = z;
            this.iterations = i;
        }

        public void _free() {
            super._free();
            if (null != this.currentPoint) {
                this.currentPoint.freeRef();
            }
        }

        /* renamed from: addRef, reason: merged with bridge method [inline-methods] */
        public EpochResult m70addRef() {
            return super.addRef();
        }
    }

    /**
     * Decorates the training {@link SampledTrainable} so that the time spent in every
     * {@link #measure(TrainingMonitor)} call is added to the parent trainer's
     * {@code trainingMeasurementTime}.
     *
     * <p>NOTE(review): this wrapper keeps a counted reference to its parent trainer while the
     * trainer keeps this wrapper in its regimen, forming a reference cycle — confirm how the pair
     * is ultimately released.
     */
    private static class PerformanceWrapper extends TrainableWrapper<SampledTrainable> implements SampledTrainable {
        private final ValidatingTrainer parent;

        /**
         * @param sampledTrainable  trainable to wrap; passed straight to the {@code TrainableWrapper}
         *                          constructor
         * @param validatingTrainer owning trainer; the caller's reference is transferred to this
         *                          wrapper and released in {@link #_free()}
         */
        public PerformanceWrapper(SampledTrainable sampledTrainable, ValidatingTrainer validatingTrainer) {
            super(sampledTrainable);
            this.parent = validatingTrainer;
        }

        /** Delegates to the wrapped trainable. */
        @Override
        public int getTrainingSize() {
            SampledTrainable inner = getInner();
            assert inner != null;
            try {
                return inner.getTrainingSize();
            } finally {
                // Release the reference returned by getInner() even if the delegate throws.
                inner.freeRef();
            }
        }

        /** Delegates to the wrapped trainable. */
        @Override
        public void setTrainingSize(int trainingSize) {
            SampledTrainable inner = getInner();
            assert inner != null;
            try {
                inner.setTrainingSize(trainingSize);
            } finally {
                inner.freeRef();
            }
        }

        /** Returns a cached view over a new reference to this wrapper. */
        @Override
        public SampledCachedTrainable<? extends SampledTrainable> cached() {
            return new SampledCachedTrainable<>(mo1addRef());
        }

        /**
         * Measures the wrapped trainable and records the elapsed time on the parent trainer.
         *
         * @param trainingMonitor monitor forwarded to the wrapped trainable
         * @return the wrapped trainable's measurement
         */
        @Override
        public PointSample measure(TrainingMonitor trainingMonitor) {
            TimedResult time = TimedResult.time(() -> {
                SampledTrainable inner = getInner();
                assert inner != null;
                try {
                    return inner.measure(trainingMonitor);
                } finally {
                    inner.freeRef();
                }
            });
            this.parent.trainingMeasurementTime.addAndGet(time.timeNanos);
            PointSample pointSample = (PointSample) time.getResult();
            time.freeRef();
            return pointSample;
        }

        /** Releases the wrapped trainable (via super) and the parent reference. */
        @Override
        public void _free() {
            super._free();
            if (null != this.parent) {
                this.parent.freeRef();
            }
        }

        /** Returns this wrapper with its reference count incremented. */
        @Override
        public PerformanceWrapper mo1addRef() {
            return (PerformanceWrapper) super.mo1addRef();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/simiacryptus/mindseye/opt/ValidatingTrainer$StepResult.class */
    public static class StepResult extends ReferenceCountingBase {
        final PointSample currentPoint;
        final double[] performance;
        final PointSample previous;

        public StepResult(PointSample pointSample, PointSample pointSample2, double[] dArr) {
            this.currentPoint = pointSample2;
            this.previous = pointSample;
            this.performance = dArr;
        }

        public void _free() {
            super._free();
            if (null != this.previous) {
                this.previous.freeRef();
            }
            if (null != this.currentPoint) {
                this.currentPoint.freeRef();
            }
        }

        /* renamed from: addRef, reason: merged with bridge method [inline-methods] */
        public StepResult m72addRef() {
            return super.addRef();
        }
    }

    /* loaded from: input_file:com/simiacryptus/mindseye/opt/ValidatingTrainer$TrainingPhase.class */
    public static class TrainingPhase extends ReferenceCountingBase {
        private Function<CharSequence, LineSearchStrategy> lineSearchFactory = charSequence -> {
            return new ArmijoWolfeSearch();
        };
        private RefMap<CharSequence, LineSearchStrategy> lineSearchStrategyMap = new RefHashMap();
        private OrientationStrategy<?> orientation = new LBFGS();
        private SampledTrainable trainingSubject;

        public TrainingPhase(SampledTrainable sampledTrainable) {
            setTrainingSubject(sampledTrainable == null ? null : sampledTrainable.mo1addRef());
            if (null != sampledTrainable) {
                sampledTrainable.freeRef();
            }
        }

        public Function<CharSequence, LineSearchStrategy> getLineSearchFactory() {
            return this.lineSearchFactory;
        }

        public void setLineSearchFactory(Function<CharSequence, LineSearchStrategy> function) {
            this.lineSearchFactory = function;
        }

        public RefMap<CharSequence, LineSearchStrategy> getLineSearchStrategyMap() {
            if (this.lineSearchStrategyMap == null) {
                return null;
            }
            return this.lineSearchStrategyMap.addRef();
        }

        public void setLineSearchStrategyMap(RefMap<CharSequence, LineSearchStrategy> refMap) {
            RefMap addRef = refMap == null ? null : refMap.addRef();
            if (null != this.lineSearchStrategyMap) {
                this.lineSearchStrategyMap.freeRef();
            }
            this.lineSearchStrategyMap = addRef == null ? null : addRef.addRef();
            if (null != addRef) {
                addRef.freeRef();
            }
            if (null != refMap) {
                refMap.freeRef();
            }
        }

        public OrientationStrategy<?> getOrientation() {
            if (this.orientation == null) {
                return null;
            }
            return this.orientation.mo84addRef();
        }

        public void setOrientation(OrientationStrategy<?> orientationStrategy) {
            OrientationStrategy<?> mo84addRef = orientationStrategy == null ? null : orientationStrategy.mo84addRef();
            if (null != this.orientation) {
                this.orientation.freeRef();
            }
            this.orientation = mo84addRef == null ? null : mo84addRef.mo84addRef();
            if (null != mo84addRef) {
                mo84addRef.freeRef();
            }
            if (null != orientationStrategy) {
                orientationStrategy.freeRef();
            }
        }

        public SampledTrainable getTrainingSubject() {
            if (this.trainingSubject == null) {
                return null;
            }
            return this.trainingSubject.mo1addRef();
        }

        public void setTrainingSubject(SampledTrainable sampledTrainable) {
            SampledTrainable mo1addRef = sampledTrainable == null ? null : sampledTrainable.mo1addRef();
            if (null != this.trainingSubject) {
                this.trainingSubject.freeRef();
            }
            this.trainingSubject = mo1addRef == null ? null : mo1addRef.mo1addRef();
            if (null != mo1addRef) {
                mo1addRef.freeRef();
            }
            if (null != sampledTrainable) {
                sampledTrainable.freeRef();
            }
        }

        public String toString() {
            return "TrainingPhase{trainingSubject=" + this.trainingSubject + ", orientation=" + this.orientation + '}';
        }

        public void _free() {
            super._free();
            if (null != this.trainingSubject) {
                this.trainingSubject.freeRef();
            }
            this.trainingSubject = null;
            if (null != this.orientation) {
                this.orientation.freeRef();
            }
            this.orientation = null;
            if (null != this.lineSearchStrategyMap) {
                this.lineSearchStrategyMap.freeRef();
            }
            this.lineSearchStrategyMap = null;
        }

        /* renamed from: addRef, reason: merged with bridge method [inline-methods] */
        public TrainingPhase m73addRef() {
            return super.addRef();
        }
    }

    /**
     * Creates a trainer with a single-phase regimen over {@code sampledTrainable}, validated
     * against {@code trainable}.
     *
     * <p>NOTE(review): the regimen's PerformanceWrapper receives an extra reference to this
     * trainer, creating a reference cycle — confirm how instances are finally released.
     *
     * @param sampledTrainable training subject; its current training size seeds
     *                         {@link #getTrainingSize()}. The caller's reference is released here.
     * @param trainable        validation subject; the caller's reference is transferred into the
     *                         validation wrapper and released when that wrapper is freed
     */
    public ValidatingTrainer(SampledTrainable sampledTrainable, final Trainable trainable) {
        this.regimen = new RefArrayList<>(RefArrays.asList(new TrainingPhase[]{
                new TrainingPhase(new PerformanceWrapper(sampledTrainable.mo1addRef(), m68addRef()))}));
        // Validation wrapper: delegates to `trainable` and accumulates measurement time.
        this.validationSubject = new TrainableBase() {
            @Override
            public Layer getLayer() {
                return trainable.getLayer();
            }

            @Override
            public PointSample measure(TrainingMonitor trainingMonitor) {
                TimedResult time = TimedResult.time(RefUtil.wrapInterface(
                        (UncheckedSupplier<PointSample>) () -> trainable.measure(trainingMonitor),
                        new Object[]{trainable.mo1addRef()}));
                ValidatingTrainer.this.validatingMeasurementTime.addAndGet(time.timeNanos);
                PointSample pointSample = (PointSample) time.getResult();
                time.freeRef();
                return pointSample;
            }

            @Override
            public boolean reseed(long seed) {
                return trainable.reseed(seed);
            }

            @Override
            public void _free() {
                super._free();
                // Releases the reference handed to this wrapper by the enclosing constructor.
                trainable.freeRef();
            }
        };
        this.trainingSize = sampledTrainable.getTrainingSize();
        sampledTrainable.freeRef();
        this.timeout = Duration.of(5L, ChronoUnit.MINUTES);
        this.terminateThreshold = Double.NEGATIVE_INFINITY;
    }

    /** @return exponent applied to the iteration and sample-size scaling ratios in run() */
    public double getAdjustmentFactor() {
        return adjustmentFactor;
    }

    /** @param adjustmentFactor exponent applied to the scaling ratios in run() */
    public void setAdjustmentFactor(double adjustmentFactor) {
        this.adjustmentFactor = adjustmentFactor;
    }

    /** @return half-width of the band around 1.0 within which scaling ratios are ignored */
    public double getAdjustmentTolerance() {
        return adjustmentTolerance;
    }

    /** @param adjustmentTolerance half-width of the band around 1.0 for scaling ratios */
    public void setAdjustmentTolerance(double adjustmentTolerance) {
        this.adjustmentTolerance = adjustmentTolerance;
    }

    /** @return the live iteration counter (not a copy) */
    public AtomicInteger getCurrentIteration() {
        return currentIteration;
    }

    /** @param currentIteration counter instance to use from now on */
    public void setCurrentIteration(AtomicInteger currentIteration) {
        this.currentIteration = currentIteration;
    }

    /** @return number of stale epochs tolerated before training is declared converged */
    public int getDisappointmentThreshold() {
        return disappointmentThreshold;
    }

    /** @param disappointmentThreshold number of stale epochs tolerated */
    public void setDisappointmentThreshold(int disappointmentThreshold) {
        this.disappointmentThreshold = disappointmentThreshold;
    }

    /** @return iterations requested for the first epoch */
    public int getEpochIterations() {
        return epochIterations;
    }

    /** @param epochIterations iterations requested for the first epoch */
    public void setEpochIterations(int epochIterations) {
        this.epochIterations = epochIterations;
    }

    /** @return iterations since the best validation result beyond which an epoch is "stale" */
    public int getImprovmentStaleThreshold() {
        return improvmentStaleThreshold;
    }

    /** @param improvmentStaleThreshold iterations without improvement before an epoch is stale */
    public void setImprovmentStaleThreshold(int improvmentStaleThreshold) {
        this.improvmentStaleThreshold = improvmentStaleThreshold;
    }

    /** @return upper bound for the per-epoch iteration count */
    public int getMaxEpochIterations() {
        return maxEpochIterations;
    }

    /** @param maxEpochIterations upper bound for the per-epoch iteration count */
    public void setMaxEpochIterations(int maxEpochIterations) {
        this.maxEpochIterations = maxEpochIterations;
    }

    /** @return overall iteration cap (presumably consulted by the halt check) */
    public int getMaxIterations() {
        return maxIterations;
    }

    /** @param maxIterations overall iteration cap */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** @return upper bound for the per-epoch sample size */
    public int getMaxTrainingSize() {
        return maxTrainingSize;
    }

    /** @param maxTrainingSize upper bound for the per-epoch sample size */
    public void setMaxTrainingSize(int maxTrainingSize) {
        this.maxTrainingSize = maxTrainingSize;
    }

    /** @return lower bound for the per-epoch iteration count */
    public int getMinEpochIterations() {
        return minEpochIterations;
    }

    /** @param minEpochIterations lower bound for the per-epoch iteration count */
    public void setMinEpochIterations(int minEpochIterations) {
        this.minEpochIterations = minEpochIterations;
    }

    /** @return lower bound used when rescaling the per-epoch sample size */
    public int getMinTrainingSize() {
        return minTrainingSize;
    }

    /** @param minTrainingSize lower bound used when rescaling the sample size */
    public void setMinTrainingSize(int minTrainingSize) {
        this.minTrainingSize = minTrainingSize;
    }

    /** @return monitor used for logging and measurement */
    public TrainingMonitor getMonitor() {
        return monitor;
    }

    /** @param monitor monitor used for logging and measurement */
    public void setMonitor(TrainingMonitor monitor) {
        this.monitor = monitor;
    }

    /** @return target ratio of training to validation log-improvement */
    public double getOvertrainingTarget() {
        return overtrainingTarget;
    }

    /** @param overtrainingTarget target ratio of training to validation log-improvement */
    public void setOvertrainingTarget(double overtrainingTarget) {
        this.overtrainingTarget = overtrainingTarget;
    }

    /** @return exponent used by the randomized convergence test */
    public double getPessimism() {
        return pessimism;
    }

    /** @param pessimism exponent used by the randomized convergence test */
    public void setPessimism(double pessimism) {
        this.pessimism = pessimism;
    }

    /** @return a new reference to the regimen list; the caller must free it */
    public RefList<TrainingPhase> getRegimen() {
        return regimen.addRef();
    }

    /** @return termination threshold (not read by the code in this section) */
    public double getTerminateThreshold() {
        return terminateThreshold;
    }

    /** @param terminateThreshold termination threshold */
    public void setTerminateThreshold(double terminateThreshold) {
        this.terminateThreshold = terminateThreshold;
    }

    /** @return wall-clock budget for run() */
    public Duration getTimeout() {
        return timeout;
    }

    /** @param timeout wall-clock budget for run() */
    public void setTimeout(Duration timeout) {
        this.timeout = timeout;
    }

    /** @return initial per-epoch sample size */
    public int getTrainingSize() {
        return trainingSize;
    }

    /** @param trainingSize initial per-epoch sample size */
    public void setTrainingSize(int trainingSize) {
        this.trainingSize = trainingSize;
    }

    /** @return target per-epoch validation improvement ratio */
    public double getTrainingTarget() {
        return trainingTarget;
    }

    /** @param trainingTarget target per-epoch validation improvement ratio */
    public void setTrainingTarget(double trainingTarget) {
        this.trainingTarget = trainingTarget;
    }

    /** @return a new reference to the validation wrapper; the caller must free it */
    public Trainable getValidationSubject() {
        return validationSubject.mo1addRef();
    }

    /**
     * Sets the line-search factory on the first regimen phase.
     *
     * @param factory factory mapping a key to a line-search strategy
     */
    public void setLineSearchFactory(Function<CharSequence, LineSearchStrategy> factory) {
        RefList<TrainingPhase> phases = getRegimen();
        TrainingPhase firstPhase = phases.get(0);
        firstPhase.setLineSearchFactory(factory);
        firstPhase.freeRef();
        phases.freeRef();
    }

    /**
     * Sets the orientation strategy on the first regimen phase.
     *
     * @param orientationStrategy new orientation; the caller's reference is handed on to the phase
     */
    public void setOrientation(OrientationStrategy<?> orientationStrategy) {
        RefList<TrainingPhase> phases = getRegimen();
        TrainingPhase firstPhase = phases.get(0);
        // The phase's setter consumes the reference, so the caller's ref is transferred as-is.
        firstPhase.setOrientation(orientationStrategy);
        firstPhase.freeRef();
        phases.freeRef();
    }

    /**
     * Returns the string form of a buffer's UUID key and releases the buffer.
     *
     * @param doubleBuffer buffer whose reference is consumed by this call
     * @return {@code doubleBuffer.key.toString()}
     */
    private static CharSequence getId(DoubleBuffer<UUID> doubleBuffer) {
        final UUID key = doubleBuffer.key;
        final String id = key.toString();
        doubleBuffer.freeRef();
        return id;
    }

    /**
     * Runs training epochs until a halt condition, convergence, or a phase stops training.
     *
     * <p>Each epoch runs every regimen phase and uses the first phase's result. It then
     * re-measures the validation subject and rescales the next epoch's iteration count and sample
     * size. Stochastic noise in the network is cleared before every validation measurement.
     *
     * @return mean of the most recent validation measurement
     * @throws RuntimeException any failure, wrapped via {@code Util.throwException}
     */
    public double run() {
        Layer layer = this.validationSubject.getLayer();
        try {
            long deadlineMs = RefSystem.currentTimeMillis() + this.timeout.toMillis();
            clearStochasticNoise(layer);
            EpochParams epochParams = new EpochParams(deadlineMs, this.epochIterations, getTrainingSize(),
                    this.validationSubject.measure(this.monitor));
            try {
                runEpochs(layer, deadlineMs, epochParams);
                clearStochasticNoise(layer);
                if (!$assertionsDisabled && epochParams.validation == null) {
                    throw new AssertionError();
                }
                return epochParams.validation.getMean();
            } finally {
                epochParams.freeRef();
            }
        } catch (Throwable th) {
            throw Util.throwException(th);
        } finally {
            if (null != layer) {
                layer.freeRef();
            }
        }
    }

    /** Epoch loop of run(); mutates {@code epochParams} and returns when training should stop. */
    private void runEpochs(Layer layer, long deadlineMs, EpochParams epochParams) {
        int epoch = 0;
        int totalIterations = 0;
        int iterationsAtBest = 0;
        double bestValidation = Double.POSITIVE_INFINITY;
        while (true) {
            if (shouldHalt(this.monitor, deadlineMs)) {
                this.monitor.log("Training halted");
                return;
            }
            this.monitor.log(RefString.format("Epoch parameters: %s, %s",
                    new Object[]{epochParams.trainingSize, epochParams.iterations}));
            EpochResult epochResult = runRegimenPhases(epochParams);
            PointSample validation = null;
            try {
                totalIterations += epochResult.iterations;
                if (!$assertionsDisabled && epochResult.currentPoint == null) {
                    throw new AssertionError();
                }
                double trainingDelta = epochResult.currentPoint.getMean() / epochResult.priorMean;
                clearStochasticNoise(layer);
                validation = this.validationSubject.measure(this.monitor);
                if (!$assertionsDisabled && epochParams.validation == null) {
                    throw new AssertionError();
                }
                double validationDelta = validation.getMean() / epochParams.validation.getMean();
                double overtraining = Math.log(trainingDelta) / Math.log(validationDelta);
                double iterationScale = Math.pow(Math.log(getTrainingTarget()) / Math.log(validationDelta), this.adjustmentFactor);
                double sampleScale = Math.pow(overtraining / getOvertrainingTarget(), this.adjustmentFactor);
                double validationMean = validation.getMean();
                if (validationMean < bestValidation) {
                    bestValidation = validationMean;
                    iterationsAtBest = totalIterations;
                }
                epoch++;
                int sinceImprovement = totalIterations - iterationsAtBest;
                this.monitor.log(RefString.format("Epoch %d result apply %s iterations, %s/%s samples: {validation *= 2^%.5f; training *= 2^%.3f; Overtraining = %.2f}, {itr*=%.2f, len*=%.2f} %s since improvement; %.4f validation time",
                        new Object[]{epoch, epochResult.iterations, epochParams.trainingSize, getMaxTrainingSize(),
                                Math.log(validationDelta) / Math.log(2.0d), Math.log(trainingDelta) / Math.log(2.0d),
                                overtraining, iterationScale, sampleScale, sinceImprovement,
                                this.validatingMeasurementTime.getAndSet(0L) / 1.0E9d}));
                if (!epochResult.continueTraining) {
                    this.monitor.log(RefString.format("Training %d runPhase halted", new Object[]{epoch}));
                    return;
                }
                if (epochParams.trainingSize >= getMaxTrainingSize()) {
                    double random = FastRandom.INSTANCE.random();
                    if (random > Math.pow(2.0d - validationDelta, this.pessimism)) {
                        this.monitor.log(RefString.format("Training randomly converged: %3f", new Object[]{random}));
                        return;
                    }
                    if (sinceImprovement <= this.improvmentStaleThreshold) {
                        this.disappointments.set(0);
                    } else {
                        if (this.disappointments.incrementAndGet() > getDisappointmentThreshold()) {
                            this.monitor.log(RefString.format("Training converged after %s iterations", new Object[]{sinceImprovement}));
                            return;
                        }
                        this.monitor.log(RefString.format("Training failed to converged on %s attempt after %s iterations",
                                new Object[]{this.disappointments.get(), sinceImprovement}));
                    }
                }
                // NOTE(review): in both size updates below the outer Math.min(..., trainingSize) means
                // the sample size can never grow; confirm whether growth (e.g. the "* 5") was intended.
                if (validationDelta >= 1.0d || trainingDelta >= 1.0d) {
                    epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(),
                            Math.min(getMaxTrainingSize(), epochParams.trainingSize * 5)), epochParams.trainingSize));
                    epochParams.iterations = 1;
                } else {
                    if (iterationScale < 1.0d - this.adjustmentTolerance || iterationScale > 1.0d + this.adjustmentTolerance) {
                        epochParams.iterations = Math.max(getMinEpochIterations(),
                                Math.min(getMaxEpochIterations(), (int) (epochResult.iterations * iterationScale)));
                    }
                    // Same "outside the tolerance band" test as for iterations; the decompiled
                    // condition had the bounds swapped and was always true.
                    if (sampleScale < 1.0d - this.adjustmentTolerance || sampleScale > 1.0d + this.adjustmentTolerance) {
                        epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(),
                                Math.min(getMaxTrainingSize(), (int) (epochParams.trainingSize * sampleScale))), epochParams.trainingSize));
                    }
                }
                // Replace the validation baseline, releasing the one it supersedes.
                PointSample previousBaseline = epochParams.validation;
                epochParams.validation = validation.m35addRef();
                if (null != previousBaseline) {
                    previousBaseline.freeRef();
                }
            } finally {
                // Released on every path, including the early returns above.
                epochResult.freeRef();
                if (null != validation) {
                    validation.freeRef();
                }
            }
        }
    }

    /** Runs every regimen phase for one epoch and returns the first phase's result. */
    private EpochResult runRegimenPhases(EpochParams epochParams) {
        RefList<TrainingPhase> regimenRef = getRegimen();
        int phaseCount = regimenRef.size();
        regimenRef.freeRef();
        long startNanos = RefSystem.nanoTime();
        RefList<EpochResult> results = (RefList<EpochResult>) RefIntStream.range(0, phaseCount)
                .mapToObj(RefUtil.wrapInterface((IntFunction<EpochResult>) phaseIndex -> {
                    RefList<TrainingPhase> phases = getRegimen();
                    TrainingPhase phase = phases.get(phaseIndex);
                    phases.freeRef();
                    EpochResult result = runPhase(epochParams.m69addRef(), phase == null ? null : phase.m73addRef(), phaseIndex, startNanos);
                    if (null != phase) {
                        phase.freeRef();
                    }
                    return result;
                }, new Object[]{epochParams.m69addRef()}))
                .collect(RefCollectors.toList());
        EpochResult first = results.get(0);
        results.freeRef();
        return first;
    }

    /** Clears noise on every StochasticComponent inside {@code layer} when it is a DAGNetwork. */
    private static void clearStochasticNoise(Layer layer) {
        if (!(layer instanceof DAGNetwork)) {
            return;
        }
        ((DAGNetwork) layer).visitLayers(child -> {
            if (child instanceof StochasticComponent) {
                ((StochasticComponent) child).clearNoise();
            }
            if (null != child) {
                child.freeRef();
            }
        });
    }

    /**
     * Sets the overall training timeout to {@code i} units of {@code temporalUnit}.
     *
     * @param i            the amount of time
     * @param temporalUnit the unit {@code i} is expressed in
     */
    public void setTimeout(int i, TemporalUnit temporalUnit) {
        final Duration duration = Duration.of(i, temporalUnit);
        timeout = duration;
    }

    /**
     * Convenience overload accepting a {@link TimeUnit}. The unit is converted with
     * {@code Util.cvt} and the call is forwarded to the {@link TemporalUnit} overload.
     *
     * @param i        the amount of time
     * @param timeUnit the unit {@code i} is expressed in
     */
    public void setTimeout(int i, TimeUnit timeUnit) {
        final TemporalUnit converted = Util.cvt(timeUnit);
        this.setTimeout(i, converted);
    }

    /**
     * Releases this trainer's resources. The superclass state is released first, then the
     * references this trainer holds on {@code validationSubject} and {@code regimen}.
     */
    public void _free() {
        super._free();
        this.validationSubject.freeRef();
        this.regimen.freeRef();
    }

    /* renamed from: addRef, reason: merged with bridge method [inline-methods] */
    /**
     * Increments this trainer's reference count and returns it with its concrete type, so
     * callers (e.g. {@code runPhase}) can invoke trainer methods on the result directly.
     *
     * @return this trainer; the caller owns the additional reference
     */
    public ValidatingTrainer m68addRef() {
        // NOTE(review): the superclass addRef() is presumably declared with a base-class return
        // type; the explicit cast is required in that case and harmless if it is already covariant.
        return (ValidatingTrainer) super.addRef();
    }

    /**
     * Runs one training phase: sets the phase's training size, reseeds/resets it, takes an
     * initial measurement, then repeatedly calls {@link #runStep} until one of the following:
     * <ul>
     *   <li>the step budget is exhausted (only enforced when {@code epochParams.iterations > 0};
     *       a non-positive value means no step limit). The result's first flag is {@code true};</li>
     *   <li>{@link #shouldHalt} reports a timeout or iteration overflow. The flag is {@code false};</li>
     *   <li>a step fails to strictly lower the mean. The flag is {@code false}.</li>
     * </ul>
     * The mean stored in the returned {@link EpochResult} is that of the <em>initial</em>
     * measurement (it is never updated inside the loop). The point sample is the latest one, and
     * the count is the loop counter at exit.
     *
     * <p>Consumes one reference each to {@code epochParams} and {@code trainingPhase}. Every exit
     * path releases both.
     *
     * @param epochParams   epoch settings (training size, iteration budget, timeout); reference consumed
     * @param trainingPhase the phase to train; reference consumed
     * @param i             phase index, used only in the log message
     * @param j             seed forwarded to {@code reset}
     * @return the outcome of this phase
     */
    protected EpochResult runPhase(EpochParams epochParams, TrainingPhase trainingPhase, int i, long j) {
        this.monitor.log(RefString.format("Phase %d: %s", new Object[]{Integer.valueOf(i), trainingPhase.m73addRef()}));
        if (!$assertionsDisabled && trainingPhase.trainingSubject == null) {
            throw new AssertionError();
        }
        trainingPhase.trainingSubject.setTrainingSize(epochParams.trainingSize);
        this.monitor.log(RefString.format("resetAndMeasure; trainingSize=%s", new Object[]{Integer.valueOf(epochParams.trainingSize)}));
        // reset() and measure() each consume the extra phase reference passed to them.
        reset(trainingPhase.m73addRef(), j);
        ValidatingTrainer m68addRef = m68addRef();
        PointSample measure = m68addRef.measure(trainingPhase.m73addRef());
        m68addRef.freeRef();
        // Mean of the initial point; reported in every EpochResult below.
        double mean = measure.getMean();
        if (!$assertionsDisabled && 0 >= measure.delta.size()) {
            throw new AssertionError("Nothing to optimize");
        }
        // Step counter, starting at 1.
        int i2 = 1;
        while (true) {
            // Step budget exhausted (a non-positive budget disables this check).
            if (i2 > epochParams.iterations && epochParams.iterations > 0) {
                trainingPhase.freeRef();
                epochParams.freeRef();
                return new EpochResult(true, mean, measure, i2);
            }
            if (shouldHalt(this.monitor, epochParams.timeoutMs)) {
                EpochResult epochResult = new EpochResult(false, mean, measure.m35addRef(), i2);
                measure.freeRef();
                epochParams.freeRef();
                trainingPhase.freeRef();
                return epochResult;
            }
            // Snapshot wall time and cumulative GC time (ms, summed over all collectors) so the
            // log line can report how long this step and its garbage collection took.
            long nanoTime = RefSystem.nanoTime();
            long sum = ManagementFactory.getGarbageCollectorMXBeans().stream().mapToLong((v0) -> {
                return v0.getCollectionTime();
            }).sum();
            StepResult runStep = runStep(measure.m35addRef(), trainingPhase.m73addRef());
            String format = RefString.format("%s in %.3f seconds; %.3f in orientation, %.3f in gc, %.3f in line search; %.3f trainAll time", new Object[]{Integer.valueOf(epochParams.trainingSize), Double.valueOf((RefSystem.nanoTime() - nanoTime) / 1.0E9d), Double.valueOf(runStep.performance[0]), Double.valueOf((ManagementFactory.getGarbageCollectorMXBeans().stream().mapToLong((v0) -> {
                return v0.getCollectionTime();
            }).sum() - sum) / 1000.0d), Double.valueOf(runStep.performance[1]), Double.valueOf(this.trainingMeasurementTime.getAndSet(0L) / 1.0E9d)});
            // Advance to the step's resulting point (its rate is zeroed first).
            measure.freeRef();
            runStep.currentPoint.setRate(0.0d);
            measure = runStep.currentPoint.m35addRef();
            // A step that does not strictly lower the mean aborts the phase.
            if (runStep.previous.getMean() <= runStep.currentPoint.getMean()) {
                this.monitor.log(RefString.format("Iteration %s failed, aborting. Error: %s (%s)", new Object[]{Integer.valueOf(this.currentIteration.get()), Double.valueOf(runStep.currentPoint.getMean()), format}));
                runStep.freeRef();
                epochParams.freeRef();
                trainingPhase.freeRef();
                return new EpochResult(false, mean, measure, i2);
            }
            this.monitor.log(RefString.format("Iteration %s complete. Error: %s (%s)", new Object[]{Integer.valueOf(this.currentIteration.get()), Double.valueOf(runStep.currentPoint.getMean()), format}));
            runStep.freeRef();
            this.monitor.onStepComplete(new Step(measure.m35addRef(), this.currentIteration.get()));
            i2++;
        }
    }

    /**
     * Performs one optimization step. It increments the global iteration counter, asks the
     * phase's orientation strategy for a search direction (a {@link LineSearchCursor}), runs a
     * line search along it, and packages the start and resulting points with timing data.
     *
     * <p>The line-search strategy is cached per direction type in
     * {@code trainingPhase.lineSearchStrategyMap}. A missing entry is created with
     * {@code trainingPhase.lineSearchFactory} and stored.
     *
     * <p>Consumes {@code pointSample} (it is handed to the returned {@link StepResult}, or released
     * on the exception path) and one reference to {@code trainingPhase}.
     *
     * @param pointSample   the current point; reference consumed
     * @param trainingPhase supplies the training subject, orientation and line-search settings; reference consumed
     * @return a step result whose {@code performance[0]} is orientation time and
     *         {@code performance[1]} is line-search time, both in seconds
     * @throws IllegalStateException if the line search returns a point with a higher mean than {@code pointSample}
     */
    protected StepResult runStep(PointSample pointSample, TrainingPhase trainingPhase) {
        LineSearchStrategy lineSearchStrategy;
        this.currentIteration.incrementAndGet();
        // Timed: compute the search direction via the phase's orientation strategy.
        TimedResult time = TimedResult.time((UncheckedSupplier) RefUtil.wrapInterface(() -> {
            if (!$assertionsDisabled && trainingPhase.trainingSubject == null) {
                throw new AssertionError();
            }
            if ($assertionsDisabled || trainingPhase.orientation != null) {
                return trainingPhase.orientation.orient(trainingPhase.trainingSubject.mo1addRef(), pointSample.m35addRef(), this.monitor);
            }
            throw new AssertionError();
        }, new Object[]{pointSample.m35addRef(), trainingPhase.m73addRef()}));
        LineSearchCursor lineSearchCursor = (LineSearchCursor) time.getResult();
        CharSequence directionType = lineSearchCursor.getDirectionType();
        if (!$assertionsDisabled && trainingPhase.lineSearchStrategyMap == null) {
            throw new AssertionError();
        }
        // Reuse the cached strategy for this direction type, or build and cache a new one.
        if (trainingPhase.lineSearchStrategyMap.containsKey(directionType)) {
            lineSearchStrategy = (LineSearchStrategy) trainingPhase.lineSearchStrategyMap.get(directionType);
        } else {
            this.monitor.log(RefString.format("Constructing line search parameters: %s", new Object[]{directionType}));
            lineSearchStrategy = (LineSearchStrategy) trainingPhase.lineSearchFactory.apply(lineSearchCursor.getDirectionType());
            RefUtil.freeRef(trainingPhase.lineSearchStrategyMap.put(directionType, lineSearchStrategy));
        }
        trainingPhase.freeRef();
        LineSearchStrategy lineSearchStrategy2 = lineSearchStrategy;
        // Timed: run the line search through a FailsafeLineSearchCursor, then restore and
        // return the best point it recorded.
        TimedResult time2 = TimedResult.time((UncheckedSupplier) RefUtil.wrapInterface(() -> {
            FailsafeLineSearchCursor failsafeLineSearchCursor = new FailsafeLineSearchCursor(lineSearchCursor.mo77addRef(), pointSample.m35addRef(), this.monitor);
            if (!$assertionsDisabled && lineSearchStrategy2 == null) {
                throw new AssertionError();
            }
            RefUtil.freeRef(lineSearchStrategy2.step(failsafeLineSearchCursor.mo77addRef(), this.monitor));
            PointSample best = failsafeLineSearchCursor.getBest();
            if (!$assertionsDisabled && best == null) {
                throw new AssertionError();
            }
            best.restore();
            PointSample m35addRef = best.m35addRef();
            best.freeRef();
            failsafeLineSearchCursor.freeRef();
            return m35addRef;
        }, new Object[]{pointSample.m35addRef(), lineSearchCursor.mo77addRef()}));
        lineSearchCursor.freeRef();
        PointSample pointSample2 = (PointSample) time2.getResult();
        // Sanity check: the line search must never return a worse point than it started from.
        if (pointSample2.getMean() > pointSample.getMean()) {
            IllegalStateException illegalStateException = new IllegalStateException(pointSample2.getMean() + " > " + pointSample.getMean());
            pointSample2.freeRef();
            pointSample.freeRef();
            time.freeRef();
            time2.freeRef();
            throw illegalStateException;
        }
        this.monitor.log(compare(pointSample.m35addRef(), pointSample2.m35addRef()));
        // pointSample's reference is transferred into the StepResult.
        StepResult stepResult = new StepResult(pointSample, pointSample2.m35addRef(), new double[]{time.timeNanos / 1.0E9d, time2.timeNanos / 1.0E9d});
        time2.freeRef();
        time.freeRef();
        pointSample2.freeRef();
        return stepResult;
    }

    /**
     * Decides whether training must stop now and logs the reason when it does.
     *
     * @param trainingMonitor receives the halt reason
     * @param j               deadline as an absolute wall-clock time in milliseconds
     *                        (compared against {@code RefSystem.currentTimeMillis()})
     * @return {@code true} if the deadline has passed or the global iteration counter exceeds
     *         {@code maxIterations}; {@code false} otherwise
     */
    protected boolean shouldHalt(TrainingMonitor trainingMonitor, long j) {
        // Sample the clock exactly once so the deadline check uses a single timestamp.
        final long now = RefSystem.currentTimeMillis();
        if (j < now) {
            trainingMonitor.log("Training timeout");
            return true;
        }
        if (this.currentIteration.get() > this.maxIterations) {
            trainingMonitor.log("Training iteration overflow");
            return true;
        }
        return false;
    }

    /**
     * Builds a log line summarising how much each weight state changed during a step.
     *
     * <p>The states in {@code pointSample.weights} are grouped by identity, so each state is its
     * own key. For every state, the reported ratio is the RMS of that state's delta statistics
     * divided by the RMS of the buffer found under {@code state.key} in
     * {@code pointSample2.weights}. The divisor is 1 when that buffer is missing or its RMS is
     * zero. Single-value groups are rendered with {@code Double.toString}; larger groups as
     * {@link DoubleStatistics}.
     *
     * @param pointSample  the earlier point (the step's start in {@code runStep}); reference consumed
     * @param pointSample2 the later point (the line-search result in {@code runStep}); reference consumed
     * @return a message of the form "Overall network state change: {...}"
     */
    private String compare(PointSample pointSample, PointSample pointSample2) {
        StateSet<UUID> mo16addRef = pointSample2.weights.mo16addRef();
        pointSample2.freeRef();
        StateSet<UUID> mo16addRef2 = pointSample.weights.mo16addRef();
        pointSample.freeRef();
        // Identity grouping: each state maps to a singleton list of itself.
        RefMap refMap = (RefMap) mo16addRef2.stream().collect(RefCollectors.groupingBy(state -> {
            return state;
        }, RefCollectors.toList()));
        RefSet entrySet = refMap.entrySet();
        String format = RefString.format("Overall network state change: %s", new Object[]{(RefMap) entrySet.stream().collect(RefCollectors.toMap(entry -> {
            State state2 = (State) entry.getKey();
            RefUtil.freeRef(entry);
            return state2;
        }, (Function) RefUtil.wrapInterface(entry2 -> {
            RefList refList = (RefList) entry2.getValue();
            // Per-state ratio: rms(earlier delta) / rms(matching later buffer), guarding against 0/missing.
            double[] array = refList.stream().mapToDouble((ToDoubleFunction) RefUtil.wrapInterface(state2 -> {
                DoubleBuffer doubleBuffer = mo16addRef.get(state2.key);
                double rms = state2.deltaStatistics().rms();
                state2.freeRef();
                double rms2 = null == doubleBuffer ? 0.0d : doubleBuffer.deltaStatistics().rms();
                if (null != doubleBuffer) {
                    doubleBuffer.freeRef();
                }
                return rms / (0.0d == rms2 ? 1.0d : rms2);
            }, new Object[]{mo16addRef.mo16addRef()})).toArray();
            refList.freeRef();
            RefUtil.freeRef(entry2);
            return 1 == array.length ? Double.toString(array[0]) : new DoubleStatistics().accept(array).toString();
        }, new Object[]{mo16addRef})))});
        entrySet.freeRef();
        refMap.freeRef();
        mo16addRef2.freeRef();
        return format;
    }

    /**
     * Measures the phase's training subject, retrying while the measured mean is not finite.
     *
     * <p>Each non-finite measurement is released, and the phase's orientation strategy is reset
     * before the next attempt. At most 11 measurements are taken; after that an
     * {@link IterativeStopException} is thrown.
     *
     * <p>Consumes one reference to {@code trainingPhase}. It is released exactly once, on every
     * exit path (success, retry exhaustion, or any other exception).
     *
     * @param trainingPhase the phase whose {@code trainingSubject} is measured; reference consumed
     * @return a measurement with a finite mean; the caller owns the returned reference
     * @throws IterativeStopException if no finite measurement is obtained
     */
    private PointSample measure(TrainingPhase trainingPhase) {
        // The try/finally encloses the whole retry loop. With the finally inside the loop, the
        // phase reference would be released once per attempt (over-release on retries).
        try {
            int attempt = 0;
            while (true) {
                if (10 < attempt++) {
                    throw new IterativeStopException();
                }
                if (!$assertionsDisabled && trainingPhase.trainingSubject == null) {
                    throw new AssertionError();
                }
                PointSample measure = trainingPhase.trainingSubject.measure(this.monitor);
                if (Double.isFinite(measure.getMean())) {
                    return measure;
                }
                measure.freeRef();
                if (!$assertionsDisabled && trainingPhase.orientation == null) {
                    throw new AssertionError();
                }
                trainingPhase.orientation.reset();
            }
        } finally {
            trainingPhase.freeRef();
        }
    }

    /**
     * Prepares a phase for a new run. It reseeds the phase's training subject, resets its
     * orientation strategy, and, if the subject's layer is a {@link DAGNetwork}, shuffles that
     * network with a fresh random value.
     *
     * <p>Consumes one reference to {@code trainingPhase}. The phase reference and the fetched
     * layer reference are released on every exit path, including assertion failures and
     * exceptions from {@code shuffle}.
     *
     * @param trainingPhase the phase to reset; reference consumed
     * @param j             seed passed to {@code trainingSubject.reseed}
     * @throws IterativeStopException if the first {@code reseed} call returns {@code false}
     */
    private void reset(TrainingPhase trainingPhase, long j) {
        try {
            if (!$assertionsDisabled && trainingPhase.trainingSubject == null) {
                throw new AssertionError();
            }
            if (!trainingPhase.trainingSubject.reseed(j)) {
                throw new IterativeStopException();
            }
            if (!$assertionsDisabled && trainingPhase.orientation == null) {
                throw new AssertionError();
            }
            trainingPhase.orientation.reset();
            // The subject is reseeded a second time with the same seed after the orientation reset.
            // NOTE(review): the intent of the repeat is not visible here; confirm before removing.
            trainingPhase.trainingSubject.reseed(j);
            Layer layer = trainingPhase.trainingSubject.getLayer();
            try {
                if (layer instanceof DAGNetwork) {
                    ((DAGNetwork) layer).shuffle(StochasticComponent.random.get().nextLong());
                }
            } finally {
                if (null != layer) {
                    layer.freeRef();
                }
            }
        } finally {
            trainingPhase.freeRef();
        }
    }

    /**
     * Decompiled form of the synthesized {@code $deserializeLambda$} hook. It rebuilds the two
     * serializable {@code UncheckedSupplier} lambdas created in {@code runStep} (the line-search
     * supplier and the orientation supplier) from a {@link SerializedLambda}. It throws
     * {@link IllegalArgumentException} for any lambda it does not recognise.
     *
     * <p>NOTE(review): this is not valid Java as written. It contains {@code boolean z = -1},
     * a switch on a boolean, and {@code this.monitor} referenced from a static method. It appears
     * to be decompiler output and should be regenerated from source rather than hand-edited.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 853531514:
                if (implMethodName.equals("lambda$runStep$55006801$1")) {
                    z = false;
                    break;
                }
                break;
            case 1118358802:
                if (implMethodName.equals("lambda$runStep$90d577a4$1")) {
                    z = true;
                    break;
                }
                break;
        }
        switch (z) {
            // Line-search supplier: captures trainer, cursor, start point and strategy.
            case false:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/simiacryptus/lang/UncheckedSupplier") && serializedLambda.getFunctionalInterfaceMethodName().equals("get") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/simiacryptus/mindseye/opt/ValidatingTrainer") && serializedLambda.getImplMethodSignature().equals("(Lcom/simiacryptus/mindseye/opt/line/LineSearchCursor;Lcom/simiacryptus/mindseye/lang/PointSample;Lcom/simiacryptus/mindseye/opt/line/LineSearchStrategy;)Lcom/simiacryptus/mindseye/lang/PointSample;")) {
                    ValidatingTrainer validatingTrainer = (ValidatingTrainer) serializedLambda.getCapturedArg(0);
                    LineSearchCursor lineSearchCursor = (LineSearchCursor) serializedLambda.getCapturedArg(1);
                    PointSample pointSample = (PointSample) serializedLambda.getCapturedArg(2);
                    LineSearchStrategy lineSearchStrategy = (LineSearchStrategy) serializedLambda.getCapturedArg(3);
                    return () -> {
                        FailsafeLineSearchCursor failsafeLineSearchCursor = new FailsafeLineSearchCursor(lineSearchCursor.mo77addRef(), pointSample.m35addRef(), this.monitor);
                        if (!$assertionsDisabled && lineSearchStrategy == null) {
                            throw new AssertionError();
                        }
                        RefUtil.freeRef(lineSearchStrategy.step(failsafeLineSearchCursor.mo77addRef(), this.monitor));
                        PointSample best = failsafeLineSearchCursor.getBest();
                        if (!$assertionsDisabled && best == null) {
                            throw new AssertionError();
                        }
                        best.restore();
                        PointSample m35addRef = best.m35addRef();
                        best.freeRef();
                        failsafeLineSearchCursor.freeRef();
                        return m35addRef;
                    };
                }
                break;
            // Orientation supplier: captures trainer, training phase and start point.
            case true:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/simiacryptus/lang/UncheckedSupplier") && serializedLambda.getFunctionalInterfaceMethodName().equals("get") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/simiacryptus/mindseye/opt/ValidatingTrainer") && serializedLambda.getImplMethodSignature().equals("(Lcom/simiacryptus/mindseye/opt/ValidatingTrainer$TrainingPhase;Lcom/simiacryptus/mindseye/lang/PointSample;)Lcom/simiacryptus/mindseye/opt/line/LineSearchCursor;")) {
                    ValidatingTrainer validatingTrainer2 = (ValidatingTrainer) serializedLambda.getCapturedArg(0);
                    TrainingPhase trainingPhase = (TrainingPhase) serializedLambda.getCapturedArg(1);
                    PointSample pointSample2 = (PointSample) serializedLambda.getCapturedArg(2);
                    return () -> {
                        if (!$assertionsDisabled && trainingPhase.trainingSubject == null) {
                            throw new AssertionError();
                        }
                        if ($assertionsDisabled || trainingPhase.orientation != null) {
                            return trainingPhase.orientation.orient(trainingPhase.trainingSubject.mo1addRef(), pointSample2.m35addRef(), this.monitor);
                        }
                        throw new AssertionError();
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }

    static {
        $assertionsDisabled = !ValidatingTrainer.class.desiredAssertionStatus();
    }
}
