ssbc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ssbc/__init__.py +59 -0
- ssbc/__main__.py +4 -0
- ssbc/cli.py +21 -0
- ssbc/conformal.py +333 -0
- ssbc/core.py +205 -0
- ssbc/hyperparameter.py +258 -0
- ssbc/simulation.py +148 -0
- ssbc/ssbc.py +1 -0
- ssbc/statistics.py +158 -0
- ssbc/utils.py +2 -0
- ssbc/visualization.py +459 -0
- ssbc-0.1.0.dist-info/METADATA +266 -0
- ssbc-0.1.0.dist-info/RECORD +17 -0
- ssbc-0.1.0.dist-info/WHEEL +5 -0
- ssbc-0.1.0.dist-info/entry_points.txt +2 -0
- ssbc-0.1.0.dist-info/licenses/LICENSE +21 -0
- ssbc-0.1.0.dist-info/top_level.txt +1 -0
ssbc/__init__.py
ADDED
@@ -0,0 +1,59 @@
|
|
1
|
+
"""Top-level package for SSBC (Small-Sample Beta Correction)."""
|
2
|
+
|
3
|
+
__author__ = """Petrus H Zwart"""
__email__ = "phzwart@lbl.gov"
__version__ = "0.1.0"

# Conformal prediction
from .conformal import (
    mondrian_conformal_calibrate,
    split_by_class,
)

# Core SSBC algorithm
from .core import (
    SSBCResult,
    ssbc_correct,
)

# Hyperparameter tuning
from .hyperparameter import (
    sweep_and_plot_parallel_plotly,
    sweep_hyperparams_and_collect,
)

# Simulation (for testing and examples)
from .simulation import (
    BinaryClassifierSimulator,
)

# Statistics utilities
from .statistics import (
    clopper_pearson_intervals,
    cp_interval,
)

# Visualization and reporting
from .visualization import (
    plot_parallel_coordinates_plotly,
    report_prediction_stats,
)

# Public API re-exported at package level; keep in sync with the imports above.
__all__ = [
    # Core
    "SSBCResult",
    "ssbc_correct",
    # Conformal
    "mondrian_conformal_calibrate",
    "split_by_class",
    # Statistics
    "clopper_pearson_intervals",
    "cp_interval",
    # Simulation
    "BinaryClassifierSimulator",
    # Visualization
    "report_prediction_stats",
    "plot_parallel_coordinates_plotly",
    # Hyperparameter
    "sweep_hyperparams_and_collect",
    "sweep_and_plot_parallel_plotly",
]
|
ssbc/__main__.py
ADDED
ssbc/cli.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
"""Console script for ssbc."""
|
2
|
+
|
3
|
+
import typer
|
4
|
+
from rich.console import Console
|
5
|
+
|
6
|
+
from ssbc import utils
|
7
|
+
|
8
|
+
app = typer.Typer()
console = Console()


@app.command()
def main():
    """Console script for ssbc."""
    # The docstring above is also the Typer help text, so it is kept verbatim.
    # Template placeholder behaviour: print two hint lines, then call the
    # utils hook.
    placeholder_hints = (
        "Replace this message by putting your code into ssbc.cli.main",
        "See Typer documentation at https://typer.tiangolo.com/",
    )
    for hint in placeholder_hints:
        console.print(hint)
    utils.do_something_useful()


if __name__ == "__main__":
    app()
|
ssbc/conformal.py
ADDED
@@ -0,0 +1,333 @@
|
|
1
|
+
"""Mondrian conformal prediction with SSBC correction."""
|
2
|
+
|
3
|
+
from typing import Any, Literal
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from .core import ssbc_correct
|
8
|
+
from .statistics import cp_interval
|
9
|
+
|
10
|
+
|
11
|
+
def split_by_class(labels: np.ndarray, probs: np.ndarray) -> dict[int, dict[str, Any]]:
    """Split calibration data by true class for Mondrian conformal prediction.

    Parameters
    ----------
    labels : array-like, shape (n,)
        True binary labels (0 or 1). Plain sequences are accepted and
        converted with ``np.asarray``.
    probs : array-like, shape (n, 2)
        Classification probabilities [P(class=0), P(class=1)].

    Returns
    -------
    dict
        Dictionary with keys 0 and 1, each containing:
        - 'labels': labels for this class (all same value)
        - 'probs': probabilities for samples in this class
        - 'indices': original indices (for tracking)
        - 'n': number of samples in this class, as a plain Python ``int``

    Notes
    -----
    Samples whose label is neither 0 nor 1 are not assigned to either class.

    Examples
    --------
    >>> labels = np.array([0, 1, 0, 1])
    >>> probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.9, 0.1], [0.2, 0.8]])
    >>> class_data = split_by_class(labels, probs)
    >>> print(class_data[0]['n'])  # Number of class 0 samples
    2
    """
    # Coerce so that list inputs get element-wise comparison and boolean masking.
    labels = np.asarray(labels)
    probs = np.asarray(probs)

    class_data: dict[int, dict[str, Any]] = {}
    for label in (0, 1):
        mask = labels == label
        indices = np.where(mask)[0]
        class_data[label] = {
            "labels": labels[mask],
            "probs": probs[mask],
            "indices": indices,
            # Plain int rather than a NumPy scalar; callers pass it on as a
            # calibration-set size typed ``int``.
            "n": int(indices.size),
        }

    return class_data
