package org.diffkt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SerializedIr;
import org.diffkt.Shape;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: Combinators.kt */
/**
 * Reduction combinators for {@link FloatTensor}s.
 *
 * <p>A reduction splits a tensor's axes into <em>cell</em> axes (the axes being reduced) and
 * <em>frame</em> axes (every other axis). Each frame position owns one cell of elements, and that
 * cell is reduced to a single float which becomes the output element for the frame position.
 */
@SerializedIr(b = {"��"})
@Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��N\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u0007\n\u0002\b\u0002\n\u0002\u0010\u000b\n��\n\u0002\u0018\u0002\n\u0002\u0010\u0014\n\u0002\u0010\b\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\bÆ\u0002\u0018��2\u00020\u0001:\u0001\u0019B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u0018\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\bH\u0002J<\u0010\t\u001a\u00020\u00062\u0006\u0010\u0005\u001a\u00020\u00062\u0018\u0010\n\u001a\u0014\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00020\f0\u000b2\b\b\u0002\u0010\r\u001a\u00020\b2\b\b\u0002\u0010\u000e\u001a\u00020\u000fJS\u0010\u0010\u001a\u00020\u00062\u0006\u0010\u0005\u001a\u00020\u00062*\u0010\n\u001a&\u0012\u0004\u0012\u00020\u0012\u0012\u0004\u0012\u00020\u0013\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0013\u0012\u0004\u0012\u00020\u00130\u0014\u0012\u0004\u0012\u00020\f0\u00112\u0006\u0010\u0007\u001a\u00020\b2\b\b\u0002\u0010\u000e\u001a\u00020\u000fH��¢\u0006\u0002\b\u0015JC\u0010\u0016\u001a\u00020\u00062\u0006\u0010\u0005\u001a\u00020\u00172\u0018\u0010\n\u001a\u0014\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00020\f0\u000b2\b\b\u0002\u0010\r\u001a\u00020\b2\b\b\u0002\u0010\u000e\u001a\u00020\u000fH��¢\u0006\u0002\b\u0018¨\u0006\u001a"}, d2 = {"Lorg/diffkt/Combinators;", "", "()V", "partitionForReduce", "Lorg/diffkt/Combinators$PartitionInfo;", "x", "Lorg/diffkt/FloatTensor;", "cellAxes", "", "reduce", "f", "Lkotlin/Function2;", "", "axes", "keepDims", "", "reduceImpl", "Lkotlin/Function3;", "", "", "Lkotlin/Function1;", "reduceImpl$api", "reduceSparse", "Lorg/diffkt/SparseFloatTensor;", "reduceSparse$api", "PartitionInfo", "api"})
/* loaded from: input_file:org/diffkt/Combinators.class */
public final class Combinators {

    /** Singleton instance (the Kotlin {@code object} instance). */
    @NotNull
    public static final Combinators INSTANCE = new Combinators();

