package com.simiacryptus.mindseye.network;

import com.simiacryptus.mindseye.lang.CoreSettings;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrayList;
import com.simiacryptus.ref.wrappers.RefCollection;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.ref.wrappers.RefStream;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/simiacryptus/mindseye/network/CountingResult.class */
/**
 * A {@link Result} that wraps another result and replaces its accumulator with a
 * {@link CountingAccumulator}.
 *
 * <p>The wrapping accumulator records how many forward links were registered for the result.
 * When more than one link is registered, incoming passback tensors are buffered and only
 * forwarded to the wrapped accumulator once the number of received passbacks is a multiple of
 * the registered forward-link count.
 */
public class CountingResult extends Result {
    protected static final Logger logger = LoggerFactory.getLogger(CountingResult.class);

    /**
     * Accumulator decorator that counts forward links and aggregates passback tensors before
     * delegating to the wrapped accumulator.
     */
    static class CountingAccumulator extends Result.Accumulator {
        /** Layers registered through {@link #incrementFwd(Layer)}; guarded by its own monitor. */
        private final RefList<Layer> fwdLinks = new RefArrayList<>();
        /** Tensors waiting to be reduced; guarded by its own monitor. */
        private final RefMap<StackTraceElement[], TensorList> passbackBuffers = new RefHashMap<>();
        /** One entry per buffered passback received; guarded by its own monitor. */
        private final List<StackTraceElement[]> accumulations = new ArrayList<>();
        /** The accumulator that finally receives the (possibly reduced) passback. */
        private final Result.Accumulator innerAccumulator;

        /**
         * Creates a counting decorator around the given accumulator.
         *
         * @param accumulator the accumulator to delegate to; released in {@link #_free()} when
         *                    it is not {@code null}
         */
        public CountingAccumulator(Result.Accumulator accumulator) {
            this.innerAccumulator = accumulator;
        }

        /**
         * Returns the number of forward links registered so far.
         *
         * <p>The read is performed under the same monitor that {@link #incrementFwd(Layer)}
         * uses for writing, so the two never race on the underlying list.
         *
         * @return the current size of the forward-link list
         */
        public int getFwdCount() {
            synchronized (this.fwdLinks) {
                return this.fwdLinks.size();
            }
        }

        /**
         * Registers one more forward link.
         *
         * @param layer the layer to record; it is stored in the forward-link list
         * @return the number of forward links after the addition
         */
        public int incrementFwd(Layer layer) {
            synchronized (this.fwdLinks) {
                this.fwdLinks.add(layer);
                return this.fwdLinks.size();
            }
        }

        /**
         * Receives a passback. With at most one forward link the passback is forwarded
         * immediately; otherwise it is buffered via {@link #add(DeltaSet, TensorList)}.
         *
         * @param deltaSet   the delta set accompanying the passback
         * @param tensorList the passback data; must be alive
         */
        @Override // java.util.function.BiConsumer
        public void accept(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            assertAlive();
            tensorList.assertAlive();
            if (getFwdCount() <= 1) {
                accum(deltaSet, tensorList);
            } else {
                add(deltaSet, tensorList);
            }
        }

        /**
         * Forwards a passback directly to the wrapped accumulator.
         *
         * @param deltaSet   the delta set to forward
         * @param tensorList the data to forward
         */
        public void accum(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            this.innerAccumulator.accept(deltaSet, tensorList);
        }

        /**
         * Releases the buffers, the forward-link list and the wrapped accumulator. Logs an
         * error when buffered passbacks remain and at least one accumulation was recorded.
         */
        @Override // com.simiacryptus.mindseye.lang.Result.Accumulator
        public void _free() {
            super._free();
            synchronized (this.passbackBuffers) {
                if (this.passbackBuffers.size() > 0 && this.accumulations.size() > 0) {
                    CountingResult.logger.error("Passback incomplete");
                }
                this.passbackBuffers.freeRef();
            }
            this.fwdLinks.freeRef();
            if (null != this.innerAccumulator) {
                this.innerAccumulator.freeRef();
            }
        }

        /**
         * Delegates to the superclass and narrows the return type.
         *
         * @return the value returned by the superclass, cast to {@code CountingAccumulator}
         */
        @Override // com.simiacryptus.mindseye.lang.Result.Accumulator
        /* renamed from: addRef */
        public CountingAccumulator mo38addRef() {
            return (CountingAccumulator) super.mo38addRef();
        }