|
47
|
+
|
48
|
+
|
49
|
+
def _per_class_param(value: float | dict[int, float], name: str) -> dict[int, float]:
    """Normalise a scalar-or-per-class parameter into a ``{label: float}`` dict.

    A real scalar (Python or NumPy) is broadcast to classes 0 and 1; a dict has
    its values converted to ``float``.

    Raises
    ------
    TypeError
        If ``value`` is neither a real scalar nor a dict.
    """
    if isinstance(value, (int, float, np.integer, np.floating)):
        return {0: float(value), 1: float(value)}
    if isinstance(value, dict):
        return {k: float(v) for k, v in value.items()}
    raise TypeError(f"{name} must be a float or a dict mapping class label to float")


def _prediction_sets(probs: np.ndarray, threshold_0: float, threshold_1: float) -> list[list[int]]:
    """Build the Mondrian prediction set for every row of ``probs``.

    Label ``y`` enters a row's set when its nonconformity score
    ``1 - probs[row, y]`` is ``<=`` that label's threshold. Each set is a list
    in label order: ``[]``, ``[0]``, ``[1]`` or ``[0, 1]``.
    """
    include_0 = (1.0 - probs[:, 0]) <= threshold_0
    include_1 = (1.0 - probs[:, 1]) <= threshold_1

    sets: list[list[int]] = []
    for keep_0, keep_1 in zip(include_0, include_1):
        pred_set = []
        if keep_0:
            pred_set.append(0)
        if keep_1:
            pred_set.append(1)
        sets.append(pred_set)
    return sets


def _calibrate_class(
    data: dict[str, Any],
    label: int,
    alpha_class: float,
    delta_class: float,
    mode: Literal["beta", "beta-binomial"],
    m: int | None,
) -> dict[str, Any]:
    """Run SSBC plus the conformal-quantile threshold for one class."""
    n = data["n"]
    if n == 0:
        return {
            "n": 0,
            "alpha_target": alpha_class,
            "alpha_corrected": None,
            "delta": delta_class,
            "threshold": None,
            "scores": np.array([]),
            "ssbc_result": None,
            "error": "No calibration samples for this class",
        }

    # Nonconformity scores: s(x, y) = 1 - P(y|x)
    scores = 1.0 - data["probs"][:, label]

    ssbc_result = ssbc_correct(alpha_target=alpha_class, n=n, delta=delta_class, mode=mode, m=m)
    alpha_corrected = ssbc_result.alpha_corrected

    # alpha_corrected == u_star / (n + 1), so the conformal rank
    # ceil((n + 1) * (1 - alpha_corrected)) is exactly n + 1 - u_star.
    # Use the integer form: the float product can land one ulp above an
    # integer (e.g. 100 * (1 - 0.44) == 56.00000000000001), and ceil() would
    # then select the next order statistic.
    k = min(n + 1 - ssbc_result.u_star, n)

    sorted_scores = np.sort(scores)
    threshold = sorted_scores[k - 1] if k > 0 else sorted_scores[0]

    return {
        "n": n,
        "alpha_target": alpha_class,
        "alpha_corrected": alpha_corrected,
        "delta": delta_class,
        "threshold": threshold,
        "scores": sorted_scores,
        "ssbc_result": ssbc_result,
        "k": k,
    }


