diff --git a/CHANGELOG.md b/CHANGELOG.md index 579cce0..d89ddba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix panic in `NormalInverseGaussian::new` with very large `alpha`; this is a Value-breaking change (#40) - Fix panic in `Binomial::sample` with `n ≥ 2^63`; this is a Value-breaking change (#43) - Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44) +- Avoid returning NaN from `Gamma::sample`; this is a Value-breaking change and also affects `ChiSquared` and `Dirichlet` (#46) ## [0.5.2] diff --git a/src/gamma.rs b/src/gamma.rs index 0fc6b75..bf802e3 100644 --- a/src/gamma.rs +++ b/src/gamma.rs @@ -53,6 +53,12 @@ use serde::{Deserialize, Serialize}; /// /// # Notes /// +/// When the shape (`k`) or scale (`θ`) parameters are close to the upper limits +/// of the floating point type `F`, the implementation may overflow and +/// produce `inf`; similarly, if either is sufficiently close to `0.0`, +/// it may output `0.0`. Sampling may become inaccurate if `k` is close +/// to zero and `θ` is very large. +/// /// The algorithm used is that described by Marsaglia & Tsang 2000[^1], /// falling back to directly sampling from an Exponential for `shape /// == 1`, and using the boosting technique described in that paper for @@ -173,8 +179,10 @@ where return Err(Error::ScaleTooSmall); } - let repr = if shape == F::one() { - One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?) + let repr = if shape == F::infinity() || scale == F::infinity() { + One(Exp::new(F::zero()).unwrap()) + } else if shape == F::one() { + One(Exp::new(F::one() / scale).unwrap()) } else if shape < F::one() { Small(GammaSmallShape::new_raw(shape, scale)) } else { @@ -212,6 +220,28 @@ where d, } } + + fn sample_unscaled(&self, rng: &mut R) -> F { + // Marsaglia & Tsang method, 2000 + loop { + let x: F = rng.sample(StandardNormal); + let v_cbrt = F::one() + self.c * x; + if v_cbrt <= F::zero() { + continue; + } + + let v = v_cbrt * v_cbrt * v_cbrt; + let u: F = rng.sample(Open01); + + let x_sqr = x * x; + if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr + || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln()) + { + // `x` is concentrated enough that `v` should always be finite + return v; + } + } + } } impl Distribution for Gamma @@ -238,9 +268,14 @@ where fn sample(&self, rng: &mut R) -> F { let u: F = rng.sample(Open01); - self.large_shape.sample(rng) * u.powf(self.inv_shape) + let a = self.large_shape.sample_unscaled(rng); + let b = u.powf(self.inv_shape); + // Multiplying numbers with `scale` can overflow, so do it last to avoid + // producing NaN = inf * 0.0. All the other terms are finite and small. + (a * b * self.large_shape.d) * self.large_shape.scale } } + impl Distribution for GammaLargeShape where F: Float, @@ -248,25 +283,7 @@ where Open01: Distribution, { fn sample(&self, rng: &mut R) -> F { - // Marsaglia & Tsang method, 2000 - loop { - let x: F = rng.sample(StandardNormal); - let v_cbrt = F::one() + self.c * x; - if v_cbrt <= F::zero() { - // a^3 <= 0 iff a <= 0 - continue; - } - - let v = v_cbrt * v_cbrt * v_cbrt; - let u: F = rng.sample(Open01); - - let x_sqr = x * x; - if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr - || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln()) - { - return self.d * v * self.scale; - } - } + self.sample_unscaled(rng) * (self.d * self.scale) } } @@ -278,4 +295,13 @@ mod test { fn gamma_distributions_can_be_compared() { assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); } + + #[test] + fn gamma_extreme_values() { + let d = Gamma::new(f64::infinity(), 2.0).unwrap(); + assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity()); + + let d = Gamma::new(2.0, f64::infinity()).unwrap(); + assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity()); + } }