Skip to content

Commit d63e8e6

Browse files
committed
Avoid panics in Binomial BTPE sampler by using u64 values
1 parent a6a9f7b commit d63e8e6

File tree

2 files changed

+29
-24
lines changed

2 files changed

+29
-24
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121
- Use `direct-minimal-versions` (#38)
2222
- Fix panic in `FisherF::new` on almost zero parameters (#39)
2323
- Fix panic in `NormalInverseGaussian::new` with very large `alpha`; this is a Value-breaking change (#40)
24+
- Fix panic in `Binomial::sample` with `n ≥ 2^63`; this is a Value-breaking change (#43)
2425
- Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44)
2526

2627
## [0.5.2]

src/binomial.rs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct Binv {
7878
struct Btpe {
7979
n: u64,
8080
p: f64,
81-
m: i64,
81+
m: u64,
8282
p1: f64,
8383
}
8484

@@ -168,17 +168,17 @@ impl Binomial {
168168
let npq = np * q;
169169
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
170170
let f_m = np + p;
171-
let m = f64_to_i64(f_m);
171+
let m = f64_to_u64(f_m);
172172
Method::Btpe(Btpe { n, p, m, p1 }, flipped)
173173
};
174174
Ok(Binomial { method })
175175
}
176176
}
177177

178-
/// Convert a `f64` to an `i64`, panicking on overflow.
179-
fn f64_to_i64(x: f64) -> i64 {
180-
assert!(x < (i64::MAX as f64));
181-
x as i64
178+
/// Convert a `f64` to a `u64`, panicking on overflow.
179+
fn f64_to_u64(x: f64) -> u64 {
180+
assert!(x >= 0.0 && x < (u64::MAX as f64));
181+
x as u64
182182
}
183183

184184
fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
@@ -211,11 +211,11 @@ fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
211211
fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
212212
// Threshold for using the squeeze algorithm. This can be freely
213213
// chosen based on performance. Ranlib and GSL use 20.
214-
const SQUEEZE_THRESHOLD: i64 = 20;
214+
const SQUEEZE_THRESHOLD: u64 = 20;
215215

216216
// Step 0: Calculate constants as functions of `n` and `p`.
217-
let n = btpe.n as f64;
218-
let np = n * btpe.p;
217+
let n = btpe.n;
218+
let np = (n as f64) * btpe.p;
219219
let q = 1. - btpe.p;
220220
let npq = np * q;
221221
let f_m = np + btpe.p;
@@ -244,7 +244,7 @@ fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
244244
let p4 = p3 + c / lambda_r;
245245

246246
// return value
247-
let mut y: i64;
247+
let mut y: u64;
248248

249249
let gen_u = Uniform::new(0., p4).unwrap();
250250
let gen_v = Uniform::new(0., 1.).unwrap();
@@ -255,7 +255,7 @@ fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
255255
let u = gen_u.sample(rng);
256256
let mut v = gen_v.sample(rng);
257257
if !(u > p1) {
258-
y = f64_to_i64(x_m - p1 * v + u);
258+
y = f64_to_u64(x_m - p1 * v + u);
259259
break;
260260
}
261261

@@ -267,20 +267,21 @@ fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
267267
if v > 1. {
268268
continue;
269269
} else {
270-
y = f64_to_i64(x);
270+
y = f64_to_u64(x);
271271
}
272272
} else if !(u > p3) {
273273
// Step 3: Region 3, left exponential tail.
274-
y = f64_to_i64(x_l + v.ln() / lambda_l);
275-
if y < 0 {
274+
let y_tmp = x_l + v.ln() / lambda_l;
275+
if y_tmp < 0.0 {
276276
continue;
277277
} else {
278+
y = f64_to_u64(y_tmp);
278279
v *= (u - p2) * lambda_l;
279280
}
280281
} else {
281282
// Step 4: Region 4, right exponential tail.
282-
y = f64_to_i64(x_r - v.ln() / lambda_r);
283-
if y > 0 && (y as u64) > btpe.n {
283+
y = (x_r - v.ln() / lambda_r) as u64; // `as` cast saturates
284+
if y > btpe.n {
284285
continue;
285286
} else {
286287
v *= (u - p3) * lambda_r;
@@ -290,12 +291,12 @@ fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
290291
// Step 5: Acceptance/rejection comparison.
291292

292293
// Step 5.0: Test for appropriate method of evaluating f(y).
293-
let k = (y - m).abs();
294+
let k = y.abs_diff(m);
294295
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
295296
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
296297
// search from the mode.
297298
let s = btpe.p / q;
298-
let a = s * (n + 1.);
299+
let a = s * (n as f64 + 1.);
299300
let mut f = 1.0;
300301
match m.cmp(&y) {
301302
Ordering::Less => {
@@ -343,18 +344,23 @@ fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
343344
// Step 5.3: Final acceptance/rejection test.
344345
let x1 = (y + 1) as f64;
345346
let f1 = (m + 1) as f64;
346-
let z = (f64_to_i64(n) + 1 - m) as f64;
347-
let w = (f64_to_i64(n) - y + 1) as f64;
347+
let z = ((n - m) + 1) as f64;
348+
let w = ((n - y) + 1) as f64;
348349

349350
fn stirling(a: f64) -> f64 {
350351
let a2 = a * a;
351352
(13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
352353
}
353354

355+
let y_sub_m = if y > m {
356+
(y - m) as f64
357+
} else {
358+
-((m - y) as f64)
359+
};
354360
if alpha
355361
> x_m * (f1 / x1).ln()
356-
+ (n - (m as f64) + 0.5) * (z / w).ln()
357-
+ ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln()
362+
+ (((n - m) as f64) + 0.5) * (z / w).ln()
363+
+ y_sub_m * (w * btpe.p / (x1 * q)).ln()
358364
// We use the signs from the GSL implementation, which are
359365
// different than the ones in the reference. According to
360366
// the GSL authors, the new signs were verified to be
@@ -370,8 +376,6 @@ fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
370376

371377
break;
372378
}
373-
assert!(y >= 0);
374-
let y = y as u64;
375379

376380
if flipped { btpe.n - y } else { y }
377381
}

0 commit comments

Comments
 (0)