package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.IntSupplier;
import java.util.function.UnaryOperator;
import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.exception.MathInternalError;
import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
import org.apache.commons.math4.legacy.optim.InitialGuess;
import org.apache.commons.math4.legacy.optim.MaxEval;
import org.apache.commons.math4.legacy.optim.OptimizationData;
import org.apache.commons.math4.legacy.optim.PointValuePair;
import org.apache.commons.math4.legacy.optim.SimpleValueChecker;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.PopulationSize;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv.Simplex;

/**
 * Direct-search optimizer that evolves a {@link Simplex} of points.
 *
 * <p>The following {@link OptimizationData} must be passed to {@code optimize}:
 * <ul>
 *  <li>the initial {@link Simplex}, which is translated to the start point
 *   before the search begins;</li>
 *  <li>the {@link Simplex.TransformFactory} that creates the update rule
 *   applied at each iteration.</li>
 * </ul>
 * Optionally:
 * <ul>
 *  <li>{@link SimulatedAnnealing}: at each iteration the temperature is
 *   lowered by the cooling schedule and the update rule is applied
 *   {@code epochDuration} times using the acceptance criterion returned by
 *   {@link SimulatedAnnealing#metropolis(double)}.  When the temperature falls
 *   below the end temperature, additional local searches are started from the
 *   best points visited so far;</li>
 *  <li>{@link PopulationSize}: number of best points from which those
 *   additional searches are started (at least one when simulated annealing
 *   is used).</li>
 * </ul>
 *
 * <p>Simple bounds are not supported: a {@link MathUnsupportedOperationException}
 * is thrown if a lower or upper bound is set.
 */
public class SimplexOptimizer extends MultivariateOptimizer {
    /**
     * Ratio between the side of the initial simplex of an additional search
     * and the shortest distance from its start point to the other best points.
     */
    private static final double SIMPLEX_SIDE_RATIO = 0.1d;
    /** Factory for the simplex update rule. */
    private Simplex.TransformFactory updateRule;
    /** Initial simplex (before translation to the start point). */
    private Simplex initialSimplex;
    /** Simulated annealing settings, or {@code null} when not used. */
    private SimulatedAnnealing simulatedAnnealing = null;
    /** Number of additional searches requested by the user. */
    private int populationSize = 0;
    /** Effective number of additional searches (set by {@link #checkParameters()}). */
    private int additionalSearch = 0;
    /** Registered observers. */
    private final List<Observer> callbacks = new CopyOnWriteArrayList<>();

    /**
     * Callback notified each time a simplex has been evaluated.
     */
    @FunctionalInterface
    public interface Observer {
        /**
         * @param simplex Evaluated simplex.
         * @param isInit {@code true} for the initial simplex, {@code false}
         * for a simplex produced by the update rule.
         * @param numEval Number of objective function evaluations performed
         * by the (top-level) optimizer so far.
         */
        void update(Simplex simplex, boolean isInit, int numEval);
    }

