package org.apache.commons.statistics.distribution;

import java.util.function.DoubleSupplier;
import org.apache.commons.numbers.gamma.Erf;
import org.apache.commons.numbers.gamma.ErfDifference;
import org.apache.commons.numbers.gamma.Erfcx;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import org.apache.commons.statistics.distribution.ContinuousDistribution;

/* loaded from: input_file:org/apache/commons/statistics/distribution/TruncatedNormalDistribution.class */
/**
 * Normal distribution truncated to the closed interval {@code [lower, upper]}.
 *
 * <p>Inside the interval the density is the parent normal density divided by the
 * parent probability mass of {@code [lower, upper]}. Outside the interval it is zero.
 * Instances are created with {@link #of(double, double, double, double)}.
 */
public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {
    /**
     * Equal to {@code sqrt(Double.MAX_VALUE)}, so {@code x * x} is finite for
     * {@code |x| < MAX_X}. The moment functions treat a standardized bound at or beyond
     * this magnitude as "no truncation" on that side.
     */
    private static final double MAX_X = 1.3407807929942596E154d;
    /**
     * Minimum parent probability mass of the truncation range. A range whose mass is at
     * or below this is rejected by {@link #of(double, double, double, double)}.
     */
    private static final double MIN_P = 0.0d;
    /** sqrt(2). */
    private static final double ROOT2 = 1.4142135623730951d;
    /** sqrt(2 / pi). */
    private static final double ROOT_2_PI = 0.7978845608028654d;
    /** sqrt(pi / 2). */
    private static final double ROOT_PI_2 = 1.2533141373155003d;
    /**
     * Parent-normal mass of the truncation range above which a rejection sampler is
     * used instead of the inherited inverse-CDF sampler.
     */
    private static final double REJECTION_THRESHOLD = 0.2d;

    /** The untruncated parent normal distribution. */
    private final NormalDistribution parentNormal;
    /** Lower truncation bound (inclusive). */
    private final double lower;
    /** Upper truncation bound (inclusive). */
    private final double upper;
    /** Parent probability of {@code [lower, upper]}; the normalization constant. */
    private final double cdfDelta;
    /** {@code log(cdfDelta)}, used by {@link #logDensity(double)}. */
    private final double logCdfDelta;
    /** Parent CDF at {@code lower}; maps probabilities onto the parent scale. */
    private final double cdfAlpha;
    /** Parent survival function at {@code upper}; maps probabilities onto the parent scale. */
    private final double sfBeta;

    /**
     * @param parent parent normal distribution
     * @param z parent probability of {@code [lower, upper]} (must be positive)
     * @param lower lower bound
     * @param upper upper bound
     */
    private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
        this.parentNormal = parent;
        this.lower = lower;
        this.upper = upper;
        this.cdfDelta = z;
        this.logCdfDelta = Math.log(z);
        this.cdfAlpha = parent.cumulativeProbability(lower);
        this.sfBeta = parent.survivalProbability(upper);
    }

    /**
     * Creates a truncated normal distribution.
     *
     * @param mean mean of the parent normal distribution
     * @param sd standard deviation of the parent normal distribution
     * @param lower lower truncation bound
     * @param upper upper truncation bound
     * @return the distribution
     * @throws DistributionException if {@code sd <= 0}, if {@code lower >= upper}, or if
     *     the parent probability of {@code [lower, upper]} is not greater than {@code MIN_P}
     */
    public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
        if (sd <= 0) {
            throw new DistributionException("Number %s is not greater than 0", sd);
        }
        if (lower >= upper) {
            throw new DistributionException("Lower bound %s >= upper bound %s", lower, upper);
        }
        final NormalDistribution parent = NormalDistribution.of(mean, sd);
        final double z = parent.probability(lower, upper);
        if (z <= MIN_P) {
            // Report the bounds on the standard normal scale.
            throw new DistributionException("Excess truncation of standard normal : CDF(%s, %s) = %s",
                (lower - mean) / sd, (upper - mean) / sd, z);
        }
        return new TruncatedNormalDistribution(parent, z, lower, upper);
    }

    @Override
    public double density(double x) {
        if (x < lower || x > upper) {
            return 0;
        }
        return parentNormal.density(x) / cdfDelta;
    }

    @Override
    public double probability(double x0, double x1) {
        if (x0 > x1) {
            throw new DistributionException("Lower bound %s > upper bound %s", x0, x1);
        }
        return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
    }

    @Override
    public double logDensity(double x) {
        if (x < lower || x > upper) {
            return Double.NEGATIVE_INFINITY;
        }
        return parentNormal.logDensity(x) - logCdfDelta;
    }

    @Override
    public double cumulativeProbability(double x) {
        if (x <= lower) {
            return 0;
        }
        if (x >= upper) {
            return 1;
        }
        return parentNormal.probability(lower, x) / cdfDelta;
    }

    @Override
    public double survivalProbability(double x) {
        if (x <= lower) {
            return 1;
        }
        if (x >= upper) {
            return 0;
        }
        return parentNormal.probability(x, upper) / cdfDelta;
    }

    @Override
    public double inverseCumulativeProbability(double p) {
        ArgumentUtils.checkProbability(p);
        if (p == 0) {
            return lower;
        }
        if (p == 1) {
            return upper;
        }
        // Map p into the parent CDF range [cdfAlpha, cdfAlpha + cdfDelta]; clip to
        // guard against floating-point error placing the result outside the bounds.
        return clipToRange(parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta));
    }

    @Override
    public double inverseSurvivalProbability(double p) {
        ArgumentUtils.checkProbability(p);
        if (p == 1) {
            return lower;
        }
        if (p == 0) {
            return upper;
        }
        // Map p into the parent survival range [sfBeta, sfBeta + cdfDelta].
        return clipToRange(parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta));
    }

    /**
     * Creates a sampler.
     *
     * <p>When the truncation range covers enough of the parent distribution, a rejection
     * sampler over a standard Gaussian is used. Otherwise the inherited
     * (inverse-CDF based) sampler is returned.
     *
     * @param rng source of randomness
     * @return a sampler
     */
    @Override
    public ContinuousDistribution.Sampler createSampler(UniformRandomProvider rng) {
        // All rejection decisions are made on the standard normal scale.
        final double mean = parentNormal.getMean();
        final double sd = parentNormal.getStandardDeviation();
        final double a = (lower - mean) / sd;
        final double b = (upper - mean) / sd;

        // If the standardized range lies on one side of zero, a mirrored half-normal
        // doubles the acceptance rate, so rejection pays off at half the coverage.
        // This must test the standardized bounds: testing the raw bounds is wrong
        // whenever the parent mean is non-zero.
        double threshold = REJECTION_THRESHOLD;
        if (a >= 0 || b <= 0) {
            threshold *= 0.5;
        }
        if (cdfDelta <= threshold) {
            return super.createSampler(rng);
        }

        final ZigguratSampler.NormalizedGaussian gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
        final DoubleSupplier gen;
        if (a >= 0) {
            // Range is in the upper half of N(0, 1).
            gen = () -> Math.abs(gaussian.sample());
        } else if (b <= 0) {
            // Range is in the lower half of N(0, 1).
            gen = () -> -Math.abs(gaussian.sample());
        } else {
            gen = gaussian::sample;
        }
        return () -> {
            double z;
            do {
                z = gen.getAsDouble();
            } while (z < a || z > b);
            // Clip to avoid floating-point error producing a value outside the bounds.
            return clipToRange(mean + z * sd);
        };
    }

    @Override
    public double getMean() {
        final double mean = parentNormal.getMean();
        final double sd = parentNormal.getStandardDeviation();
        final double a = (lower - mean) / sd;
        final double b = (upper - mean) / sd;
        return mean + moment1(a, b) * sd;
    }

    @Override
    public double getVariance() {
        final double mean = parentNormal.getMean();
        final double sd = parentNormal.getStandardDeviation();
        final double a = (lower - mean) / sd;
        final double b = (upper - mean) / sd;
        return variance(a, b) * sd * sd;
    }

    @Override
    public double getSupportLowerBound() {
        return lower;
    }

    @Override
    public double getSupportUpperBound() {
        return upper;
    }

    /**
     * Clips {@code x} to {@code [lower, upper]}.
     *
     * @param x value
     * @return the clipped value
     */
    private double clipToRange(double x) {
        return clip(x, lower, upper);
    }

    /**
     * Clips {@code x} to {@code [lo, hi]}. A NaN {@code x} yields {@code hi}.
     *
     * @param x value
     * @param lo lower limit
     * @param hi upper limit
     * @return the clipped value
     */
    private static double clip(double x, double lo, double hi) {
        if (x <= lo) {
            return lo;
        }
        return x < hi ? x : hi;
    }

    /**
     * Computes the mean of the standard normal truncated to {@code [a, b]}.
     *
     * <p>{@code E[Z] = (phi(a) - phi(b)) / (Phi(b) - Phi(a))}, evaluated with
     * {@code expm1} and scaled error functions to limit cancellation.
     *
     * @param a lower bound (assumed {@code a <= b})
     * @param b upper bound
     * @return the first moment, clipped to {@code [a, b]}
     */
    static double moment1(double a, double b) {
        if (a == b) {
            return a;
        }
        if (Math.abs(a) > Math.abs(b)) {
            // Reflect so that |a| <= |b|. Subtracting from 0.0 gives +0.0 for a zero result.
            return 0.0 - moment1(-b, -a);
        }
        // Here |a| <= |b| and a < b, hence b > 0.
        if (a <= -MAX_X) {
            // Effectively no truncation.
            return 0;
        }
        if (b >= MAX_X) {
            // One-sided truncation at a.
            return ROOT_2_PI / Erfcx.value(a / ROOT2);
        }
        // dx = (b^2 - a^2) / 2, so exp(-b^2/2) - exp(-a^2/2) = expm1(-dx) * exp(-a^2/2).
        final double dx = 0.5 * (b + a) * (b - a);
        final double m;
        if (a <= 0) {
            // Bounds straddle zero: the plain erf difference is well conditioned.
            m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a)
                / ErfDifference.value(a / ROOT2, b / ROOT2);
        } else {
            // Both bounds in the upper tail: use scaled erfc to avoid underflow.
            final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
            if (z == 0) {
                // Large, nearly equal bounds: the range is effectively a point.
                return (a + b) * 0.5;
            }
            m = ROOT_2_PI * Math.expm1(-dx) / z;
        }
        return clip(m, a, b);
    }

    /**
     * Computes the second raw moment {@code E[Z^2]} of the standard normal truncated to
     * {@code [a, b]}. This is {@code 1 - (b phi(b) - a phi(a)) / (Phi(b) - Phi(a))}.
     *
     * @param a lower bound (assumed {@code a < b}; {@code a == b} is handled by the caller)
     * @param b upper bound
     * @return the second moment, clipped to a feasible range
     */
    private static double moment2(double a, double b) {
        if (Math.abs(a) > Math.abs(b)) {
            // E[Z^2] is symmetric under reflection.
            return moment2(-b, -a);
        }
        // Here |a| <= |b| and a < b, hence b > 0.
        if (a <= -MAX_X) {
            // Effectively no truncation.
            return 1;
        }
        if (b >= MAX_X) {
            // One-sided truncation at a.
            return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
        }
        if (a <= 0) {
            // Bounds straddle zero: E[Z^2] lies in [0, 1].
            final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
            final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
            final double fa = ea - a * Math.exp(-0.5 * a * a);
            final double fb = eb - b * Math.exp(-0.5 * b * b);
            return clip((fb - fa) / (eb - ea), 0, 1);
        }
        // Both bounds in the upper tail: E[Z^2] lies in [a^2, b^2].
        final double ex = Math.exp(-(0.5 * (b + a) * (b - a)));
        final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
        final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
        return clip(((ea + a) - (eb + b) * ex) / (ea - eb * ex), a * a, b * b);
    }

    /**
     * Computes the variance of the standard normal truncated to {@code [a, b]}.
     *
     * @param a lower bound (assumed {@code a <= b})
     * @param b upper bound
     * @return the variance, in {@code [0, 1]}
     */
    static double variance(double a, double b) {
        if (a == b) {
            return 0;
        }
        final double m1 = moment1(a, b);
        final double rootM2 = Math.sqrt(moment2(a, b));
        // m2 - m1^2 computed as (sqrt(m2) - m1)(sqrt(m2) + m1) to reduce cancellation.
        final double v = (rootM2 - m1) * (rootM2 + m1);
        if (v < 1) {
            // Floating-point error can give a negative variance.
            return v <= 0 ? 0 : v;
        }
        // A value >= 1 signals floating-point error. Report 1 only when the bounds lie
        // far on either side of zero (little truncation); otherwise report 0.
        if (a >= -1 || b <= 1) {
            return 0;
        }
        return 1;
    }
}
