package com.simiacryptus.mindseye.opt.orient;

import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.Delta;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.DoubleBufferSet;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.LineSearchPoint;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrayList;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefComparator;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefSet;
import com.simiacryptus.ref.wrappers.RefString;
import com.simiacryptus.ref.wrappers.RefSystem;
import com.simiacryptus.ref.wrappers.RefTreeSet;
import com.simiacryptus.util.ArrayUtil;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Limited-memory BFGS orientation strategy.
 *
 * <p>Keeps a bounded history of evaluated {@link PointSample}s. When the history is larger than
 * {@code minHistory}, it derives a search direction with the L-BFGS two-loop recursion over
 * consecutive history entries. If no acceptable direction can be produced, it falls back to the
 * negated delta of the current sample ("GD").
 *
 * <p>Reference counting: every method of this class takes ownership of the ref-counted arguments
 * it is given and releases them before returning. Callers that keep using an object pass an
 * {@code addRef()}-ed copy. Library calls such as {@code DeltaSet.dot(x)} likewise appear to take
 * ownership of their argument (hence the recurring {@code x.dot(x.addRef())} pattern).
 *
 * <p>NOTE(review): {@code mo16addRef}, {@code m35addRef} and {@code mo84addRef} are
 * decompiler-renamed {@code addRef} methods; their names are kept so existing callers still link.
 */
public class LBFGS extends OrientationStrategyBase<SimpleLineSearchCursor> {
    protected final boolean verbose = true;
    /**
     * Evaluated points, ordered by the reversed comparison of {@link PointSample#getMean()}.
     * NOTE(review): if this set is comparator-based like a TreeSet, two samples with equal means
     * compare as duplicates and the second is not added — confirm this is intended.
     */
    private final RefTreeSet<PointSample> history = new RefTreeSet<>(RefComparator.reversed(RefComparator.comparingDouble((v0) -> {
        return v0.getMean();
    })));
    /** History size kept after a step that used an L-BFGS direction. */
    private int maxHistory = 30;
    /** L-BFGS is only attempted with more than this many points; history is trimmed to it after a fallback step. */
    private int minHistory = 3;

    /* Originally a private inner class (per decompiler note). */
    /**
     * Diagnostic summary comparing a candidate orientation with a reference direction.
     *
     * <p>At the call sites in this class, the reference is the negated delta of the current sample
     * and the candidate is the L-BFGS direction.
     */
    public class Stats {
        /** Magnitude of the candidate direction. */
        private final double mag;
        /** Magnitude of the reference direction. */
        private final double magGrad;
        /** Cosine of the angle between reference and candidate. */
        private final double dot;
        /** One "uuid = cos/ratio" (or "uuid = magnitude") entry per layer of the reference set. */
        private final List<CharSequence> anglesPerLayer;

        /**
         * @param deltaSet  reference direction; ownership passes to this object (its reference is
         *                  attached to the per-layer mapping function)
         * @param deltaSet2 candidate direction; owned by and released within this constructor
         */
        public Stats(DeltaSet<UUID> deltaSet, DeltaSet<UUID> deltaSet2) {
            this.mag = Math.sqrt(deltaSet2.dot(deltaSet2.mo16addRef()));
            this.magGrad = Math.sqrt(deltaSet.dot(deltaSet.mo16addRef()));
            // Pass an extra reference: deltaSet2 is still needed for the per-layer comparison below.
            this.dot = deltaSet.dot(deltaSet2.mo16addRef()) / (this.mag * this.magGrad);
            RefMap<UUID, Delta<UUID>> map = deltaSet.getMap();
            RefSet<Map.Entry<UUID, Delta<UUID>>> entrySet = map.entrySet();
            try {
                this.anglesPerLayer = entrySet.stream().map(RefUtil.wrapInterface(
                        (Function<Map.Entry<UUID, Delta<UUID>>, CharSequence>) entry -> {
                            UUID key = entry.getKey();
                            RefUtil.freeRef(entry);
                            return describeLayer(deltaSet, deltaSet2, key);
                        }, new Object[]{deltaSet})).collect(Collectors.toList());
            } finally {
                entrySet.freeRef();
                map.freeRef();
                deltaSet2.freeRef();
            }
        }

