package com.simiacryptus.mindseye.layers;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.lang.UncheckedRunnable;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.lang.DataSerializer;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefDoubleStream;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.util.MonitoredItem;
import com.simiacryptus.util.MonitoredObject;
import com.simiacryptus.util.data.PercentileStatistics;
import com.simiacryptus.util.data.ScalarStatistics;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicLong;

/* loaded from: input_file:com/simiacryptus/mindseye/layers/MonitoringWrapperLayer.class */
/**
 * A {@link WrapperLayer} that delegates to an inner layer while collecting runtime metrics.
 *
 * <p>Each {@link #eval} call records the inner forward-pass wall time (seconds per batch) and
 * counts batches and items. It also wraps the output so that the backward pass records its own
 * time (seconds per item, net of the time spent inside the upstream input accumulators). When
 * {@link #shouldRecordSignalMetrics(boolean)} is enabled, element statistics of the most recent
 * forward output and backward gradient are captured as well.
 */
public final class MonitoringWrapperLayer extends WrapperLayer implements MonitoredItem {
    /** Inner backward time in seconds per item, net of upstream pass-back time. */
    private final PercentileStatistics backwardPerformance = new PercentileStatistics();
    /** Element statistics of the latest gradient signal (filled only when signal recording is on). */
    private final ScalarStatistics backwardSignal = new PercentileStatistics();
    /** Inner forward time in seconds per batch. */
    private final PercentileStatistics forwardPerformance = new PercentileStatistics();
    /** Element statistics of the latest forward output (filled only when signal recording is on). */
    private final ScalarStatistics forwardSignal = new PercentileStatistics();
    private boolean recordSignalMetrics;
    private int totalBatches;
    private int totalItems;

    /**
     * Wraps an input's accumulator and adds the time spent inside it to a shared pass-back
     * counter. The owning layer uses that counter to subtract upstream work from its own
     * backward timing.
     */
    public static class Accumulator extends Result.Accumulator {
        private final AtomicLong passbackNanos;
        private Result.Accumulator accumulator;

        /**
         * @param atomicLong  counter shared with the output-side accumulator; receives elapsed nanos
         * @param accumulator the input's original accumulator (ownership is taken; freed in {@link #_free()})
         */
        public Accumulator(AtomicLong atomicLong, Result.Accumulator accumulator) {
            this.passbackNanos = atomicLong;
            this.accumulator = accumulator;
        }

        /** Forwards the gradient to the wrapped accumulator and adds the elapsed time to the counter. */
        @Override // java.util.function.BiConsumer
        public void accept(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            TimedResult<?> time = TimedResult.time(RefUtil.wrapInterface((UncheckedRunnable<Object>) () -> {
                // Null-tolerant, matching Accumulator2: pass a null delta set through instead of throwing.
                this.accumulator.accept(deltaSet == null ? null : deltaSet.mo16addRef(), tensorList.mo36addRef());
            }, tensorList, deltaSet));
            this.passbackNanos.addAndGet(time.timeNanos);
            time.freeRef();
        }

        /** Releases the wrapped accumulator together with this object. */
        @Override // com.simiacryptus.mindseye.lang.Result.Accumulator
        public void _free() {
            super._free();
            this.accumulator.freeRef();
        }
    }

    /**
     * Wraps the inner layer's output accumulator. It optionally records gradient signal
     * statistics and records the backward time per item, excluding time spent in the input
     * {@link Accumulator}s.
     */
    private class Accumulator2 extends Result.Accumulator {
        private final AtomicLong passbackNanos;
        private final int items;
        private Result.Accumulator accumulator;

        /**
         * @param atomicLong  counter filled by the input accumulators during this backward pass
         * @param i           item count of the batch, used to normalize the timing
         * @param accumulator the inner output's accumulator (ownership is taken; freed in {@link #_free()})
         */
        public Accumulator2(AtomicLong atomicLong, int i, Result.Accumulator accumulator) {
            this.passbackNanos = atomicLong;
            this.items = i;
            this.accumulator = accumulator;
        }

        @Override // java.util.function.BiConsumer
        public void accept(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            if (MonitoringWrapperLayer.this.recordSignalMetrics) {
                // Statistics describe only the most recent gradient.
                MonitoringWrapperLayer.this.backwardSignal.clear();
                tensorList.stream().parallel().forEach(tensor -> {
                    MonitoringWrapperLayer.this.backwardSignal.add(tensor.getData());
                    tensor.freeRef();
                });
            }
            TimedResult<?> time = TimedResult.time(RefUtil.wrapInterface((UncheckedRunnable<Object>) () -> {
                DeltaSet<UUID> deltaRef = deltaSet == null ? null : deltaSet.mo16addRef();
                Result.Accumulator target = this.accumulator;
                try {
                    target.accept(deltaRef, tensorList.mo36addRef());
                } finally {
                    // Exactly one release, on both the normal and the exceptional path.
                    target.freeRef();
                }
            }, tensorList, this.accumulator.mo38addRef(), deltaSet));
            // Subtract (and reset) upstream pass-back time, then normalize to seconds per item.
            MonitoringWrapperLayer.this.backwardPerformance.add((time.timeNanos - this.passbackNanos.getAndSet(0L)) / (this.items * 1.0E9d));
            time.freeRef();
        }