        /**
         * Buffers one passback and forwards the reduced buffer once the number of recorded
         * accumulations is a multiple of the forward-link count.
         *
         * @param deltaSet   forwarded together with the reduced data when a round completes,
         *                   released otherwise
         * @param tensorList the passback data to buffer
         */
        private void add(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            // A fresh array per call; arrays do not override equals/hashCode, so this is
            // presumably meant to act as a unique key for each buffered tensor.
            final StackTraceElement[] key = CountingResult.getStackTrace();
            synchronized (this.passbackBuffers) {
                RefUtil.freeRef(this.passbackBuffers.put(key, tensorList));
                // Collapse the buffer early once it grows past the configured limit.
                if (this.passbackBuffers.size() > CoreSettings.INSTANCE().backpropAggregationSize) {
                    RefUtil.freeRef(this.passbackBuffers.put(key, reduce()));
                }
            }
            final int fwdCount = getFwdCount();
            final boolean roundComplete;
            synchronized (this.accumulations) {
                this.accumulations.add(key);
                roundComplete = this.accumulations.size() % fwdCount == 0;
            }
            if (roundComplete) {
                accum(deltaSet, reduce());
            } else {
                deltaSet.freeRef();
            }
        }

        /**
         * Asserts that every buffered tensor list is alive, releasing the reference handed to
         * the callback after each check.
         *
         * @return always {@code true} (presumably so the call can be placed inside an
         *         {@code assert} - no caller is visible in this file)
         */
        private boolean allAlive() {
            synchronized (this.passbackBuffers) {
                this.passbackBuffers.forEach((key, buffered) -> {
                    buffered.assertAlive();
                    buffered.freeRef();
                });
            }
            return true;
        }

        /**
         * Combines all buffered tensor lists into one with pairwise {@code addAndFree} calls
         * and then clears the buffer. The stream is switched to parallel unless
         * {@code CoreSettings.INSTANCE().singleThreaded} is set.
         *
         * @return the combined tensor list as unwrapped by {@code RefUtil.get}
         *         (NOTE(review): behaviour for an empty buffer depends on {@code RefUtil.get};
         *         confirm against that helper)
         */
        private TensorList reduce() {
            synchronized (this.passbackBuffers) {
                final RefCollection<TensorList> values = this.passbackBuffers.values();
                RefStream<TensorList> stream = values.stream();
                if (!CoreSettings.INSTANCE().singleThreaded) {
                    stream = stream.parallel();
                }
                final TensorList reduced = RefUtil.get(stream.reduce((left, right) -> {
                    final TensorList sum = left.addAndFree(right);
                    // Kept from the original: the left operand is released after the add.
                    left.freeRef();
                    return sum;
                }));
                values.freeRef();
                this.passbackBuffers.clear();
                return reduced;
            }
        }
    }

    /**
     * Wraps the given result, reusing its data and alive flag and decorating its accumulator
     * with a {@link CountingAccumulator}. The passed result reference is released.
     *
     * @param result the result to wrap
     */
    public CountingResult(Result result) {
        super(result.getData(), new CountingAccumulator(result.getAccumulator()), result.isAlive());
        result.freeRef();
    }

    /**
     * Wraps the given result and registers {@code layer} as a forward link {@code i} times.
     *
     * @param result the result to wrap
     * @param i      how many forward links to register
     * @param layer  the layer recorded for each forward link; the passed reference is released
     *               before returning
     */
    public CountingResult(Result result, int i, Layer layer) {
        this(result);
        final Result.Accumulator accumulator = getAccumulator();
        if (accumulator instanceof CountingAccumulator) {
            final CountingAccumulator counting = (CountingAccumulator) accumulator;
            for (int link = 0; link < i; link++) {
                counting.incrementFwd(layer.mo21addRef());
            }
            counting.freeRef();
        } else {
            // Any accumulator other than the counting one is expected to be the null accumulator.
            assert Result.isNull(accumulator);
            accumulator.freeRef();
        }
        layer.freeRef();
    }

    /**
     * Returns a new, empty stack-trace array on every call.
     *
     * @return a fresh zero-length array
     */
    private static StackTraceElement[] getStackTrace() {
        return new StackTraceElement[0];
    }

    /** Releases this result by delegating to the superclass. */
    @Override // com.simiacryptus.mindseye.lang.Result
    public void _free() {
        super._free();
    }

    /**
     * Delegates to the superclass and narrows the return type.
     *
     * @return the value returned by the superclass, cast to {@code CountingResult}
     */
    @Override // com.simiacryptus.mindseye.lang.Result
    /* renamed from: addRef */
    public CountingResult mo10addRef() {
        return (CountingResult) super.mo10addRef();
    }
}
