package com.simiacryptus.mindseye.eval;

import com.google.common.util.concurrent.AtomicDouble;
import com.simiacryptus.mindseye.lang.Delta;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.ref.wrappers.RefCollection;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import java.util.UUID;

/* loaded from: input_file:com/simiacryptus/mindseye/eval/L12Normalizer.class */
/**
 * A {@link Trainable} wrapper that adds per-layer L1 and L2 weight penalties to the
 * measurement produced by an inner trainable.
 *
 * <p>For every weight {@code w} of each layer whose id appears in the inner measurement's
 * delta set, {@link #measure(TrainingMonitor)}:
 * <ul>
 *   <li>adds {@code l1 * sign(w) + 2 * l2 * w} to the gradient; {@code sign(0)} is taken as {@code +1};</li>
 *   <li>adds {@code l1 * |w| + l2 * w * w} to the reported value, unless {@link #hideAdj} is set.</li>
 * </ul>
 * The coefficients {@code l1} and {@code l2} come from {@link #getL1(Layer)} and
 * {@link #getL2(Layer)}.
 */
public abstract class L12Normalizer extends TrainableBase {
    /** The wrapped trainable. It is released in {@link #_free()}. */
    public final Trainable inner;
    /** When {@code true}, the penalty is applied to the gradient only and not added to the reported value. */
    private final boolean hideAdj = false;

    /**
     * Wraps {@code trainable}.
     *
     * <p>The add/free sequence below leaves the net reference count unchanged. In effect,
     * the reference passed in is transferred to {@link #inner}.
     *
     * @param trainable the trainable to regularize; may be {@code null}
     */
    public L12Normalizer(Trainable trainable) {
        Trainable held = trainable == null ? null : trainable.mo1addRef();
        this.inner = held == null ? null : held.mo1addRef();
        if (null != held) {
            held.freeRef();
        }
        if (null != trainable) {
            trainable.freeRef();
        }
    }

    /**
     * Looks up a layer by id in the network returned by {@code inner.getLayer()}.
     *
     * <p>The inner layer is cast to {@link DAGNetwork}. A non-network layer therefore
     * raises {@link ClassCastException}.
     *
     * @param uuid id of the layer to find
     * @return the layer with that id as returned by the network's id map, or {@code null}
     *     when {@code inner.getLayer()} returns {@code null}
     */
    public Layer toLayer(UUID uuid) {
        assert this.inner != null;
        DAGNetwork network = (DAGNetwork) this.inner.getLayer();
        if (null == network) {
            return null;
        }
        RefMap<UUID, Layer> layersById = network.getLayersById();
        Layer layer = layersById.get(uuid);
        layersById.freeRef();
        network.freeRef();
        return layer;
    }

    /**
     * Resolves each id in {@code refCollection} with {@link #toLayer(UUID)}.
     *
     * @param refCollection layer ids; this reference is consumed
     * @return the resolved layers, in stream order; entries may be {@code null} if
     *     {@link #toLayer(UUID)} returned {@code null}
     */
    public RefCollection<Layer> getLayers(RefCollection<UUID> refCollection) {
        RefList<Layer> layers = refCollection.stream().map(this::toLayer).collect(RefCollectors.toList());
        refCollection.freeRef();
        return layers;
    }

    /**
     * Measures the inner trainable, then adds the L1/L2 penalty to its gradient and value.
     *
     * @param trainingMonitor passed through to {@code inner.measure}
     * @return the normalized, penalized sample
     */
    @Override
    public PointSample measure(TrainingMonitor trainingMonitor) {
        assert this.inner != null;
        PointSample measure = this.inner.measure(trainingMonitor);
        DeltaSet<UUID> normalizationVector = new DeltaSet<>();
        // The accumulator is mutated from inside the forEach lambda, so it needs a mutable holder.
        AtomicDouble valueAdj = new AtomicDouble(0.0d);
        RefCollection<Layer> layers = getLayers(measure.delta.keySet());
        layers.forEach(layer -> {
            Delta weightsDelta = measure.delta.get(layer.getId());
            assert weightsDelta != null;
            double[] weights = weightsDelta.target;
            weightsDelta.freeRef();
            // Fixed: the key is the layer's UUID itself. The previous (DeltaSet) cast of a UUID did not compile.
            Delta adjDelta = normalizationVector.get(layer.getId(), weights);
            assert adjDelta != null;
            double[] gradientAdj = adjDelta.getDelta();
            adjDelta.freeRef();
            // getL1 receives an extra reference; getL2 receives the lambda's own reference to the layer.
            double l1 = getL1(layer.mo21addRef());
            double l2 = getL2(layer);
            assert null != gradientAdj;
            for (int i = 0; i < gradientAdj.length; i++) {
                // Subgradient of |w|: zero weights are treated as positive.
                double sign = weights[i] < 0.0d ? -1.0d : 1.0d;
                gradientAdj[i] = gradientAdj[i] + (l1 * sign) + (2.0d * l2 * weights[i]);
                // (l1 * sign + l2 * w) * w == l1 * |w| + l2 * w^2
                valueAdj.addAndGet(((l1 * sign) + (l2 * weights[i])) * weights[i]);
            }
        });
        layers.freeRef();
        // NOTE(review): normalizationVector is never freed here; presumably DeltaSet.add consumes its argument — confirm.
        DeltaSet<UUID> adjusted = measure.delta.add(normalizationVector);
        double sum = measure.sum + (hideAdj ? 0.0d : valueAdj.get());
        PointSample adjustedSample = new PointSample(adjusted.mo16addRef(), measure.weights.mo16addRef(), sum, measure.rate, measure.count);
        adjusted.freeRef();
        measure.freeRef();
        PointSample normalized = adjustedSample.normalize();
        adjustedSample.freeRef();
        return normalized;
    }

    /**
     * Delegates reseeding to the inner trainable.
     *
     * @param j the new seed
     * @return the inner trainable's result
     */
    @Override
    public boolean reseed(long j) {
        assert this.inner != null;
        return this.inner.reseed(j);
    }

    /** Releases the superclass state and the reference held on {@link #inner}. */
    @Override
    public void _free() {
        super._free();
        if (null != this.inner) {
            this.inner.freeRef();
        }
    }

    /* renamed from: addRef */
    @Override
    public L12Normalizer mo1addRef() {
        return (L12Normalizer) super.mo1addRef();
    }

    /**
     * Returns the L1 penalty coefficient for {@code layer}.
     *
     * <p>NOTE(review): {@link #measure(TrainingMonitor)} passes a reference it does not release
     * afterwards, so implementations are presumably expected to free it — confirm.
     */
    protected abstract double getL1(Layer layer);

    /**
     * Returns the L2 penalty coefficient for {@code layer}.
     *
     * <p>NOTE(review): {@link #measure(TrainingMonitor)} passes a reference it does not release
     * afterwards, so implementations are presumably expected to free it — confirm.
     */
    protected abstract double getL2(Layer layer);
}
