package com.simiacryptus.mindseye.opt.line;

import com.simiacryptus.mindseye.lang.IterativeStopException;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.orient.DescribeOrientationWrapper;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefString;

/**
 * Line search strategy that fits a linear model to the directional derivative between a left and a
 * right point (equivalently, a quadratic model of the objective) and probes at the root of that fit.
 *
 * <p>Each call to {@link #step(LineSearchCursor, TrainingMonitor)} evaluates rate 0, locates an
 * initial right point (starting from {@link #getCurrentRate()} and dividing it by 13 or multiplying
 * it by 7), then iteratively refines the pair of points. The accepted rate is stored as the starting
 * rate for the next call.
 */
public class QuadraticSearch implements LineSearchStrategy {
    /**
     * Factor applied to the left (initial) derivative when deciding whether to keep expanding the
     * right point during the initial search.
     */
    private final double initialDerivFactor = 0.95d;
    /** Absolute tolerance used by {@link #isSame(double, double, double)}. */
    private double absoluteTolerance = 1.0E-12d;
    /** Rate used for the first right point; updated to the accepted rate after every search. */
    private double currentRate = 0.0d;
    /** Lower clamp for probed rates. */
    private double minRate = 1.0E-10d;
    /** Upper clamp for probed rates. */
    private double maxRate = 1.0E10d;
    /** Relative tolerance used by {@link #isSame(double, double, double)}. */
    private double relativeTolerance = 0.01d;
    /** Multiplier applied to the accepted rate by {@link #filter}; 1.0 returns the found point unchanged. */
    private double stepSize = 1.0d;

    /**
     * State of a single line search: the cursor, the left/right points and the latest probe.
     *
     * <p>Holds references to the cursor and to every stored {@link LineSearchPoint}; all of them are
     * released in {@link #_free()}.
     */
    private class Stepper extends ReferenceCountingBase {
        private final LineSearchCursor cursor;
        private final TrainingMonitor monitor;
        /** Rate of the most recent probe. */
        private double thisX;
        /** Rate of the left point. */
        private double leftX;
        /** Rate of the right point. */
        private double rightX;
        /** Most recent probe point. */
        private LineSearchPoint thisPoint;
        /** Point evaluated at rate 0; used as the baseline for logged deltas. */
        private LineSearchPoint initialPoint;
        private LineSearchPoint leftPoint;
        private LineSearchPoint rightPoint;
        /** Iteration counter, reset to 0 before each search phase. */
        private int loops;

        /**
         * Evaluates the cursor at rate 0 and uses that point as both the initial and the left point.
         *
         * @param lineSearchCursor cursor to search along; this Stepper takes ownership of the reference
         * @param trainingMonitor  monitor receiving progress log lines
         */
        private Stepper(LineSearchCursor lineSearchCursor, TrainingMonitor trainingMonitor) {
            this.cursor = lineSearchCursor;
            this.monitor = trainingMonitor;
            this.thisX = 0.0d;
            this.thisPoint = this.cursor.step(this.thisX, this.monitor);
            // initialPoint and leftPoint each hold their own reference to the rate-0 point.
            this.initialPoint = getThisPoint();
            this.leftX = this.thisX;
            this.leftPoint = getThisPoint();
            this.monitor.log(RefString.format("F(%s) = %s", new Object[]{Double.valueOf(this.leftX), getLeftPoint()}));
            assert this.leftPoint != null;
        }

        /** @return a new reference to the left point, or null if unset */
        private LineSearchPoint getLeftPoint() {
            if (this.leftPoint == null) {
                return null;
            }
            return this.leftPoint.m78addRef();
        }

        /** @return a new reference to the right point, or null if unset */
        private LineSearchPoint getRightPoint() {
            if (this.rightPoint == null) {
                return null;
            }
            return this.rightPoint.m78addRef();
        }

        /** @return a new reference to the latest probe point, or null if unset */
        private LineSearchPoint getThisPoint() {
            if (this.thisPoint == null) {
                return null;
            }
            return this.thisPoint.m78addRef();
        }

