package com.simiacryptus.mindseye.opt;

import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.IterativeStopException;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch;
import com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchPoint;
import com.simiacryptus.mindseye.opt.line.LineSearchStrategy;
import com.simiacryptus.mindseye.opt.orient.LBFGS;
import com.simiacryptus.mindseye.opt.orient.OrientationStrategy;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefString;
import com.simiacryptus.ref.wrappers.RefSystem;
import com.simiacryptus.util.Util;
import java.lang.invoke.SerializedLambda;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/simiacryptus/mindseye/opt/IterativeTrainer.class */
/**
 * Iteratively optimizes a {@link Trainable} subject.
 *
 * <p>Each iteration measures the subject, asks the {@link OrientationStrategy} (default {@link LBFGS}) for a search
 * direction, and runs a {@link LineSearchStrategy} along it. Line search strategies are created lazily by
 * {@link #getLineSearchFactory()} (default {@link ArmijoWolfeSearch}) and cached per direction type.
 *
 * <p>Training stops when one of the following happens:
 * <ul>
 *   <li>the timeout elapses (default 5 minutes);</li>
 *   <li>the fitness mean is no longer above {@link #getTerminateThreshold()};</li>
 *   <li>{@link #getMaxIterations()} is reached;</li>
 *   <li>the {@link TrainingMonitor} declines to retry a failed step.</li>
 * </ul>
 *
 * <p>NOTE(review): instances mutate shared counters and fields without synchronization; treat them as
 * single-threaded unless callers guarantee otherwise.
 */
public class IterativeTrainer extends ReferenceCountingBase {
    private static final Logger log = LoggerFactory.getLogger(IterativeTrainer.class);

    /** The object being optimized; released in {@link #_free()}. */
    private final Trainable subject;
    /** Cache of line-search strategies keyed by the cursor's direction type. */
    private final RefMap<CharSequence, LineSearchStrategy> lineSearchStrategyMap = new RefHashMap<>();
    private AtomicInteger currentIteration = new AtomicInteger(0);
    /** Iterations per data sample before reshuffling; values {@code <= 0} mean "no limit". */
    private int iterationsPerSample = 100;
    private Function<CharSequence, LineSearchStrategy> lineSearchFactory = directionType -> new ArmijoWolfeSearch();
    private int maxIterations = Integer.MAX_VALUE;
    private TrainingMonitor monitor = new TrainingMonitor();
    private OrientationStrategy<?> orientation = new LBFGS();
    private Duration timeout = Duration.of(5, ChronoUnit.MINUTES);
    private double terminateThreshold = 0.0d;

    /**
     * Creates a trainer for the given subject.
     *
     * @param trainable the subject to optimize; this trainer takes over the reference and frees it in {@link #_free()}
     */
    public IterativeTrainer(Trainable trainable) {
        this.subject = trainable;
    }

    /** @return the live iteration counter (shared, not a copy) */
    public AtomicInteger getCurrentIteration() {
        return this.currentIteration;
    }

    /** @param atomicInteger the counter to use for iteration numbering */
    public void setCurrentIteration(AtomicInteger atomicInteger) {
        this.currentIteration = atomicInteger;
    }

    /** @return iterations run per data sample before reshuffling ({@code <= 0} means unlimited) */
    public int getIterationsPerSample() {
        return this.iterationsPerSample;
    }

    /** @param i iterations per data sample before reshuffling ({@code <= 0} means unlimited) */
    public void setIterationsPerSample(int i) {
        this.iterationsPerSample = i;
    }

    /** @return the factory used to build a line-search strategy for a new direction type */
    public Function<CharSequence, LineSearchStrategy> getLineSearchFactory() {
        return this.lineSearchFactory;
    }

    /** @param function factory used to build a line-search strategy for a new direction type */
    public void setLineSearchFactory(Function<CharSequence, LineSearchStrategy> function) {
        this.lineSearchFactory = function;
    }

