package com.simiacryptus.mindseye.opt.orient;

import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.Delta;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchPoint;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefCollection;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefSet;
import com.simiacryptus.ref.wrappers.RefStream;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.Function;

/**
 * Orientation strategy that wraps an inner strategy (by default {@link LBFGS}) and adapts its search
 * direction in the style of OWL-QN. An L1 term is added to the inner direction, and the returned
 * cursor's line-search steps never move a weight across zero (crossing weights are clamped to zero).
 */
public class OwlQn extends OrientationStrategyBase<LineSearchCursor> {
    /** Strategy whose direction is adapted; this instance owns one reference, released in {@link #_free()}. */
    public final OrientationStrategy<?> inner;
    /** Coefficient of the L1 term added to every direction component (default 0). */
    private double factor_L1 = 0.0d;
    /** Values whose magnitude is at or below this threshold are treated as zero by {@link #sign(double)}. */
    private double zeroTol = 1.0E-20d;

    /** Creates an instance wrapping a new {@link LBFGS} strategy. */
    public OwlQn() {
        this(new LBFGS());
    }

    /**
     * Creates an instance wrapping {@code orientationStrategy}.
     *
     * @param orientationStrategy the inner strategy; the caller's reference is transferred to this
     *                            instance and released in {@link #_free()}. If null, {@link #orient}
     *                            and {@link #reset()} fail their {@code assert} checks when assertions
     *                            are enabled.
     */
    protected OwlQn(OrientationStrategy<?> orientationStrategy) {
        // Net effect of the former addRef/addRef/freeRef/freeRef sequence: keep the caller's reference.
        this.inner = orientationStrategy;
    }

    /** @return the coefficient of the L1 term. */
    public double getFactor_L1() {
        return this.factor_L1;
    }

    /** @param d the new coefficient of the L1 term. */
    public void setFactor_L1(double d) {
        this.factor_L1 = d;
    }

    /** @return the magnitude at or below which {@link #sign(double)} reports zero. */
    public double getZeroTol() {
        return this.zeroTol;
    }

    /** @param d the new magnitude at or below which {@link #sign(double)} reports zero. */
    public void setZeroTol(double d) {
        this.zeroTol = d;
    }

    /**
     * Copies the given layers into a new list.
     *
     * @param refCollection layers to copy; this reference is consumed (freed) by the call
     * @return a new list containing the layers
     */
    public RefCollection<Layer> getLayers(RefCollection<Layer> refCollection) {
        RefList refList = (RefList) refCollection.stream().collect(RefCollectors.toList());
        refCollection.freeRef();
        return refList;
    }

