package com.simiacryptus.mindseye.layers;

import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefString;
import java.util.Map;
import java.util.UUID;
import java.util.function.IntFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/simiacryptus/mindseye/layers/LoggingWrapperLayer.class */
/**
 * A {@link WrapperLayer} that delegates evaluation to its inner layer and logs, at INFO level,
 * the tensor dimensions that flow through it.
 *
 * <p>On evaluation it logs the dimensions of every input and of the output. It also wraps each
 * input's accumulator and the output's accumulator, so the dimensions of feedback tensor lists
 * passed to those accumulators are logged too. Only dimensions (see {@link #getString(Tensor)})
 * are logged, never tensor values.
 */
public final class LoggingWrapperLayer extends WrapperLayer {
    static final Logger log = LoggerFactory.getLogger(LoggingWrapperLayer.class);

    /**
     * Accumulator attached to this layer's output. It logs incoming feedback as
     * "Feedback Input" and then forwards it to the inner layer's output accumulator.
     */
    private static class Accumulator extends Result.Accumulator {
        private final Layer inner;
        private final Result.Accumulator accumulator;

        /**
         * @param layer       reference to the wrapped layer (used for its name); owned by this object
         * @param accumulator the inner output's accumulator to forward to; owned by this object
         */
        public Accumulator(Layer layer, Result.Accumulator accumulator) {
            this.inner = layer;
            this.accumulator = accumulator;
        }

        @Override
        public void accept(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            log.info(RefString.format("Feedback Input for key %s: \n\t%s",
                    new Object[]{this.inner.getName(), indent(describe(tensorList))}));
            // The field reference is owned by this object and released exactly once in _free();
            // releasing it here as well would free it twice.
            this.accumulator.accept(deltaSet, tensorList);
        }

        @Override
        public void _free() {
            super._free();
            this.accumulator.freeRef();
            this.inner.freeRef();
        }
    }

    /**
     * Accumulator attached to the {@code i}-th input. It logs incoming feedback as
     * "Feedback Output i" and then forwards it to that input's original accumulator.
     */
    private static class Accumulator2 extends Result.Accumulator {
        private final int i;
        private final Layer inner;
        private final Result.Accumulator accumulator;

        /**
         * @param layer       reference to the wrapped layer (used for its name); owned by this object
         * @param i           index of the input this accumulator belongs to
         * @param accumulator the input's original accumulator to forward to; owned by this object
         */
        public Accumulator2(Layer layer, int i, Result.Accumulator accumulator) {
            this.i = i;
            this.inner = layer;
            this.accumulator = accumulator;
        }

        @Override
        public void accept(DeltaSet<UUID> deltaSet, TensorList tensorList) {
            log.info(RefString.format("Feedback Output %s for key %s: \n\t%s",
                    new Object[]{Integer.valueOf(this.i), this.inner.getName(), indent(describe(tensorList))}));
            // The field reference is owned by this object and released exactly once in _free();
            // releasing it here as well would free it twice.
            this.accumulator.accept(deltaSet, tensorList);
        }

        @Override
        public void _free() {
            super._free();
            this.accumulator.freeRef();
            this.inner.freeRef();
        }
    }

    protected LoggingWrapperLayer(JsonObject jsonObject, Map<CharSequence, byte[]> map) {
        super(jsonObject, map);
    }

    /**
     * @param layer the layer to wrap; passed to {@link WrapperLayer}
     */
    public LoggingWrapperLayer(Layer layer) {
        super(layer);
    }

    /** Deserialization factory; delegates to the JSON constructor. */
    public static LoggingWrapperLayer fromJson(JsonObject jsonObject, Map<CharSequence, byte[]> map) {
        return new LoggingWrapperLayer(jsonObject, map);
    }

    /**
     * Formats a tensor's dimensions with {@code RefArrays.toString}. The caller's reference to
     * {@code tensor} is released, even if formatting throws.
     *
     * @param tensor the tensor to describe; consumed by this call
     * @return the dimension array rendered as a string
     */
    public static String getString(Tensor tensor) {
        try {
            return RefArrays.toString(tensor.getDimensions());
        } finally {
            tensor.freeRef();
        }
    }

    /**
     * Joins the dimension strings of every tensor in {@code tensorList} with newlines.
     * {@code tensorList} itself is not released. NOTE(review): the result is obtained via
     * {@code RefUtil.get} on the reduced optional, so an empty list will not yield "" here.
     */
    private static String describe(TensorList tensorList) {
        return (String) RefUtil.get(tensorList.stream()
                .map(LoggingWrapperLayer::getString)
                .reduce((left, right) -> left + "\n" + right));
    }

    /** Indents every continuation line of {@code text} with a tab, for log readability. */
    private static String indent(String text) {
        return text.replace("\n", "\n\t");
    }

    /**
     * Logs the dimensions of each input and of the output, then evaluates the inner layer on
     * inputs whose accumulators log feedback. The result's accumulator also logs feedback.
     * The caller's references to {@code resultArr} are released.
     */
    @Override
    public Result eval(Result... resultArr) {
        Result[] wrappedInputs = new Result[resultArr.length];
        for (int i = 0; i < resultArr.length; i++) {
            Result input = resultArr[i];

            TensorList inputData = input.getData();
            String inputDescription = describe(inputData);
            inputData.freeRef();
            log.info(RefString.format("Input %s for key %s: \n\t%s",
                    new Object[]{Integer.valueOf(i), this.inner.getName(), indent(inputDescription)}));

            // Re-wrap the input so feedback sent back to it passes through a logging accumulator.
            boolean isAlive = input.isAlive();
            TensorList data = input.getData();
            Accumulator2 accumulator2 = new Accumulator2(this.inner.mo21addRef(), i, input.getAccumulator());
            wrappedInputs[i] = new Result(data, accumulator2, isAlive);
        }
        RefUtil.freeRef(resultArr);

        Result eval = this.inner.eval(wrappedInputs);
        assert eval != null;

        TensorList outputData = eval.getData();
        String outputDescription = describe(outputData);
        outputData.freeRef();
        log.info(RefString.format("Output for key %s: \n\t%s",
                new Object[]{this.inner.getName(), indent(outputDescription)}));

        boolean isAlive = eval.isAlive();
        TensorList data = eval.getData();
        Accumulator accumulator = new Accumulator(this.inner.mo21addRef(), eval.getAccumulator());
        eval.freeRef();
        return new Result(data, accumulator, isAlive);
    }

    @Override
    public void _free() {
        super._free();
    }

    /** Covariant {@code addRef} returning this concrete type. */
    @Override
    public LoggingWrapperLayer mo21addRef() {
        return (LoggingWrapperLayer) super.mo21addRef();
    }
}
