-
Notifications
You must be signed in to change notification settings - Fork 0
Fix errors in edge cases for wang method #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
151d643
bd72da6
175ad19
e16579c
f4fd5b4
9dc245b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -101,6 +101,7 @@ def wang_binomial_ci( | |||||||||||||||||||||
| precision, | ||||||||||||||||||||||
| grid_one, | ||||||||||||||||||||||
| grid_two, | ||||||||||||||||||||||
| verbose, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if verbose: | ||||||||||||||||||||||
| print(f"Left CI: {ci_l}") | ||||||||||||||||||||||
|
|
@@ -114,6 +115,7 @@ def wang_binomial_ci( | |||||||||||||||||||||
| precision, | ||||||||||||||||||||||
| grid_one, | ||||||||||||||||||||||
| grid_two, | ||||||||||||||||||||||
| verbose, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if verbose: | ||||||||||||||||||||||
| print(f"Right CI: {ci_u}") | ||||||||||||||||||||||
|
|
@@ -122,7 +124,7 @@ def wang_binomial_ci( | |||||||||||||||||||||
| return ConfidenceInterval(lower, upper, estimate, conf_level, "wang", sides_val) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| ci = binomial_ci_one_sided( | ||||||||||||||||||||||
| n_positive, n_total, ref_positive, ref_total, conf_level, sides_val, precision, grid_one, grid_two | ||||||||||||||||||||||
| n_positive, n_total, ref_positive, ref_total, conf_level, sides_val, precision, grid_one, grid_two, verbose | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return ConfidenceInterval(ci[1], ci[2], ci[0], conf_level, "wang", sides_val) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -137,6 +139,7 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
| precision: float, | ||||||||||||||||||||||
| grid_one: int, | ||||||||||||||||||||||
| grid_two: int, | ||||||||||||||||||||||
| verbose: bool = False, | ||||||||||||||||||||||
| ) -> List[float]: | ||||||||||||||||||||||
| """Helper function that calculates one-sided confidence interval. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -160,11 +163,13 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
| "right-sided" (aliases "right_sided", "right", "rs", "r"), | ||||||||||||||||||||||
| case insensitive. | ||||||||||||||||||||||
| precision : float, optional | ||||||||||||||||||||||
| Precision for the search algorithm, by default 1e-5 | ||||||||||||||||||||||
| Precision for the search algorithm, by default 1e-5. | ||||||||||||||||||||||
| grid_one : int, optional | ||||||||||||||||||||||
| Number of grid points in first step, by default 30 | ||||||||||||||||||||||
| Number of grid points in first step, by default 30. | ||||||||||||||||||||||
| grid_two : int, optional | ||||||||||||||||||||||
| Number of grid points in second step, by default 20 | ||||||||||||||||||||||
| Number of grid points in second step, by default 20. | ||||||||||||||||||||||
| verbose : bool, optional | ||||||||||||||||||||||
| Verbosity for debug message, by default False. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Returns | ||||||||||||||||||||||
| ------- | ||||||||||||||||||||||
|
|
@@ -204,7 +209,7 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
| f[:, 2] = (p1hat - p0hat) / np.sqrt(denom) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Sort f by the third column in descending order | ||||||||||||||||||||||
| f = f[(-f[:, 2]).argsort(), :] | ||||||||||||||||||||||
| f = f[(-f[:, 2]).argsort(kind="stable"), :] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| allvector = np.round(f[:, 0] * (m + 2) + f[:, 1]).astype(int) | ||||||||||||||||||||||
| allvectormove = np.round((f[:, 0] + 1) * (m + 3) + (f[:, 1] + 1)).astype(int) | ||||||||||||||||||||||
|
|
@@ -268,7 +273,7 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Generate N | ||||||||||||||||||||||
| n_arr = np.unique(np.vstack((a, b)), axis=0) | ||||||||||||||||||||||
| nvector = ((n_arr[:, 0] + 1) * (m + 3) + n_arr[:, 1] + 1).astype(int) | ||||||||||||||||||||||
| nvector = ((n_arr[:, 0] + 1) * (m + 3) + n_arr[:, 1] + 1).astype(int) # type: ignore | ||||||||||||||||||||||
| nvector = nvector[np.isin(nvector, allvectormove)] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| skvector = ((s[:kk, 0] + 1) * (m + 3) + s[:kk, 1] + 1).astype(int) | ||||||||||||||||||||||
|
|
@@ -303,6 +308,7 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
| else: | ||||||||||||||||||||||
| length_nc = nc_arr.shape[0] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| ncmax = 0 # avoid pylance warning | ||||||||||||||||||||||
| for ci in range(length_nc): | ||||||||||||||||||||||
| ls_arr[kk, 0:2] = nc_arr[ci, 0:2] | ||||||||||||||||||||||
| i1_vec = ls_arr[: (kk + 1), 0] | ||||||||||||||||||||||
|
|
@@ -363,7 +369,7 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
| if length_nc >= 2: | ||||||||||||||||||||||
| valid = ~np.isnan(nc_arr[:, 0]) | ||||||||||||||||||||||
| ncnomiss = nc_arr[valid] | ||||||||||||||||||||||
| ncnomiss = ncnomiss[(-ncnomiss[:, 2]).argsort(), :] | ||||||||||||||||||||||
| ncnomiss = ncnomiss[(-ncnomiss[:, 2]).argsort(kind="stable"), :] | ||||||||||||||||||||||
| morepoint = np.sum(ncnomiss[:, 2] >= ncnomiss[0, 2] - delta) | ||||||||||||||||||||||
| if morepoint >= 2: | ||||||||||||||||||||||
| ls_arr[kk : kk + morepoint, 0:2] = ncnomiss[:morepoint, 0:2] | ||||||||||||||||||||||
|
|
@@ -429,7 +435,8 @@ def binomial_ci_one_sided( | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| kk1 = kk | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| output = [val.item() if isinstance(val, np.generic) else val for val in output] | ||||||||||||||||||||||
| # output = [val.item() if isinstance(val, np.generic) else val for val in output] | ||||||||||||||||||||||
| output = np.array(output).tolist() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return output | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -456,27 +463,32 @@ def _prob2step(delv, delta, n, m, i1, i2, grid_one, grid_two): | |||||||||||||||||||||
| p0 = np.linspace(-delv + delta, 1 - delta, grid_one) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| p0 = np.linspace(delta, 1 - delv - delta, grid_one) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| i1 = np.atleast_1d(i1) | ||||||||||||||||||||||
| i2 = np.atleast_1d(i2) | ||||||||||||||||||||||
| part1 = np.log(comb(n, i1))[:, None] + np.outer(i1, np.log(p0 + delv)) + np.outer(n - i1, np.log(1 - p0 - delv)) | ||||||||||||||||||||||
| part2 = np.log(comb(m, i2))[:, None] + np.outer(i2, np.log(p0)) + np.outer(m - i2, np.log(1 - p0)) | ||||||||||||||||||||||
| sumofprob = np.exp(part1 + part2).sum(axis=0) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # plateau-aware refinement (R: which(sumofprob == max(sumofprob))) | ||||||||||||||||||||||
| mansum = sumofprob.max() | ||||||||||||||||||||||
| atol = 1e-14 * (mansum if mansum > 0 else 1.0) | ||||||||||||||||||||||
| plateau_idx = np.where(np.isclose(sumofprob, mansum, rtol=0.0, atol=atol))[0] | ||||||||||||||||||||||
| # mansum = sumofprob.max() | ||||||||||||||||||||||
| # atol = 1e-14 * (mansum if mansum > 0 else 1.0) | ||||||||||||||||||||||
| # plateau_idx = np.where(np.isclose(sumofprob, mansum, rtol=0.0, atol=atol))[0] | ||||||||||||||||||||||
|
Comment on lines
+473
to
+475
|
||||||||||||||||||||||
| # mansum = sumofprob.max() | |
| # atol = 1e-14 * (mansum if mansum > 0 else 1.0) | |
| # plateau_idx = np.where(np.isclose(sumofprob, mansum, rtol=0.0, atol=atol))[0] |
Copilot
AI
Jan 22, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The commented-out code should be removed rather than left in the codebase. These lines appear to be the old implementation that was replaced with the simpler plateau detection logic on line 507. Leaving commented-out code in production reduces readability and can cause confusion.
| # mansum = sumofprob.min() | |
| # atol = 1e-14 * (abs(mansum) if mansum != 0 else 1.0) | |
| # plateau_idx = np.where(np.isclose(sumofprob, mansum, rtol=0.0, atol=atol))[0] |
Copilot
AI
Jan 22, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using exact equality (==) for floating-point comparison could miss values that are numerically very close to the minimum due to floating-point precision issues. The commented-out code used np.isclose with a small tolerance to handle this. Consider whether exact equality is appropriate here, or if the tolerance-based comparison should be retained to handle numerical precision edge cases robustly.
| # mansum = sumofprob.min() | |
| # atol = 1e-14 * (abs(mansum) if mansum != 0 else 1.0) | |
| # plateau_idx = np.where(np.isclose(sumofprob, mansum, rtol=0.0, atol=atol))[0] | |
| plateau_idx = np.where(sumofprob == sumofprob.min())[0] | |
| mansum = sumofprob.min() | |
| atol = 1e-14 * (abs(mansum) if mansum != 0 else 1.0) | |
| plateau_idx = np.where(np.isclose(sumofprob, mansum, rtol=0.0, atol=atol))[0] |
Copilot
AI
Jan 22, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a potential issue where lowerb could be greater than upperb after the calculations on lines 513-514. This happens because lowerb and upperb are computed independently, with + delta and - delta adjustments respectively. When p0[rightmost] - stepv + delta > p0[leftmost] + stepv - delta, the second call to np.linspace(lowerb, upperb, grid_two) on line 516 would not raise an error; it would silently return a descending (reversed) grid. Consider adding a check to ensure lowerb <= upperb, or swap them if necessary.
| lowerb = max(p0[0], p0[rightmost] - stepv) + delta | |
| upperb = min(p0[-1], p0[leftmost] + stepv) - delta | |
| raw_lowerb = max(p0[0], p0[rightmost] - stepv) + delta | |
| raw_upperb = min(p0[-1], p0[leftmost] + stepv) - delta | |
| if raw_lowerb <= raw_upperb: | |
| lowerb, upperb = raw_lowerb, raw_upperb | |
| else: | |
| # Ensure bounds are ordered for linspace; swap if necessary | |
| lowerb, upperb = raw_upperb, raw_lowerb |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,3 +1,5 @@ | ||||||
| import warnings | ||||||
|
|
||||||
| import numpy as np | ||||||
| import pytest | ||||||
| from rpy2.robjects import r | ||||||
|
|
@@ -130,7 +132,6 @@ | |||||
| allvector<-setdiff(allvector,partvector) | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| ################### from the second table ################################ | ||||||
|
|
||||||
| morepoint=1 | ||||||
|
|
@@ -144,7 +145,6 @@ | |||||
| if(x==0 && y==m && CItype=="Upper"){output[2]=-1;output[3]=-Ls[1,4];kk<-dimoftable} | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| while(kk<=(dimoftable-2)) | ||||||
| { | ||||||
| C<-Ls[(kk-morepoint+1):kk,1:2] | ||||||
|
|
@@ -205,8 +205,6 @@ | |||||
| } | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
| prob2step<-function(delv) | ||||||
| { | ||||||
| delvalue<-delv | ||||||
|
|
@@ -359,8 +357,6 @@ | |||||
| }## end of function morepointLsest | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
| if(i>=2) | ||||||
| {NCnomiss<-NC[1:dim(na.omit(NC))[1],] | ||||||
| NCnomiss<-NCnomiss[order(-NCnomiss[,3]),] | ||||||
|
|
@@ -415,6 +411,10 @@ | |||||
| # fmt: off | ||||||
|
|
||||||
|
|
||||||
| ERR_LIMIT_STRICT = 1e-4 | ||||||
| ERR_LIMIT_LOOSE = 1e-2 | ||||||
|
Comment on lines
+414
to
+415
|
||||||
|
|
||||||
|
|
||||||
| def test_wang_method(): | ||||||
| n_test = 7 | ||||||
| tot_ub = 100 | ||||||
|
|
@@ -425,33 +425,63 @@ def test_wang_method(): | |||||
| ref_positive = np.random.randint(ref_total + 1) | ||||||
|
|
||||||
| # results computed from R function | ||||||
| r_result = r["wang_binomial_ci_r"](n_positive, n_total, ref_positive, ref_total) | ||||||
| r_result = r["wang_binomial_ci_r"](n_positive, n_total, ref_positive, ref_total) # type: ignore | ||||||
| r_result_dict = dict(zip(r_result.names, r_result)) | ||||||
| r_lb, r_ub = [item[1] for item in r_result_dict["ExactCI"].items()] | ||||||
|
|
||||||
| # results computed from Python function | ||||||
| lb, ub = compute_difference_confidence_interval(n_positive, n_total, ref_positive, ref_total, method="wang").astuple() | ||||||
| lb, ub = compute_difference_confidence_interval(n_positive, n_total, ref_positive, ref_total, method="wang").astuple() # type: ignore | ||||||
|
|
||||||
| # compare results | ||||||
| assert np.isclose( | ||||||
| (r_lb, r_ub), (lb, ub), atol=1e-4 | ||||||
| ).all(), f"R result: {r_lb, r_ub}, Python result: {lb, ub} for {n_positive = }, {n_total = }, {ref_positive = }, {ref_total = }" # noqa: E202, E251 | ||||||
| print(f"Test passed for {n_positive = }, {n_total = }, {ref_positive = }, {ref_total = }") # noqa: E202, E251 | ||||||
| if not np.isclose((r_lb, r_ub), (lb, ub), atol=ERR_LIMIT_STRICT).all(): | ||||||
| warnings.warn( | ||||||
| f"Strict test failed for {n_positive = }, {n_total = }, {ref_positive = }, {ref_total = }, " | ||||||
| f"R result: {r_lb, r_ub}, Python result: {lb, ub}. falling back to loose test.", | ||||||
|
||||||
| f"R result: {r_lb, r_ub}, Python result: {lb, ub}. falling back to loose test.", | |
| f" R result: {r_lb, r_ub}, Python result: {lb, ub}. falling back to loose test.", |
Copilot
AI
Jan 22, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space after comma in the warning message. The message should read "...{ref_total = }, R result:" instead of "...{ref_total = },R result:" for better readability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The commented-out code should be removed rather than left in the codebase. This appears to be the old implementation that was replaced. Leaving commented-out code in production reduces readability.