def _class_prediction_stats(
    true_label: int,
    data: dict[str, Any],
    alpha_class: float,
    delta_class: float,
    threshold_0: float,
    threshold_1: float,
) -> dict[str, Any]:
    """Set-size counts, Clopper-Pearson intervals and PAC quantities for one true class."""
    n_class = data["n"]
    if n_class == 0:
        return {"n_class": 0, "error": "No samples in this class"}

    prediction_sets = _prediction_sets(data["probs"], threshold_0, threshold_1)

    # Count set sizes and correctness
    n_abstentions = sum(len(ps) == 0 for ps in prediction_sets)
    n_doublets = sum(len(ps) == 2 for ps in prediction_sets)
    n_singletons_correct = sum(ps == [true_label] for ps in prediction_sets)
    n_singletons_incorrect = sum(len(ps) == 1 and true_label not in ps for ps in prediction_sets)
    n_singletons_total = n_singletons_correct + n_singletons_incorrect

    # PAC bounds
    n_escalations = n_doublets + n_abstentions
    if n_escalations > 0 and n_singletons_total > 0:
        rho = n_singletons_total / n_escalations
        kappa = n_abstentions / n_escalations
        alpha_singlet_bound = alpha_class * (1 + 1 / rho) - kappa / rho
        alpha_singlet_observed = n_singletons_incorrect / n_singletons_total
    else:
        rho = None
        kappa = None
        alpha_singlet_bound = None
        alpha_singlet_observed = None

    return {
        "n_class": n_class,
        "alpha_target": alpha_class,
        "delta": delta_class,
        "abstentions": cp_interval(n_abstentions, n_class),
        "singletons": cp_interval(n_singletons_total, n_class),
        "singletons_correct": cp_interval(n_singletons_correct, n_class),
        "singletons_incorrect": cp_interval(n_singletons_incorrect, n_class),
        "doublets": cp_interval(n_doublets, n_class),
        "prediction_sets": prediction_sets,
        "pac_bounds": {
            "rho": rho,
            "kappa": kappa,
            "alpha_singlet_bound": alpha_singlet_bound,
            "alpha_singlet_observed": alpha_singlet_observed,
            "n_singletons": n_singletons_total,
            "n_escalations": n_escalations,
        },
    }


def _marginal_prediction_stats(
    class_data: dict[int, dict[str, Any]],
    alpha_dict: dict[int, float],
    threshold_0: float,
    threshold_1: float,
) -> dict[str, Any]:
    """Statistics over the whole calibration set in original sample order, ignoring true labels."""
    # Reconstruct the full dataset and restore the original sample order.
    all_labels = np.concatenate([class_data[0]["labels"], class_data[1]["labels"]])
    all_probs = np.concatenate([class_data[0]["probs"], class_data[1]["probs"]], axis=0)
    all_indices = np.concatenate([class_data[0]["indices"], class_data[1]["indices"]])
    order = np.argsort(all_indices)
    all_labels = all_labels[order]
    all_probs = all_probs[order]
    n_total = len(all_labels)

    all_prediction_sets = _prediction_sets(all_probs, threshold_0, threshold_1)
    set_sizes = [len(ps) for ps in all_prediction_sets]

    # Overall set sizes, with singletons broken down by predicted class
    n_abstentions_total = set_sizes.count(0)
    n_singletons_total = set_sizes.count(1)
    n_doublets_total = set_sizes.count(2)
    n_singletons_pred_0 = sum(ps == [0] for ps in all_prediction_sets)
    n_singletons_pred_1 = sum(ps == [1] for ps in all_prediction_sets)

    # Coverage overall and on singletons
    covered = [label in ps for label, ps in zip(all_labels, all_prediction_sets)]
    n_covered = sum(covered)
    coverage = n_covered / n_total
    n_singletons_covered = sum(hit for hit, size in zip(covered, set_sizes) if size == 1)
    n_singletons_errors = n_singletons_total - n_singletons_covered

    # Overall PAC quantities, using a class-size-weighted alpha for interpretation
    n_escalations_total = n_doublets_total + n_abstentions_total
    if n_escalations_total > 0 and n_singletons_total > 0:
        rho_marginal = n_singletons_total / n_escalations_total
        kappa_marginal = n_abstentions_total / n_escalations_total
        n_0 = class_data[0]["n"]
        n_1 = class_data[1]["n"]
        alpha_weighted = (n_0 * alpha_dict[0] + n_1 * alpha_dict[1]) / (n_0 + n_1)
        alpha_singlet_bound_marginal = alpha_weighted * (1 + 1 / rho_marginal) - kappa_marginal / rho_marginal
        alpha_singlet_observed_marginal = n_singletons_errors / n_singletons_total
    else:
        rho_marginal = None
        kappa_marginal = None
        alpha_weighted = None
        alpha_singlet_bound_marginal = None
        alpha_singlet_observed_marginal = None

    return {
        "n_total": n_total,
        "coverage": {"count": n_covered, "rate": coverage, "ci_95": cp_interval(n_covered, n_total)},
        "abstentions": cp_interval(n_abstentions_total, n_total),
        "singletons": {
            **cp_interval(n_singletons_total, n_total),
            "pred_0": n_singletons_pred_0,
            "pred_1": n_singletons_pred_1,
            "errors": n_singletons_errors,
        },
        "doublets": cp_interval(n_doublets_total, n_total),
        "prediction_sets": all_prediction_sets,
        "pac_bounds": {
            "rho": rho_marginal,
            "kappa": kappa_marginal,
            "alpha_weighted": alpha_weighted,
            "alpha_singlet_bound": alpha_singlet_bound_marginal,
            "alpha_singlet_observed": alpha_singlet_observed_marginal,
            "n_singletons": n_singletons_total,
            "n_escalations": n_escalations_total,
        },
    }


