package us.ihmc.behaviors.sharedControl;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import us.ihmc.euclid.geometry.interfaces.Pose3DReadOnly;
import us.ihmc.euclid.orientation.interfaces.Orientation3DReadOnly;
import us.ihmc.euclid.referenceFrame.FramePose3D;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.euclid.referenceFrame.interfaces.FixedFramePoint3DBasics;
import us.ihmc.euclid.referenceFrame.interfaces.FixedFrameQuaternionBasics;
import us.ihmc.euclid.transform.RigidBodyTransform;
import us.ihmc.euclid.tuple3D.Point3D;
import us.ihmc.euclid.tuple3D.interfaces.Tuple3DReadOnly;
import us.ihmc.euclid.tuple4D.Quaternion;
import us.ihmc.log.LogTools;
import us.ihmc.tools.io.JSONFileTools;

/**
 * Shared-control assistant driven by learned ProMP task models, with one {@link ProMPManager} per task.
 * <p>
 * Tasks are declared in the classpath resource {@value #CONFIGURATION_FILE}. At runtime the assistant:
 * <ol>
 *   <li>selects a task for the detected object context ({@code processFrameAndObjectInformation}),</li>
 *   <li>records observed body-part poses once the inference body part has moved far enough,</li>
 *   <li>after enough observations, conditions the task on them and generates predicted trajectories,</li>
 *   <li>hands out the predicted poses one sample at a time through {@link #framePoseToPack}.</li>
 * </ol>
 * This class performs no synchronization. NOTE(review): presumably it is driven from a single thread - confirm.
 */
public class ProMPAssistant {
    /** Classpath location of the JSON file declaring the tasks. */
    private static final String CONFIGURATION_FILE = "us/ihmc/behaviors/sharedControl/ProMPAssistant.json";
    /** Number of packed samples over which the output blends from the last observation into the prediction. */
    private static final int INTERPOLATION_SAMPLES = 10;
    /** Displacement [m] that flags the user as moving when no threshold was given via {@link #setIsMovingThreshold}. */
    private static final double DEFAULT_IS_MOVING_THRESHOLD = 0.04;

    /** Goal pose of the current task (object pose with the task's goal transform appended), or null if unknown. */
    private FramePose3D taskGoalPose;
    /**
     * Object frame the observations were expressed in (set by the ReferenceFrame overload of
     * processFrameAndObjectInformation when the task has an observable-goal body part); null otherwise.
     */
    private ReferenceFrame objectFrame;
    /** Task name -> model of that task. */
    private final HashMap<String, ProMPManager> proMPManagers = new HashMap<>();
    /** Context (object) name -> names of the tasks declared for it. */
    private final HashMap<String, List<String>> contextTasksMap = new HashMap<>();
    /** Per candidate task, accumulated distance used to choose between several tasks of one context. */
    private final List<Double> distanceCandidateTasks = new ArrayList<>();
    /** True until the first candidate-distance evaluation has filled {@link #distanceCandidateTasks}. */
    private boolean firstObservedBodyPart = true;
    /** Name of the selected task; empty while no task is selected. */
    private String currentTask = "";
    /** Number of observed samples required before a prediction is generated (read from the configuration). */
    private int numberObservations = 0;
    private final HashMap<String, Point3D[]> initialStdDeviations = new HashMap<>();
    private final HashMap<String, Point3D[]> initialMeans = new HashMap<>();
    /** Body part of the current task whose displacement decides whether the user is moving. */
    private String bodyPartInference = "";
    /** Body part of the current task that has an observable goal; empty if none. */
    private String bodyPartGoal = "";
    private final HashMap<String, String> taskBodyPartInferenceMap = new HashMap<>();
    private final HashMap<String, String> taskBodyPartGoalMap = new HashMap<>();
    /** Task name -> transform appended to the object pose to obtain the task goal pose. */
    private final HashMap<String, RigidBodyTransform> taskTransformGoalMap = new HashMap<>();
    private final HashMap<String, List<FramePose3D>> bodyPartObservedTrajectoryMap = new HashMap<>();
    private final HashMap<String, List<FramePose3D>> bodyPartGeneratedTrajectoryMap = new HashMap<>();
    /** Body part -> index of the next sample {@link #framePoseToPack} will output. */
    private final HashMap<String, Integer> bodyPartTrajectorySampleCounter = new HashMap<>();
    private boolean doneInitialProcessingTask = false;
    private boolean doneCurrentTask = false;
    /** Shared with every ProMPManager; set before the last via point of an update is passed to it. */
    private final AtomicBoolean isLastViaPoint = new AtomicBoolean(false);
    private int testNumber = 0;
    private int numberOfInferredSpeeds = 0;
    private boolean conditionOnlyLastObservation = true;
    /** Poses of the inference body part collected until the user is detected to be moving. */
    private final ArrayList<Pose3DReadOnly> observationRecognition = new ArrayList<>();
    private boolean isMoving = false;
    /** Negative means "use {@link #DEFAULT_IS_MOVING_THRESHOLD}". */
    private double isMovingThreshold = -1.0d;