    /**
     * Computes a search direction from the inner strategy and adapts it.
     *
     * <p>For every component, {@code factor_L1 * s} is added to the inner direction, where {@code s} is
     * -1 for a negative weight and +1 otherwise. If that changes the component's {@link #sign(double)}
     * relative to the inner direction, the component reverts to the inner value. The returned cursor's
     * {@code step} clamps any weight that would cross zero to zero. It also zeroes that component in the
     * direction used for the directional derivative.
     *
     * @param trainable       the objective; handed to the returned cursor (the inner strategy gets its own reference)
     * @param pointSample     the current point; handed to the returned cursor (the inner strategy gets its own reference)
     * @param trainingMonitor monitor passed through to the inner strategy
     * @return a cursor whose direction type is {@code "OWL/QN"}
     * @throws ClassCastException if the inner strategy does not return a {@link SimpleLineSearchCursor}
     */
    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    public LineSearchCursor orient(Trainable trainable, PointSample pointSample, TrainingMonitor trainingMonitor) {
        assert this.inner != null;
        final SimpleLineSearchCursor gradient = (SimpleLineSearchCursor) this.inner.orient(
                trainable == null ? null : trainable.mo1addRef(), pointSample.m35addRef(), trainingMonitor);
        assert gradient.direction != null;
        final DeltaSet<UUID> searchDirection = gradient.direction.copy();
        // NOTE(review): orthant signs are filled in below but never read afterwards in this class.
        final DeltaSet<UUID> orthant = new DeltaSet<>();
        final BiConsumer<UUID, Delta<UUID>> adjustLayer = (layerId, gradientBuffer) -> {
            final double[] weights = gradientBuffer.target;
            final double[] delta = gradientBuffer.getDelta();
            gradientBuffer.freeRef();
            final Delta<UUID> searchBuffer = searchDirection.get(layerId, weights);
            assert searchBuffer != null;
            final double[] searchDir = searchBuffer.getDelta();
            searchBuffer.freeRef();
            final Delta<UUID> orthantBuffer = orthant.get(layerId, weights);
            assert orthantBuffer != null;
            final double[] suborthant = orthantBuffer.getDelta();
            orthantBuffer.freeRef();
            assert searchDir != null;
            assert delta != null;
            assert suborthant != null;
            for (int i = 0; i < searchDir.length; i++) {
                final int positionSign = sign(weights[i]);
                final int directionSign = sign(delta[i]);
                suborthant[i] = 0 == positionSign ? directionSign : positionSign;
                // L1 term: sign of the weight, with an exactly-zero weight treated as positive.
                searchDir[i] += this.factor_L1 * (weights[i] < 0.0d ? -1.0d : 1.0d);
                // If the L1 term flipped this component relative to the inner direction, undo it.
                if (sign(searchDir[i]) != directionSign) {
                    searchDir[i] = delta[i];
                }
            }
        };
        final RefMap<UUID, Delta<UUID>> gradientMap = gradient.direction.getMap();
        gradientMap.forEach(RefUtil.wrapInterface(adjustLayer, searchDirection.mo16addRef(), orthant));
        gradientMap.freeRef();
        gradient.freeRef();
        final SimpleLineSearchCursor cursor = new SimpleLineSearchCursor(trainable, pointSample, searchDirection) {
            /**
             * Restores the origin weights, then moves each weight by {@code alpha * direction}. A weight
             * whose nonzero sign would change is clamped to zero instead, and its component is zeroed
             * in the direction copy used for the returned directional derivative.
             */
            @Override // com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor, com.simiacryptus.mindseye.opt.line.LineSearchCursorBase, com.simiacryptus.mindseye.opt.line.LineSearchCursor
            public LineSearchPoint step(double alpha, TrainingMonitor stepMonitor) {
                this.origin.weights.stream().forEach(state -> {
                    state.restore();
                    state.freeRef();
                });
                assert this.direction != null;
                final DeltaSet<UUID> currentDirection = this.direction.copy();
                final BiConsumer<UUID, Delta<UUID>> applyStep = (layerId, buffer) -> {
                    final double[] stepDelta = buffer.getDelta();
                    if (null == stepDelta) {
                        buffer.freeRef();
                        return;
                    }
                    final Delta<UUID> currentBuffer = currentDirection.get(layerId, buffer.target);
                    assert currentBuffer != null;
                    final double[] currentDelta = currentBuffer.getDelta();
                    currentBuffer.freeRef();
                    for (int i = 0; i < stepDelta.length; i++) {
                        final double prevValue = buffer.target[i];
                        final double newValue = prevValue + stepDelta[i] * alpha;
                        final int prevSign = OwlQn.this.sign(prevValue);
                        if (prevSign == 0 || prevSign == OwlQn.this.sign(newValue)) {
                            buffer.target[i] = newValue;
                        } else {
                            // The step would cross zero: clamp the weight and drop it from the derivative.
                            assert currentDelta != null;
                            currentDelta[i] = 0.0d;
                            buffer.target[i] = 0.0d;
                        }
                    }
                    buffer.freeRef();
                };
                final RefMap<UUID, Delta<UUID>> directionMap = this.direction.getMap();
                directionMap.forEach(RefUtil.wrapInterface(applyStep, currentDirection.mo16addRef()));
                directionMap.freeRef();
                assert this.subject != null;
                final PointSample measure = this.subject.measure(stepMonitor);
                measure.setRate(alpha);
                final PointSample stepped = afterStep(measure);
                final double dot = currentDirection.dot(stepped.delta.mo16addRef());
                currentDirection.freeRef();
                return new LineSearchPoint(stepped, dot);
            }
        };
        cursor.setDirectionType("OWL/QN");
        return cursor;
    }

    /** Resets the inner strategy. */
    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    public void reset() {
        assert this.inner != null;
        this.inner.reset();
    }

    /** Releases base-class resources, then this instance's reference to {@link #inner}. */
    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategyBase, com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    public void _free() {
        super._free();
        if (null != this.inner) {
            this.inner.freeRef();
        }
    }

    @Override // com.simiacryptus.mindseye.opt.orient.OrientationStrategyBase, com.simiacryptus.mindseye.opt.orient.OrientationStrategy
    /* renamed from: addRef */
    public OwlQn mo84addRef() {
        return (OwlQn) super.mo84addRef();
    }

    /**
     * Tolerance-aware sign.
     *
     * @param d value to classify
     * @return 1 if {@code d > zeroTol}, -1 if {@code d < -zeroTol}, otherwise 0 (including NaN)
     */
    protected int sign(double d) {
        if (d > this.zeroTol) {
            return 1;
        }
        if (d < -this.zeroTol) {
            return -1;
        }
        return 0;
    }
}