    /**
     * Describes how a tensor's axes are split for a reduction.
     *
     * <p>{@code frameAxes} are the axes that are NOT reduced. {@code frameShape} and
     * {@code cellShape} both have the tensor's full rank, with size 1 along every axis that belongs
     * to the other partition. The strides and sizes are computed from those shapes with
     * {@code StridedUtils.contigStrides} / {@code StridedUtils.dataSize}.
     *
     * <p>JADX note: originally private in the Kotlin source; visibility widened by the decompiler.
     */
    /* compiled from: Combinators.kt */
    @Metadata(mv = {Convolve.H_AXIS, 6, Convolve.N_AXIS}, k = Convolve.H_AXIS, xi = 48, d1 = {"��.\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\u0015\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0017\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n��\b\u0082\b\u0018��2\u00020\u0001B=\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0003\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\u0005\u0012\u0006\u0010\n\u001a\u00020\u0003\u0012\u0006\u0010\u000b\u001a\u00020\b¢\u0006\u0002\u0010\fJ\t\u0010\u0017\u001a\u00020\u0003HÆ\u0003J\t\u0010\u0018\u001a\u00020\u0005HÆ\u0003J\t\u0010\u0019\u001a\u00020\u0003HÆ\u0003J\t\u0010\u001a\u001a\u00020\bHÆ\u0003J\t\u0010\u001b\u001a\u00020\u0005HÆ\u0003J\t\u0010\u001c\u001a\u00020\u0003HÆ\u0003J\t\u0010\u001d\u001a\u00020\bHÆ\u0003JO\u0010\u001e\u001a\u00020��2\b\b\u0002\u0010\u0002\u001a\u00020\u00032\b\b\u0002\u0010\u0004\u001a\u00020\u00052\b\b\u0002\u0010\u0006\u001a\u00020\u00032\b\b\u0002\u0010\u0007\u001a\u00020\b2\b\b\u0002\u0010\t\u001a\u00020\u00052\b\b\u0002\u0010\n\u001a\u00020\u00032\b\b\u0002\u0010\u000b\u001a\u00020\bHÆ\u0001J\u0013\u0010\u001f\u001a\u00020 2\b\u0010!\u001a\u0004\u0018\u00010\u0001HÖ\u0003J\t\u0010\"\u001a\u00020\bHÖ\u0001J\t\u0010#\u001a\u00020$HÖ\u0001R\u0011\u0010\t\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\r\u0010\u000eR\u0011\u0010\u000b\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0011\u0010\n\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0012R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u000eR\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0010R\u0011\u0010\u0006\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0012¨\u0006%"}, d2 = {"Lorg/diffkt/Combinators$PartitionInfo;", 
"", "frameAxes", "", "frameShape", "Lorg/diffkt/Shape;", "frameStrides", "frameSize", "", "cellShape", "cellStrides", "cellSize", "([ILorg/diffkt/Shape;[IILorg/diffkt/Shape;[II)V", "getCellShape", "()Lorg/diffkt/Shape;", "getCellSize", "()I", "getCellStrides", "()[I", "getFrameAxes", "getFrameShape", "getFrameSize", "getFrameStrides", "component1", "component2", "component3", "component4", "component5", "component6", "component7", "copy", "equals", "", "other", "hashCode", "toString", "", "api"})
    /* loaded from: input_file:org/diffkt/Combinators$PartitionInfo.class */
    public static final class PartitionInfo {

        /** Axes of the tensor that are kept (not reduced), in increasing order. */
        @NotNull
        private final int[] frameAxes;

        /** Full-rank shape with the tensor's sizes on frame axes and 1 on cell axes. */
        @NotNull
        private final Shape frameShape;

        /** Strides of {@link #frameShape} as computed by {@code StridedUtils.contigStrides}. */
        @NotNull
        private final int[] frameStrides;

        /** Number of frame positions, i.e. number of output elements. */
        private final int frameSize;

        /** Full-rank shape with the tensor's sizes on cell axes and 1 on frame axes. */
        @NotNull
        private final Shape cellShape;

        /** Strides of {@link #cellShape} as computed by {@code StridedUtils.contigStrides}. */
        @NotNull
        private final int[] cellStrides;

        /** Number of elements in each cell. */
        private final int cellSize;

        public PartitionInfo(@NotNull int[] iArr, @NotNull Shape shape, @NotNull int[] iArr2, int i, @NotNull Shape shape2, @NotNull int[] iArr3, int i2) {
            Intrinsics.checkNotNullParameter(iArr, "frameAxes");
            Intrinsics.checkNotNullParameter(shape, "frameShape");
            Intrinsics.checkNotNullParameter(iArr2, "frameStrides");
            Intrinsics.checkNotNullParameter(shape2, "cellShape");
            Intrinsics.checkNotNullParameter(iArr3, "cellStrides");
            this.frameAxes = iArr;
            this.frameShape = shape;
            this.frameStrides = iArr2;
            this.frameSize = i;
            this.cellShape = shape2;
            this.cellStrides = iArr3;
            this.cellSize = i2;
        }

        @NotNull
        public final int[] getFrameAxes() {
            return this.frameAxes;
        }

        @NotNull
        public final Shape getFrameShape() {
            return this.frameShape;
        }

        @NotNull
        public final int[] getFrameStrides() {
            return this.frameStrides;
        }

        public final int getFrameSize() {
            return this.frameSize;
        }

        @NotNull
        public final Shape getCellShape() {
            return this.cellShape;
        }

        @NotNull
        public final int[] getCellStrides() {
            return this.cellStrides;
        }

        public final int getCellSize() {
            return this.cellSize;
        }

        // componentN accessors mirror the Kotlin data-class destructuring API.

        @NotNull
        public final int[] component1() {
            return this.frameAxes;
        }

        @NotNull
        public final Shape component2() {
            return this.frameShape;
        }

        @NotNull
        public final int[] component3() {
            return this.frameStrides;
        }