    /**
     * Creates the assistant. If the configuration resource is on the classpath, every task it declares is loaded
     * and each task's {@link ProMPManager} loads its demonstrations. If the resource is missing, a message is
     * logged and the assistant has no tasks.
     */
    public ProMPAssistant() {
        // try-with-resources closes the stream even if parsing throws; a null resource is simply not closed.
        try (InputStream configurationStream = getClass().getClassLoader().getResourceAsStream(CONFIGURATION_FILE)) {
            if (configurationStream == null) {
                LogTools.info("File path is null");
                return;
            }
            JSONFileTools.load(configurationStream, this::loadConfiguration);
        } catch (IOException e) {
            LogTools.info(e);
        }
    }

    /** Reads the global settings and the task list from the parsed configuration, then loads all demos. */
    private void loadConfiguration(JsonNode configuration) {
        this.testNumber = configuration.get("testNumberUseOnlyForTesting").asInt();
        boolean logging = configuration.get("logging").asBoolean();
        this.numberObservations = configuration.get("numberObservations").asInt();
        this.conditionOnlyLastObservation = configuration.get("conditionOnlyLastObservation").asBoolean();
        int numberBasisFunctions = configuration.get("numberBasisFunctions").asInt();
        long allowedIncreaseDecreaseSpeedFactor = configuration.get("allowedIncreaseDecreaseSpeedFactor").asLong();
        this.numberOfInferredSpeeds = configuration.get("numberOfInferredSpeeds").asInt();

        JsonNode tasksArrayNode = configuration.get("tasks");
        LogTools.info("Loading ProMPs for tasks:");
        for (JsonNode taskNode : tasksArrayNode) {
            String context = taskNode.get("context").asText();
            String taskName = taskNode.get("name").asText();
            this.contextTasksMap.computeIfAbsent(context, key -> new ArrayList<>()).add(taskName);

            JsonNode translationNode = taskNode.get("translationGoalToEE");
            Tuple3DReadOnly translationGoalToEE = new Point3D(translationNode.get(0).asDouble(),
                                                              translationNode.get(1).asDouble(),
                                                              translationNode.get(2).asDouble());
            JsonNode rotationNode = taskNode.get("rotationGoalToEE");
            Orientation3DReadOnly rotationGoalToEE = new Quaternion(rotationNode.get(0).asDouble(),
                                                                    rotationNode.get(1).asDouble(),
                                                                    rotationNode.get(2).asDouble(),
                                                                    rotationNode.get(3).asDouble());

            HashMap<String, String> bodyPartsGeometry = new HashMap<>();
            for (JsonNode bodyPartNode : taskNode.get("bodyParts")) {
                bodyPartsGeometry.put(bodyPartNode.get("name").asText(), bodyPartNode.get("geometry").asText());
            }

            this.proMPManagers.put(taskName, new ProMPManager(taskName, bodyPartsGeometry, logging, this.isLastViaPoint,
                                                              numberBasisFunctions, allowedIncreaseDecreaseSpeedFactor,
                                                              this.numberOfInferredSpeeds));
            this.taskBodyPartInferenceMap.put(taskName, taskNode.get("bodyPartForInference").asText());
            this.taskBodyPartGoalMap.put(taskName, taskNode.get("bodyPartWithObservableGoal").asText());
            this.taskTransformGoalMap.put(taskName, new RigidBodyTransform(rotationGoalToEE, translationGoalToEE));

            // Log only this task's body parts; the other tasks have not been parsed yet.
            LogTools.info("{}", taskName);
            for (Map.Entry<String, String> bodyPart : bodyPartsGeometry.entrySet()) {
                LogTools.info("     {} {}", bodyPart.getKey(), bodyPart.getValue());
            }
        }

        for (ProMPManager manager : this.proMPManagers.values()) {
            manager.loadTaskFromDemos();
        }
        LogTools.info("ProMPs are ready to be used!");
    }