    /** @return the maximum number of iterations before {@link #run()} stops */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @param i the maximum number of iterations before {@link #run()} stops */
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }

    /** @return the monitor receiving log messages and step callbacks */
    public TrainingMonitor getMonitor() {
        return this.monitor;
    }

    /** @param trainingMonitor the monitor receiving log messages and step callbacks */
    public void setMonitor(TrainingMonitor trainingMonitor) {
        this.monitor = trainingMonitor;
    }

    /** @return a new reference to the orientation strategy, or {@code null} if none is set */
    public OrientationStrategy<?> getOrientation() {
        if (this.orientation == null) {
            return null;
        }
        return this.orientation.mo84addRef();
    }

    /**
     * Replaces the orientation strategy, releasing the previous one.
     *
     * @param orientationStrategy the new strategy; this trainer takes over the passed reference
     */
    public void setOrientation(OrientationStrategy<?> orientationStrategy) {
        if (null != this.orientation) {
            this.orientation.freeRef();
        }
        this.orientation = orientationStrategy;
    }

    /** @return training stops once the fitness mean is no longer above this value */
    public double getTerminateThreshold() {
        return this.terminateThreshold;
    }

    /** @param d training stops once the fitness mean is no longer above this value */
    public void setTerminateThreshold(double d) {
        this.terminateThreshold = d;
    }

    /** @return the wall-clock budget for {@link #run()} */
    public Duration getTimeout() {
        return this.timeout;
    }

    /** @param duration the wall-clock budget for {@link #run()} */
    public void setTimeout(Duration duration) {
        this.timeout = duration;
    }

    /**
     * Measures the subject, retrying while the result is non-finite and the monitor allows it.
     *
     * @return a sample with a finite mean; the caller owns the returned reference
     * @throws AssertionError if the sample has nothing to optimize (empty delta)
     * @throws IterativeStopException if the mean is non-finite and the monitor declines to retry
     */
    public PointSample measure() {
        assert this.subject != null;
        // Loop rather than recurse so repeated retries cannot grow the stack.
        while (true) {
            final PointSample sample = this.subject.measure(this.monitor);
            if (sample.delta.size() <= 0) {
                sample.freeRef();
                throw new AssertionError("Nothing to optimize");
            }
            final double mean = sample.getMean();
            if (Double.isFinite(mean)) {
                return sample;
            }
            // The rejected sample is handed to the Step without freeRef here;
            // presumably Step takes over that reference.
            if (!this.monitor.onStepFail(new Step(sample, this.currentIteration.get()))) {
                this.monitor.log(RefString.format("Optimization terminated %s", new Object[]{this.currentIteration.get()}));
                throw new IterativeStopException(Double.toString(mean));
            }
            this.monitor.log(RefString.format("Retrying iteration %s", new Object[]{this.currentIteration.get()}));
        }
    }

    /**
     * Resets the orientation strategy and reseeds the subject (and its network, when it is a
     * {@link DAGNetwork}) with the current nano time.
     */
    public void shuffle() {
        final long seed = RefSystem.nanoTime();
        this.monitor.log(RefString.format("Reset training subject: " + seed, new Object[0]));
        assert this.orientation != null;
        this.orientation.reset();
        assert this.subject != null;
        this.subject.reseed(seed);
        final Layer layer = this.subject.getLayer();
        try {
            if (layer instanceof DAGNetwork) {
                ((DAGNetwork) layer).shuffle(seed);
            }
        } finally {
            if (null != layer) {
                layer.freeRef();
            }
        }
    }