        /**
         * Formats the comparison for one layer.
         *
         * <p>Non-finite entries are treated as 0. The work is done on copies, because
         * {@code getDelta()} exposes the live buffer.
         *
         * @param reference reference set (not consumed)
         * @param candidate candidate set (not consumed)
         * @param key       layer key present in both sets
         * @return {@code "uuid = cos/ratio"} where ratio = |candidate| / |reference|, or
         *     {@code "uuid = |candidate|"} when the reference layer has zero magnitude
         * @throws IllegalStateException if either magnitude is not finite
         */
        private CharSequence describeLayer(DeltaSet<UUID> reference, DeltaSet<UUID> candidate, UUID key) {
            Delta<UUID> referenceDelta = reference.get(key);
            assert referenceDelta != null;
            String uuid = referenceDelta.key.toString();
            double[] referenceValues = finiteCopy(referenceDelta.getDelta());
            referenceDelta.freeRef();

            Delta<UUID> candidateDelta = candidate.get(key);
            assert candidateDelta != null;
            double[] candidateValues = finiteCopy(candidateDelta.getDelta());
            candidateDelta.freeRef();

            double candidateMag = ArrayUtil.magnitude(candidateValues);
            double referenceMag = ArrayUtil.magnitude(referenceValues);
            if (!Double.isFinite(referenceMag) || !Double.isFinite(candidateMag)) {
                throw new IllegalStateException();
            }
            if (referenceMag == 0.0d) {
                return RefString.format("%s = %.3e", new Object[]{uuid, Double.valueOf(candidateMag)});
            }
            return RefString.format("%s = %.3f/%.3e", new Object[]{uuid,
                    Double.valueOf(ArrayUtil.dot(candidateValues, referenceValues) / (candidateMag * referenceMag)),
                    Double.valueOf(candidateMag / referenceMag)});
        }

        public List<CharSequence> getAnglesPerLayer() {
            return this.anglesPerLayer;
        }

        public double getDot() {
            return this.dot;
        }

        public double getMag() {
            return this.mag;
        }

        public double getMagGrad() {
            return this.magGrad;
        }

        public String toString() {
            return RefString.format("LBFGS Orientation magnitude: %.3e, gradient %.3e, dot %.3f; %s", new Object[]{Double.valueOf(getMag()), Double.valueOf(getMagGrad()), Double.valueOf(getDot()), getAnglesPerLayer()});
        }
    }

    public int getMaxHistory() {
        return this.maxHistory;
    }

    public void setMaxHistory(int i) {
        this.maxHistory = i;
    }

    public int getMinHistory() {
        return this.minHistory;
    }

    public void setMinHistory(int i) {
        this.minHistory = i;
    }

    /**
     * Returns a copy of {@code values} with every non-finite entry (NaN, +/-Infinity) set to 0.
     * The input array is not modified.
     */
    private static double[] finiteCopy(double[] values) {
        assert values != null;
        double[] copy = values.clone();
        for (int i = 0; i < copy.length; i++) {
            if (!Double.isFinite(copy[i])) {
                copy[i] = 0.0d;
            }
        }
        return copy;
    }

    /**
     * Returns true when every buffer in the set contains only finite values.
     * Consumes {@code doubleBufferSet}.
     */
    private static boolean isFinite(DoubleBufferSet<?, ?> doubleBufferSet) {
        boolean allMatch = doubleBufferSet.stream().parallel().allMatch(doubleBuffer -> {
            boolean allMatch2 = RefArrays.stream(doubleBuffer.getDelta()).allMatch(Double::isFinite);
            doubleBuffer.freeRef();
            return allMatch2;
        });
        doubleBufferSet.freeRef();
        return allMatch;
    }