        /** Releases the wrapped accumulator together with this object. */
        @Override // com.simiacryptus.mindseye.lang.Result.Accumulator
        public void _free() {
            super._free();
            this.accumulator.freeRef();
        }
    }

    /**
     * Deserializing constructor.
     *
     * <p>Statistics objects are restored only when their keys are present. Counters and the
     * signal-recording flag fall back to {@code false}/{@code 0} when absent, instead of throwing.
     */
    protected MonitoringWrapperLayer(JsonObject jsonObject, Map<CharSequence, byte[]> map) {
        super(jsonObject, map);
        readStatistics(jsonObject, "forwardPerf", this.forwardPerformance);
        readStatistics(jsonObject, "backwardPerf", this.backwardPerformance);
        readStatistics(jsonObject, "backpropStatistics", this.backwardSignal);
        readStatistics(jsonObject, "outputStatistics", this.forwardSignal);
        this.recordSignalMetrics = jsonObject.has("recordSignalMetrics") && jsonObject.get("recordSignalMetrics").getAsBoolean();
        this.totalBatches = jsonObject.has("totalBatches") ? jsonObject.get("totalBatches").getAsInt() : 0;
        this.totalItems = jsonObject.has("totalItems") ? jsonObject.get("totalItems").getAsInt() : 0;
    }

    /** Wraps {@code layer} with empty statistics and signal recording disabled. */
    public MonitoringWrapperLayer(Layer layer) {
        super(layer);
    }

    /** Loads {@code target} from {@code json[key]} if that key is present; otherwise leaves it untouched. */
    private static void readStatistics(JsonObject json, String key, ScalarStatistics target) {
        if (json.has(key)) {
            target.readJson(json.getAsJsonObject(key));
        }
    }

    public PercentileStatistics getBackwardPerformance() {
        return this.backwardPerformance;
    }

    public ScalarStatistics getBackwardSignal() {
        return this.backwardSignal;
    }

    public PercentileStatistics getForwardPerformance() {
        return this.forwardPerformance;
    }

    public ScalarStatistics getForwardSignal() {
        return this.forwardSignal;
    }

    /**
     * Builds a metrics snapshot.
     *
     * <p>The snapshot holds the inner class name, batch/item counters, signal statistics, and
     * forward/backward milliseconds per item (mean and median). When the inner layer has state,
     * it also includes weight statistics under {@code "weights"}.
     */
    public Map<CharSequence, Object> getMetrics() {
        Map<CharSequence, Object> metrics = new HashMap<>();
        metrics.put("class", this.inner.getClass().getName());
        metrics.put("totalBatches", this.totalBatches);
        metrics.put("totalItems", this.totalItems);
        metrics.put("outputStatistics", this.forwardSignal.getMetrics());
        metrics.put("backpropStatistics", this.backwardSignal.getMetrics());
        // forwardPerformance holds seconds per *batch*; scale by batches/items to get per-item time.
        double batchesPerItem = (this.totalBatches * 1.0d) / this.totalItems;
        metrics.put("avgMsPerItem", 1000.0d * batchesPerItem * this.forwardPerformance.getMean());
        metrics.put("medianMsPerItem", 1000.0d * batchesPerItem * this.forwardPerformance.getPercentile(0.5d).doubleValue());
        // backwardPerformance is already normalized per item by Accumulator2; scaling it by
        // batchesPerItem again would divide by the batch size twice.
        metrics.put("avgMsPerItem_Backward", 1000.0d * this.backwardPerformance.getMean());
        metrics.put("medianMsPerItem_Backward", 1000.0d * this.backwardPerformance.getPercentile(0.5d).doubleValue());
        RefList<double[]> state = state();
        assert state != null;
        PercentileStatistics weightStatistics = new PercentileStatistics();
        state.stream().flatMapToDouble(Arrays::stream).forEach(weightStatistics::add);
        if (weightStatistics.getCount() > 0) {
            Map<CharSequence, Object> weights = new HashMap<>();
            weights.put("buffers", state.size());
            weights.putAll(weightStatistics.getMetrics());
            metrics.put("weights", weights);
        }
        state.freeRef();
        return metrics;
    }

    /** Returns the inner layer's name; this wrapper has no name of its own. */
    @Override // com.simiacryptus.mindseye.lang.LayerBase, com.simiacryptus.mindseye.lang.Layer
    public String getName() {
        return this.inner.getName();
    }

    /** Renames the inner layer; a no-op while {@code inner} is still unset. */
    @Override // com.simiacryptus.mindseye.lang.LayerBase, com.simiacryptus.mindseye.lang.Layer
    public void setName(String str) {
        if (null != this.inner) {
            this.inner.setName(str);
        }
    }

    public static MonitoringWrapperLayer fromJson(JsonObject jsonObject, Map<CharSequence, byte[]> map) {
        return new MonitoringWrapperLayer(jsonObject, map);
    }

