package com.simiacryptus.mindseye.network;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.DataSerializer;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.SerialPrecision;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.layers.ValueLayer;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;

/* loaded from: input_file:com/simiacryptus/mindseye/network/PipelineNetwork.class */
/**
 * A {@link DAGNetwork} that tracks a single "head" node.
 *
 * <p>{@link #add(Layer)} appends a layer whose input is the current head and then makes the new
 * node the head, so successive calls build a linear pipeline. The lower-level {@code add}
 * overloads also move the head to the node they create. Objects handled here follow the
 * {@code addRef}/{@code freeRef} reference-counting convention used throughout this code base.
 */
public class PipelineNetwork extends DAGNetwork {
    /** Current output node; {@code null} means "use input 0" (see {@link #getHead()}). */
    private DAGNode head;
    // Decompiler-visible equivalent of the compiler's assertion switch; the explicit checks below
    // mirror `assert` statements.
    static final /* synthetic */ boolean $assertionsDisabled;

    /** Creates an empty pipeline with one input. */
    public PipelineNetwork() {
        this(1, new Layer[0]);
    }

    /**
     * Creates a pipeline with {@code i} inputs and appends each layer in order.
     *
     * @param i number of network inputs
     * @param layerArr layers to chain after the current head
     */
    public PipelineNetwork(int i, Layer... layerArr) {
        super(i);
        RefArrays.stream(layerArr).forEach(layer -> {
            RefUtil.freeRef(add(layer));
        });
    }

    /**
     * Deserialization constructor; restores the head from the {@code "head"} UUID property.
     *
     * <p>If that UUID is one of the input handles, {@link #head} is left unset, so
     * {@link #getHead()} falls back to {@code getInput(0)}. NOTE(review): if the head was an input
     * other than input 0, this fallback may not match the serialized network; confirm.
     */
    protected PipelineNetwork(JsonObject jsonObject, Map<CharSequence, byte[]> map) {
        super(jsonObject, map);
        UUID fromString = UUID.fromString(jsonObject.get("head").getAsString());
        if (this.inputHandles.contains(fromString)) {
            return;
        }
        setHead(getNodeById(fromString));
    }

    /**
     * Creates a single-input pipeline and chains the given layers in order.
     *
     * @param layerArr layers to append; {@code null} is tolerated
     */
    public PipelineNetwork(Layer... layerArr) {
        this();
        addAll((Layer[]) RefUtil.addRef(layerArr));
        if (null != layerArr) {
            RefUtil.freeRef(layerArr);
        }
    }

    /** Creates an empty pipeline with {@code i} inputs and the given id and name. */
    public PipelineNetwork(int i, UUID uuid, String str) {
        super(i, uuid, str);
    }

    /**
     * Returns the current output node.
     *
     * <p>When no head has been set, this returns {@code getInput(0)}. Otherwise it returns the
     * head with an added reference, which the caller must release.
     */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork
    public final DAGNode getHead() {
        assertAlive();
        return null == this.head ? getInput(0) : this.head.mo58addRef();
    }

    /**
     * Replaces the head node.
     *
     * <p>The caller's reference to {@code dAGNode} is taken over. The previous head's reference
     * is released. If {@code dAGNode} is already the head, the redundant incoming reference is
     * freed instead.
     *
     * @param dAGNode new head node, may be {@code null}
     */
    public synchronized void setHead(DAGNode dAGNode) {
        if (dAGNode != this.head) {
            if (null != this.head) {
                this.head.freeRef();
            }
            this.head = dAGNode;
        } else if (null != dAGNode) {
            dAGNode.freeRef();
        }
    }

    /** Deserializes a pipeline produced by {@link #getJson(Map, DataSerializer)}. */
    public static PipelineNetwork fromJson(JsonObject jsonObject, Map<CharSequence, byte[]> map) {
        return new PipelineNetwork(jsonObject, map);
    }