    /**
     * @param checker Convergence checker, applied to every vertex of the
     * simplex at each iteration.
     */
    public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) {
        super(checker);
    }

    /**
     * Creates an optimizer that uses a {@link SimpleValueChecker}.
     *
     * @param rel Relative threshold.
     * @param abs Absolute threshold.
     */
    public SimplexOptimizer(double rel, double abs) {
        this(new SimpleValueChecker(rel, abs));
    }

    /**
     * Registers a callback.
     *
     * @param observer Callback.
     * @throws NullPointerException if {@code observer} is {@code null}.
     */
    public void addObserver(Observer observer) {
        Objects.requireNonNull(observer, "Callback");
        callbacks.add(observer);
    }

    /**
     * Runs the simplex search.
     *
     * <p>Iterates until every vertex of the simplex satisfies the convergence
     * checker.  With simulated annealing, the search may instead stop when
     * the temperature falls below the end temperature.  In that case, the
     * result is the best point found by additional local searches started
     * from the best visited points.
     *
     * @return the best point found.
     */
    @Override
    public PointValuePair doOptimize() {
        checkParameters();

        final MultivariateFunction evalFunc = this::computeObjectiveValue;
        final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
        // "Better" points sort first, whatever the goal type.
        final Comparator<PointValuePair> comparator = (o1, o2) -> {
            final double v1 = o1.getValue();
            final double v2 = o2.getValue();
            return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1);
        };

        // Start points for the additional searches.
        final List<PointValuePair> bestList = new ArrayList<>();

        Simplex previous = null;
        Simplex current = initialSimplex.translate(getStartPoint()).evaluate(evalFunc, comparator);
        notifyObservers(current, true);

        // Only used with simulated annealing.
        double temperature = simulatedAnnealing == null ?
            Double.NaN :
            temperature(current.get(0),
                        current.get(current.getDimension()),
                        simulatedAnnealing.getStartProbability());

        while (true) {
            if (previous != null && hasConverged(previous, current)) {
                return current.get(0);
            }

            // We still need to search.
            previous = current;

            if (simulatedAnnealing != null) {
                temperature = simulatedAnnealing.getCoolingSchedule().apply(temperature, current);

                final double endTemperature =
                    temperature(current.get(0),
                                current.get(current.getDimension()),
                                simulatedAnnealing.getEndProbability());

                if (temperature < endTemperature) {
                    if (additionalSearch <= 0) {
                        // Cannot happen: "checkParameters" sets it to at least 1
                        // when simulated annealing is used.
                        throw new MathInternalError();
                    }
                    if (bestList.isEmpty()) {
                        // Annealing stopped before any update was applied: use the
                        // vertices of the current simplex as start points.
                        storeBestPoints(current, comparator, bestList);
                    }
                    // Evaluations of the nested searches are reported to observers
                    // using the counter of this optimizer.
                    return bestListSearch(evalFunc, comparator, bestList, this::getEvaluations);
                }

                final UnaryOperator<Simplex> update =
                    updateRule.create(evalFunc, comparator, simulatedAnnealing.metropolis(temperature));
                for (int i = 0; i < simulatedAnnealing.getEpochDuration(); i++) {
                    current = applyUpdate(update, current, evalFunc, comparator);
                }
            } else {
                current = applyUpdate(updateRule.create(evalFunc, comparator, null),
                                      current, evalFunc, comparator);
            }

            if (additionalSearch != 0) {
                storeBestPoints(current, comparator, bestList);
            }

            incrementIterationCount();
        }
    }

    /**
     * Scans the optimization data for the {@link Simplex initial simplex},
     * the {@link Simplex.TransformFactory update rule}, the
     * {@link SimulatedAnnealing} settings and the {@link PopulationSize},
     * in addition to what the parent class handles.
     *
     * @param optData Optimization data.
     */
    @Override
    public void parseOptimizationData(OptimizationData... optData) {
        super.parseOptimizationData(optData);

        for (OptimizationData data : optData) {
            if (data instanceof Simplex) {
                initialSimplex = (Simplex) data;
            } else if (data instanceof Simplex.TransformFactory) {
                updateRule = (Simplex.TransformFactory) data;
            } else if (data instanceof SimulatedAnnealing) {
                simulatedAnnealing = (SimulatedAnnealing) data;
            } else if (data instanceof PopulationSize) {
                populationSize = ((PopulationSize) data).getPopulationSize();
            }
        }
    }

    /**
     * @param previous Simplex at the previous iteration.
     * @param current Simplex at the current iteration.
     * @return {@code true} if every pair of corresponding vertices satisfies
     * the convergence checker.
     */
    private boolean hasConverged(Simplex previous, Simplex current) {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        for (int i = 0; i < current.getSize(); i++) {
            if (!checker.converged(getIterations(), previous.get(i), current.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Validates the parsed optimization data and computes the effective
     * number of additional searches.
     *
     * @throws NullPointerException if the update rule or the initial simplex
     * is missing.
     * @throws MathUnsupportedOperationException if bounds were specified.
     * @throws IllegalArgumentException if the population size is negative.
     */
    private void checkParameters() {
        Objects.requireNonNull(updateRule, "Update rule");
        Objects.requireNonNull(initialSimplex, "Initial simplex");

        if (getLowerBound() != null ||
            getUpperBound() != null) {
            throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
        }

        if (populationSize < 0) {
            throw new IllegalArgumentException("Population size");
        }

        // With simulated annealing, at least one additional search is performed.
        additionalSearch = simulatedAnnealing == null ?
            Math.max(0, populationSize) :
            Math.max(1, populationSize);
    }

    /**
     * Computes the temperature {@code T} that solves {@code exp(-d / T) = p},
     * where {@code d} is the absolute difference between the objective values
     * of the two given points.
     *
     * @param p1 First point.
     * @param p2 Second point.
     * @param prob Probability (presumably in {@code (0, 1)}, so that the
     * result is non-negative).
     * @return the temperature.
     */
    private static double temperature(PointValuePair p1,
                                      PointValuePair p2,
                                      double prob) {
        final double delta = Math.abs(p1.getValue() - p2.getValue());
        return -delta / Math.log(prob);
    }

    /**
     * Offers every vertex of the simplex to the list of best points.
     *
     * @param simplex Simplex whose vertices are candidates.
     * @param comparator Fitness comparator.
     * @param bestList List of best points (updated in-place).
     */
    private void storeBestPoints(Simplex simplex,
                                 Comparator<PointValuePair> comparator,
                                 List<PointValuePair> bestList) {
        // At least two points are kept so that the size of the initial
        // simplex of an additional search can be computed.
        final int capacity = Math.max(additionalSearch, 2);
        for (int i = 0; i < simplex.getSize(); i++) {
            keepIfBetter(simplex.get(i), comparator, bestList, capacity);
        }
    }

    /**
     * Inserts {@code candidate} into {@code list} if it is not already
     * present (same coordinates) and either the list is not full or the
     * candidate is better than the worst element.  The list is sorted
     * (best first) whenever it is full.
     *
     * @param candidate Candidate point.
     * @param comparator Fitness comparator.
     * @param list List of best points (updated in-place).
     * @param capacity Maximum size of the list.
     */
    private static void keepIfBetter(PointValuePair candidate,
                                     Comparator<PointValuePair> comparator,
                                     List<PointValuePair> list,
                                     int capacity) {
        final int size = list.size();
        if (size == 0) {
            list.add(candidate);
            return;
        }

        final double[] point = candidate.getPoint();
        if (size < capacity) {
            if (!containsPoint(list, point)) {
                list.add(candidate);
                if (list.size() == capacity) {
                    list.sort(comparator);
                }
            }
            return;
        }

        // List is full and sorted: only replace the worst element.
        final int last = capacity - 1;
        if (comparator.compare(candidate, list.get(last)) < 0 &&
            !containsPoint(list, point)) {
            list.set(last, candidate);
            list.sort(comparator);
        }
    }

    /**
     * @param list Points.
     * @param point Coordinates.
     * @return {@code true} if an element of {@code list} has exactly the
     * given coordinates.
     */
    private static boolean containsPoint(List<PointValuePair> list,
                                         double[] point) {
        for (PointValuePair p : list) {
            if (Arrays.equals(p.getPoint(), point)) {
                return true;
            }
        }
        return false;
    }

    /**
     * @param from Reference point.
     * @param list Other points.
     * @return the smallest Euclidean distance between {@code from} and the
     * elements of {@code list} that have different coordinates, or
     * {@link Double#POSITIVE_INFINITY} if there is no such element.
     */
    private static double shortestDistance(PointValuePair from,
                                           List<PointValuePair> list) {
        final double[] p = from.getPoint();
        double shortest = Double.POSITIVE_INFINITY;
        for (PointValuePair other : list) {
            final double[] q = other.getPoint();
            if (!Arrays.equals(p, q)) {
                final double d = MathArrays.distance(p, q);
                if (d < shortest) {
                    shortest = d;
                }
            }
        }
        return shortest;
    }

    /**
     * Performs additional local searches, each starting from one of the best
     * points visited during the main search.
     *
     * @param func Objective function.
     * @param comparator Fitness comparator.
     * @param bestList Best points (must not be empty; sorted in-place).
     * @param evalCount Evaluation counter of this optimizer (reported to
     * observers during the nested searches).
     * @return the best point among {@code bestList} and the results of the
     * additional searches.
     */
    private PointValuePair bestListSearch(MultivariateFunction func,
                                          Comparator<PointValuePair> comparator,
                                          List<PointValuePair> bestList,
                                          IntSupplier evalCount) {
        // "keepIfBetter" sorts the list only once it is full.
        bestList.sort(comparator);

        PointValuePair best = bestList.get(0);

        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        final GoalType goal = getGoalType();
        // Fewer distinct points than requested may have been visited.
        final int numSearches = Math.min(additionalSearch, bestList.size());
        for (int i = 0; i < numSearches; i++) {
            final PointValuePair start = bestList.get(i);
            final double dist = shortestDistance(start, bestList);
            if (Double.isInfinite(dist)) {
                // No other distinct point from which to size the simplex.
                continue;
            }

            final double[] point = start.getPoint();
            final PointValuePair result =
                directSearch(point,
                             Simplex.equalSidesAlongAxes(point.length, SIMPLEX_SIDE_RATIO * dist),
                             func,
                             checker,
                             goal,
                             callbacks,
                             evalCount);
            if (comparator.compare(result, best) < 0) {
                best = result;
            }
        }

        return best;
    }

    /**
     * Runs a nested optimization with a {@link MultiDirectionalTransform}
     * update rule and no evaluation limit.
     *
     * @param start Start point.
     * @param simplex Initial simplex.
     * @param func Objective function.
     * @param checker Convergence checker.
     * @param goal Goal type.
     * @param observers Observers of the calling optimizer.
     * @param evalCount Evaluation counter passed to {@code observers} in
     * place of the nested optimizer's own counter.
     * @return the result of the nested optimization.
     */
    private static PointValuePair directSearch(double[] start,
                                               Simplex simplex,
                                               MultivariateFunction func,
                                               ConvergenceChecker<PointValuePair> checker,
                                               GoalType goal,
                                               List<Observer> observers,
                                               IntSupplier evalCount) {
        final SimplexOptimizer optim = new SimplexOptimizer(checker);

        for (Observer cOpt : observers) {
            optim.addObserver((spx, isInit, numEval) ->
                              cOpt.update(spx, isInit, evalCount.getAsInt()));
        }

        return optim.optimize(MaxEval.unlimited(),
                              new ObjectiveFunction(func),
                              goal,
                              new InitialGuess(start),
                              simplex,
                              new MultiDirectionalTransform());
    }

    /**
     * @param simplex Evaluated simplex.
     * @param isInit Whether this is the initial simplex.
     */
    private void notifyObservers(Simplex simplex,
                                 boolean isInit) {
        for (Observer cb : callbacks) {
            cb.update(simplex, isInit, getEvaluations());
        }
    }

    /**
     * Transforms and evaluates the simplex, then notifies the observers.
     *
     * @param update Update rule.
     * @param simplex Current simplex.
     * @param eval Objective function.
     * @param comparator Fitness comparator.
     * @return the new (evaluated) simplex.
     */
    private Simplex applyUpdate(UnaryOperator<Simplex> update,
                                Simplex simplex,
                                MultivariateFunction eval,
                                Comparator<PointValuePair> comparator) {
        final Simplex transformed = update.apply(simplex).evaluate(eval, comparator);
        notifyObservers(transformed, false);
        return transformed;
    }
}