    /**
     * Registers this layer under the inner layer's current name.
     *
     * @return a new reference to this layer
     */
    public MonitoringWrapperLayer addTo2(MonitoredObject monitoredObject) {
        addTo(monitoredObject, this.inner.getName());
        return mo21addRef();
    }

    /**
     * Renames the inner layer to {@code str} and registers a new reference to this layer in
     * {@code monitoredObject}. The caller's reference to {@code monitoredObject} is released.
     */
    public void addTo(MonitoredObject monitoredObject, String str) {
        setName(str);
        monitoredObject.addObj(getName(), mo21addRef());
        monitoredObject.freeRef();
    }

    /**
     * Evaluates the inner layer while recording timing, batch/item counts and (optionally)
     * output signal statistics.
     *
     * <p>Each element of {@code resultArr} is released after its length is read. The batch's
     * item count is the largest input length, or 1 when there are no inputs.
     */
    @Override // com.simiacryptus.mindseye.layers.WrapperLayer, com.simiacryptus.mindseye.lang.Layer
    public Result eval(Result... resultArr) {
        // Shared by the input accumulators (which add upstream time) and the output accumulator
        // (which subtracts it), so the recorded backward time covers only the inner layer.
        AtomicLong passbackNanos = new AtomicLong(0L);
        Result[] wrappedInputs = (Result[]) RefArrays.stream((Result[]) RefUtil.addRef(resultArr)).map(result -> {
            boolean alive = result.isAlive();
            TensorList inputData = result.getData();
            Accumulator accumulator = new Accumulator(passbackNanos, result.getAccumulator());
            result.freeRef();
            return new Result(inputData, accumulator, alive);
        }).toArray(Result[]::new);
        TimedResult<Result> timed = TimedResult.time(RefUtil.wrapInterface(
                (UncheckedSupplier<Result>) () -> this.inner.eval((Result[]) RefUtil.addRef(wrappedInputs)),
                (Object[]) wrappedInputs));
        Result output = timed.getResult();
        this.forwardPerformance.add(timed.timeNanos / 1.0E9d);
        timed.freeRef();
        this.totalBatches++;
        int items = RefArrays.stream(resultArr).mapToInt(input -> {
            TensorList inputData = input.getData();
            int length = inputData.length();
            inputData.freeRef();
            input.freeRef();
            return length;
        }).max().orElse(1);
        this.totalItems += items;
        if (this.recordSignalMetrics) {
            // Statistics describe only the most recent output.
            this.forwardSignal.clear();
            TensorList signalData = output.getData();
            signalData.stream().parallel().forEach(tensor -> {
                this.forwardSignal.add(tensor.getData());
                tensor.freeRef();
            });
            signalData.freeRef();
        }
        boolean alive = output.isAlive();
        TensorList outputData = output.getData();
        Accumulator2 accumulator2 = new Accumulator2(passbackNanos, items, output.getAccumulator());
        output.freeRef();
        return new Result(outputData, accumulator2, alive);
    }

    /**
     * Serializes the wrapper, including the counters and the signal-recording flag.
     *
     * <p>The performance and signal statistics are not written here, although the deserializing
     * constructor reads them when present.
     */
    @Override // com.simiacryptus.mindseye.layers.WrapperLayer
    public JsonObject getJson(Map<CharSequence, byte[]> map, DataSerializer dataSerializer) {
        JsonObject json = super.getJson(map, dataSerializer);
        json.addProperty("totalBatches", Integer.valueOf(this.totalBatches));
        json.addProperty("totalItems", Integer.valueOf(this.totalItems));
        json.addProperty("recordSignalMetrics", Boolean.valueOf(this.recordSignalMetrics));
        return json;
    }

    /** Enables or disables collection of forward-output and backward-gradient element statistics. */
    public void shouldRecordSignalMetrics(boolean z) {
        this.recordSignalMetrics = z;
    }

    @Override // com.simiacryptus.mindseye.layers.WrapperLayer, com.simiacryptus.mindseye.lang.LayerBase, com.simiacryptus.mindseye.lang.Layer
    public void _free() {
        super._free();
    }

    /** Covariant {@code addRef} (decompiler-renamed); returns this layer with its reference count incremented. */
    @Override // com.simiacryptus.mindseye.layers.WrapperLayer, com.simiacryptus.mindseye.lang.LayerBase, com.simiacryptus.mindseye.lang.Layer
    public MonitoringWrapperLayer mo21addRef() {
        return (MonitoringWrapperLayer) super.mo21addRef();
    }

    /** Decompiler-renamed bridge to {@link #getJson(Map, DataSerializer)}. */
    @Override // com.simiacryptus.mindseye.layers.WrapperLayer, com.simiacryptus.mindseye.lang.ZipSerializable
    @SuppressWarnings("unchecked")
    public JsonElement mo52getJson(Map map, DataSerializer dataSerializer) {
        return getJson((Map<CharSequence, byte[]>) map, dataSerializer);
    }
}