    /**
     * Packs the next assistance pose for {@code bodyPart} into {@code framePose} and advances that body part's
     * sample counter.
     * <ul>
     *   <li>While the counter is below {@code numberObservations}, the recorded observations are replayed.</li>
     *   <li>After that, the output follows the generated trajectory. It blends linearly from the last observed pose
     *       into it over {@value #INTERPOLATION_SAMPLES} samples.</li>
     *   <li>When the end of the generated trajectory is reached, its final pose is packed and
     *       {@link #isCurrentTaskDone()} becomes true.</li>
     * </ul>
     * Nothing is packed if no task is selected, if the body part does not belong to the current task, or if no
     * sample counter or generated trajectory exists yet for it (see {@link #readyToPack()}).
     * <p>
     * When an object frame is in use, the stored poses are moved to world frame in place before being copied.
     *
     * @param framePose pose to pack. NOTE(review): presumably expected in world frame - confirm against callers.
     * @param bodyPart  name of the body part to pack.
     */
    public void framePoseToPack(FramePose3D framePose, String bodyPart) {
        if (!isBodyPartOfCurrentTask(bodyPart)) {
            return;
        }
        Integer sampleCounter = this.bodyPartTrajectorySampleCounter.get(bodyPart);
        if (sampleCounter == null) {
            return; // trajectories have not been set up for this body part yet
        }
        int sampleIndex = sampleCounter;
        List<FramePose3D> observedTrajectory = this.bodyPartObservedTrajectoryMap.get(bodyPart);

        if (sampleIndex < this.numberObservations) {
            // Replay what was actually observed.
            FramePose3D observedPose = observedTrajectory.get(sampleIndex);
            expressInWorldIfNeeded(observedPose);
            copyPose(observedPose, framePose);
            this.bodyPartTrajectorySampleCounter.put(bodyPart, sampleIndex + 1);
            return;
        }

        List<FramePose3D> generatedTrajectory = this.bodyPartGeneratedTrajectoryMap.get(bodyPart);
        if (generatedTrajectory == null) {
            return; // counter was set but no trajectory has been generated yet
        }

        if (sampleIndex >= generatedTrajectory.size() - 1) {
            // Log only on the transition to "done" instead of on every subsequent call.
            boolean justCompleted = !this.doneCurrentTask;
            if (generatedTrajectory.size() < this.numberObservations) {
                if (justCompleted) {
                    LogTools.warn("The predicted motion results being faster than the time set to observe it. You can either decrease the number of required observations or increase the range of possible inferred speeds in {}", CONFIGURATION_FILE);
                }
            } else {
                FramePose3D finalPose = generatedTrajectory.get(generatedTrajectory.size() - 1);
                expressInWorldIfNeeded(finalPose);
                copyPose(finalPose, framePose);
            }
            if (justCompleted) {
                LogTools.info("Assistance completed");
            }
            this.doneCurrentTask = true;
            return;
        }

        FramePose3D nextGeneratedPose = generatedTrajectory.get(sampleIndex + 1);
        FramePose3D lastObservedPose = observedTrajectory.get(observedTrajectory.size() - 1);
        expressInWorldIfNeeded(nextGeneratedPose);
        expressInWorldIfNeeded(lastObservedPose);

        double alpha = (sampleIndex - this.numberObservations) / (double) INTERPOLATION_SAMPLES;
        if (alpha > 1.0d) {
            copyPose(nextGeneratedPose, framePose);
        } else {
            // Component-wise blend of the quaternions. NOTE(review): relies on the 4-argument quaternion setter
            // normalizing the result; no hemisphere (q vs -q) alignment is done before blending.
            framePose.getOrientation().set(lerp(lastObservedPose.getOrientation().getX(), nextGeneratedPose.getOrientation().getX(), alpha),
                                           lerp(lastObservedPose.getOrientation().getY(), nextGeneratedPose.getOrientation().getY(), alpha),
                                           lerp(lastObservedPose.getOrientation().getZ(), nextGeneratedPose.getOrientation().getZ(), alpha),
                                           lerp(lastObservedPose.getOrientation().getS(), nextGeneratedPose.getOrientation().getS(), alpha));
            framePose.getPosition().set(lerp(lastObservedPose.getPosition().getX(), nextGeneratedPose.getPosition().getX(), alpha),
                                        lerp(lastObservedPose.getPosition().getY(), nextGeneratedPose.getPosition().getY(), alpha),
                                        lerp(lastObservedPose.getPosition().getZ(), nextGeneratedPose.getPosition().getZ(), alpha));
        }
        this.bodyPartTrajectorySampleCounter.put(bodyPart, sampleIndex + 1);
    }