def mondrian_conformal_calibrate(
    class_data: dict[int, dict[str, Any]],
    alpha_target: float | dict[int, float],
    delta: float | dict[int, float],
    mode: Literal["beta", "beta-binomial"] = "beta",
    m: int | None = None,
) -> tuple[dict[int, dict[str, Any]], dict[Any, Any]]:
    """Perform Mondrian (per-class) conformal calibration with SSBC correction.

    For each class, compute:
    1. Nonconformity scores: s(x, y) = 1 - P(y|x)
    2. SSBC-corrected alpha for PAC guarantee
    3. Conformal quantile threshold: the k-th smallest score with
       k = n + 1 - u_star (u_star from ``ssbc_correct``)
    4. Singleton error rate bounds via PAC guarantee

    Then evaluate prediction set sizes on calibration data PER CLASS and MARGINALLY.

    Parameters
    ----------
    class_data : dict
        Output from split_by_class()
    alpha_target : float or dict
        Target miscoverage rate for each class
        If float (Python or NumPy scalar): same for both classes
        If dict: {0: α0, 1: α1} for per-class control
    delta : float or dict
        PAC risk tolerance for each class
        If float (Python or NumPy scalar): same for both classes
        If dict: {0: δ0, 1: δ1} for per-class control
    mode : str, default="beta"
        "beta" (infinite test) or "beta-binomial" (finite test)
    m : int, optional
        Test window size for beta-binomial mode

    Returns
    -------
    calibration_result : dict
        Dictionary with keys 0 and 1, each containing calibration info
    prediction_stats : dict
        Dictionary with keys:
        - 0, 1: per-class statistics (conditioned on true label)
        - 'marginal': overall statistics (ignoring true labels)
        If either class has no calibration samples, this is instead
        ``{"error": ...}``.

    Raises
    ------
    TypeError
        If ``alpha_target`` or ``delta`` is neither a real scalar nor a dict.
    KeyError
        If a per-class dict has no entry for class 0 or 1.

    Examples
    --------
    >>> labels = np.array([0, 1, 0, 1])
    >>> probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.9, 0.1], [0.2, 0.8]])
    >>> class_data = split_by_class(labels, probs)
    >>> cal_result, pred_stats = mondrian_conformal_calibrate(
    ...     class_data, alpha_target=0.1, delta=0.1
    ... )
    """
    # Handle scalar or dict inputs for alpha and delta
    alpha_dict = _per_class_param(alpha_target, "alpha_target")
    delta_dict = _per_class_param(delta, "delta")

    # Step 1: Calibrate per class
    calibration_result: dict[int, dict[str, Any]] = {}
    for label in (0, 1):
        calibration_result[label] = _calibrate_class(
            class_data[label], label, alpha_dict[label], delta_dict[label], mode, m
        )

    # Step 2: Evaluate prediction sets (needs a threshold for both classes)
    if calibration_result[0].get("threshold") is None or calibration_result[1].get("threshold") is None:
        return calibration_result, {
            "error": "Cannot compute prediction sets - missing threshold for at least one class"
        }

    threshold_0 = calibration_result[0]["threshold"]
    threshold_1 = calibration_result[1]["threshold"]

    prediction_stats: dict[Any, Any] = {}

    # Step 2a: Evaluate per true class
    for true_label in (0, 1):
        prediction_stats[true_label] = _class_prediction_stats(
            true_label,
            class_data[true_label],
            alpha_dict[true_label],
            delta_dict[true_label],
            threshold_0,
            threshold_1,
        )

    # Step 2b: Marginal analysis (ignoring true labels)
    prediction_stats["marginal"] = _marginal_prediction_stats(class_data, alpha_dict, threshold_0, threshold_1)

    return calibration_result, prediction_stats
