ssbc 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ssbc/__init__.py +50 -2
- ssbc/bootstrap.py +411 -0
- ssbc/cli.py +0 -3
- ssbc/conformal.py +700 -1
- ssbc/cross_conformal.py +425 -0
- ssbc/mcp_server.py +93 -0
- ssbc/operational_bounds_simple.py +367 -0
- ssbc/rigorous_report.py +601 -0
- ssbc/statistics.py +70 -0
- ssbc/utils.py +72 -2
- ssbc/validation.py +409 -0
- ssbc/visualization.py +323 -300
- ssbc-1.1.0.dist-info/METADATA +337 -0
- ssbc-1.1.0.dist-info/RECORD +22 -0
- ssbc-1.1.0.dist-info/licenses/LICENSE +29 -0
- ssbc/ssbc.py +0 -1
- ssbc-0.1.0.dist-info/METADATA +0 -266
- ssbc-0.1.0.dist-info/RECORD +0 -17
- ssbc-0.1.0.dist-info/licenses/LICENSE +0 -21
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/WHEEL +0 -0
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/entry_points.txt +0 -0
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/top_level.txt +0 -0
ssbc/cross_conformal.py
ADDED
@@ -0,0 +1,425 @@
"""Cross-conformal validation for estimating rate variability.

This implements K-fold cross-validation specifically for conformal prediction:
- Split calibration data into K folds
- For each fold: train thresholds on K-1 folds, evaluate rates on held-out fold
- Aggregate rates across folds to quantify finite-sample variability

Different from:
- LOO-CV: Leave-one-out, aggregates counts (not rates per fold)
- Bootstrap: Resamples with replacement, tests on fresh data
- Cross-conformal: K-fold split, estimates rate distribution from finite calibration
"""

from typing import Any

import numpy as np

from ssbc.conformal import split_by_class
from ssbc.core import ssbc_correct


def _compute_fold_rates_mondrian(
    train_labels: np.ndarray,
    train_probs: np.ndarray,
    test_labels: np.ndarray,
    test_probs: np.ndarray,
    alpha_target: float,
    delta: float,
) -> dict[str, dict[str, float]]:
    """Compute operational rates for one fold in Mondrian conformal.

    Parameters
    ----------
    train_labels : np.ndarray
        Training fold labels
    train_probs : np.ndarray
        Training fold probabilities
    test_labels : np.ndarray
        Test fold labels
    test_probs : np.ndarray
        Test fold probabilities
    alpha_target : float
        Target miscoverage
    delta : float
        PAC risk (for SSBC correction)

    Returns
    -------
    dict
        Rates for this fold: marginal and per-class
    """
    # Split training data by class
    train_class_data = split_by_class(train_labels, train_probs)

    # SSBC correction and threshold computation
    thresholds = {}
    for class_label in [0, 1]:
        class_data = train_class_data[class_label]
        if class_data["n"] == 0:
            thresholds[class_label] = 0.0
            continue

        # SSBC correction
        ssbc_result = ssbc_correct(alpha_target=alpha_target, n=class_data["n"], delta=delta)

        # Compute threshold
        n_class = class_data["n"]
        k = int(np.ceil((n_class + 1) * (1 - ssbc_result.alpha_corrected)))

        mask = train_labels == class_label
        scores = 1.0 - train_probs[mask, class_label]
        sorted_scores = np.sort(scores)

        thresholds[class_label] = sorted_scores[min(k - 1, len(sorted_scores) - 1)]

    # Evaluate on test fold
    n_test = len(test_labels)

    # Marginal counters
    n_abstentions = 0
    n_singletons = 0
    n_doublets = 0
    n_singletons_correct = 0

    # Per-class counters
    counts_0 = {"abstentions": 0, "singletons": 0, "doublets": 0, "singletons_correct": 0, "n": 0}
    counts_1 = {"abstentions": 0, "singletons": 0, "doublets": 0, "singletons_correct": 0, "n": 0}

    for i in range(n_test):
        true_label = test_labels[i]
        score_0 = 1.0 - test_probs[i, 0]
        score_1 = 1.0 - test_probs[i, 1]

        in_0 = score_0 <= thresholds[0]
        in_1 = score_1 <= thresholds[1]

        # Marginal
        if in_0 and in_1:
            n_doublets += 1
        elif in_0 or in_1:
            n_singletons += 1
            if (in_0 and true_label == 0) or (in_1 and true_label == 1):
                n_singletons_correct += 1
        else:
            n_abstentions += 1

        # Per-class
        if true_label == 0:
            counts_0["n"] += 1
            if in_0 and in_1:
                counts_0["doublets"] += 1
            elif in_0 or in_1:
                counts_0["singletons"] += 1
                if in_0:
                    counts_0["singletons_correct"] += 1
            else:
                counts_0["abstentions"] += 1
        else:
            counts_1["n"] += 1
            if in_0 and in_1:
                counts_1["doublets"] += 1
            elif in_0 or in_1:
                counts_1["singletons"] += 1
                if in_1:
                    counts_1["singletons_correct"] += 1
            else:
                counts_1["abstentions"] += 1

    # Compute rates
    marginal_rates = {
        "abstention": n_abstentions / n_test,
        "singleton": n_singletons / n_test,
        "doublet": n_doublets / n_test,
        "singleton_error": (n_singletons - n_singletons_correct) / n_singletons if n_singletons > 0 else np.nan,
    }

    class_0_rates = {
        "abstention": counts_0["abstentions"] / counts_0["n"] if counts_0["n"] > 0 else np.nan,
        "singleton": counts_0["singletons"] / counts_0["n"] if counts_0["n"] > 0 else np.nan,
        "doublet": counts_0["doublets"] / counts_0["n"] if counts_0["n"] > 0 else np.nan,
        "singleton_error": (
            (counts_0["singletons"] - counts_0["singletons_correct"]) / counts_0["singletons"]
            if counts_0["singletons"] > 0
            else np.nan
        ),
    }

    class_1_rates = {
        "abstention": counts_1["abstentions"] / counts_1["n"] if counts_1["n"] > 0 else np.nan,
        "singleton": counts_1["singletons"] / counts_1["n"] if counts_1["n"] > 0 else np.nan,
        "doublet": counts_1["doublets"] / counts_1["n"] if counts_1["n"] > 0 else np.nan,
        "singleton_error": (
            (counts_1["singletons"] - counts_1["singletons_correct"]) / counts_1["singletons"]
            if counts_1["singletons"] > 0
            else np.nan
        ),
    }

    return {
        "marginal": marginal_rates,
        "class_0": class_0_rates,
        "class_1": class_1_rates,
    }
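
# A worked instance of the threshold rule above, with illustrative numbers
# (alpha_target=0.10, n_class=100, delta=0.05, for which ssbc_correct gives
# alpha_corrected ≈ 0.0571):
#   k = ceil((100 + 1) * (1 - 0.0571)) = ceil(95.23) = 96
# so the class threshold is the 96th smallest calibration nonconformity score.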


def cross_conformal_validation(
    labels: np.ndarray,
    probs: np.ndarray,
    alpha_target: float = 0.10,
    delta: float = 0.10,
    n_folds: int = 5,
    stratify: bool = True,
    seed: int | None = None,
) -> dict[str, Any]:
    """K-fold cross-conformal validation for Mondrian conformal prediction.

    Estimates the variability of operational rates (abstentions, singletons, doublets)
    due to finite calibration sample effects by splitting data into K folds.

    For each fold:
    1. Train: Compute SSBC-corrected thresholds on K-1 folds
    2. Test: Evaluate operational rates on held-out fold
    3. Record: Store rates for this fold

    Aggregate rates across folds to quantify finite-sample variability.

    Parameters
    ----------
    labels : np.ndarray, shape (n,)
        Calibration labels (0 or 1)
    probs : np.ndarray, shape (n, 2)
        Calibration probabilities [P(class=0), P(class=1)]
    alpha_target : float, default=0.10
        Target miscoverage rate
    delta : float, default=0.10
        PAC risk for SSBC correction
    n_folds : int, default=5
        Number of folds (K)
    stratify : bool, default=True
        Stratify folds by class labels
    seed : int, optional
        Random seed for reproducibility

    Returns
    -------
    dict
        Cross-conformal results with keys:
        - 'fold_rates': List of rate dicts for each fold
        - 'marginal': Statistics for marginal rates
        - 'class_0': Statistics for class 0 rates
        - 'class_1': Statistics for class 1 rates

        Each statistics dict contains:
        - 'samples': Array of rates across folds
        - 'mean': Mean rate
        - 'std': Standard deviation
        - 'quantiles': Dict with q05, q25, q50, q75, q95
        - 'ci_95': Approximate 95% interval (5th-95th fold-rate quantiles)

    Examples
    --------
    >>> from ssbc import cross_conformal_validation
    >>> results = cross_conformal_validation(labels, probs, n_folds=10)
    >>> m = results['marginal']['singleton']
    >>> print(f"Singleton rate: {m['mean']:.3f} ± {m['std']:.3f}")
    >>> print(f"95% range: [{m['quantiles']['q05']:.3f}, {m['quantiles']['q95']:.3f}]")

    Notes
    -----
    Different from other methods:
    - **LOO-CV**: Leave-one-out, aggregates counts (not fold-level rates)
    - **Bootstrap**: Resamples with replacement, tests on fresh data
    - **Cross-conformal**: K-fold split, estimates rate distribution from calibration

    This method directly estimates the variability of rates due to finite calibration samples,
    without requiring a data simulator.
    """
    if seed is not None:
        np.random.seed(seed)

    n = len(labels)

    # Create fold indices
    indices = np.arange(n)

    if stratify:
        # Stratified K-fold: maintain class proportions in each fold
        class_0_idx = indices[labels == 0]
        class_1_idx = indices[labels == 1]

        np.random.shuffle(class_0_idx)
        np.random.shuffle(class_1_idx)

        class_0_folds = np.array_split(class_0_idx, n_folds)
        class_1_folds = np.array_split(class_1_idx, n_folds)

        folds = [np.concatenate([class_0_folds[i], class_1_folds[i]]) for i in range(n_folds)]
    else:
        # Standard K-fold
        np.random.shuffle(indices)
        folds = np.array_split(indices, n_folds)

    # Compute rates for each fold
    fold_rates = []

    for fold_idx in range(n_folds):
        # Test fold
        test_idx = folds[fold_idx]

        # Train folds (all except test)
        train_idx = np.concatenate([folds[i] for i in range(n_folds) if i != fold_idx])

        # Compute fold rates
        rates = _compute_fold_rates_mondrian(
            train_labels=labels[train_idx],
            train_probs=probs[train_idx],
            test_labels=labels[test_idx],
            test_probs=probs[test_idx],
            alpha_target=alpha_target,
            delta=delta,
        )

        fold_rates.append(rates)

    # Aggregate statistics
    metrics = ["abstention", "singleton", "doublet", "singleton_error"]

    def compute_stats(values: list[float], metric_name: str) -> dict[str, Any]:
        """Compute statistics for a metric across folds."""
        arr = np.array(values)
        valid = arr[~np.isnan(arr)]

        if len(valid) == 0:
            return {
                "samples": arr,
                "mean": np.nan,
                "std": np.nan,
                "quantiles": {"q05": np.nan, "q25": np.nan, "q50": np.nan, "q75": np.nan, "q95": np.nan},
                "ci_95": {"lower": np.nan, "upper": np.nan},
            }

        quantiles = {
            "q05": float(np.percentile(valid, 5)),
            "q25": float(np.percentile(valid, 25)),
            "q50": float(np.percentile(valid, 50)),
            "q75": float(np.percentile(valid, 75)),
            "q95": float(np.percentile(valid, 95)),
        }

        stats = {
            "samples": arr,
            "mean": float(np.mean(valid)),
            "std": float(np.std(valid, ddof=1)) if len(valid) > 1 else 0.0,
            "quantiles": quantiles,
        }

        # Add empirical CI based on fold distribution (binomial-like but for fold means)
        # This is approximate - treats fold means as if they were Bernoulli trials
        # Better: just use quantiles, but keeping for compatibility
        stats["ci_95"] = {
            "lower": quantiles["q05"],
            "upper": quantiles["q95"],
        }

        return stats

    # Aggregate marginal statistics
    marginal_stats = {
        metric: compute_stats([fold["marginal"][metric] for fold in fold_rates], metric) for metric in metrics
    }

    # Aggregate class-specific statistics
    class_0_stats = {
        metric: compute_stats([fold["class_0"][metric] for fold in fold_rates], metric) for metric in metrics
    }

    class_1_stats = {
        metric: compute_stats([fold["class_1"][metric] for fold in fold_rates], metric) for metric in metrics
    }

    return {
        "n_folds": n_folds,
        "n_samples": n,
        "stratified": stratify,
        "fold_rates": fold_rates,
        "marginal": marginal_stats,
        "class_0": class_0_stats,
        "class_1": class_1_stats,
        "parameters": {
            "alpha_target": alpha_target,
            "delta": delta,
            "n_folds": n_folds,
            "stratify": stratify,
        },
    }


def print_cross_conformal_results(results: dict) -> None:
    """Pretty print cross-conformal validation results.

    Parameters
    ----------
    results : dict
        Results from cross_conformal_validation()
    """
    print("=" * 80)
    print("CROSS-CONFORMAL VALIDATION RESULTS")
    print("=" * 80)
    print("\nParameters:")
    print(f"  K-folds: {results['n_folds']}")
    print(f"  Samples: {results['n_samples']}")
    print(f"  Stratified: {results['stratified']}")
    print(f"  Alpha target: {results['parameters']['alpha_target']:.3f}")
    print(f"  Delta (PAC): {results['parameters']['delta']:.3f}")

    # Marginal
    print("\n" + "-" * 80)
    print("MARGINAL RATES (Across All Samples)")
    print("-" * 80)

    for metric, name in [
        ("singleton", "SINGLETON"),
        ("doublet", "DOUBLET"),
        ("abstention", "ABSTENTION"),
        ("singleton_error", "SINGLETON ERROR"),
    ]:
        m = results["marginal"][metric]
        q = m["quantiles"]

        print(f"\n{name}:")
        print(f"  Mean across folds: {m['mean']:.4f} ± {m['std']:.4f}")
        print(f"  Median: {q['q50']:.4f}")
        print(f"  [5%, 95%] range: [{q['q05']:.4f}, {q['q95']:.4f}]")
        print(f"  [25%, 75%] IQR: [{q['q25']:.4f}, {q['q75']:.4f}]")

    # Per-class
    for class_label in [0, 1]:
        print(f"\n{'-' * 80}")
        print(f"CLASS {class_label} RATES")
        print("-" * 80)

        for metric, name in [
            ("singleton", "SINGLETON"),
            ("doublet", "DOUBLET"),
            ("abstention", "ABSTENTION"),
            ("singleton_error", "SINGLETON ERROR"),
        ]:
            m = results[f"class_{class_label}"][metric]
            q = m["quantiles"]

            print(f"\n{name}:")
            print(f"  Mean across folds: {m['mean']:.4f} ± {m['std']:.4f}")
            print(f"  Median: {q['q50']:.4f}")
            print(f"  [5%, 95%] range: [{q['q05']:.4f}, {q['q95']:.4f}]")

    print("\n" + "=" * 80)
    print("INTERPRETATION")
    print("=" * 80)
    print("\n✓ Shows finite-sample variability from K-fold splits of calibration data")
    print("✓ [5%, 95%] range indicates expected rate fluctuations")
    print("✓ Smaller std → more stable rates across different calibration subsets")
    print("✓ Complementary to:")
    print("  • LOO-CV bounds: Uncertainty for fixed full calibration")
    print("  • Bootstrap: Recalibration uncertainty with fresh test data")
    print("  • Cross-conformal: Rate variability from finite calibration splits")
    print("\n" + "=" * 80)
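
A minimal end-to-end sketch of the new module (not part of the wheel): the labels and probabilities below are synthetic stand-ins for real calibration data, and the call mirrors the Examples section of the docstring above.

import numpy as np
from ssbc.cross_conformal import cross_conformal_validation, print_cross_conformal_results

# Synthetic stand-in: 200 binary labels with class-1 probabilities
# loosely correlated with the true label.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200)
p1 = np.clip(rng.normal(loc=0.2 + 0.6 * labels, scale=0.15), 0.01, 0.99)
probs = np.column_stack([1.0 - p1, p1])

results = cross_conformal_validation(labels, probs, alpha_target=0.10, delta=0.10, n_folds=5, seed=42)
print_cross_conformal_results(results)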
ssbc/mcp_server.py
ADDED
@@ -0,0 +1,93 @@
"""MCP Server for SSBC - Small-Sample Beta Correction.

Exposes SSBC functionality as MCP tools for AI assistants.
"""

from typing import Literal

from mcp.server.fastmcp import FastMCP  # type: ignore[import-untyped]

from .core import ssbc_correct

# Initialize FastMCP server
mcp = FastMCP("SSBC Server")


@mcp.tool()
def compute_ssbc_correction(
    alpha_target: float,
    n: int,
    delta: float,
    mode: Literal["beta", "beta-binomial"] = "beta",
) -> dict[str, float | int | str]:
    """Compute Small-Sample Beta Correction for conformal prediction.

    Corrects the miscoverage rate α to provide finite-sample PAC guarantees.
    Unlike asymptotic methods or concentration inequalities (Hoeffding, DKWM),
    SSBC uses the exact induced beta distribution for tighter bounds.

    Parameters
    ----------
    alpha_target : float
        Target miscoverage rate (e.g., 0.10 for 90% coverage target)
    n : int
        Calibration set size (number of calibration points)
    delta : float
        PAC risk parameter (e.g., 0.05 for 95% probability guarantee)
    mode : str, default="beta"
        "beta" for infinite test window, "beta-binomial" for finite test window

    Returns
    -------
    result : dict
        Dictionary containing:
        - alpha_corrected: Corrected miscoverage rate (α')
        - u_star: Optimal threshold index (1-based)
        - guarantee: Human-readable guarantee statement
        - explanation: How to apply the correction
        - calibration_size, target_coverage, pac_confidence: Formatted echoes of the inputs

    Examples
    --------
    For a calibration set of 100 points targeting 90% coverage with 95% confidence:

    >>> compute_ssbc_correction(alpha_target=0.10, n=100, delta=0.05)
    {
        "alpha_corrected": 0.0571,
        "u_star": 95,
        "guarantee": "With 95.0% probability, coverage ≥ 90.0%",
        ...
    }

    Notes
    -----
    Statistical Properties:
    - Distribution-free: No assumptions about P(X,Y)
    - Frequentist: Valid frequentist guarantee (no priors)
    - Finite-sample: Exact for ANY n (not asymptotic)
    - Model-agnostic: Works with any probabilistic classifier

    The corrected α' < α_target provides more conservative thresholds,
    leading to larger prediction sets and higher coverage guarantees.
    """
    # Call the core SSBC function
    result = ssbc_correct(alpha_target=alpha_target, n=n, delta=delta, mode=mode)

    # Format response
    return {
        "alpha_corrected": float(result.alpha_corrected),
        "u_star": int(result.u_star),
        "guarantee": f"With {100 * (1 - delta):.1f}% probability, coverage ≥ {100 * (1 - alpha_target):.1f}%",
        "explanation": (
            f"Use α'={result.alpha_corrected:.4f} instead of α={alpha_target:.4f}. "
            f"This ensures coverage ≥ {100 * (1 - alpha_target):.1f}% with "
            f"{100 * (1 - delta):.1f}% probability over calibration sets of size {n}."
        ),
        "calibration_size": n,
        "target_coverage": f"{100 * (1 - alpha_target):.1f}%",
        "pac_confidence": f"{100 * (1 - delta):.1f}%",
    }


if __name__ == "__main__":
    # Run the MCP server
    mcp.run()
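
The __main__ guard means the server can be started over stdio with: python -m ssbc.mcp_server. For a quick sanity check without an MCP client, the tool can also be invoked as a plain function; this sketch (not part of the wheel) assumes the SDK's @mcp.tool() decorator returns the wrapped function unchanged, which the official mcp package's FastMCP currently does:

from ssbc.mcp_server import compute_ssbc_correction

# Direct call, bypassing the MCP transport entirely.
out = compute_ssbc_correction(alpha_target=0.10, n=100, delta=0.05)
print(out["guarantee"])        # "With 95.0% probability, coverage ≥ 90.0%"
print(out["alpha_corrected"])  # corrected miscoverage rate α' from ssbc_correct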