package org.clulab.scala_transformers.encoder;

import org.clulab.scala_transformers.encoder.math.Mathematics$;
import org.clulab.shaded.org.ejml.data.FMatrixRMaj;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.ArrayOps$;
import scala.math.Ordering$DeprecatedFloatOrdering$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.Scala3RunTime$;
import scala.runtime.ScalaRunTime$;

/* compiled from: LinearLayer.scala */
/* loaded from: input_file:org/clulab/scala_transformers/encoder/LinearLayer.class */
/**
 * A linear output layer: multiplies an input matrix by {@link #weights()}, optionally adds
 * {@link #biasesOpt()}, and maps the highest-scoring column of each output row to an entry of
 * {@link #labelsOpt()}.
 *
 * <p>When {@link #dual()} is {@code false} ("primal" mode) every input row is classified on its own.
 * When it is {@code true}, every input row (a "modifier") is first combined with another row of the
 * same matrix (its "head", addressed by a relative offset) and that pair representation is classified.
 * {@link #USE_CONCAT()} selects whether a pair is the horizontal concatenation of the two rows or
 * their accumulated sum.
 *
 * <p>This Java source mirrors the compiled Scala class {@code LinearLayer.scala}; the
 * {@code $default$N} methods follow the Scala compiler's naming for default-argument accessors.
 */
public class LinearLayer {
    /** Message of the exception thrown when a prediction is requested but no labels were supplied. */
    private static final String MISSING_LABELS_MESSAGE = "ERROR: can't predict without labels!";

    private final String name;
    private final boolean dual;
    private final FMatrixRMaj weights;
    private final Option<FMatrixRMaj> biasesOpt;
    private final Option<String[]> labelsOpt;

    /** Returns the companion object's pair-building mode: {@code true} = concatenate, {@code false} = sum. */
    public static boolean USE_CONCAT() {
        return LinearLayer$.MODULE$.USE_CONCAT();
    }

    /** Loads a layer via the companion object's {@code fromFiles}; {@code location} is passed through unchanged. */
    public static LinearLayer fromFiles(String location) {
        return LinearLayer$.MODULE$.fromFiles(location);
    }

    /** Loads a layer via the companion object's {@code fromResources}; {@code location} is passed through unchanged. */
    public static LinearLayer fromResources(String location) {
        return LinearLayer$.MODULE$.fromResources(location);
    }

    /**
     * Creates a layer.
     *
     * @param name      identifier of this layer
     * @param dual      {@code true} to classify modifier/head pairs, {@code false} to classify rows directly
     * @param weights   right-hand operand of the input multiplication
     * @param biasesOpt optional matrix added to every product
     * @param labelsOpt optional label vocabulary, indexed by output column; required for prediction
     */
    public LinearLayer(String name, boolean dual, FMatrixRMaj weights, Option<FMatrixRMaj> biasesOpt, Option<String[]> labelsOpt) {
        this.name = name;
        this.dual = dual;
        this.weights = weights;
        this.biasesOpt = biasesOpt;
        this.labelsOpt = labelsOpt;
    }

    /** Returns the layer name given at construction. */
    public String name() {
        return this.name;
    }

    /** Returns {@code true} if this layer classifies modifier/head pairs rather than single rows. */
    public boolean dual() {
        return this.dual;
    }

    /** Returns the weight matrix multiplied onto every input. */
    public FMatrixRMaj weights() {
        return this.weights;
    }

    /** Returns the optional bias matrix added after multiplication. */
    public Option<FMatrixRMaj> biasesOpt() {
        return this.biasesOpt;
    }

    /** Returns the optional label vocabulary used to decode argmax columns. */
    public Option<String[]> labelsOpt() {
        return this.labelsOpt;
    }

    /** Applies {@link #forward(FMatrixRMaj[])} to a single input matrix. */
    public FMatrixRMaj forward(FMatrixRMaj input) {
        return forward(new FMatrixRMaj[]{input})[0];
    }