|
ssbc/core.py
ADDED
@@ -0,0 +1,205 @@
|
|
1
|
+
"""Core SSBC (Small-Sample Beta Correction) algorithm."""
|
2
|
+
|
3
|
+
import math
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from typing import Literal
|
6
|
+
|
7
|
+
from scipy.stats import beta as beta_dist
|
8
|
+
from scipy.stats import betabinom, norm
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass()
class SSBCResult:
    """Outcome of a Small-Sample Beta Correction.

    Attributes
    ----------
    alpha_target : float
        Miscoverage rate that was requested.
    alpha_corrected : float
        Miscoverage rate selected by the correction, ``u_star / (n + 1)``.
    u_star : int
        Integer u chosen by the search.
    n : int
        Size of the calibration set.
    satisfied_mass : float
        Tail probability P(coverage >= target) evaluated at ``u_star``.
    mode : {"beta", "beta-binomial"}
        "beta" models an infinite test window, "beta-binomial" a finite one.
    details : dict
        Free-form diagnostics recorded during the search.
    """

    # Requested and corrected miscoverage rates.
    alpha_target: float
    alpha_corrected: float
    # Outcome of the search over u and the sample size it used.
    u_star: int
    n: int
    satisfied_mass: float
    # Test-window model and diagnostics.
    mode: Literal["beta", "beta-binomial"]
    details: dict