    /**
     * Runs the optimization loop until timeout, threshold, iteration limit, or monitor-requested termination.
     *
     * @return the final fitness mean, or {@code NaN} if no point is available
     */
    public double run() {
        final long startTime = RefSystem.currentTimeMillis();
        final long deadline = startTime + this.timeout.toMillis();
        long lastIterationTime = RefSystem.nanoTime();
        shuffle();
        PointSample current = measure();
        try {
            assert current != null;
            mainLoop:
            while (deadline > RefSystem.currentTimeMillis()
                    && this.terminateThreshold < current.getMean()
                    && this.maxIterations > this.currentIteration.get()) {
                shuffle();
                current.freeRef();
                // Null out released samples so the finally block never touches a freed reference
                // if measure() throws (e.g. IterativeStopException).
                current = null;
                current = measure();
                int subiteration = 0;
                while (subiteration < this.iterationsPerSample || this.iterationsPerSample <= 0) {
                    // Short-circuit: the iteration counter only advances when time remains.
                    if (deadline < RefSystem.currentTimeMillis()
                            || this.currentIteration.incrementAndGet() > this.maxIterations) {
                        break mainLoop;
                    }
                    if (null != current) {
                        current.freeRef();
                        current = null;
                    }
                    final PointSample previous = measure();
                    try {
                        final UncheckedSupplier<LineSearchCursor> orientTask = () -> this.orientation.orient(
                                this.subject == null ? null : this.subject.mo1addRef(), previous.m35addRef(), this.monitor);
                        final TimedResult orientTimed = TimedResult.time(
                                (UncheckedSupplier) RefUtil.wrapInterface(orientTask, new Object[]{previous.m35addRef()}));
                        try {
                            final LineSearchCursor cursor = (LineSearchCursor) orientTimed.getResult();
                            try {
                                final CharSequence directionType = cursor.getDirectionType();
                                final UncheckedSupplier<PointSample> stepTask =
                                        () -> step(cursor.mo77addRef(), directionType, previous.m35addRef());
                                final TimedResult stepTimed = TimedResult.time((UncheckedSupplier) RefUtil.wrapInterface(
                                        stepTask, new Object[]{previous.m35addRef(), cursor.mo77addRef()}));
                                final String perf;
                                try {
                                    current = (PointSample) stepTimed.getResult();
                                    final long now = RefSystem.nanoTime();
                                    perf = RefString.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f",
                                            new Object[]{
                                                (now - lastIterationTime) / 1.0E9d,
                                                orientTimed.timeNanos / 1.0E9d,
                                                stepTimed.timeNanos / 1.0E9d});
                                    lastIterationTime = now;
                                } finally {
                                    stepTimed.freeRef();
                                }
                                this.monitor.log(RefString.format("Fitness changed from %s to %s",
                                        new Object[]{previous.getMean(), current.getMean()}));
                                if (previous.getMean() <= current.getMean()) {
                                    if (previous.getMean() < current.getMean()) {
                                        // The line search made things worse: fall back to the start of the direction.
                                        this.monitor.log(RefString.format("Resetting Iteration %s", new Object[]{perf}));
                                        final LineSearchPoint reset = cursor.step(0.0d, this.monitor);
                                        assert reset != null;
                                        current.freeRef();
                                        current = null;
                                        current = reset.getPoint();
                                        reset.freeRef();
                                    } else {
                                        this.monitor.log(RefString.format("Static Iteration %s", new Object[]{perf}));
                                    }
                                    this.monitor.log(RefString.format("Iteration %s failed. Error: %s",
                                            new Object[]{this.currentIteration.get(), current.getMean()}));
                                    this.monitor.log(RefString.format("Previous Error: %s -> %s",
                                            new Object[]{previous.getRate(), previous.getMean()}));
                                    if (!this.monitor.onStepFail(new Step(current.m35addRef(), this.currentIteration.get()))) {
                                        this.monitor.log(RefString.format("Optimization terminated %s",
                                                new Object[]{this.currentIteration.get()}));
                                        break mainLoop;
                                    }
                                    this.monitor.log(RefString.format("Retrying iteration %s",
                                            new Object[]{this.currentIteration.get()}));
                                } else {
                                    this.monitor.log(RefString.format("Iteration %s complete. Error: %s " + ((Object) perf),
                                            new Object[]{this.currentIteration.get(), current.getMean()}));
                                    this.monitor.onStepComplete(new Step(current.m35addRef(), this.currentIteration.get()));
                                    subiteration++;
                                }
                            } finally {
                                cursor.freeRef();
                            }
                        } finally {
                            orientTimed.freeRef();
                        }
                    } finally {
                        previous.freeRef();
                    }
                }
            }
            assert this.subject != null;
            final Layer layer = this.subject.getLayer();
            try {
                if (layer instanceof DAGNetwork) {
                    ((DAGNetwork) layer).clearNoise();
                }
            } finally {
                if (null != layer) {
                    layer.freeRef();
                }
            }
            return null == current ? Double.NaN : current.getMean();
        } catch (Throwable e) {
            this.monitor.log(RefString.format("Error %s", new Object[]{Util.toString(e)}));
            throw Util.throwException(e);
        } finally {
            this.monitor.log(RefString.format("Final threshold in iteration %s: %s (> %s) after %.3fs (< %.3fs)",
                    new Object[]{
                        this.currentIteration.get(),
                        null == current ? null : Double.valueOf(current.getMean()),
                        this.terminateThreshold,
                        (RefSystem.currentTimeMillis() - startTime) / 1000.0d,
                        this.timeout.toMillis() / 1000.0d}));
            if (null != current) {
                current.freeRef();
            }
        }
    }

    /**
     * Sets the timeout from an amount and a {@link TemporalUnit}.
     *
     * @param i            amount of time
     * @param temporalUnit unit of {@code i}
     */
    public void setTimeout(int i, TemporalUnit temporalUnit) {
        this.timeout = Duration.of(i, temporalUnit);
    }

    /**
     * Sets the timeout from an amount and a {@link TimeUnit}.
     *
     * @param i        amount of time
     * @param timeUnit unit of {@code i}
     */
    public void setTimeout(int i, TimeUnit timeUnit) {
        setTimeout(i, Util.cvt(timeUnit));
    }

    /**
     * Runs one line search along {@code lineSearchCursor}. The strategy for {@code charSequence} is built on
     * first use and cached.
     *
     * @param lineSearchCursor the direction to search; its reference is passed into a {@link FailsafeLineSearchCursor}
     * @param charSequence     the direction type used as the cache key
     * @param pointSample      the starting point; its reference is passed into the failsafe cursor
     * @return the best point the failsafe cursor observed
     */
    public PointSample step(LineSearchCursor lineSearchCursor, CharSequence charSequence, PointSample pointSample) {
        final LineSearchStrategy strategy;
        if (this.lineSearchStrategyMap.containsKey(charSequence)) {
            strategy = this.lineSearchStrategyMap.get(charSequence);
        } else {
            log.info(RefString.format("Constructing line search parameters: %s", new Object[]{charSequence}));
            strategy = this.lineSearchFactory.apply(lineSearchCursor.getDirectionType());
            RefUtil.freeRef(this.lineSearchStrategyMap.put(charSequence, strategy));
        }
        final FailsafeLineSearchCursor failsafe = new FailsafeLineSearchCursor(lineSearchCursor, pointSample, this.monitor);
        try {
            assert strategy != null;
            RefUtil.freeRef(strategy.step(failsafe.mo77addRef(), this.monitor));
            return failsafe.getBest();
        } finally {
            failsafe.freeRef();
        }
    }

    /** Releases the orientation strategy, the subject and the line-search cache. */
    @Override
    public void _free() {
        super._free();
        if (null != this.orientation) {
            this.orientation.freeRef();
        }
        this.orientation = null;
        if (null != this.subject) {
            this.subject.freeRef();
        }
        this.lineSearchStrategyMap.freeRef();
    }

    /**
     * Adds a reference to this trainer (decompiler-renamed covariant {@code addRef}).
     *
     * @return this trainer
     */
    public IterativeTrainer m65addRef() {
        return (IterativeTrainer) super.addRef();
    }
}