        public final int component4() {
            return this.frameSize;
        }

        @NotNull
        public final Shape component5() {
            return this.cellShape;
        }

        @NotNull
        public final int[] component6() {
            return this.cellStrides;
        }

        public final int component7() {
            return this.cellSize;
        }

        /** Returns a new {@code PartitionInfo} built from the given components (arrays are not copied). */
        @NotNull
        public final PartitionInfo copy(@NotNull int[] iArr, @NotNull Shape shape, @NotNull int[] iArr2, int i, @NotNull Shape shape2, @NotNull int[] iArr3, int i2) {
            Intrinsics.checkNotNullParameter(iArr, "frameAxes");
            Intrinsics.checkNotNullParameter(shape, "frameShape");
            Intrinsics.checkNotNullParameter(iArr2, "frameStrides");
            Intrinsics.checkNotNullParameter(shape2, "cellShape");
            Intrinsics.checkNotNullParameter(iArr3, "cellStrides");
            return new PartitionInfo(iArr, shape, iArr2, i, shape2, iArr3, i2);
        }

        /**
         * Kotlin default-argument bridge for {@link #copy}: each set bit in {@code i3} (bit 0 =
         * first component, ..., bit 6 = seventh) means "use this instance's current value".
         */
        public static /* synthetic */ PartitionInfo copy$default(PartitionInfo partitionInfo, int[] iArr, Shape shape, int[] iArr2, int i, Shape shape2, int[] iArr3, int i2, int i3, Object obj) {
            if ((i3 & 1) != 0) {
                iArr = partitionInfo.frameAxes;
            }
            if ((i3 & 2) != 0) {
                shape = partitionInfo.frameShape;
            }
            if ((i3 & 4) != 0) {
                iArr2 = partitionInfo.frameStrides;
            }
            if ((i3 & 8) != 0) {
                i = partitionInfo.frameSize;
            }
            if ((i3 & 16) != 0) {
                shape2 = partitionInfo.cellShape;
            }
            if ((i3 & 32) != 0) {
                iArr3 = partitionInfo.cellStrides;
            }
            if ((i3 & 64) != 0) {
                i2 = partitionInfo.cellSize;
            }
            return partitionInfo.copy(iArr, shape, iArr2, i, shape2, iArr3, i2);
        }

        @Override
        @NotNull
        public String toString() {
            return "PartitionInfo(frameAxes=" + Arrays.toString(this.frameAxes) + ", frameShape=" + this.frameShape + ", frameStrides=" + Arrays.toString(this.frameStrides) + ", frameSize=" + this.frameSize + ", cellShape=" + this.cellShape + ", cellStrides=" + Arrays.toString(this.cellStrides) + ", cellSize=" + this.cellSize + ")";
        }

        @Override
        public int hashCode() {
            return (((((((((((Arrays.hashCode(this.frameAxes) * 31) + this.frameShape.hashCode()) * 31) + Arrays.hashCode(this.frameStrides)) * 31) + Integer.hashCode(this.frameSize)) * 31) + this.cellShape.hashCode()) * 31) + Arrays.hashCode(this.cellStrides)) * 31) + Integer.hashCode(this.cellSize);
        }