    /**
     * Computes {@code mul(input, weights)} for every input and, when biases are present, adds them to the
     * product in place.
     *
     * @param inputs matrices to transform
     * @return one freshly computed output matrix per input, in the same order
     */
    public FMatrixRMaj[] forward(FMatrixRMaj[] inputs) {
        FMatrixRMaj[] outputs = new FMatrixRMaj[inputs.length];
        for (int k = 0; k < inputs.length; k++) {
            FMatrixRMaj product = Mathematics$.MODULE$.Math().mul(inputs[k], weights());
            if (biasesOpt().isDefined()) {
                Mathematics$.MODULE$.Math().inplaceMatrixAddition(product, biasesOpt().get());
            }
            outputs[k] = product;
        }
        return outputs;
    }

    /**
     * Predicts one label per row of a single input matrix.
     *
     * @param input                    the input matrix
     * @param headRelativePositionsOpt per-row head offsets; required in dual mode, ignored in primal mode
     * @param masksOpt                 required (but not otherwise read) in dual mode, ignored in primal mode
     * @return the predicted labels
     */
    public String[] predict(FMatrixRMaj input, Option<int[]> headRelativePositionsOpt, Option<boolean[]> masksOpt) {
        Option<int[][]> batchedHeads = headRelativePositionsOpt.map(heads -> new int[][]{heads});
        Option<boolean[][]> batchedMasks = masksOpt.map(mask -> new boolean[][]{mask});
        return predict(new FMatrixRMaj[]{input}, batchedHeads, batchedMasks)[0];
    }

    /** Dispatches to {@link #predictDual} or {@link #predictPrimal} depending on {@link #dual()}. */
    public String[][] predict(FMatrixRMaj[] inputs, Option<int[][]> headRelativePositionsOpt, Option<boolean[][]> masksOpt) {
        return dual() ? predictDual(inputs, headRelativePositionsOpt, masksOpt) : predictPrimal(inputs);
    }

    /**
     * Scored prediction for a single input matrix; see {@link #predictWithScores(FMatrixRMaj[], Option, Option)}.
     *
     * @param headCandidatesOpt per-row lists of candidate head offsets (dual mode only)
     */
    public Tuple2<String, Object>[][] predictWithScores(FMatrixRMaj input, Option<int[][]> headCandidatesOpt, Option<boolean[]> masksOpt) {
        Option<int[][][]> batchedHeads = headCandidatesOpt.map(heads -> new int[][][]{heads});
        Option<boolean[][]> batchedMasks = masksOpt.map(mask -> new boolean[][]{mask});
        return predictWithScores(new FMatrixRMaj[]{input}, batchedHeads, batchedMasks)[0];
    }

    /** Dispatches to {@link #predictDualWithScores} or {@link #predictPrimalWithScores} depending on {@link #dual()}. */
    public Tuple2<String, Object>[][][] predictWithScores(FMatrixRMaj[] inputs, Option<int[][][]> headCandidatesOpt, Option<boolean[][]> masksOpt) {
        return dual() ? predictDualWithScores(inputs, headCandidatesOpt, masksOpt) : predictPrimalWithScores(inputs);
    }

    /**
     * Builds one pair representation per row of {@code embeddings}.
     *
     * <p>Result row {@code m} combines row {@code m} of {@code embeddings} (the modifier) with row
     * {@code m + headRelativePositions[m]} (the head). A head index outside {@code [0, rows)} falls back to
     * the modifier row itself.
     *
     * @param embeddings            input matrix whose rows are paired
     * @param headRelativePositions head offset for each row; needs at least {@code rows(embeddings)} entries
     * @return a {@code rows x 2*cols} matrix when {@link #USE_CONCAT()} is true, otherwise {@code rows x cols}
     */
    public FMatrixRMaj concatenateModifiersAndHeads(FMatrixRMaj embeddings, int[] headRelativePositions) {
        int rowCount = Mathematics$.MODULE$.Math().rows(embeddings);
        FMatrixRMaj pairs = zeroPairMatrix(rowCount, Mathematics$.MODULE$.Math().cols(embeddings));
        for (int modifier = 0; modifier < rowCount; modifier++) {
            FMatrixRMaj modifierRow = Mathematics$.MODULE$.Math().row(embeddings, modifier);
            // The head index must be distinct from the modifier index; otherwise every row pairs with itself.
            int head = resolveHeadRow(modifier, headRelativePositions[modifier], rowCount);
            FMatrixRMaj headRow = Mathematics$.MODULE$.Math().row(embeddings, head);
            addPairToRow(pairs, modifier, modifierRow, headRow);
        }
        return pairs;
    }