    /** Returns {@code (1 - alpha) * start + alpha * end}. */
    private static double lerp(double start, double end, double alpha) {
        return ((1.0d - alpha) * start) + (alpha * end);
    }

    /** Copies the position and orientation of {@code source} into {@code destination}. */
    private static void copyPose(FramePose3D source, FramePose3D destination) {
        destination.getPosition().set(source.getPosition());
        destination.getOrientation().set(source.getOrientation());
    }

    /** Moves {@code pose} to world frame, in place, when an object frame is in use. */
    private void expressInWorldIfNeeded(FramePose3D pose) {
        if (this.objectFrame != null) {
            pose.changeFrame(ReferenceFrame.getWorldFrame());
        }
    }

    /** Returns whether a task is selected and {@code bodyPart} is one of its body parts. */
    private boolean isBodyPartOfCurrentTask(String bodyPart) {
        ProMPManager manager = this.proMPManagers.get(this.currentTask);
        return manager != null && manager.getBodyPartsGeometry().containsKey(bodyPart);
    }

    /**
     * Feeds one observed body-part pose.
     * <p>
     * If the current task has an observable-goal body part, the observation is re-expressed in
     * {@code objectReferenceFrame}, which is also remembered as the object frame. A prediction is generated once
     * more than {@code numberObservations} samples have been recorded for the body part.
     *
     * @param observedPose         observed pose of the body part, taken as world-frame values.
     * @param bodyPart             name of the observed body part.
     * @param context              name of the detected object context, used to look up candidate tasks.
     * @param objectReferenceFrame frame of the detected object (may be null).
     */
    public void processFrameAndObjectInformation(Pose3DReadOnly observedPose, String bodyPart, String context, ReferenceFrame objectReferenceFrame) {
        if (!taskDetected(observedPose, bodyPart, context, objectReferenceFrame) || !isBodyPartOfCurrentTask(bodyPart)) {
            return;
        }
        FramePose3D framePose = new FramePose3D(observedPose);
        if (!this.bodyPartGoal.isEmpty()) {
            this.objectFrame = objectReferenceFrame;
            framePose.changeFrame(objectReferenceFrame);
        }
        if (recordObservationIfMoving(framePose, bodyPart) > this.numberObservations) {
            generatePrediction();
        }
    }