    /**
     * Adds a full copy of {@code pointSample} to the history if its delta and weights are finite
     * and its mean is lower than every mean already in the history. Otherwise the reason is logged.
     *
     * @param pointSample     sample to record (consumed); {@code null} is ignored
     * @param trainingMonitor log sink
     */
    public void addToHistory(PointSample pointSample, TrainingMonitor trainingMonitor) {
        assert assertAlive();
        // Null check must precede any dereference (previously assertAlive() ran first and NPE'd).
        if (null == pointSample) {
            return;
        }
        assert pointSample.assertAlive();
        try {
            if (!isFinite(pointSample.delta.mo16addRef())) {
                trainingMonitor.log("Corrupt evalInputDelta measurement");
            } else if (isFinite(pointSample.weights.mo16addRef())) {
                double orElse = this.history.stream().mapToDouble(pointSample2 -> {
                    double mean = pointSample2.getMean();
                    pointSample2.freeRef();
                    return mean;
                }).min().orElse(Double.POSITIVE_INFINITY);
                if (pointSample.getMean() < orElse) {
                    PointSample copyFull = pointSample.copyFull();
                    trainingMonitor.log(RefString.format("Adding measurement %s to history. Total: %s", new Object[]{Long.toHexString(RefSystem.identityHashCode(copyFull.m35addRef())), Integer.valueOf(this.history.size())}));
                    this.history.add(copyFull);
                } else {
                    trainingMonitor.log(RefString.format("Non-optimal measurement %s < %s. Total: %s", new Object[]{Double.valueOf(pointSample.getMean()), Double.valueOf(orElse), Integer.valueOf(this.history.size())}));
                }
            } else {
                trainingMonitor.log("Corrupt weights measurement");
            }
        } finally {
            pointSample.freeRef();
        }
    }

    /**
     * Records {@code pointSample} and returns a line-search cursor.
     *
     * <p>The cursor follows the L-BFGS direction when one is accepted ("LBFGS"), otherwise the
     * negated delta ("GD"). Afterwards the history is trimmed to {@code maxHistory} (L-BFGS used)
     * or {@code minHistory} (fallback).
     */
    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    public SimpleLineSearchCursor orient(Trainable trainable, PointSample pointSample, TrainingMonitor trainingMonitor) {
        addToHistory(pointSample.m35addRef(), trainingMonitor);
        DeltaSet<UUID> lbfgs = lbfgs(pointSample.m35addRef(), trainingMonitor, new RefArrayList<>(this.history.addRef()));
        int historyLimit = null == lbfgs ? this.minHistory : this.maxHistory;
        try {
            if (null != lbfgs) {
                return cursor(trainable, pointSample, "LBFGS", lbfgs);
            }
            return cursor(trainable, pointSample, "GD", pointSample.delta.scale(-1.0d));
        } finally {
            truncateHistory(trainingMonitor, historyLimit);
        }
    }

    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    public synchronized void reset() {
        this.history.clear();
    }

    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategyBase, com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    public void _free() {
        super._free();
        this.history.freeRef();
    }

    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategyBase, com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    /* renamed from: addRef */
    public LBFGS mo84addRef() {
        return (LBFGS) super.mo84addRef();
    }

    /**
     * Computes an L-BFGS direction from {@code refList}.
     *
     * <p>If the direction is rejected, the last list element is dropped and the computation is
     * retried, recursively, until at most {@code minHistory} points remain. On acceptance the
     * history is overwritten with the list that produced the direction.
     *
     * @param pointSample     current sample (consumed)
     * @param trainingMonitor log sink
     * @param refList         candidate history (consumed)
     * @return the accepted direction, or {@code null} if the history is or becomes too short
     */
    protected DeltaSet<UUID> lbfgs(PointSample pointSample, TrainingMonitor trainingMonitor, RefList<PointSample> refList) {
        DeltaSet<UUID> scale = pointSample.delta.scale(-1.0d);
        if (refList.size() <= this.minHistory) {
            scale.freeRef();
            trainingMonitor.log(RefString.format("LBFGS Accumulation History: %s points", new Object[]{Integer.valueOf(refList.size())}));
            pointSample.freeRef();
            refList.freeRef();
            return null;
        }
        if (lbfgs(pointSample.m35addRef(), trainingMonitor, refList.addRef(), scale.mo16addRef())) {
            setHistory(trainingMonitor, refList);
            pointSample.freeRef();
            return scale;
        }
        scale.freeRef();
        trainingMonitor.log("Orientation rejected. Popping history element from " + ((String) RefUtil.get(refList.stream().map(pointSample2 -> {
            double mean = pointSample2.getMean();
            pointSample2.freeRef();
            return String.format("%s", Double.valueOf(mean));
        }).reduce((str, str2) -> {
            return str + ", " + str2;
        }))));
        RefList<PointSample> subList = refList.subList(0, refList.size() - 1);
        refList.freeRef();
        return lbfgs(pointSample, trainingMonitor, subList);
    }