    /**
     * Builds the pair representation of a single modifier row and its head.
     *
     * @param embeddings   input matrix whose rows are paired
     * @param modifier     row index of the modifier
     * @param relativeHead offset from {@code modifier} to the head; out-of-range heads fall back to the modifier
     * @return a one-row matrix, {@code 2*cols} wide when {@link #USE_CONCAT()} is true, otherwise {@code cols} wide
     */
    public FMatrixRMaj concatenateModifierAndHead(FMatrixRMaj embeddings, int modifier, int relativeHead) {
        int rowCount = Mathematics$.MODULE$.Math().rows(embeddings);
        FMatrixRMaj pair = zeroPairMatrix(1, Mathematics$.MODULE$.Math().cols(embeddings));
        FMatrixRMaj modifierRow = Mathematics$.MODULE$.Math().row(embeddings, modifier);
        FMatrixRMaj headRow = Mathematics$.MODULE$.Math().row(embeddings, resolveHeadRow(modifier, relativeHead, rowCount));
        addPairToRow(pair, 0, modifierRow, headRow);
        return pair;
    }

    /**
     * Dual-mode prediction: for each sentence, pairs every row with its head, runs the pairs through
     * {@link #forward(FMatrixRMaj)}, and decodes each output row by argmax.
     *
     * <p>Like Scala's {@code zip}, only the first {@code min(inputs.length, heads.length)} sentences are processed.
     *
     * @param inputs                   one matrix per sentence
     * @param headRelativePositionsOpt per-sentence, per-row head offsets; must be defined
     * @param masksOpt                 must be defined, but is not otherwise read here
     * @return one label array per processed sentence
     * @throws AssertionError    (via Scala's assert) if either option is empty
     * @throws RuntimeException  if this layer has no labels
     */
    public String[][] predictDual(FMatrixRMaj[] inputs, Option<int[][]> headRelativePositionsOpt, Option<boolean[][]> masksOpt) {
        if (!headRelativePositionsOpt.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        if (!masksOpt.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        String[] labels = requireLabels();
        int[][] headRelativePositions = headRelativePositionsOpt.get();
        int sentenceCount = Math.min(inputs.length, headRelativePositions.length);
        String[][] predictions = new String[sentenceCount][];
        for (int s = 0; s < sentenceCount; s++) {
            FMatrixRMaj pairs = concatenateModifiersAndHeads(inputs[s], headRelativePositions[s]);
            predictions[s] = argmaxLabels(labels, forward(pairs));
        }
        return predictions;
    }

    /** Default value ({@code None}) of the second argument of {@link #predictDual}. */
    public Option<int[][]> predictDual$default$2() {
        return None$.MODULE$;
    }

    /** Default value ({@code None}) of the third argument of {@link #predictDual}. */
    public Option<boolean[][]> predictDual$default$3() {
        return None$.MODULE$;
    }

    /**
     * Dual-mode scored prediction: for every sentence, every modifier row and every candidate head offset
     * of that row, returns the best label and its score for that modifier/head pair.
     *
     * <p>Like Scala's {@code zip}, only the first {@code min(inputs.length, candidates.length)} sentences are processed.
     *
     * @param inputs            one matrix per sentence
     * @param headCandidatesOpt per-sentence, per-modifier candidate head offsets; must be defined
     * @param masksOpt          must be defined, but is not otherwise read here
     * @return {@code [sentence][modifier][candidate]} pairs of (label, boxed float score)
     * @throws AssertionError   (via Scala's assert) if either option is empty
     * @throws RuntimeException if this layer has no labels
     */
    @SuppressWarnings("unchecked")
    public Tuple2<String, Object>[][][] predictDualWithScores(FMatrixRMaj[] inputs, Option<int[][][]> headCandidatesOpt, Option<boolean[][]> masksOpt) {
        if (!headCandidatesOpt.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        if (!masksOpt.isDefined()) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        String[] labels = requireLabels();
        int[][][] headCandidates = headCandidatesOpt.get();
        int sentenceCount = Math.min(inputs.length, headCandidates.length);
        Tuple2<String, Object>[][][] predictions = new Tuple2[sentenceCount][][];
        for (int s = 0; s < sentenceCount; s++) {
            FMatrixRMaj sentence = inputs[s];
            int[][] candidatesPerModifier = headCandidates[s];
            Tuple2<String, Object>[][] sentencePredictions = new Tuple2[candidatesPerModifier.length][];
            for (int modifier = 0; modifier < candidatesPerModifier.length; modifier++) {
                int[] candidates = candidatesPerModifier[modifier];
                Tuple2<String, Object>[] scored = new Tuple2[candidates.length];
                for (int c = 0; c < candidates.length; c++) {
                    scored[c] = scoreHeadCandidate(labels, sentence, modifier, candidates[c]);
                }
                sentencePredictions[modifier] = scored;
            }
            predictions[s] = sentencePredictions;
        }
        return predictions;
    }

    /** Default value ({@code None}) of the second argument of {@link #predictDualWithScores}. */
    public Option<int[][][]> predictDualWithScores$default$2() {
        return None$.MODULE$;
    }

    /** Default value ({@code None}) of the third argument of {@link #predictDualWithScores}. */
    public Option<boolean[][]> predictDualWithScores$default$3() {
        return None$.MODULE$;
    }

    /**
     * Primal-mode prediction: runs every input through {@link #forward(FMatrixRMaj[])} and decodes each
     * output row by argmax.
     *
     * @throws RuntimeException if this layer has no labels (checked before any computation)
     */
    public String[][] predictPrimal(FMatrixRMaj[] inputs) {
        String[] labels = requireLabels();
        FMatrixRMaj[] outputs = forward(inputs);
        String[][] predictions = new String[outputs.length][];
        for (int s = 0; s < outputs.length; s++) {
            predictions[s] = argmaxLabels(labels, outputs[s]);
        }
        return predictions;
    }

    /**
     * Primal-mode scored prediction: for every output row, returns all (label, score) pairs sorted by
     * descending score.
     *
     * @throws RuntimeException if this layer has no labels (checked before any computation)
     */
    @SuppressWarnings("unchecked")
    public Tuple2<String, Object>[][][] predictPrimalWithScores(FMatrixRMaj[] inputs) {
        String[] labels = requireLabels();
        FMatrixRMaj[] outputs = forward(inputs);
        Tuple2<String, Object>[][][] predictions = new Tuple2[outputs.length][][];
        for (int s = 0; s < outputs.length; s++) {
            FMatrixRMaj scores = outputs[s];
            int rowCount = Mathematics$.MODULE$.Math().rows(scores);
            Tuple2<String, Object>[][] ranked = new Tuple2[rowCount][];
            for (int r = 0; r < rowCount; r++) {
                ranked[r] = rankLabels(labels, scores, r);
            }
            predictions[s] = ranked;
        }
        return predictions;
    }

    /**
     * Compatibility shim for a synthetic lambda body that the decompiler exposed as public.
     * Returns the label at the argmax column of row {@code row}.
     */
    public static final String $anonfun$7(String[] labels, FMatrixRMaj scores, int row) {
        return argmaxLabel(labels, scores, row);
    }

    /**
     * Compatibility shim for a synthetic lambda body that the decompiler exposed as public.
     * Returns the label at the argmax column of row {@code row}.
     */
    public static final String $anonfun$12(String[] labels, FMatrixRMaj scores, int row) {
        return argmaxLabel(labels, scores, row);
    }

    /**
     * Compatibility shim for a synthetic lambda body that the decompiler exposed as public.
     * Returns the (label, score) pairs of row {@code row} sorted by descending score.
     */
    public static final Tuple2[] $anonfun$15(String[] labels, FMatrixRMaj scores, int row) {
        return rankLabels(labels, scores, row);
    }

    /** Returns the label vocabulary, or throws the "can't predict without labels" error when absent. */
    private String[] requireLabels() {
        Option<String[]> labels = labelsOpt();
        if (labels.isEmpty()) {
            throw new RuntimeException(MISSING_LABELS_MESSAGE);
        }
        return labels.get();
    }

    /** Returns {@code modifier + relativeHead} when it is a valid row index, otherwise {@code modifier}. */
    private static int resolveHeadRow(int modifier, int relativeHead, int rowCount) {
        int head = modifier + relativeHead;
        return (head < 0 || head >= rowCount) ? modifier : head;
    }

    /** Allocates a zero matrix wide enough for pair representations of rows with {@code cols} columns. */
    private static FMatrixRMaj zeroPairMatrix(int rowCount, int cols) {
        int width = LinearLayer$.MODULE$.USE_CONCAT() ? 2 * cols : cols;
        return Mathematics$.MODULE$.Math().zeros(rowCount, width);
    }

    /**
     * Writes the modifier/head pair into row {@code targetRow} of {@code target}: either their horizontal
     * concatenation, or both rows accumulated one after the other. Assumes
     * {@code inplaceMatrixAddition(dst, row, src)} adds {@code src} into row {@code row} of {@code dst}.
     */
    private static void addPairToRow(FMatrixRMaj target, int targetRow, FMatrixRMaj modifierRow, FMatrixRMaj headRow) {
        if (LinearLayer$.MODULE$.USE_CONCAT()) {
            Mathematics$.MODULE$.Math().inplaceMatrixAddition(target, targetRow, Mathematics$.MODULE$.Math().horcat(modifierRow, headRow));
        } else {
            Mathematics$.MODULE$.Math().inplaceMatrixAddition(target, targetRow, modifierRow);
            Mathematics$.MODULE$.Math().inplaceMatrixAddition(target, targetRow, headRow);
        }
    }

    /** Returns the label at the argmax column of row {@code row} of {@code scores}. */
    private static String argmaxLabel(String[] labels, FMatrixRMaj scores, int row) {
        return labels[Mathematics$.MODULE$.Math().argmax(Mathematics$.MODULE$.Math().row(scores, row))];
    }

    /** Decodes every row of {@code scores} to its argmax label. */
    private static String[] argmaxLabels(String[] labels, FMatrixRMaj scores) {
        int rowCount = Mathematics$.MODULE$.Math().rows(scores);
        String[] result = new String[rowCount];
        for (int r = 0; r < rowCount; r++) {
            result[r] = argmaxLabel(labels, scores, r);
        }
        return result;
    }

    /**
     * Scores one modifier/head candidate: builds the pair, runs it through {@link #forward(FMatrixRMaj)},
     * and returns the argmax label together with its (boxed float) score.
     */
    private Tuple2<String, Object> scoreHeadCandidate(String[] labels, FMatrixRMaj embeddings, int modifier, int relativeHead) {
        FMatrixRMaj scores = Mathematics$.MODULE$.Math().row(forward(concatenateModifierAndHead(embeddings, modifier, relativeHead)), 0);
        int best = Mathematics$.MODULE$.Math().argmax(scores);
        return new Tuple2<>(labels[best], BoxesRunTime.boxToFloat(Mathematics$.MODULE$.Math().get(scores, best)));
    }

    /**
     * Pairs each label with the matching score of row {@code row} (truncated to the shorter of the two,
     * as Scala's {@code zip} does) and sorts the pairs by descending score using Scala's stable
     * {@code sortBy} with {@code DeprecatedFloatOrdering}.
     */
    @SuppressWarnings("unchecked")
    private static Tuple2<String, Object>[] rankLabels(String[] labels, FMatrixRMaj scores, int row) {
        float[] rowScores = Mathematics$.MODULE$.Math().toArray(Mathematics$.MODULE$.Math().row(scores, row));
        Tuple2<String, Object>[] pairs = ArrayOps$.MODULE$.zip$extension(Predef$.MODULE$.refArrayOps(labels), Predef$.MODULE$.wrapFloatArray(rowScores));
        return (Tuple2<String, Object>[]) ArrayOps$.MODULE$.sortBy$extension(
                Predef$.MODULE$.refArrayOps(pairs),
                (Tuple2<String, Object> pair) -> -BoxesRunTime.unboxToFloat(pair._2()),
                Ordering$DeprecatedFloatOrdering$.MODULE$);
    }
}