    /**
     * Feeds one observed body-part pose, expressed in world frame, together with the pose of the detected object.
     * <p>
     * A prediction is generated once more than {@code numberObservations + 1} samples have been recorded for the
     * body part. Before that, if the task has an observable-goal body part and {@code objectPose} is given, the
     * task goal pose is refreshed from it.
     *
     * @param observedPose observed pose of the body part, taken as world-frame values.
     * @param bodyPart     name of the observed body part.
     * @param context      name of the detected object context, used to look up candidate tasks.
     * @param objectPose   pose of the detected object (may be null).
     */
    public void processFrameAndObjectInformation(Pose3DReadOnly observedPose, String bodyPart, String context, FramePose3D objectPose) {
        if (!taskDetected(observedPose, bodyPart, context, null) || !isBodyPartOfCurrentTask(bodyPart)) {
            return;
        }
        FramePose3D framePose = new FramePose3D(observedPose);
        if (recordObservationIfMoving(framePose, bodyPart) > this.numberObservations + 1) {
            if (!this.bodyPartGoal.isEmpty() && objectPose != null) {
                updateTaskGoalPose(objectPose);
            }
            generatePrediction();
        }
    }

    /**
     * Appends {@code framePose} to the body part's observed trajectory if the user is moving.
     *
     * @return the size of the observed trajectory after the call, or -1 if nothing was recorded.
     */
    private int recordObservationIfMoving(FramePose3D framePose, String bodyPart) {
        if (!userIsMoving(framePose, bodyPart)) {
            return -1;
        }
        List<FramePose3D> observedTrajectory = this.bodyPartObservedTrajectoryMap.get(bodyPart);
        observedTrajectory.add(framePose);
        return observedTrajectory.size();
    }

    /** Sets the task goal pose to {@code objectPose} with the current task's goal transform appended. */
    private void updateTaskGoalPose(FramePose3D objectPose) {
        this.taskGoalPose = new FramePose3D(objectPose);
        this.taskGoalPose.appendTransform(this.taskTransformGoalMap.get(this.currentTask));
        // q and -q encode the same rotation: flip the goal so its scalar part has the same sign as the learned
        // value (presumably the s component of the mean final orientation - confirm in ProMPManager).
        if (this.proMPManagers.get(this.currentTask).getMeanEndValueQS() * this.taskGoalPose.getOrientation().getS() < 0.0d) {
            this.taskGoalPose.getOrientation().negate();
        }
    }

    /** Conditions the current task on the observations, regenerates the trajectories and enables packing. */
    private void generatePrediction() {
        updateTask();
        generateTaskTrajectories();
        this.doneInitialProcessingTask = true;
        LogTools.info("Generated prediction");
    }

    /**
     * Selects a task for {@code context} if none is selected yet.
     * <p>
     * With several candidate tasks, the one with the smallest distance reported by
     * {@code ProMPManager.computeInitialDistance} for this observation is chosen.
     *
     * @return true if a task is (now) selected.
     */
    private boolean taskDetected(Pose3DReadOnly observedPose, String bodyPart, String context, ReferenceFrame objectReferenceFrame) {
        if (!this.currentTask.isEmpty()) {
            return true;
        }
        List<String> candidateTasks = this.contextTasksMap.get(context);
        if (candidateTasks == null) {
            this.firstObservedBodyPart = true;
            LogTools.info("Detected object ({}) does not have any associated learned policy for assistance", context);
            return false;
        }

        FramePose3D framePose = new FramePose3D(observedPose);
        if (objectReferenceFrame != null) {
            framePose.changeFrame(objectReferenceFrame);
        }

        if (candidateTasks.size() > 1) {
            // The first evaluation appends one distance per candidate; later ones accumulate into those entries.
            // The flag is cleared only after the loop, so every candidate gets its own entry on the first pass.
            for (int i = 0; i < candidateTasks.size(); i++) {
                double distance = this.proMPManagers.get(candidateTasks.get(i)).computeInitialDistance(framePose, bodyPart);
                if (this.firstObservedBodyPart) {
                    this.distanceCandidateTasks.add(distance);
                } else {
                    this.distanceCandidateTasks.set(i, this.distanceCandidateTasks.get(i) + distance);
                }
            }
            this.firstObservedBodyPart = false;
            selectTask(candidateTasks.get(getMinIndex(this.distanceCandidateTasks)));
        } else {
            selectTask(candidateTasks.get(0));
        }

        LogTools.info("Found task! {}", this.currentTask);
        ProMPManager manager = this.proMPManagers.get(this.currentTask);
        for (String taskBodyPart : manager.getBodyPartsGeometry().keySet()) {
            this.initialStdDeviations.put(taskBodyPart, manager.generateStdDeviationTrajectory(taskBodyPart));
            this.initialMeans.put(taskBodyPart, manager.generateMeanTrajectory(taskBodyPart, objectReferenceFrame));
        }
        return true;
    }

