Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix panic in `NormalInverseGaussian::new` with very large `alpha`; this is a Value-breaking change (#40)
- Fix panic in `Binomial::sample` with `n ≥ 2^63`; this is a Value-breaking change (#43)
- Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44)
- Avoid returning NaN from `Gamma::sample`; this is a Value-breaking change and also affects `ChiSquared` and `Dirichlet` (#46)

## [0.5.2]

Expand Down
70 changes: 48 additions & 22 deletions src/gamma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ use serde::{Deserialize, Serialize};
///
/// # Notes
///
/// When the shape (`k`) or scale (`θ`) parameters are close to the upper limits
/// of the floating point type `F`, the implementation may overflow and
/// produce `inf`; similarly, if either is sufficiently close to `0.0`,
/// it may output `0.0`. Sampling may become inaccurate if `k` is close
/// to zero and `θ` is very large.
///
/// The algorithm used is that described by Marsaglia & Tsang 2000[^1],
/// falling back to directly sampling from an Exponential for `shape
/// == 1`, and using the boosting technique described in that paper for
Expand Down Expand Up @@ -173,8 +179,10 @@ where
return Err(Error::ScaleTooSmall);
}

let repr = if shape == F::one() {
One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
let repr = if shape == F::infinity() || scale == F::infinity() {
One(Exp::new(F::zero()).unwrap())
} else if shape == F::one() {
One(Exp::new(F::one() / scale).unwrap())
} else if shape < F::one() {
Small(GammaSmallShape::new_raw(shape, scale))
} else {
Expand Down Expand Up @@ -212,6 +220,28 @@ where
d,
}
}

fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// Marsaglia & Tsang method, 2000
loop {
let x: F = rng.sample(StandardNormal);
let v_cbrt = F::one() + self.c * x;
if v_cbrt <= F::zero() {
continue;
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: F = rng.sample(Open01);

let x_sqr = x * x;
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
{
// `x` is concentrated enough that `v` should always be finite
return v;
}
}
}
}

impl<F> Distribution<F> for Gamma<F>
Expand All @@ -238,35 +268,22 @@ where
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let u: F = rng.sample(Open01);

self.large_shape.sample(rng) * u.powf(self.inv_shape)
let a = self.large_shape.sample_unscaled(rng);
let b = u.powf(self.inv_shape);
// Multiplying numbers with `scale` can overflow, so do it last to avoid
// producing NaN = inf * 0.0. All the other terms are finite and small.
(a * b * self.large_shape.d) * self.large_shape.scale
}
}

impl<F> Distribution<F> for GammaLargeShape<F>
where
F: Float,
StandardNormal: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// Marsaglia & Tsang method, 2000
loop {
let x: F = rng.sample(StandardNormal);
let v_cbrt = F::one() + self.c * x;
if v_cbrt <= F::zero() {
// a^3 <= 0 iff a <= 0
continue;
}

let v = v_cbrt * v_cbrt * v_cbrt;
let u: F = rng.sample(Open01);

let x_sqr = x * x;
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
{
return self.d * v * self.scale;
}
}
self.sample_unscaled(rng) * (self.d * self.scale)
}
}

Expand All @@ -278,4 +295,13 @@ mod test {
fn gamma_distributions_can_be_compared() {
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
}

#[test]
fn gamma_extreme_values() {
let d = Gamma::new(f64::infinity(), 2.0).unwrap();
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());

let d = Gamma::new(2.0, f64::infinity()).unwrap();
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
}
}