    /**
     * Builds a pipeline with {@code i} inputs from the given layers, appended in order.
     *
     * <p>Each layer is checked with {@code assertAlive()} and added with an extra reference. The
     * array is released afterwards.
     *
     * @param i number of network inputs
     * @param layerArr layers to chain; elements must be non-null
     * @return the new pipeline
     */
    public static PipelineNetwork build(int i, Layer... layerArr) {
        PipelineNetwork pipelineNetwork = new PipelineNetwork(i, new Layer[0]);
        int length = layerArr.length;
        for (int i2 = 0; i2 < length; i2++) {
            Layer layer = layerArr[i2];
            // Dereferences the layer, so a null element fails here; the layer is non-null below.
            layer.assertAlive();
            RefUtil.freeRef(pipelineNetwork.add(layer.mo21addRef()));
        }
        RefUtil.freeRef(layerArr);
        return pipelineNetwork;
    }

    /**
     * Combines several pipelines by feeding each one's head into {@code layer}.
     *
     * <p>With exactly one pipeline, {@code layer} is released and that pipeline is returned with
     * an added reference. Otherwise a new single-input pipeline is built: every source pipeline is
     * transferred into it, and {@code layer} is added on top of the transferred heads.
     *
     * @param layer combining layer
     * @param pipelineNetworkArr source pipelines; the array is consumed
     * @return the combined pipeline
     */
    public static PipelineNetwork combine(Layer layer, PipelineNetwork... pipelineNetworkArr) {
        if (!$assertionsDisabled && !RefArrays.stream((Object[]) RefUtil.addRef(pipelineNetworkArr)).allMatch(pipelineNetwork -> {
            boolean assertAlive = pipelineNetwork.assertAlive();
            pipelineNetwork.freeRef();
            return assertAlive;
        })) {
            throw new AssertionError();
        }
        if (1 != pipelineNetworkArr.length) {
            PipelineNetwork pipelineNetwork2 = new PipelineNetwork(1, new Layer[0]);
            RefUtil.freeRef(pipelineNetwork2.add(layer, (DAGNode[]) RefArrays.stream(pipelineNetworkArr).map((Function) RefUtil.wrapInterface(pipelineNetwork3 -> {
                return pipelineNetwork2.transferNode(pipelineNetwork3, pipelineNetwork3.getHead());
            }, new Object[0])).toArray(i -> {
                return new DAGNode[i];
            })));
            return pipelineNetwork2;
        }
        if (null != layer) {
            layer.freeRef();
        }
        PipelineNetwork mo21addRef = pipelineNetworkArr[0].mo21addRef();
        RefUtil.freeRef(pipelineNetworkArr);
        return mo21addRef;
    }

    /**
     * Concatenates pipelines into a new one by transferring each pipeline's head node in order.
     *
     * <p>With exactly one pipeline, that pipeline is returned with an added reference.
     *
     * @param pipelineNetworkArr pipelines to sequence; the array is consumed
     * @return the sequenced pipeline
     * @throws IllegalArgumentException if no pipelines are given, or if they do not all have
     *     exactly one input
     */
    public static PipelineNetwork sequence(PipelineNetwork... pipelineNetworkArr) {
        if (!$assertionsDisabled && !RefArrays.stream((Object[]) RefUtil.addRef(pipelineNetworkArr)).allMatch(pipelineNetwork -> {
            boolean assertAlive = pipelineNetwork.assertAlive();
            pipelineNetwork.freeRef();
            return assertAlive;
        })) {
            throw new AssertionError();
        }
        if (1 == pipelineNetworkArr.length) {
            PipelineNetwork mo21addRef = pipelineNetworkArr[0].mo21addRef();
            RefUtil.freeRef(pipelineNetworkArr);
            return mo21addRef;
        }
        // Guard the [0] access below: an empty argument list is a caller error, not an index bug.
        if (0 == pipelineNetworkArr.length) {
            RefUtil.freeRef(pipelineNetworkArr);
            throw new IllegalArgumentException("At least one pipeline is required");
        }
        if (1 != pipelineNetworkArr[0].inputHandles.size()) {
            RefUtil.freeRef(pipelineNetworkArr);
            throw new IllegalArgumentException();
        }
        // All pipelines must agree on their input count.
        if (1 != RefArrays.stream((Object[]) RefUtil.addRef(pipelineNetworkArr)).mapToInt(pipelineNetwork2 -> {
            int size = pipelineNetwork2.inputHandles.size();
            pipelineNetwork2.freeRef();
            return size;
        }).distinct().count()) {
            RefUtil.freeRef(pipelineNetworkArr);
            throw new IllegalArgumentException();
        }
        PipelineNetwork pipelineNetwork3 = new PipelineNetwork(pipelineNetworkArr[0].inputHandles.size(), new Layer[0]);
        for (PipelineNetwork pipelineNetwork4 : pipelineNetworkArr) {
            RefUtil.freeRef(pipelineNetwork3.transferNode(pipelineNetwork4.mo21addRef(), pipelineNetwork4.getHead()));
        }
        RefUtil.freeRef(pipelineNetworkArr);
        return pipelineNetwork3;
    }