    /** Removes elements from the front of the history until it holds at most {@code limit}, logging each removal. */
    private void truncateHistory(TrainingMonitor trainingMonitor, int limit) {
        while (this.history.size() > limit) {
            trainingMonitor.log(RefString.format("Removed measurement %s to history. Total: %s", new Object[]{Long.toHexString(RefSystem.identityHashCode(this.history.pollFirst())), Integer.valueOf(this.history.size())}));
        }
    }

    /**
     * Builds a cursor that also feeds every line-search point back into the history.
     * Consumes {@code pointSample} and {@code deltaSet}.
     */
    private SimpleLineSearchCursor cursor(Trainable trainable, PointSample pointSample, String str, DeltaSet<UUID> deltaSet) {
        SimpleLineSearchCursor simpleLineSearchCursor = new SimpleLineSearchCursor(trainable, pointSample, deltaSet) {
            @Override // com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor, com.simiacryptus.mindseye.opt.line.LineSearchCursorBase, com.simiacryptus.mindseye.opt.line.LineSearchCursor
            public LineSearchPoint step(double d, TrainingMonitor trainingMonitor) {
                LineSearchPoint step = super.step(d, trainingMonitor);
                LBFGS.this.addToHistory(step.getPoint(), trainingMonitor);
                return step;
            }

            @Override // com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor, com.simiacryptus.mindseye.opt.line.LineSearchCursorBase, com.simiacryptus.mindseye.opt.line.LineSearchCursor
            public void _free() {
                super._free();
            }
        };
        simpleLineSearchCursor.setDirectionType(str);
        return simpleLineSearchCursor;
    }

    /**
     * Replaces the history with {@code refList} unless it already holds exactly the same points.
     * Consumes {@code refList}.
     */
    private void setHistory(TrainingMonitor trainingMonitor, RefList<PointSample> refList) {
        if (refList.size() == this.history.size() && refList.stream().filter(pointSample -> {
            return !this.history.contains(pointSample);
        }).count() == 0) {
            refList.freeRef();
            return;
        }
        trainingMonitor.log(RefString.format("Overwriting history with %s points", new Object[]{Integer.valueOf(refList.size())}));
        synchronized (this.history) {
            this.history.clear();
            this.history.addAll(refList);
        }
    }