        @Override
        public boolean equals(@Nullable Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof PartitionInfo)) {
                return false;
            }
            PartitionInfo other = (PartitionInfo) obj;
            // Arrays are compared by content so that equals() agrees with hashCode() and
            // toString(), which already use Arrays.hashCode / Arrays.toString.
            return Arrays.equals(this.frameAxes, other.frameAxes)
                    && Intrinsics.areEqual(this.frameShape, other.frameShape)
                    && Arrays.equals(this.frameStrides, other.frameStrides)
                    && this.frameSize == other.frameSize
                    && Intrinsics.areEqual(this.cellShape, other.cellShape)
                    && Arrays.equals(this.cellStrides, other.cellStrides)
                    && this.cellSize == other.cellSize;
        }
    }

    private Combinators() {
    }

    /**
     * Reduces {@code floatTensor} along the given axes by left-folding {@code function2} over the
     * elements of each cell.
     *
     * <p>On the dense path, for every frame position the cell's elements are visited in increasing
     * cell-linear-index order and combined as {@code acc = f(acc, next)}, starting from the cell's
     * first element. {@link SparseFloatTensor} inputs are delegated to {@code reduceSparse$api}.
     *
     * @param floatTensor the tensor to reduce (Kotlin name {@code x})
     * @param function2   the binary combining function (Kotlin name {@code f})
     * @param iArr        the axes to reduce (Kotlin name {@code axes}); on the dense path an empty
     *                    array returns {@code floatTensor} unchanged
     * @param z           {@code keepDims}: if true the reduced axes stay in the result with size 1;
     *                    otherwise they are removed, and a result with scalar shape is returned as
     *                    a {@link FloatScalar}
     * @return the reduced tensor
     * @throws IllegalArgumentException (dense path) if an axis is negative or not less than the
     *                                  tensor's rank, or if a cell to be reduced is empty
     */
    @NotNull
    public final FloatTensor reduce(@NotNull FloatTensor floatTensor, @NotNull final Function2<? super Float, ? super Float, Float> function2, @NotNull int[] iArr, boolean z) {
        Intrinsics.checkNotNullParameter(floatTensor, "x");
        Intrinsics.checkNotNullParameter(function2, "f");
        Intrinsics.checkNotNullParameter(iArr, "axes");
        if (floatTensor instanceof SparseFloatTensor) {
            return reduceSparse$api((SparseFloatTensor) floatTensor, function2, iArr, z);
        }
        // Per-cell reducer: folds function2 over the cell described by (data, cellSize, ix).
        Function3<float[], Integer, Function1<? super Integer, Integer>, Float> cellReducer =
                new Function3<float[], Integer, Function1<? super Integer, Integer>, Float>() {
                    @Override
                    public Float invoke(float[] data, Integer cellSize, Function1<? super Integer, Integer> ix) {
                        Intrinsics.checkNotNullParameter(data, "data");
                        Intrinsics.checkNotNullParameter(ix, "ix");
                        return Float.valueOf(Combinators.reduce$cell(function2, data, cellSize.intValue(), ix));
                    }
                };
        return reduceImpl$api(floatTensor, cellReducer, iArr, z);
    }

    /**
     * Kotlin default-argument bridge for {@link #reduce}: bit 2 of {@code i} defaults {@code iArr}
     * to {@code floatTensor.getAllAxes()}, bit 3 defaults {@code z} to {@code false}.
     */
    public static /* synthetic */ FloatTensor reduce$default(Combinators combinators, FloatTensor floatTensor, Function2 function2, int[] iArr, boolean z, int i, Object obj) {
        if ((i & 4) != 0) {
            iArr = floatTensor.getAllAxes();
        }
        if ((i & 8) != 0) {
            z = false;
        }
        return combinators.reduce(floatTensor, function2, iArr, z);
    }

    /**
     * Dense reduction driver: partitions the axes of {@code floatTensor} into frame and cell axes
     * and invokes {@code function3} once per frame position.
     *
     * <p>{@code function3} receives the array returned by {@code asStrided().getData$api()}, the
     * number of elements per cell, and an index function that maps a cell-local linear index
     * ({@code 0 .. cellSize-1}) to an index into that array. Its result becomes the output element
     * for that frame position.
     *
     * @param floatTensor the tensor to reduce (Kotlin name {@code x})
     * @param function3   the per-cell reducer (Kotlin name {@code f})
     * @param iArr        the axes to reduce (Kotlin name {@code cellAxes}); if empty,
     *                    {@code floatTensor} is returned as is
     * @param z           {@code keepDims}; see {@link #reduce}
     * @return the reduced tensor ({@link FloatScalar} when the result shape is scalar)
     * @throws IllegalArgumentException if an axis is negative or not less than the tensor's rank
     */
    @NotNull
    public final FloatTensor reduceImpl$api(@NotNull FloatTensor floatTensor, @NotNull Function3<? super float[], ? super Integer, ? super Function1<? super Integer, Integer>, Float> function3, @NotNull int[] iArr, boolean z) {
        Intrinsics.checkNotNullParameter(floatTensor, "x");
        Intrinsics.checkNotNullParameter(function3, "f");
        Intrinsics.checkNotNullParameter(iArr, "cellAxes");
        if (iArr.length == 0) {
            return floatTensor;
        }
        final PartitionInfo partition = partitionForReduce(floatTensor, iArr);
        final StridedFloatTensor strided = floatTensor.asStrided();
        int frameSize = partition.getFrameSize();
        float[] result = new float[frameSize];
        for (int frameIndex = 0; frameIndex < frameSize; frameIndex++) {
            // Offset (before the tensor's own offset) of this frame position's first cell element.
            final int frameOffset = StridedUtils.INSTANCE.strided$api(frameIndex, partition.getFrameShape(), strided.getStrides$api());
            Function1<Integer, Integer> ix = new Function1<Integer, Integer>() {
                @Override
                public Integer invoke(Integer cellIndex) {
                    return Integer.valueOf(Combinators.reduceImpl$ix(frameOffset, partition, strided, cellIndex.intValue()));
                }
            };
            result[frameIndex] = function3.invoke(strided.getData$api(), Integer.valueOf(partition.getCellSize()), ix).floatValue();
        }

        Shape resultShape;
        int[] resultStrides;
        if (z) {
            // keepDims: the frame shape already has size 1 on every reduced axis.
            resultShape = partition.getFrameShape();
            resultStrides = partition.getFrameStrides();
        } else {
            // Drop the size-1 placeholders at the reduced axes, keeping only the frame axes.
            int[] frameAxes = partition.getFrameAxes();
            List<Integer> dims = new ArrayList<Integer>(frameAxes.length);
            int[] strides = new int[frameAxes.length];
            for (int k = 0; k < frameAxes.length; k++) {
                int axis = frameAxes[k];
                dims.add(Integer.valueOf(partition.getFrameShape().get(axis)));
                strides[k] = partition.getFrameStrides()[axis];
            }
            resultShape = Shape.Companion.invoke(dims);
            resultStrides = strides;
        }
        return resultShape.isScalar()
                ? new FloatScalar(result[0])
                : new StridedFloatTensor(resultShape, 0, resultStrides, result, null, 16, null);
    }

    /** Kotlin default-argument bridge for {@link #reduceImpl$api}: bit 3 of {@code i} defaults {@code z} to {@code false}. */
    public static /* synthetic */ FloatTensor reduceImpl$api$default(Combinators combinators, FloatTensor floatTensor, Function3 function3, int[] iArr, boolean z, int i, Object obj) {
        if ((i & 8) != 0) {
            z = false;
        }
        return combinators.reduceImpl$api(floatTensor, function3, iArr, z);
    }

    // NOTE(review): the body of reduceSparse$api could not be decompiled; the stub below is kept
    // as emitted by JADX and always throws. Its real behavior must be taken from the Kotlin source.
    /* JADX WARN: Removed duplicated region for block: B:11:0x005e  */
    /* JADX WARN: Removed duplicated region for block: B:14:0x006f  */
    @org.jetbrains.annotations.NotNull
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public final org.diffkt.FloatTensor reduceSparse$api(@org.jetbrains.annotations.NotNull org.diffkt.SparseFloatTensor r6, @org.jetbrains.annotations.NotNull kotlin.jvm.functions.Function2<? super java.lang.Float, ? super java.lang.Float, java.lang.Float> r7, @org.jetbrains.annotations.NotNull int[] r8, boolean r9) {
        /*
            Method dump skipped, instructions count: 1360
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.diffkt.Combinators.reduceSparse$api(org.diffkt.SparseFloatTensor, kotlin.jvm.functions.Function2, int[], boolean):org.diffkt.FloatTensor");
    }

    /**
     * Kotlin default-argument bridge for {@code reduceSparse$api}: bit 2 of {@code i} defaults
     * {@code iArr} to {@code sparseFloatTensor.getAllAxes()}, bit 3 defaults {@code z} to false.
     */
    public static /* synthetic */ FloatTensor reduceSparse$api$default(Combinators combinators, SparseFloatTensor sparseFloatTensor, Function2 function2, int[] iArr, boolean z, int i, Object obj) {
        if ((i & 4) != 0) {
            iArr = sparseFloatTensor.getAllAxes();
        }
        if ((i & 8) != 0) {
            z = false;
        }
        return combinators.reduceSparse$api(sparseFloatTensor, function2, iArr, z);
    }

    /**
     * Validates {@code iArr} against {@code floatTensor} and splits its axes into cell axes
     * ({@code iArr}) and frame axes (all remaining axes).
     *
     * @param floatTensor the tensor being reduced
     * @param iArr        the axes to reduce; not modified (a sorted copy is used)
     * @return the frame/cell partition
     * @throws IllegalArgumentException if any axis is negative or not less than the tensor's rank
     */
    private final PartitionInfo partitionForReduce(FloatTensor floatTensor, int[] iArr) {
        int[] cellAxes = iArr.clone();
        Arrays.sort(cellAxes);
        if (cellAxes.length != 0) {
            // After sorting, only the extremes need to be checked; report the offending axis.
            int smallest = cellAxes[0];
            if (smallest < 0) {
                throw new IllegalArgumentException("reduce: invalid axis " + smallest);
            }
            int largest = cellAxes[cellAxes.length - 1];
            if (largest >= floatTensor.getRank()) {
                throw new IllegalArgumentException("reduce: invalid axis " + largest + " for shape " + floatTensor.getShape());
            }
        }

        // Frame axes: every axis index of the shape that is not a cell axis, in iteration order.
        Iterable<?> indices = floatTensor.getShape().getIndices();
        List<Integer> frameAxisList = new ArrayList<Integer>();
        for (Object index : indices) {
            int axis = ((Number) index).intValue();
            if (Arrays.binarySearch(cellAxes, axis) < 0) {
                frameAxisList.add(Integer.valueOf(axis));
            }
        }
        int[] frameAxes = CollectionsKt.toIntArray(frameAxisList);

        Shape cellShape = Shape.Companion.invoke(partitionForReduce$subshape(floatTensor, cellAxes));
        int[] cellStrides = StridedUtils.INSTANCE.contigStrides(cellShape);
        int cellSize = StridedUtils.INSTANCE.dataSize(cellShape, cellStrides);

        Shape frameShape = Shape.Companion.invoke(partitionForReduce$subshape(floatTensor, frameAxes));
        int[] frameStrides = StridedUtils.INSTANCE.contigStrides(frameShape);
        int frameSize = StridedUtils.INSTANCE.dataSize(frameShape, frameStrides);

        return new PartitionInfo(frameAxes, frameShape, frameStrides, frameSize, cellShape, cellStrides, cellSize);
    }

    /**
     * Left-folds {@code function2} over the {@code i} elements of one cell:
     * {@code acc = data[ix(0)]; acc = f(acc, data[ix(k)])} for {@code k = 1 .. i-1}.
     *
     * <p>JADX note: private in the original bytecode; visibility widened by the decompiler.
     *
     * @param function2 the binary combining function
     * @param fArr      the backing data array
     * @param i         the number of elements in the cell
     * @param function1 maps a cell-local linear index to an index into {@code fArr}
     * @return the folded value
     * @throws IllegalArgumentException if the data array or the cell is empty
     */
    public static final float reduce$cell(Function2<? super Float, ? super Float, Float> function2, float[] fArr, int i, Function1<? super Integer, Integer> function1) {
        // An empty cell has no first element to seed the fold; the original only checked the
        // data array, which let a zero-size cell read an element outside of it.
        if (fArr.length == 0 || i <= 0) {
            throw new IllegalArgumentException("Empty selection cannot be reduced");
        }
        float acc = fArr[function1.invoke(Integer.valueOf(0)).intValue()];
        for (int k = 1; k < i; k++) {
            float next = fArr[function1.invoke(Integer.valueOf(k)).intValue()];
            acc = function2.invoke(Float.valueOf(acc), Float.valueOf(next)).floatValue();
        }
        return acc;
    }

    /**
     * Maps cell-local linear index {@code i2} of the frame position whose strided offset is
     * {@code i} to an index into {@code stridedFloatTensor}'s data array. The result is
     * {@code i + strided$api(i2, cellShape, tensorStrides) + tensorOffset}.
     *
     * <p>JADX note: private in the original bytecode; visibility widened by the decompiler.
     */
    public static final int reduceImpl$ix(int i, PartitionInfo partitionInfo, StridedFloatTensor stridedFloatTensor, int i2) {
        return i + StridedUtils.INSTANCE.strided$api(i2, partitionInfo.getCellShape(), stridedFloatTensor.getStrides$api()) + stridedFloatTensor.getOffset$api();
    }

    /**
     * Builds a full-rank dimension list for {@code floatTensor}: the tensor's size on each axis in
     * {@code iArr}, and 1 on every other axis.
     */
    private static final List<Integer> partitionForReduce$subshape(FloatTensor floatTensor, int[] iArr) {
        int rank = floatTensor.getRank();
        List<Integer> dims = new ArrayList<Integer>(rank);
        for (int axis = 0; axis < rank; axis++) {
            dims.add(Integer.valueOf(1));
        }
        for (int axis : iArr) {
            dims.set(axis, Integer.valueOf(floatTensor.getShape().get(axis)));
        }
        return dims;
    }
}
