ssbc-0.1.0-py3-none-any.whl

This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
ssbc/__init__.py ADDED
@@ -0,0 +1,59 @@
+"""Top-level package for SSBC (Small-Sample Beta Correction)."""
+
+__author__ = """Petrus H Zwart"""
+__email__ = "phzwart@lbl.gov"
+__version__ = "0.1.0"
+
+# Core SSBC algorithm
+# Conformal prediction
+from .conformal import (
+    mondrian_conformal_calibrate,
+    split_by_class,
+)
+from .core import (
+    SSBCResult,
+    ssbc_correct,
+)
+
+# Hyperparameter tuning
+from .hyperparameter import (
+    sweep_and_plot_parallel_plotly,
+    sweep_hyperparams_and_collect,
+)
+
+# Simulation (for testing and examples)
+from .simulation import (
+    BinaryClassifierSimulator,
+)
+
+# Statistics utilities
+from .statistics import (
+    clopper_pearson_intervals,
+    cp_interval,
+)
+
+# Visualization and reporting
+from .visualization import (
+    plot_parallel_coordinates_plotly,
+    report_prediction_stats,
+)
+
+__all__ = [
+    # Core
+    "SSBCResult",
+    "ssbc_correct",
+    # Conformal
+    "mondrian_conformal_calibrate",
+    "split_by_class",
+    # Statistics
+    "clopper_pearson_intervals",
+    "cp_interval",
+    # Simulation
+    "BinaryClassifierSimulator",
+    # Visualization
+    "report_prediction_stats",
+    "plot_parallel_coordinates_plotly",
+    # Hyperparameter
+    "sweep_hyperparams_and_collect",
+    "sweep_and_plot_parallel_plotly",
+]
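The re-exports above give the wheel a flat public namespace. As orientation, a minimal sketch of what a consumer sees after installing the wheel (assuming its numpy/scipy/plotly dependencies resolve):

    import ssbc

    print(ssbc.__version__)      # "0.1.0"
    print(sorted(ssbc.__all__))  # the eleven re-exported names listed above

    # All public entry points are importable from the top level:
    from ssbc import mondrian_conformal_calibrate, split_by_class, ssbc_correct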
ssbc/__main__.py ADDED
@@ -0,0 +1,4 @@
+from .cli import app
+
+if __name__ == "__main__":
+    app()
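With the wheel installed, this entry point makes the CLI reachable as `python -m ssbc`, which simply runs the Typer `app` imported from `ssbc.cli` below.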
ssbc/cli.py ADDED
@@ -0,0 +1,21 @@
+"""Console script for ssbc."""
+
+import typer
+from rich.console import Console
+
+from ssbc import utils
+
+app = typer.Typer()
+console = Console()
+
+
+@app.command()
+def main():
+    """Console script for ssbc."""
+    console.print("Replace this message by putting your code into ssbc.cli.main")
+    console.print("See Typer documentation at https://typer.tiangolo.com/")
+    utils.do_something_useful()
+
+
+if __name__ == "__main__":
+    app()
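This is stock Typer scaffolding: `main` is the app's single command, and the two `console.print` calls are placeholders. Note that `utils.do_something_useful()` presumes a `ssbc.utils` module that does not appear in this section of the diff, so the command will fail unless that helper exists. A sketch of exercising the app in-process with Typer's test runner, under that assumption:

    from typer.testing import CliRunner

    from ssbc.cli import app

    runner = CliRunner()
    result = runner.invoke(app)  # runs main(), the sole registered command
    print(result.output)         # the placeholder messages, if utils imports cleanly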
ssbc/conformal.py ADDED
@@ -0,0 +1,333 @@
+"""Mondrian conformal prediction with SSBC correction."""
+
+from typing import Any, Literal
+
+import numpy as np
+
+from .core import ssbc_correct
+from .statistics import cp_interval
+
+
+def split_by_class(labels: np.ndarray, probs: np.ndarray) -> dict[int, dict[str, Any]]:
+    """Split calibration data by true class for Mondrian conformal prediction.
+
+    Parameters
+    ----------
+    labels : np.ndarray, shape (n,)
+        True binary labels (0 or 1)
+    probs : np.ndarray, shape (n, 2)
+        Classification probabilities [P(class=0), P(class=1)]
+
+    Returns
+    -------
+    dict
+        Dictionary with keys 0 and 1, each containing:
+        - 'labels': labels for this class (all same value)
+        - 'probs': probabilities for samples in this class
+        - 'indices': original indices (for tracking)
+        - 'n': number of samples in this class
+
+    Examples
+    --------
+    >>> labels = np.array([0, 1, 0, 1])
+    >>> probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.9, 0.1], [0.2, 0.8]])
+    >>> class_data = split_by_class(labels, probs)
+    >>> print(class_data[0]['n'])  # Number of class 0 samples
+    2
+    """
+    class_data = {}
+
+    for label in [0, 1]:
+        mask = labels == label
+        indices = np.where(mask)[0]
+
+        class_data[label] = {"labels": labels[mask], "probs": probs[mask], "indices": indices, "n": np.sum(mask)}
+
+    return class_data
+
+
+def mondrian_conformal_calibrate(
+    class_data: dict[int, dict[str, Any]],
+    alpha_target: float | dict[int, float],
+    delta: float | dict[int, float],
+    mode: Literal["beta", "beta-binomial"] = "beta",
+    m: int | None = None,
+) -> tuple[dict[int, dict[str, Any]], dict[Any, Any]]:
+    """Perform Mondrian (per-class) conformal calibration with SSBC correction.
+
+    For each class, compute:
+    1. Nonconformity scores: s(x, y) = 1 - P(y|x)
+    2. SSBC-corrected alpha for PAC guarantee
+    3. Conformal quantile threshold
+    4. Singleton error rate bounds via PAC guarantee
+
+    Then evaluate prediction set sizes on calibration data PER CLASS and MARGINALLY.
+
+    Parameters
+    ----------
+    class_data : dict
+        Output from split_by_class()
+    alpha_target : float or dict
+        Target miscoverage rate for each class
+        If float: same for both classes
+        If dict: {0: α0, 1: α1} for per-class control
+    delta : float or dict
+        PAC risk tolerance for each class
+        If float: same for both classes
+        If dict: {0: δ0, 1: δ1} for per-class control
+    mode : str, default="beta"
+        "beta" (infinite test) or "beta-binomial" (finite test)
+    m : int, optional
+        Test window size for beta-binomial mode
+
+    Returns
+    -------
+    calibration_result : dict
+        Dictionary with keys 0 and 1, each containing calibration info
+    prediction_stats : dict
+        Dictionary with keys:
+        - 0, 1: per-class statistics (conditioned on true label)
+        - 'marginal': overall statistics (ignoring true labels)
+
+    Examples
+    --------
+    >>> labels = np.array([0, 1, 0, 1])
+    >>> probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.9, 0.1], [0.2, 0.8]])
+    >>> class_data = split_by_class(labels, probs)
+    >>> cal_result, pred_stats = mondrian_conformal_calibrate(
+    ...     class_data, alpha_target=0.1, delta=0.1
+    ... )
+    """
+    # Handle scalar or dict inputs for alpha and delta
+    alpha_dict: dict[int, float]
+    if isinstance(alpha_target, int | float):
+        alpha_dict = {0: float(alpha_target), 1: float(alpha_target)}
+    else:
+        # alpha_target is dict[int, float] in this branch
+        assert isinstance(alpha_target, dict), "alpha_target must be dict if not scalar"
+        alpha_dict = {k: float(v) for k, v in alpha_target.items()}
+
+    delta_dict: dict[int, float]
+    if isinstance(delta, int | float):
+        delta_dict = {0: float(delta), 1: float(delta)}
+    else:
+        # delta is dict[int, float] in this branch
+        assert isinstance(delta, dict), "delta must be dict if not scalar"
+        delta_dict = {k: float(v) for k, v in delta.items()}
+
+    calibration_result = {}
+
+    # Step 1: Calibrate per class
+    for label in [0, 1]:
+        data = class_data[label]
+        n = data["n"]
+        alpha_class = alpha_dict[label]
+        delta_class = delta_dict[label]
+
+        if n == 0:
+            calibration_result[label] = {
+                "n": 0,
+                "alpha_target": alpha_class,
+                "alpha_corrected": None,
+                "delta": delta_class,
+                "threshold": None,
+                "scores": np.array([]),
+                "ssbc_result": None,
+                "error": "No calibration samples for this class",
+            }
+            continue
+
+        # Compute nonconformity scores: s(x, y) = 1 - P(y|x)
+        true_class_probs = data["probs"][:, label]
+        scores = 1.0 - true_class_probs
+
+        # Apply SSBC to get corrected alpha
+        ssbc_result = ssbc_correct(alpha_target=alpha_class, n=n, delta=delta_class, mode=mode, m=m)
+
+        alpha_corrected = ssbc_result.alpha_corrected
+
+        # Compute conformal quantile threshold
+        k = int(np.ceil((n + 1) * (1 - alpha_corrected)))
+        k = min(k, n)
+
+        sorted_scores = np.sort(scores)
+        threshold = sorted_scores[k - 1] if k > 0 else sorted_scores[0]
+
+        calibration_result[label] = {
+            "n": n,
+            "alpha_target": alpha_class,
+            "alpha_corrected": alpha_corrected,
+            "delta": delta_class,
+            "threshold": threshold,
+            "scores": sorted_scores,
+            "ssbc_result": ssbc_result,
+            "k": k,
+        }
+
+    # Step 2: Evaluate prediction sets
+    if calibration_result[0].get("threshold") is None or calibration_result[1].get("threshold") is None:
+        return calibration_result, {
+            "error": "Cannot compute prediction sets - missing threshold for at least one class"
+        }
+
+    threshold_0 = calibration_result[0]["threshold"]
+    threshold_1 = calibration_result[1]["threshold"]
+
+    prediction_stats = {}
+
+    # Step 2a: Evaluate per true class
+    for true_label in [0, 1]:
+        data = class_data[true_label]
+        n_class = data["n"]
+
+        if n_class == 0:
+            prediction_stats[true_label] = {"n_class": 0, "error": "No samples in this class"}
+            continue
+
+        probs = data["probs"]
+        prediction_sets = []
+
+        for i in range(n_class):
+            score_0 = 1.0 - probs[i, 0]
+            score_1 = 1.0 - probs[i, 1]
+
+            pred_set = []
+            if score_0 <= threshold_0:
+                pred_set.append(0)
+            if score_1 <= threshold_1:
+                pred_set.append(1)
+
+            prediction_sets.append(pred_set)
+
+        # Count set sizes and correctness
+        n_abstentions = sum(len(ps) == 0 for ps in prediction_sets)
+        n_doublets = sum(len(ps) == 2 for ps in prediction_sets)
+
+        n_singletons_correct = sum(ps == [true_label] for ps in prediction_sets)
+        n_singletons_incorrect = sum(len(ps) == 1 and true_label not in ps for ps in prediction_sets)
+        n_singletons_total = n_singletons_correct + n_singletons_incorrect
+
+        # PAC bounds
+        n_escalations = n_doublets + n_abstentions
+
+        if n_escalations > 0 and n_singletons_total > 0:
+            rho = n_singletons_total / n_escalations
+            kappa = n_abstentions / n_escalations
+            alpha_singlet_bound = alpha_dict[true_label] * (1 + 1 / rho) - kappa / rho
+            alpha_singlet_observed = n_singletons_incorrect / n_singletons_total if n_singletons_total > 0 else 0.0
+        else:
+            rho = None
+            kappa = None
+            alpha_singlet_bound = None
+            alpha_singlet_observed = None
+
+        prediction_stats[true_label] = {
+            "n_class": n_class,
+            "alpha_target": alpha_dict[true_label],
+            "delta": delta_dict[true_label],
+            "abstentions": cp_interval(n_abstentions, n_class),
+            "singletons": cp_interval(n_singletons_total, n_class),
+            "singletons_correct": cp_interval(n_singletons_correct, n_class),
+            "singletons_incorrect": cp_interval(n_singletons_incorrect, n_class),
+            "doublets": cp_interval(n_doublets, n_class),
+            "prediction_sets": prediction_sets,
+            "pac_bounds": {
+                "rho": rho,
+                "kappa": kappa,
+                "alpha_singlet_bound": alpha_singlet_bound,
+                "alpha_singlet_observed": alpha_singlet_observed,
+                "n_singletons": n_singletons_total,
+                "n_escalations": n_escalations,
+            },
+        }
+
+    # Step 2b: MARGINAL ANALYSIS (ignoring true labels)
+    # Reconstruct full dataset
+    all_labels = np.concatenate([class_data[0]["labels"], class_data[1]["labels"]])
+    all_probs = np.concatenate([class_data[0]["probs"], class_data[1]["probs"]], axis=0)
+    all_indices = np.concatenate([class_data[0]["indices"], class_data[1]["indices"]])
+
+    # Sort back to original order
+    sort_idx = np.argsort(all_indices)
+    all_labels = all_labels[sort_idx]
+    all_probs = all_probs[sort_idx]
+
+    n_total = len(all_labels)
+
+    # Compute prediction sets for all samples
+    all_prediction_sets = []
+    for i in range(n_total):
+        score_0 = 1.0 - all_probs[i, 0]
+        score_1 = 1.0 - all_probs[i, 1]
+
+        pred_set = []
+        if score_0 <= threshold_0:
+            pred_set.append(0)
+        if score_1 <= threshold_1:
+            pred_set.append(1)
+
+        all_prediction_sets.append(pred_set)
+
+    # Count overall set sizes
+    n_abstentions_total = sum(len(ps) == 0 for ps in all_prediction_sets)
+    n_singletons_total = sum(len(ps) == 1 for ps in all_prediction_sets)
+    n_doublets_total = sum(len(ps) == 2 for ps in all_prediction_sets)
+
+    # Break down singletons by predicted class
+    n_singletons_pred_0 = sum(ps == [0] for ps in all_prediction_sets)
+    n_singletons_pred_1 = sum(ps == [1] for ps in all_prediction_sets)
+
+    # Compute overall coverage
+    n_covered = sum(all_labels[i] in all_prediction_sets[i] for i in range(n_total))
+    coverage = n_covered / n_total
+
+    # Compute errors on singletons
+    singleton_mask = [len(ps) == 1 for ps in all_prediction_sets]
+    n_singletons_covered = sum(all_labels[i] in all_prediction_sets[i] for i in range(n_total) if singleton_mask[i])
+    n_singletons_errors = n_singletons_total - n_singletons_covered
+
+    # Overall PAC bounds (using weighted average of alphas for interpretation)
+    n_escalations_total = n_doublets_total + n_abstentions_total
+
+    if n_escalations_total > 0 and n_singletons_total > 0:
+        rho_marginal = n_singletons_total / n_escalations_total
+        kappa_marginal = n_abstentions_total / n_escalations_total
+
+        # Weighted average alpha (by class size)
+        n_0 = class_data[0]["n"]
+        n_1 = class_data[1]["n"]
+        alpha_weighted = (n_0 * alpha_dict[0] + n_1 * alpha_dict[1]) / (n_0 + n_1)
+
+        alpha_singlet_bound_marginal = alpha_weighted * (1 + 1 / rho_marginal) - kappa_marginal / rho_marginal
+        alpha_singlet_observed_marginal = n_singletons_errors / n_singletons_total
+    else:
+        rho_marginal = None
+        kappa_marginal = None
+        alpha_weighted = None
+        alpha_singlet_bound_marginal = None
+        alpha_singlet_observed_marginal = None
+
+    prediction_stats["marginal"] = {
+        "n_total": n_total,
+        "coverage": {"count": n_covered, "rate": coverage, "ci_95": cp_interval(n_covered, n_total)},
+        "abstentions": cp_interval(n_abstentions_total, n_total),
+        "singletons": {
+            **cp_interval(n_singletons_total, n_total),
+            "pred_0": n_singletons_pred_0,
+            "pred_1": n_singletons_pred_1,
+            "errors": n_singletons_errors,
+        },
+        "doublets": cp_interval(n_doublets_total, n_total),
+        "prediction_sets": all_prediction_sets,
+        "pac_bounds": {
+            "rho": rho_marginal,
+            "kappa": kappa_marginal,
+            "alpha_weighted": alpha_weighted,
+            "alpha_singlet_bound": alpha_singlet_bound_marginal,
+            "alpha_singlet_observed": alpha_singlet_observed_marginal,
+            "n_singletons": n_singletons_total,
+            "n_escalations": n_escalations_total,
+        },
+    }
+
+    return calibration_result, prediction_stats
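Taken together, the two functions above form a small pipeline: split the calibration data by true label, threshold each class at its SSBC-corrected quantile, then tabulate abstentions, singletons, and doublets. A minimal end-to-end sketch on synthetic data (toy probabilities, not the package's BinaryClassifierSimulator, whose interface is not shown in this section; assumes the ssbc.statistics helpers elsewhere in the wheel import cleanly):

    import numpy as np

    from ssbc import mondrian_conformal_calibrate, split_by_class

    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=200)
    # Toy classifier: P(class=1) is high when the true label is 1
    p1 = np.clip(rng.normal(0.2 + 0.6 * labels, 0.15), 0.01, 0.99)
    probs = np.column_stack([1.0 - p1, p1])

    cal, stats = mondrian_conformal_calibrate(
        split_by_class(labels, probs), alpha_target=0.10, delta=0.10
    )
    print(cal[0]["alpha_corrected"], cal[0]["threshold"])  # per-class corrected alpha and quantile
    print(stats["marginal"]["coverage"]["rate"])           # empirical coverage on calibration data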
ssbc/core.py ADDED
@@ -0,0 +1,205 @@
+"""Core SSBC (Small-Sample Beta Correction) algorithm."""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+from scipy.stats import beta as beta_dist
+from scipy.stats import betabinom, norm
+
+
+@dataclass
+class SSBCResult:
+    """Result of SSBC correction.
+
+    Attributes:
+        alpha_target: Target miscoverage rate
+        alpha_corrected: Corrected miscoverage rate (u_star / (n+1))
+        u_star: Optimal u value found by the algorithm
+        n: Calibration set size
+        satisfied_mass: Probability that coverage >= target
+        mode: "beta" for infinite test window, "beta-binomial" for finite
+        details: Additional diagnostic information
+    """
+
+    alpha_target: float
+    alpha_corrected: float
+    u_star: int
+    n: int
+    satisfied_mass: float
+    mode: Literal["beta", "beta-binomial"]
+    details: dict
+
+
+def ssbc_correct(
+    alpha_target: float,
+    n: int,
+    delta: float,
+    *,
+    mode: Literal["beta", "beta-binomial"] = "beta",
+    m: int | None = None,
+    bracket_width: int | None = None,
+) -> SSBCResult:
+    """Small-Sample Beta Correction (SSBC), corrected acceptance rule.
+
+    Find the largest α' = u/(n+1) ≤ α_target such that:
+        P(Coverage(α') ≥ 1 - α_target) ≥ 1 - δ
+
+    where Coverage(α') ~ Beta(n+1-u, u) for infinite test window.
+
+    Parameters
+    ----------
+    alpha_target : float
+        Target miscoverage rate (must be in (0,1))
+    n : int
+        Calibration set size (must be >= 1)
+    delta : float
+        Risk tolerance / PAC parameter (must be in (0,1))
+    mode : {"beta", "beta-binomial"}, default="beta"
+        "beta" for infinite test window
+        "beta-binomial" for finite test window
+    m : int, optional
+        Test window size (required for beta-binomial mode)
+    bracket_width : int, optional
+        Search radius around initial guess (default: adaptive based on n)
+
+    Returns
+    -------
+    SSBCResult
+        Dataclass containing correction results and diagnostic details
+
+    Raises
+    ------
+    ValueError
+        If parameters are out of valid ranges
+
+    Examples
+    --------
+    >>> result = ssbc_correct(alpha_target=0.10, n=50, delta=0.10)
+    >>> print(f"Corrected alpha: {result.alpha_corrected:.4f}")
+
+    Notes
+    -----
+    The algorithm uses a bracketed search with an initial guess based on
+    normal approximation to the Beta distribution. For large n, the search
+    is adaptive to maintain efficiency.
+    """
+    # Input validation
+    if not (0.0 < alpha_target < 1.0):
+        raise ValueError("alpha_target must be in (0,1).")
+    if n < 1:
+        raise ValueError("n must be >= 1.")
+    if not (0.0 < delta < 1.0):
+        raise ValueError("delta must be in (0,1).")
+    if mode not in ("beta", "beta-binomial"):
+        raise ValueError("mode must be 'beta' or 'beta-binomial'.")
+
+    # Maximum u to search (α' must be ≤ α_target)
+    u_max = min(n, math.floor(alpha_target * (n + 1)))
+    target_coverage = 1 - alpha_target
+
+    # Initial guess for u using normal approximation to Beta distribution
+    # We want P(Beta(n+1-u, u) >= target_coverage) ≈ 1-δ
+    # Using normal approximation: u ≈ u_target - z_δ * sqrt(u_target)
+    # where u_target = (n+1)*α_target and z_δ = Φ^(-1)(1-δ)
+    u_target = (n + 1) * alpha_target
+    z_delta = norm.ppf(1 - delta)  # quantile function (inverse CDF)
+    u_star_guess = max(1, math.floor(u_target - z_delta * math.sqrt(u_target)))
+
+    # Clamp to valid range
+    u_star_guess = min(u_max, u_star_guess)
+
+    # Bracket width (Δ in Algorithm 1)
+    if bracket_width is None:
+        # Adaptive bracket: wider for small n, scales with √n for large n
+        # For large n, the uncertainty scales as √u_target ~ (n*α)^(1/2)
+        bracket_width = max(5, min(int(2 * z_delta * math.sqrt(u_target)), n // 10))
+        bracket_width = min(bracket_width, 100)  # cap at 100 for efficiency
+
+    # Search bounds - ensure we don't go outside [1, u_max]
+    u_min = max(1, u_star_guess - bracket_width)
+    u_search_max = min(u_max, u_star_guess + bracket_width)
+
+    # If the guess is way off (e.g., guess > u_max), fall back to full search
+    if u_min > u_search_max:
+        u_min = 1
+        u_search_max = u_max
+
+    if mode == "beta-binomial":
+        m_eval = m if m is not None else n
+        if m_eval < 1:
+            raise ValueError("m must be >= 1 for beta-binomial mode.")
+        k_thresh = math.ceil(target_coverage * m_eval)
+
+    u_star: int | None = None
+    mass_star: float | None = None
+
+    # Search from u_min up to u_search_max to find the largest u that satisfies the condition
+    # Keep updating u_star as we find larger values that work
+    search_log = []
+    for u in range(u_min, u_search_max + 1):
+        # When we calibrate at α' = u/(n+1), coverage follows:
+        a = n + 1 - u  # first parameter
+        b = u  # second parameter
+        alpha_prime = u / (n + 1)
+
+        if mode == "beta":
+            # P(Coverage ≥ target_coverage) where Coverage ~ Beta(a, b)
+            # Using: P(X >= x) = 1 - CDF(x) for continuous distributions
+            ptail = 1 - beta_dist.cdf(target_coverage, a, b)
+        else:
+            # P(X ≥ k_thresh) where X ~ BetaBinomial(m, a, b)
+            ptail = betabinom.sf(k_thresh - 1, m_eval, a, b)
+
+        passes = ptail >= 1 - delta
+        search_log.append(
+            {
+                "u": u,
+                "alpha_prime": alpha_prime,
+                "a": a,
+                "b": b,
+                "ptail": ptail,
+                "threshold": 1 - delta,
+                "passes": passes,
+            }
+        )
+
+        # Accept if probability is high enough - keep updating to find the largest
+        if passes:
+            u_star = u
+            mass_star = ptail
+
+    # If nothing passes, fall back to u=1 (most conservative)
+    if u_star is None:
+        u_star = 1
+        a = n + 1 - u_star
+        b = u_star
+        mass_star = (
+            1 - beta_dist.cdf(target_coverage, a, b)
+            if mode == "beta"
+            else betabinom.sf(k_thresh - 1, (m if m else n), a, b)
+        )
+
+    alpha_corrected = u_star / (n + 1)
+
+    # At this point, mass_star is always set (either from loop or fallback)
+    assert mass_star is not None, "mass_star should be set by this point"
+
+    return SSBCResult(
+        alpha_target=alpha_target,
+        alpha_corrected=alpha_corrected,
+        u_star=u_star,
+        n=n,
+        satisfied_mass=mass_star,
+        mode=mode,
+        details=dict(
+            u_max=u_max,
+            u_star_guess=u_star_guess,
+            search_range=(u_min, u_search_max),
+            bracket_width=bracket_width,
+            delta=delta,
+            m=(m if (mode == "beta-binomial") else None),
+            acceptance_rule="P(Coverage >= target) >= 1-delta",
+            search_log=search_log,
+        ),
+    )
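The acceptance rule (find the largest u with P(Beta(n+1-u, u) ≥ 1 - α_target) ≥ 1 - δ) can be spot-checked against scipy directly. A small sketch re-deriving the tail mass from the returned result, using nothing beyond the scipy the module itself imports:

    from scipy.stats import beta as beta_dist

    from ssbc import ssbc_correct

    res = ssbc_correct(alpha_target=0.10, n=50, delta=0.10)
    a, b = res.n + 1 - res.u_star, res.u_star
    tail = 1 - beta_dist.cdf(1 - res.alpha_target, a, b)

    print(res.u_star, res.alpha_corrected)   # alpha_corrected == u_star / (n + 1)
    print(tail, res.satisfied_mass)          # same tail mass the search recorded
    print(tail >= 1 - res.details["delta"])  # True: the PAC condition holds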