    /** Makes {@code taskName} the current task and starts an empty observed trajectory for each of its body parts. */
    private void selectTask(String taskName) {
        this.currentTask = taskName;
        this.bodyPartInference = this.taskBodyPartInferenceMap.get(taskName);
        this.bodyPartGoal = this.taskBodyPartGoalMap.get(taskName);
        for (String bodyPart : this.proMPManagers.get(taskName).getBodyPartsGeometry().keySet()) {
            this.bodyPartObservedTrajectoryMap.put(bodyPart, new ArrayList<>());
        }
    }

    /**
     * Returns the index of the smallest value in {@code list}; on ties the first occurrence wins.
     *
     * @throws IllegalArgumentException if {@code list} is null or empty.
     */
    public static int getMinIndex(List<Double> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("List cannot be null or empty.");
        }
        double minValue = list.get(0);
        int minIndex = 0;
        for (int i = 1; i < list.size(); i++) {
            double value = list.get(i);
            if (value < minValue) {
                minValue = value;
                minIndex = i;
            }
        }
        return minIndex;
    }

    /**
     * Detects motion onset. Poses of the inference body part are collected until the displacement between the
     * latest and the first exceeds the threshold. Once moving, the flag latches until {@link #reset()}.
     *
     * @return whether the user is currently considered to be moving.
     */
    private boolean userIsMoving(Pose3DReadOnly pose, String bodyPart) {
        if (!bodyPart.equals(this.bodyPartInference) || this.isMoving) {
            return this.isMoving;
        }
        this.observationRecognition.add(pose);
        if (this.observationRecognition.size() > 1) {
            Pose3DReadOnly firstPose = this.observationRecognition.get(0);
            Pose3DReadOnly latestPose = this.observationRecognition.get(this.observationRecognition.size() - 1);
            double displacement = latestPose.getTranslation().distance(firstPose.getTranslation());
            double threshold = this.isMovingThreshold < 0.0d ? DEFAULT_IS_MOVING_THRESHOLD : this.isMovingThreshold;
            this.isMoving = displacement > threshold;
            LogTools.info("Is user moving? {}, body part moved by {}[m]", this.isMoving, displacement);
        }
        return this.isMoving;
    }

    /**
     * Passes the observations (all of them, or only the last one per body part) and, if known, the goal pose to the
     * current task's model. If speed inference is enabled, the task speed is updated first.
     */
    private void updateTask() {
        ProMPManager manager = this.proMPManagers.get(this.currentTask);
        if (this.numberOfInferredSpeeds > 0) {
            manager.updateTaskSpeed(this.bodyPartObservedTrajectoryMap.get(this.bodyPartInference), this.bodyPartInference);
        }
        for (Map.Entry<String, List<FramePose3D>> entry : this.bodyPartObservedTrajectoryMap.entrySet()) {
            String bodyPart = entry.getKey();
            List<FramePose3D> observations = entry.getValue();
            // NOTE(review): isLastViaPoint is never cleared here, so after the first body part every via point of
            // later body parts is also flagged as last - confirm this is what ProMPManager expects.
            if (this.conditionOnlyLastObservation) {
                if (!observations.isEmpty()) {
                    int lastIndex = observations.size() - 1;
                    this.isLastViaPoint.set(true);
                    // The cast keeps the same updateTaskTrajectory overload as before.
                    manager.updateTaskTrajectory(bodyPart, (Pose3DReadOnly) observations.get(lastIndex), lastIndex);
                }
            } else {
                for (int i = 0; i < observations.size(); i++) {
                    if (i == observations.size() - 1) {
                        this.isLastViaPoint.set(true);
                    }
                    manager.updateTaskTrajectory(bodyPart, (Pose3DReadOnly) observations.get(i), i);
                }
            }
        }
        if (this.taskGoalPose != null) {
            manager.updateTaskTrajectoryGoal(this.bodyPartGoal, this.taskGoalPose);
        }
    }

    /**
     * Generates a trajectory for every observed body part, in the object frame if one is set and world frame
     * otherwise, then positions the sample counters just after the observed samples.
     */
    private void generateTaskTrajectories() {
        ReferenceFrame trajectoryFrame = this.objectFrame != null ? this.objectFrame : ReferenceFrame.getWorldFrame();
        ProMPManager manager = this.proMPManagers.get(this.currentTask);
        for (String bodyPart : this.bodyPartObservedTrajectoryMap.keySet()) {
            this.bodyPartGeneratedTrajectoryMap.put(bodyPart, manager.generateTaskTrajectory(bodyPart, trajectoryFrame));
        }
        setStartTrajectories(this.numberObservations + 1);
    }

    /**
     * Sets the next sample index used by {@link #framePoseToPack} for every observed body part and clears the
     * "task done" flag.
     *
     * @param startSample index of the next sample to pack.
     */
    public void setStartTrajectories(int startSample) {
        this.doneCurrentTask = false;
        for (String bodyPart : this.bodyPartObservedTrajectoryMap.keySet()) {
            this.bodyPartTrajectorySampleCounter.put(bodyPart, startSample);
        }
    }

    /** Returns true once a prediction has been generated and poses can be packed. */
    public boolean readyToPack() {
        return this.doneInitialProcessingTask;
    }

    /** Resets the current task's model and clears all per-task state so a new task can be detected. */
    public void reset() {
        if (!this.currentTask.isEmpty()) {
            this.proMPManagers.get(this.currentTask).resetTask();
            this.currentTask = "";
        }
        this.bodyPartInference = "";
        this.bodyPartGoal = "";
        this.doneCurrentTask = false;
        this.taskGoalPose = null;
        this.objectFrame = null;
        this.bodyPartObservedTrajectoryMap.clear();
        this.bodyPartGeneratedTrajectoryMap.clear();
        this.bodyPartTrajectorySampleCounter.clear();
        this.initialStdDeviations.clear();
        this.initialMeans.clear();
        this.observationRecognition.clear();
        this.doneInitialProcessingTask = false;
        this.isLastViaPoint.set(false);
        this.isMoving = false;
        this.firstObservedBodyPart = true;
        this.distanceCandidateTasks.clear();
    }

    /** Returns the "testNumberUseOnlyForTesting" configuration value. */
    public int getTestNumber() {
        return this.testNumber;
    }

    /** Sets the motion-onset displacement threshold [m]; a negative value selects the default. */
    public void setIsMovingThreshold(double threshold) {
        this.isMovingThreshold = threshold;
    }

    /** Returns true once {@link #framePoseToPack} has reached the end of the generated trajectory. */
    public boolean isCurrentTaskDone() {
        return this.doneCurrentTask;
    }

    /** Returns the model of task {@code taskName}, or null if there is no such task. */
    public ProMPManager getProMPManager(String taskName) {
        return this.proMPManagers.get(taskName);
    }

    /** Returns the standard-deviation trajectory computed for {@code bodyPart} when the task was selected. */
    public Point3D[] getInitialStdDeviation(String bodyPart) {
        return this.initialStdDeviations.get(bodyPart);
    }

    /** Returns the mean trajectory computed for {@code bodyPart} when the task was selected. */
    public Point3D[] getInitialMean(String bodyPart) {
        return this.initialMeans.get(bodyPart);
    }

    /** Returns the names of all loaded tasks (live view of the internal map's key set). */
    public Set<String> getTaskNames() {
        return this.proMPManagers.keySet();
    }

    /** Returns true once a task has been selected. */
    public boolean startedProcessing() {
        return !this.currentTask.isEmpty();
    }
}
