Skip to content

Commit 5c9bedc

Browse files
committed
Geometric: handle p>0 where 1-p rounds to 1
1 parent 579d7f3 commit 5c9bedc

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
- `Dirichlet` no longer uses `const` generics, which means that its size is not required at compile time. Essentially a revert of rand#1292. (#15)
1717
- Add `Dirichlet::new_with_size` constructor (#15)
1818

19+
### Fixes
20+
- Fix `Geometric::new` for small `p > 0` where `1 - p` rounds to 1 (#36)
21+
1922
## [0.5.2]
2023

2124
### API Changes

src/geometric.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,16 @@ impl Geometric {
7070
/// Construct a new `Geometric` with the given shape parameter `p`
7171
/// (probability of success on each trial).
7272
pub fn new(p: f64) -> Result<Self, Error> {
73+
let mut pi = 1.0 - p;
7374
if !p.is_finite() || !(0.0..=1.0).contains(&p) {
7475
Err(Error::InvalidProbability)
75-
} else if p == 0.0 || p >= 2.0 / 3.0 {
76-
Ok(Geometric { p, pi: p, k: 0 })
76+
} else if pi == 1.0 || p >= 2.0 / 3.0 {
77+
Ok(Geometric { p, pi, k: 0 })
7778
} else {
7879
let (pi, k) = {
7980
// choose smallest k such that pi = (1 - p)^(2^k) <= 0.5
8081
let mut k = 1;
81-
let mut pi = (1.0 - p).powi(2);
82+
pi = pi * pi;
8283
while pi > 0.5 {
8384
k += 1;
8485
pi = pi * pi;
@@ -106,7 +107,7 @@ impl Distribution<u64> for Geometric {
106107
return failures;
107108
}
108109

109-
if self.p == 0.0 {
110+
if self.pi == 1.0 {
110111
return u64::MAX;
111112
}
112113

@@ -264,4 +265,18 @@ mod test {
264265
fn geometric_distributions_can_be_compared() {
265266
assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
266267
}
268+
269+
#[test]
270+
fn small_p() {
271+
let a = f64::EPSILON / 2.0;
272+
assert!(1.0 - a < 1.0); // largest repr. value < 1
273+
assert!(Geometric::new(a).is_ok());
274+
275+
let b = f64::EPSILON / 4.0;
276+
assert!(b > 0.0);
277+
assert!(1.0 - b == 1.0); // rounds to 1
278+
let d = Geometric::new(b).unwrap();
279+
let mut rng = crate::test::VoidRng;
280+
assert_eq!(d.sample(&mut rng), u64::MAX);
281+
}
267282
}

0 commit comments

Comments
 (0)