package ml.dmlc.xgboost4j.scala.example.flink;

import java.nio.file.Path;
import java.nio.file.Paths;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import scala.MatchError;
import scala.Predef$;
import scala.collection.ArrayOps$;
import scala.collection.Map;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: DistTrainWithFlink.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink$.class */
/**
 * Example: distributed XGBoost training and prediction on Apache Flink (DataSet API).
 *
 * <p>Reads {@code veterans_lung_cancer.csv}, trains on a positional prefix of the rows and
 * predicts the remaining rows, then prints how many predictions were produced.
 *
 * <p>This is the Java view of a Scala {@code object}; the singleton instance is {@link #MODULE$}.
 */
public final class DistTrainWithFlink$ {
    public static final DistTrainWithFlink$ MODULE$ = new DistTrainWithFlink$();

    /** Flink type information for a training row: (feature vector, label). */
    private static final TypeInformation<Tuple2<Vector, Double>> rowTypeHint =
            TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>() {
            });

    /** Flink type information for a test row: the feature vector only. */
    private static final TypeInformation<Vector> testDataTypeHint = TypeInformation.of(Vector.class);

    private TypeInformation<Tuple2<Vector, Double>> rowTypeHint() {
        return rowTypeHint;
    }

    private TypeInformation<Vector> testDataTypeHint() {
        return testDataTypeHint;
    }

    /**
     * Parses the CSV file at {@code path} into index-tagged (features, label) rows.
     *
     * <p>The header line is skipped. Of the 13 columns, column 0 is read but unused; column 1 is a
     * string whose label is 1.0 if it contains {@code "inf"} and 0.0 otherwise; columns 2-12 are
     * widened to {@code double} and form a dense feature vector. Each row is then paired with an
     * index via {@link DataSetUtils#zipWithIndex}.
     *
     * @param path                 CSV file to read
     * @param executionEnvironment Flink environment the source is created in
     * @return rows as {@code (index, (features, label))}
     */
    public DataSet<Tuple2<Long, Tuple2<Vector, Double>>> parseCsv(Path path, ExecutionEnvironment executionEnvironment) {
        DataSet<Tuple2<Vector, Double>> rows = executionEnvironment
                .readCsvFile(path.toString())
                .ignoreFirstLine()
                .types(double.class, String.class, double.class, double.class, double.class,
                        Integer.class, Integer.class, Integer.class, Integer.class,
                        Integer.class, Integer.class, Integer.class, Integer.class)
                .map(record -> {
                    // Integer columns unbox and widen to double inside the array initializer.
                    Vector features = Vectors.dense(new double[]{
                            record.f2, record.f3, record.f4,
                            record.f5, record.f6, record.f7, record.f8,
                            record.f9, record.f10, record.f11, record.f12});
                    Double label = record.f1.contains("inf") ? 1.0d : 0.0d;
                    return new Tuple2<>(features, label);
                })
                // The lambda's generic Tuple2 result type is erased, so Flink needs an explicit hint.
                .returns(rowTypeHint());
        return DataSetUtils.zipWithIndex(rows);
    }

    /**
     * Trains an XGBoost model on the first part of the data set and predicts the rest.
     *
     * <p>Rows whose zipWithIndex index is below {@code round(0.01 * i * rowCount)} are used for
     * training; all remaining rows are used as the test set.
     *
     * @param path                 CSV file to read (format described on {@link #parseCsv})
     * @param i                    percentage of rows (by index position) used for training
     * @param executionEnvironment Flink environment to run in
     * @return the trained model and its predictions for the held-out rows
     */
    public scala.Tuple2<XGBoostModel, DataSet<float[]>> runPrediction(Path path, int i, ExecutionEnvironment executionEnvironment) {
        DataSet<Tuple2<Long, Tuple2<Vector, Double>>> indexed = parseCsv(path, executionEnvironment);
        long trainSize = Math.round(0.01d * i * indexed.count());
        // NOTE(review): `indexed` is evaluated by count() above and again by training and
        // prediction; this split assumes zipWithIndex assigns the same index to each row every
        // time — confirm, otherwise train and test rows could overlap.
        DataSet<Tuple2<Vector, Double>> trainData = indexed
                .filter(row -> row.f0 < trainSize)
                .map(row -> row.f1)
                .returns(rowTypeHint());
        DataSet<Vector> testData = indexed
                .filter(row -> row.f0 >= trainSize)
                .map(row -> row.f1.f0)
                .returns(testDataTypeHint());
        // Booster parameters; 2 is the number of boosting rounds.
        XGBoostModel model = XGBoost.train(trainData, CollectionConverters$.MODULE$.MapHasAsJava((Map) Predef$.MODULE$.Map().apply(ScalaRunTime$.MODULE$.wrapRefArray(new scala.Tuple2[]{new scala.Tuple2("eta", "0.1"), new scala.Tuple2("max_depth", "2"), new scala.Tuple2("objective", "binary:logistic"), new scala.Tuple2("verbosity", "1")}))).asJava(), 2);
        // Converts each prediction's boxed element array into a primitive float[].
        DataSet<float[]> predictions = model.predict(testData).map(fArr -> {
            return (float[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(fArr), obj -> {
                return BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(obj));
            }, ClassTag$.MODULE$.Float());
        });
        return new scala.Tuple2<>(model, predictions);
    }

    /**
     * Entry point: trains on 70% of {@code veterans_lung_cancer.csv} and prints the number of
     * predictions made for the remaining rows.
     *
     * @param strArr optional; the first element is the directory containing the CSV (default ".")
     */
    public void main(String[] strArr) {
        String dataDir = strArr.length > 0 ? strArr[0] : ".";
        Path csvPath = Paths.get(dataDir).resolve("veterans_lung_cancer.csv");
        scala.Tuple2<XGBoostModel, DataSet<float[]>> result =
                runPrediction(csvPath, 70, ExecutionEnvironment.getExecutionEnvironment());
        // count() yields the size without collecting every prediction array onto the client.
        Predef$.MODULE$.println(BoxesRunTime.boxToLong(result._2().count()));
    }

    private DistTrainWithFlink$() {
    }
}