    /**
     * Null-safe {@link #copyPipeline()} that also releases the argument.
     *
     * @param pipelineNetwork pipeline to copy; its reference is consumed
     * @return the copy, or {@code null} if the argument is {@code null}
     */
    public static PipelineNetwork getCopy(PipelineNetwork pipelineNetwork) {
        if (null == pipelineNetwork) {
            return null;
        }
        try {
            return pipelineNetwork.copyPipeline();
        } finally {
            pipelineNetwork.freeRef();
        }
    }

    /** Covariant override of {@code copy(SerialPrecision)}. */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork, com.simiacryptus.mindseye.lang.Layer
    public PipelineNetwork copy(SerialPrecision serialPrecision) {
        return (PipelineNetwork) super.copy(serialPrecision);
    }

    /** Covariant override of {@code copy()}. */
    @Override // com.simiacryptus.mindseye.lang.Layer
    public PipelineNetwork copy() {
        return (PipelineNetwork) super.copy();
    }

    /**
     * Appends {@code layer} fed by the current head and makes the new node the head.
     *
     * @param layer non-null layer to append
     * @return the newly created node
     */
    public InnerNode add(Layer layer) {
        if (!$assertionsDisabled && layer == null) {
            throw new AssertionError();
        }
        if ($assertionsDisabled || layer.assertAlive()) {
            return add(layer, getHead());
        }
        throw new AssertionError();
    }

    /**
     * Adds {@code layer} with the given input nodes and makes the new node the head.
     *
     * @return the newly created node; the head holds its own separate reference
     */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork
    public InnerNode add(Layer layer, DAGNode... dAGNodeArr) {
        if (!$assertionsDisabled && layer == null) {
            throw new AssertionError();
        }
        InnerNode add = super.add(layer, dAGNodeArr);
        setHead(add.mo58addRef());
        return add;
    }

    /**
     * Adds a labelled {@code layer} with the given input nodes and makes the new node the head.
     *
     * @return the newly created node; the head holds its own separate reference
     */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork
    @SafeVarargs
    public final InnerNode add(CharSequence charSequence, Layer layer, DAGNode... dAGNodeArr) {
        if (!$assertionsDisabled && layer == null) {
            throw new AssertionError();
        }
        InnerNode add = super.add(charSequence, layer, dAGNodeArr);
        setHead(add.mo58addRef());
        return add;
    }

    /**
     * Chains the layers in order, starting from {@code innerNode}. Each layer's output feeds the
     * next layer.
     *
     * @param innerNode starting node; its reference is consumed
     * @param layerArr layers to chain; the array is released afterwards
     */
    public void addAll(InnerNode innerNode, Layer... layerArr) {
        int length = layerArr.length;
        for (int i = 0; i < length; i++) {
            Layer layer = layerArr[i];
            innerNode = add(layer == null ? null : layer.mo21addRef(), innerNode);
        }
        RefUtil.freeRef(layerArr);
        if (null != innerNode) {
            innerNode.freeRef();
        }
    }