|
32
|
+
|
33
|
+
|
34
|
+
def ssbc_correct(
    alpha_target: float,
    n: int,
    delta: float,
    *,
    mode: Literal["beta", "beta-binomial"] = "beta",
    m: int | None = None,
    bracket_width: int | None = None,
) -> SSBCResult:
    """Small-Sample Beta Correction (SSBC), corrected acceptance rule.

    Find the largest α' = u/(n+1) ≤ α_target such that:
        P(Coverage(α') ≥ 1 - α_target) ≥ 1 - δ

    where Coverage(α') ~ Beta(n+1-u, u) for infinite test window.

    Parameters
    ----------
    alpha_target : float
        Target miscoverage rate (must be in (0,1))
    n : int
        Calibration set size (must be >= 1)
    delta : float
        Risk tolerance / PAC parameter (must be in (0,1))
    mode : {"beta", "beta-binomial"}, default="beta"
        "beta" for infinite test window
        "beta-binomial" for finite test window
    m : int, optional
        Test window size for beta-binomial mode (defaults to n)
    bracket_width : int, optional
        Half-width of the initial search bracket around the
        normal-approximation guess (default: adaptive based on n). It only
        changes how many candidates are evaluated; the search is extended
        beyond the bracket when needed, so the answer does not depend on it.

    Returns
    -------
    SSBCResult
        Dataclass containing correction results and diagnostic details.
        ``details["fallback_used"]`` is True when no u in [1, u_max] meets the
        acceptance rule. u_star is then set to 1 and the PAC condition is not
        guaranteed. alpha_corrected can even exceed alpha_target when
        ``alpha_target * (n + 1) < 1``.

    Raises
    ------
    ValueError
        If parameters are out of valid ranges

    Examples
    --------
    >>> result = ssbc_correct(alpha_target=0.10, n=50, delta=0.10)
    >>> print(f"Corrected alpha: {result.alpha_corrected:.4f}")
    Corrected alpha: 0.0392

    Notes
    -----
    The acceptance probability is non-increasing in u, because Beta(n+1-u, u)
    and the matching beta-binomial both shift toward lower coverage as u
    grows. The feasible u values therefore form a prefix {1, ..., u*}.

    The search scans a bracket around a normal-approximation guess. If the top
    of the bracket still passes, it walks upward until the first failure. If
    nothing in the bracket passes, it walks downward to the first success.
    The returned u* is the largest feasible value even when the guess is off.
    """
    # Input validation
    if not (0.0 < alpha_target < 1.0):
        raise ValueError("alpha_target must be in (0,1).")
    if n < 1:
        raise ValueError("n must be >= 1.")
    if not (0.0 < delta < 1.0):
        raise ValueError("delta must be in (0,1).")
    if mode not in ("beta", "beta-binomial"):
        raise ValueError("mode must be 'beta' or 'beta-binomial'.")

    # Maximum u to search (α' must be ≤ α_target)
    u_max = min(n, math.floor(alpha_target * (n + 1)))
    target_coverage = 1 - alpha_target
    pass_threshold = 1 - delta

    # Finite test window: validate m and precompute the success-count threshold.
    m_eval: int | None = None
    k_thresh: int | None = None
    if mode == "beta-binomial":
        m_eval = m if m is not None else n
        if m_eval < 1:
            raise ValueError("m must be >= 1 for beta-binomial mode.")
        k_thresh = math.ceil(target_coverage * m_eval)

    def tail_mass(u: int) -> float:
        """P(Coverage >= target) when calibrating at α' = u/(n+1)."""
        a = n + 1 - u
        b = u
        if mode == "beta":
            # Coverage ~ Beta(a, b); P(X >= x) = 1 - CDF(x)
            return float(1 - beta_dist.cdf(target_coverage, a, b))
        # P(X >= k_thresh) where X ~ BetaBinomial(m, a, b)
        return float(betabinom.sf(k_thresh - 1, m_eval, a, b))

    search_log: list[dict[str, object]] = []

    def evaluate(u: int) -> tuple[bool, float]:
        """Test candidate u against the acceptance rule and record it in search_log."""
        ptail = tail_mass(u)
        passes = ptail >= pass_threshold
        search_log.append(
            {
                "u": u,
                "alpha_prime": u / (n + 1),
                "a": n + 1 - u,
                "b": u,
                "ptail": ptail,
                "threshold": pass_threshold,
                "passes": passes,
            }
        )
        return passes, ptail

    # Initial guess for u using normal approximation to Beta distribution:
    # u ≈ u_target - z_δ * sqrt(u_target), u_target = (n+1)*α_target, z_δ = Φ^(-1)(1-δ)
    u_target = (n + 1) * alpha_target
    z_delta = norm.ppf(1 - delta)
    u_star_guess = max(1, math.floor(u_target - z_delta * math.sqrt(u_target)))
    u_star_guess = min(u_max, u_star_guess)

    # Bracket width (Δ in Algorithm 1)
    if bracket_width is None:
        # Adaptive: uncertainty scales as sqrt(u_target); at least 5, capped at 100.
        bracket_width = max(5, min(int(2 * z_delta * math.sqrt(u_target)), n // 10))
        bracket_width = min(bracket_width, 100)

    # Search bounds - ensure we don't go outside [1, u_max]
    u_min = max(1, u_star_guess - bracket_width)
    u_search_max = min(u_max, u_star_guess + bracket_width)
    if u_min > u_search_max:
        # Degenerate bracket: fall back to the full range.
        u_min = 1
        u_search_max = u_max

    u_star: int | None = None
    mass_star: float | None = None

    # 1) Scan the bracket, keeping the largest passing u.
    for u in range(u_min, u_search_max + 1):
        ok, ptail = evaluate(u)
        if ok:
            u_star, mass_star = u, ptail

    if u_star is not None and u_star == u_search_max:
        # 2a) Top of the bracket still passes: larger feasible u may exist.
        u = u_search_max + 1
        while u <= u_max:
            ok, ptail = evaluate(u)
            if not ok:
                break
            u_star, mass_star = u, ptail
            u += 1
    elif u_star is None:
        # 2b) Nothing in the bracket passes: the feasible prefix (if any) lies below it.
        for u in range(u_min - 1, 0, -1):
            ok, ptail = evaluate(u)
            if ok:
                u_star, mass_star = u, ptail
                break

    # If nothing passes, fall back to u=1 (most conservative); flagged in details.
    fallback_used = u_star is None
    if u_star is None:
        u_star = 1
        mass_star = tail_mass(1)

    alpha_corrected = u_star / (n + 1)

    # At this point, mass_star is always set (either from the search or the fallback)
    assert mass_star is not None, "mass_star should be set by this point"

    return SSBCResult(
        alpha_target=alpha_target,
        alpha_corrected=alpha_corrected,
        u_star=u_star,
        n=n,
        satisfied_mass=mass_star,
        mode=mode,
        details=dict(
            u_max=u_max,
            u_star_guess=u_star_guess,
            search_range=(u_min, u_search_max),
            bracket_width=bracket_width,
            delta=delta,
            m=(m if (mode == "beta-binomial") else None),
            acceptance_rule="P(Coverage >= target) >= 1-delta",
            search_log=search_log,
            fallback_used=fallback_used,
        ),
    )
|