    /**
     * Runs the L-BFGS two-loop recursion over {@code refList}.
     *
     * <p>Let s_k = weights[k+1] - weights[k] and y_k = delta[k+1] - delta[k]. Starting from a
     * finite copy q of {@code pointSample.delta}:
     * <ul>
     *   <li>The first loop runs k = n-2 down to 0, computing alpha_k = (q.s_k)/(s_k.y_k) and
     *       q -= alpha_k*y_k.</li>
     *   <li>q is then scaled by (s.y)/(y.y) of the last pair.</li>
     *   <li>The second loop runs k = 0 up to n-2, computing beta_k = (q.y_k)/(s_k.y_k) and
     *       q += (alpha_k - beta_k)*s_k.</li>
     * </ul>
     * Non-finite components are zeroed after every update.
     *
     * <p>The result is accepted when {@code pointSample.delta . q < 0}, and is then copied into
     * {@code deltaSet}. All four arguments are consumed.
     *
     * @return true if accepted; false if rejected or if an error occurred (the error is logged)
     */
    private boolean lbfgs(PointSample pointSample, TrainingMonitor trainingMonitor, RefList<PointSample> refList, DeltaSet<UUID> deltaSet) {
        DeltaSet<UUID> copy = pointSample.delta.copy();
        // Invariant: `direction` always holds exactly one live reference, so the finally block
        // releases it exactly once whatever path is taken (the old code could double-free here).
        DeltaSet<UUID> direction = copy.allFinite(0.0d);
        copy.freeRef();
        try {
            int n = refList.size();
            double[] alphas = new double[n];
            for (int k = n - 2; k >= 0; k--) {
                DeltaSet<UUID> s = subtractWeights(refList.get(k + 1), refList.get(k));
                DeltaSet<UUID> y = subtractDelta(refList.get(k + 1), refList.get(k));
                double sy = s.dot(y.mo16addRef());
                if (0.0d == sy) {
                    s.freeRef();
                    y.freeRef();
                    throw new IllegalStateException("Orientation vanished.");
                }
                alphas[k] = direction.dot(s) / sy;
                DeltaSet<UUID> scaled = y.scale(alphas[k]);
                y.freeRef();
                DeltaSet<UUID> subtract = direction.subtract(scaled);
                DeltaSet<UUID> next = subtract.allFinite(0.0d);
                subtract.freeRef();
                direction.freeRef();
                direction = next;
            }

            DeltaSet<UUID> sLast = subtractWeights(refList.get(n - 1), refList.get(n - 2));
            DeltaSet<UUID> yLast = subtractDelta(refList.get(n - 1), refList.get(n - 2));
            double syLast = sLast.dot(yLast.mo16addRef());
            sLast.freeRef();
            double gamma = syLast / yLast.dot(yLast.mo16addRef());
            yLast.freeRef();
            DeltaSet<UUID> scaledDirection = direction.scale(gamma);
            DeltaSet<UUID> scaledFinite = scaledDirection.allFinite(0.0d);
            scaledDirection.freeRef();
            direction.freeRef();
            direction = scaledFinite;

            for (int k = 0; k < n - 1; k++) {
                DeltaSet<UUID> s = subtractWeights(refList.get(k + 1), refList.get(k));
                DeltaSet<UUID> y = subtractDelta(refList.get(k + 1), refList.get(k));
                // First use borrows an extra reference; the second use hands over the last one.
                // (Previously the only reference was consumed first and y was then re-used after release.)
                double beta = direction.dot(y.mo16addRef()) / s.dot(y);
                DeltaSet<UUID> add = direction.add(s.scale(alphas[k] - beta));
                s.freeRef();
                DeltaSet<UUID> next = add.allFinite(0.0d);
                add.freeRef();
                direction.freeRef();
                direction = next;
            }

            boolean accepted = pointSample.delta.dot(direction.mo16addRef()) < 0.0d;
            if (accepted) {
                trainingMonitor.log("Accepted: " + new Stats(deltaSet.mo16addRef(), direction.mo16addRef()));
                copy(direction.mo16addRef(), deltaSet.mo16addRef());
            } else {
                trainingMonitor.log("Rejected: " + new Stats(deltaSet.mo16addRef(), direction.mo16addRef()));
            }
            return accepted;
        } catch (Throwable th) {
            // Deliberately broad, as before: any failure just means "no L-BFGS direction".
            trainingMonitor.log(RefString.format("LBFGS Orientation Error: %s", new Object[]{th.getMessage()}));
            return false;
        } finally {
            direction.freeRef();
            pointSample.freeRef();
            refList.freeRef();
            deltaSet.freeRef();
        }
    }

    /** Returns {@code pointSample.delta - pointSample2.delta}; consumes both samples. */
    private DeltaSet<UUID> subtractDelta(PointSample pointSample, PointSample pointSample2) {
        try {
            return pointSample.delta.subtract(pointSample2.delta.mo16addRef());
        } finally {
            pointSample2.freeRef();
            pointSample.freeRef();
        }
    }

    /** Returns {@code pointSample.weights - pointSample2.weights}; consumes both samples. */
    private DeltaSet<UUID> subtractWeights(PointSample pointSample, PointSample pointSample2) {
        try {
            return pointSample.weights.subtract(pointSample2.weights.mo16addRef());
        } finally {
            pointSample2.freeRef();
            pointSample.freeRef();
        }
    }

    /**
     * Overwrites, element by element, each buffer of {@code deltaSet2} with the buffer stored
     * under the same key in {@code deltaSet}. Consumes both sets.
     */
    private void copy(DeltaSet<UUID> deltaSet, DeltaSet<UUID> deltaSet2) {
        RefMap<UUID, Delta<UUID>> map = deltaSet2.getMap();
        RefSet<Map.Entry<UUID, Delta<UUID>>> entrySet = map.entrySet();
        map.freeRef();
        entrySet.forEach(entry -> {
            Delta<UUID> source = deltaSet.get(entry.getKey());
            assert source != null;
            double[] sourceValues = source.getDelta();
            source.freeRef();
            Delta<UUID> target = entry.getValue();
            RefUtil.freeRef(entry);
            RefArrays.setAll(target.getDelta(), i -> {
                assert sourceValues != null;
                return sourceValues[i];
            });
            target.freeRef();
        });
        entrySet.freeRef();
        deltaSet2.freeRef();
        deltaSet.freeRef();
    }
}