    /**
     * Chains the layers in order after the current head.
     *
     * <p>This goes through {@link #add(Layer)} rather than casting {@link #getHead()} to
     * {@link InnerNode}. The head is not always an {@code InnerNode}: it is {@code getInput(0)}
     * for an empty pipeline, and {@link #setHead(DAGNode)} accepts any node. The wiring is the
     * same, because each added node becomes the head used by the next layer.
     *
     * @param layerArr layers to append; {@code null} is treated as "nothing to add"
     */
    public void addAll(Layer... layerArr) {
        if (null == layerArr) {
            return;
        }
        for (Layer layer : layerArr) {
            RefUtil.freeRef(add(layer == null ? null : layer.mo21addRef()));
        }
        RefUtil.freeRef(layerArr);
    }

    /**
     * Adds a {@link ValueLayer} holding {@code tensor} with no inputs; it becomes the head.
     *
     * @param tensor constant value; passed straight to the {@code ValueLayer} constructor
     */
    public DAGNode constValue(Tensor tensor) {
        return add(new ValueLayer(tensor), new DAGNode[0]);
    }

    /**
     * Like {@link #constValue(Tensor)}, but adds its own reference to {@code tensor} for the new
     * layer and then releases the caller's reference.
     */
    public DAGNode constValueWrap(Tensor tensor) {
        DAGNode constValue = constValue(tensor == null ? null : tensor.m43addRef());
        if (null != tensor) {
            tensor.freeRef();
        }
        return constValue;
    }

    /** Serializes the network and records the head node id under {@code "head"}. */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork
    public JsonObject getJson(Map<CharSequence, byte[]> map, DataSerializer dataSerializer) {
        assertConsistent();
        JsonObject json = super.getJson(map, dataSerializer);
        if (!$assertionsDisabled && json == null) {
            throw new AssertionError();
        }
        json.addProperty("head", getHeadId().toString());
        return json;
    }

    /**
     * Appends {@code layer} to this pipeline in place.
     *
     * @param layer layer to append; its reference is consumed
     * @return this pipeline, with an added reference
     */
    @Override // com.simiacryptus.mindseye.lang.Layer
    public PipelineNetwork andThen(Layer layer) {
        if (!$assertionsDisabled && !layer.assertAlive()) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && !assertAlive()) {
            throw new AssertionError();
        }
        RefUtil.freeRef(add(layer.mo21addRef()));
        layer.freeRef();
        return mo21addRef();
    }

    /**
     * Creates a new pipeline with the same id and name, and transfers this pipeline's head into
     * it when there are internal nodes.
     *
     * <p>NOTE(review): the copy is always created with one input, regardless of this pipeline's
     * input count; confirm this is intended for multi-input pipelines.
     */
    public final PipelineNetwork copyPipeline() {
        PipelineNetwork pipelineNetwork = new PipelineNetwork(1, getId(), getName());
        if (!this.internalNodes.isEmpty()) {
            pipelineNetwork.setHead(pipelineNetwork.transferNode(mo21addRef(), getHead()));
        }
        return pipelineNetwork;
    }

    /** Frees superclass resources, then releases the head reference. */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork, com.simiacryptus.mindseye.lang.LayerBase, com.simiacryptus.mindseye.lang.Layer
    public void _free() {
        super._free();
        if (null != this.head) {
            this.head.freeRef();
            this.head = null;
        }
    }

    /** Covariant {@code addRef}. */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork, com.simiacryptus.mindseye.lang.LayerBase, com.simiacryptus.mindseye.lang.Layer
    /* renamed from: addRef */
    public PipelineNetwork mo21addRef() {
        return (PipelineNetwork) super.mo21addRef();
    }

    /** Bridge from the raw {@code ZipSerializable.getJson} signature. */
    @Override // com.simiacryptus.mindseye.network.DAGNetwork, com.simiacryptus.mindseye.lang.ZipSerializable
    /* renamed from: getJson */
    public /* bridge */ /* synthetic */ JsonElement mo52getJson(Map map, DataSerializer dataSerializer) {
        return getJson((Map<CharSequence, byte[]>) map, dataSerializer);
    }

    static {
        $assertionsDisabled = !PipelineNetwork.class.desiredAssertionStatus();
    }
}