        /** Replaces the probe point, releasing the previous one; takes ownership of the argument. */
        private void setThisPoint(LineSearchPoint lineSearchPoint) {
            if (null != this.thisPoint) {
                this.thisPoint.freeRef();
            }
            this.thisPoint = lineSearchPoint;
        }

        /**
         * Runs the complete search: initial right-point placement followed by iterative refinement.
         *
         * @return the accepted point sample; never null
         */
        public PointSample step() {
            if (0.0d == this.leftPoint.derivative) {
                // Zero slope at the origin: nothing to search for.
                return QuadraticSearch.this.getPointSample(getLeftPoint());
            }
            this.rightX = QuadraticSearch.this.getCurrentRate() > 0.0d
                    ? QuadraticSearch.this.getCurrentRate()
                    : Math.abs((leftMean() * 1.0E-4d) / this.leftPoint.derivative);
            RefUtil.freeRef(this.rightPoint);
            this.rightPoint = this.cursor.step(this.rightX, this.monitor);
            assert this.rightPoint != null;
            this.monitor.log(RefString.format("F(%s) = %s, evalInputDelta = %s", new Object[]{Double.valueOf(this.rightX), this.rightPoint.m78addRef(), Double.valueOf(rightMean() - leftMean())}));
            this.loops = 0;
            while (!locateInitialRightPoint()) {
                // keep adjusting the right point until the helper reports it is done
            }
            this.loops = 0;
            PointSample result;
            do {
                result = iterate();
            } while (result == null);
            return result;
        }

        /** Releases the cursor and every stored point. */
        @Override
        protected void _free() {
            RefUtil.freeRef(this.thisPoint);
            RefUtil.freeRef(this.initialPoint);
            RefUtil.freeRef(this.leftPoint);
            RefUtil.freeRef(this.rightPoint);
            RefUtil.freeRef(this.cursor);
            super._free();
        }

        /**
         * Performs one adjustment of the initial right point.
         *
         * <p>The right rate is divided by 13 while the right point is worse than the left one, and
         * multiplied by 7 while its derivative stays below {@code initialDerivFactor} times the left
         * derivative.
         *
         * @return true when the initial search should stop, false to call again
         */
        private boolean locateInitialRightPoint() {
            // Extra reference to the right point as it was before this adjustment.
            LineSearchPoint previousRight = getRightPoint();
            try {
                if (QuadraticSearch.this.isSame(this.cursor.mo77addRef(), this.monitor, getLeftPoint(), getRightPoint())) {
                    this.monitor.log(RefString.format("%s ~= %s", new Object[]{Double.valueOf(this.leftPoint.getPointRate()), Double.valueOf(this.rightX)}));
                    return true;
                }
                if (rightMean() > leftMean() && this.rightX > QuadraticSearch.this.minRate) {
                    // Right point is worse than the left one: shrink the step.
                    this.rightX /= 13.0d;
                } else {
                    // Stop expanding once the right derivative is no longer below initialDerivFactor times
                    // the left (initial) derivative. For a descent direction (negative left derivative),
                    // this means the slope has flattened. The previous code compared
                    // rightPoint.derivative with itself, which left initialDerivFactor unused.
                    double derivThreshold = QuadraticSearch.this.initialDerivFactor * this.leftPoint.derivative;
                    if (this.rightPoint.derivative >= derivThreshold || this.rightX >= QuadraticSearch.this.maxRate) {
                        this.monitor.log(RefString.format("%s <= %s", new Object[]{Double.valueOf(rightMean()), Double.valueOf(leftMean())}));
                        return true;
                    }
                    this.rightX *= 7.0d;
                }
                if (null != this.rightPoint) {
                    this.rightPoint.freeRef();
                }
                this.rightPoint = this.cursor.step(this.rightX, this.monitor);
                if (QuadraticSearch.this.isSame(this.cursor.mo77addRef(), this.monitor, previousRight == null ? null : previousRight.m78addRef(), getRightPoint())) {
                    assert previousRight != null;
                    this.monitor.log(RefString.format("%s ~= %s", new Object[]{Double.valueOf(previousRight.getPointRate()), Double.valueOf(this.rightX)}));
                    return true;
                }
                this.monitor.log(RefString.format("F(%s) = %s, evalInputDelta = %s", new Object[]{Double.valueOf(this.rightX), getRightPoint(), Double.valueOf(rightMean() - leftMean())}));
                if (this.loops++ <= 50) {
                    return false;
                }
                this.monitor.log(RefString.format("Loops = %s", new Object[]{Integer.valueOf(this.loops)}));
                return true;
            } finally {
                RefUtil.freeRef(previousRight);
            }
        }

        /**
         * Performs one refinement step.
         *
         * <p>The step probes at the root of the linear fit of the derivative between the left and right
         * points, evaluates it, and replaces one end with the probe.
         *
         * @return the accepted point sample when the search terminates, or null to keep iterating
         */
        private PointSample iterate() {
            assert this.rightPoint != null;
            // Linear (secant) model of the derivative; thisX is where the fitted derivative is zero.
            double derivSlope = (this.rightPoint.derivative - this.leftPoint.derivative) / (this.rightX - this.leftX);
            this.thisX = (-(this.rightPoint.derivative - (derivSlope * this.rightX))) / derivSlope;
            // Opposite derivative signs at the two ends mean a stationary point lies between them.
            boolean bracketed = Math.signum(this.leftPoint.derivative) != Math.signum(this.rightPoint.derivative);
            if (!Double.isFinite(this.thisX) || (bracketed && (this.leftX > this.thisX || this.rightX < this.thisX))) {
                // Degenerate fit, or the root escaped the bracket: bisect instead.
                this.thisX = (this.rightX + this.leftX) / 2.0d;
            }
            if (!bracketed && this.thisX < 0.0d) {
                this.thisX = this.rightX * 2.0d;
            }
            if (this.thisX < QuadraticSearch.this.getMinRate()) {
                this.thisX = QuadraticSearch.this.getMinRate();
            }
            if (this.thisX > QuadraticSearch.this.getMaxRate()) {
                this.thisX = QuadraticSearch.this.getMaxRate();
            }
            if (QuadraticSearch.this.isSame(this.leftX, this.thisX, 1.0d)) {
                this.monitor.log(RefString.format("Converged to left", new Object[0]));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.leftPoint.getPoint(), this.monitor);
            }
            if (QuadraticSearch.this.isSame(this.thisX, this.rightX, 1.0d)) {
                this.monitor.log(RefString.format("Converged to right", new Object[0]));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.rightPoint.getPoint(), this.monitor);
            }
            setThisPoint(this.cursor.step(this.thisX, this.monitor));
            if (QuadraticSearch.this.isSame(this.cursor.mo77addRef(), this.monitor, getLeftPoint(), getThisPoint())) {
                this.monitor.log(RefString.format("%s ~= %s", new Object[]{Double.valueOf(this.leftX), Double.valueOf(this.thisX)}));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.leftPoint.getPoint(), this.monitor);
            }
            if (QuadraticSearch.this.isSame(this.cursor.mo77addRef(), this.monitor, getThisPoint(), this.rightPoint.m78addRef())) {
                this.monitor.log(RefString.format("%s ~= %s", new Object[]{Double.valueOf(this.thisX), Double.valueOf(this.rightX)}));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.rightPoint.getPoint(), this.monitor);
            }
            // NOTE(review): the cursor is evaluated a second time at the same thisX. This looks redundant
            // unless cursor.step has side effects that must be re-applied here; confirm before removing.
            setThisPoint(this.cursor.step(this.thisX, this.monitor));
            assert this.thisPoint != null;
            boolean replaceLeft;
            if (bracketed) {
                // Negative slope at the probe: the probe becomes the new left end.
                replaceLeft = this.thisPoint.derivative < 0.0d;
            } else {
                // Not bracketed: replace the end whose rate is closer to the probe.
                replaceLeft = Math.abs(this.rightPoint.getPointRate() - this.thisPoint.getPointRate()) > Math.abs(this.leftPoint.getPointRate() - this.thisPoint.getPointRate());
            }
            this.monitor.log(RefString.format("F(%s) = %s, evalInputDelta = %s", new Object[]{Double.valueOf(this.thisX), this.thisPoint.m78addRef(), Double.valueOf(thisMean() - this.initialPoint.getPointMean())}));
            if (this.loops++ > 10) {
                this.monitor.log(RefString.format("Loops = %s", new Object[]{Integer.valueOf(this.loops)}));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.thisPoint.getPoint(), this.monitor);
            }
            if (QuadraticSearch.this.isSame(this.cursor.mo77addRef(), this.monitor, getLeftPoint(), this.rightPoint.m78addRef())) {
                this.monitor.log(RefString.format("%s ~= %s", new Object[]{Double.valueOf(this.leftX), Double.valueOf(this.rightX)}));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.thisPoint.getPoint(), this.monitor);
            }
            if (replaceLeft) {
                if (thisMean() > leftMean()) {
                    this.monitor.log(RefString.format("%s > %s", new Object[]{Double.valueOf(thisMean()), Double.valueOf(leftMean())}));
                    return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.leftPoint.getPoint(), this.monitor);
                }
                if (bracketed || leftMean() >= rightMean()) {
                    this.leftPoint.freeRef();
                } else {
                    // Old left end is better than the right end: it becomes the new right end.
                    this.rightX = this.leftX;
                    this.rightPoint.freeRef();
                    this.rightPoint = this.leftPoint;
                }
                this.leftPoint = this.thisPoint.m78addRef();
                this.leftX = this.thisX;
                this.monitor.log(RefString.format("Left bracket at %s", new Object[]{Double.valueOf(this.thisX)}));
                return null;
            }
            if (thisMean() > rightMean()) {
                this.monitor.log(RefString.format("%s > %s", new Object[]{Double.valueOf(thisMean()), Double.valueOf(rightMean())}));
                return QuadraticSearch.this.filter(this.cursor.mo77addRef(), this.rightPoint.getPoint(), this.monitor);
            }
            if (bracketed || rightMean() >= leftMean()) {
                this.rightPoint.freeRef();
            } else {
                // Old right end is better than the left end: it becomes the new left end.
                this.leftX = this.rightX;
                if (null != this.leftPoint) {
                    this.leftPoint.freeRef();
                }
                this.leftPoint = this.rightPoint;
            }
            this.rightX = this.thisX;
            this.rightPoint = this.thisPoint.m78addRef();
            this.monitor.log(RefString.format("Right bracket at %s", new Object[]{Double.valueOf(this.thisX)}));
            return null;
        }

        /** @return the mean value of the left point */
        private double leftMean() {
            return this.leftPoint.getPointMean();
        }

        /** @return the mean value of the latest probe point */
        private double thisMean() {
            return this.thisPoint.getPointMean();
        }

        /** @return the mean value of the right point */
        private double rightMean() {
            return this.rightPoint.getPointMean();
        }
    }

    /** @return the absolute tolerance used when comparing rates and means */
    public double getAbsoluteTolerance() {
        return this.absoluteTolerance;
    }

    /**
     * @param d new absolute tolerance
     * @return this instance, for chaining
     */
    public QuadraticSearch setAbsoluteTolerance(double d) {
        this.absoluteTolerance = d;
        return this;
    }

    /** @return the rate used to place the first right point of the next search */
    public double getCurrentRate() {
        return this.currentRate;
    }

    /**
     * @param d new starting rate
     * @return this instance, for chaining
     */
    public QuadraticSearch setCurrentRate(double d) {
        this.currentRate = d;
        return this;
    }

    /** @return the upper clamp for probed rates */
    public double getMaxRate() {
        return this.maxRate;
    }

    /**
     * @param d new upper clamp for probed rates
     * @return this instance, for chaining
     */
    public QuadraticSearch setMaxRate(double d) {
        this.maxRate = d;
        return this;
    }

    /** @return the lower clamp for probed rates */
    public double getMinRate() {
        return this.minRate;
    }

    /**
     * @param d new lower clamp for probed rates
     * @return this instance, for chaining
     */
    public QuadraticSearch setMinRate(double d) {
        this.minRate = d;
        return this;
    }

    /** @return the relative tolerance used when comparing rates and means */
    public double getRelativeTolerance() {
        return this.relativeTolerance;
    }

    /**
     * @param d new relative tolerance
     * @return this instance, for chaining
     */
    public QuadraticSearch setRelativeTolerance(double d) {
        this.relativeTolerance = d;
        return this;
    }

    /** @return the multiplier applied to the accepted rate */
    public double getStepSize() {
        return this.stepSize;
    }

    /**
     * @param d new multiplier applied to the accepted rate (1.0 disables re-evaluation)
     * @return this instance, for chaining
     */
    public QuadraticSearch setStepSize(double d) {
        this.stepSize = d;
        return this;
    }

    /**
     * Runs one line search along the cursor.
     *
     * <p>The starting rate is clamped to [minRate, maxRate] first, and afterwards it is replaced by the
     * accepted rate.
     *
     * @param lineSearchCursor cursor to search along; ownership of the reference is taken
     * @param trainingMonitor  receives progress log lines
     * @return the accepted point sample
     */
    @Override // com.simiacryptus.mindseye.opt.line.LineSearchStrategy
    public PointSample step(LineSearchCursor lineSearchCursor, TrainingMonitor trainingMonitor) {
        if (this.currentRate < getMinRate()) {
            this.currentRate = getMinRate();
        }
        if (this.currentRate > getMaxRate()) {
            this.currentRate = getMaxRate();
        }
        Stepper stepper = new Stepper(lineSearchCursor, trainingMonitor);
        try {
            PointSample step = stepper.step();
            assert step != null;
            setCurrentRate(step.rate);
            return step;
        } finally {
            stepper.freeRef();
        }
    }

    /**
     * Returns whether two values agree within tolerance after scaling their difference.
     *
     * @param d  first value
     * @param d2 second value
     * @param d3 divisor applied to |d - d2| before comparing (larger values loosen the test)
     * @return true if the scaled difference is below the absolute tolerance, or below the relative
     *     tolerance times the larger magnitude of the two values
     */
    protected boolean isSame(double d, double d2, double d3) {
        double abs = Math.abs(d - d2) / d3;
        return abs < this.absoluteTolerance || abs < Math.max(Math.abs(d), Math.abs(d2)) * this.relativeTolerance;
    }

    /**
     * Compares two line search points.
     *
     * <p>Returns false if their rates differ and true if both rates and means agree. If the rates agree
     * but the means do not (tolerance loosened 10x), a diagnosis is logged and thrown as an
     * {@link IterativeStopException}.
     *
     * <p>All reference-counted arguments are consumed (released) on every path.
     */
    protected boolean isSame(LineSearchCursor lineSearchCursor, TrainingMonitor trainingMonitor, LineSearchPoint lineSearchPoint, LineSearchPoint lineSearchPoint2) {
        PointSample point = lineSearchPoint2.getPoint();
        assert point != null;
        PointSample point2 = lineSearchPoint.getPoint();
        assert point2 != null;
        // Set once the cursor and points have been handed to diagnose(), which releases them itself.
        boolean handedToDiagnose = false;
        try {
            if (!isSame(point2.rate, point.rate, 1.0d)) {
                return false;
            }
            if (!isSame(point2.getMean(), point.getMean(), 10.0d)) {
                handedToDiagnose = true;
                String diagnose = diagnose(lineSearchCursor, trainingMonitor, lineSearchPoint, lineSearchPoint2);
                trainingMonitor.log(diagnose);
                throw new IterativeStopException(diagnose);
            }
            return true;
        } finally {
            if (!handedToDiagnose) {
                lineSearchCursor.freeRef();
                lineSearchPoint.freeRef();
                lineSearchPoint2.freeRef();
            }
            point2.freeRef();
            point.freeRef();
        }
    }

    /**
     * Re-evaluates two points whose rates match but whose means differ, to classify the mismatch.
     *
     * <p>All reference-counted arguments are released exactly once, in the {@code finally} block.
     * Previously two return paths also released them explicitly, so they were freed twice.
     *
     * @return a human-readable diagnosis message
     */
    private String diagnose(LineSearchCursor lineSearchCursor, TrainingMonitor trainingMonitor, LineSearchPoint lineSearchPoint, LineSearchPoint lineSearchPoint2) {
        PointSample point = lineSearchPoint.getPoint();
        PointSample point2 = lineSearchPoint2.getPoint();
        try {
            assert point != null;
            assert point2 != null;
            double verifyMean = verifyMean(lineSearchCursor.mo77addRef(), trainingMonitor, lineSearchPoint.getPointRate());
            boolean isSame = isSame(point.getMean(), verifyMean, 1.0d);
            trainingMonitor.log(RefString.format("Verify %s: %s (%s)", new Object[]{Double.valueOf(point.rate), Double.valueOf(verifyMean), Boolean.valueOf(isSame)}));
            if (!isSame) {
                DescribeOrientationWrapper.render(point.weights.mo16addRef(), point.delta.mo16addRef());
                return "Non-Reproducable Point Found: " + point.rate;
            }
            double verifyMean2 = verifyMean(lineSearchCursor.mo77addRef(), trainingMonitor, point2.rate);
            boolean isSame2 = isSame(point2.getMean(), verifyMean2, 1.0d);
            trainingMonitor.log(RefString.format("Verify %s: %s (%s)", new Object[]{Double.valueOf(point2.rate), Double.valueOf(verifyMean2), Boolean.valueOf(isSame2)}));
            if (isSame2) {
                return "Function Discontinuity Found";
            }
            return "Non-Reproducable Point Found: " + point2.rate;
        } finally {
            lineSearchCursor.freeRef();
            lineSearchPoint.freeRef();
            lineSearchPoint2.freeRef();
            RefUtil.freeRef(point);
            RefUtil.freeRef(point2);
        }
    }

    /**
     * Evaluates the cursor at rate {@code d} and returns the mean of the resulting point.
     *
     * <p>The cursor reference is consumed, including when {@code cursor.step} throws.
     */
    private double verifyMean(LineSearchCursor lineSearchCursor, TrainingMonitor trainingMonitor, double d) {
        LineSearchPoint step = null;
        try {
            step = lineSearchCursor.step(d, trainingMonitor);
            assert step != null;
            return step.getPointMean();
        } finally {
            lineSearchCursor.freeRef();
            RefUtil.freeRef(step);
        }
    }

    /**
     * Applies {@link #getStepSize()} to an accepted point.
     *
     * <p>With a step size of 1.0 the sample is returned unchanged. Otherwise the cursor is
     * re-evaluated at {@code rate * stepSize}. Consumes both the cursor and the sample references.
     */
    public PointSample filter(LineSearchCursor lineSearchCursor, PointSample pointSample, TrainingMonitor trainingMonitor) {
        if (this.stepSize == 1.0d) {
            lineSearchCursor.freeRef();
            return pointSample;
        }
        try {
            LineSearchPoint step = lineSearchCursor.step(pointSample.rate * this.stepSize, trainingMonitor);
            assert step != null;
            return getPointSample(step);
        } finally {
            lineSearchCursor.freeRef();
            pointSample.freeRef();
        }
    }

    /** Extracts the point sample from a line search point, releasing the line search point. */
    public PointSample getPointSample(LineSearchPoint lineSearchPoint) {
        try {
            return lineSearchPoint.getPoint();
        } finally {
            lineSearchPoint.freeRef();
        }
    }
}
