ssbc 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ssbc/rigorous_report.py ADDED
@@ -0,0 +1,601 @@
+ """Unified rigorous reporting with full PAC guarantees.
+
+ This module provides a single comprehensive report that properly accounts for
+ coverage volatility across all operational metrics.
+ """
+
+ from typing import Any, cast
+
+ import numpy as np
+
+ from .bootstrap import bootstrap_calibration_uncertainty
+ from .conformal import mondrian_conformal_calibrate, split_by_class
+ from .core import ssbc_correct
+ from .cross_conformal import cross_conformal_validation
+ from .operational_bounds_simple import (
+     compute_pac_operational_bounds_marginal,
+     compute_pac_operational_bounds_perclass,
+ )
+
+
+ def generate_rigorous_pac_report(
+     labels: np.ndarray,
+     probs: np.ndarray,
+     alpha_target: float | dict[int, float] = 0.10,
+     delta: float | dict[int, float] = 0.10,
+     test_size: int | None = None,
+     ci_level: float = 0.95,
+     use_union_bound: bool = True,
+     n_jobs: int = -1,
+     verbose: bool = True,
+     run_bootstrap: bool = False,
+     n_bootstrap: int = 1000,
+     simulator: Any = None,
+     run_cross_conformal: bool = False,
+     n_folds: int = 10,
+ ) -> dict[str, Any]:
+     """Generate complete rigorous PAC report with coverage volatility.
+
+     This is the UNIFIED function that gives you everything in one call:
+     - SSBC-corrected thresholds
+     - Coverage guarantees
+     - PAC-controlled operational bounds (marginal + per-class)
+     - Singleton error rates with PAC guarantees
+     - All bounds account for coverage volatility via Beta-Binomial
+
+     Parameters
+     ----------
+     labels : np.ndarray, shape (n,)
+         True labels (0 or 1)
+     probs : np.ndarray, shape (n, 2)
+         Predicted probabilities [P(class=0), P(class=1)]
+     alpha_target : float or dict[int, float], default=0.10
+         Target miscoverage rate per class
+     delta : float or dict[int, float], default=0.10
+         PAC risk tolerance. Used for both:
+         - Coverage guarantee (via SSBC)
+         - Operational bounds (pac_level = 1 - delta)
+     test_size : int, optional
+         Expected test set size. If None, uses the calibration size
+     ci_level : float, default=0.95
+         Confidence level for Clopper-Pearson intervals
+     use_union_bound : bool, default=True
+         Apply a Bonferroni correction so all metrics hold simultaneously (recommended)
+     n_jobs : int, default=-1
+         Number of parallel jobs for LOO-CV computation.
+         -1 = use all cores (default), 1 = single-threaded, N = use N cores.
+     verbose : bool, default=True
+         Print the comprehensive report
+     run_bootstrap : bool, default=False
+         Run bootstrap calibration uncertainty analysis
+     n_bootstrap : int, default=1000
+         Number of bootstrap trials (only used if run_bootstrap=True)
+     simulator : DataGenerator, optional
+         Simulator for generating fresh test sets (required if run_bootstrap=True)
+     run_cross_conformal : bool, default=False
+         Run cross-conformal validation for finite-sample diagnostics
+     n_folds : int, default=10
+         Number of folds for cross-conformal validation (only used if run_cross_conformal=True)
+
+     Returns
+     -------
+     dict
+         Complete report with keys:
+         - 'ssbc_class_0': SSBCResult for class 0
+         - 'ssbc_class_1': SSBCResult for class 1
+         - 'pac_bounds_marginal': PAC operational bounds (marginal)
+         - 'pac_bounds_class_0': PAC operational bounds (class 0)
+         - 'pac_bounds_class_1': PAC operational bounds (class 1)
+         - 'calibration_result': From mondrian_conformal_calibrate
+         - 'prediction_stats': From mondrian_conformal_calibrate
+         - 'bootstrap_results': Bootstrap analysis, or None if run_bootstrap=False
+         - 'cross_conformal_results': Cross-conformal diagnostics, or None if run_cross_conformal=False
+         - 'parameters': The parameters used for this report
+
+     Examples
+     --------
+     >>> from ssbc import BinaryClassifierSimulator
+     >>> from ssbc.rigorous_report import generate_rigorous_pac_report
+     >>>
+     >>> sim = BinaryClassifierSimulator(p_class1=0.5, seed=42)
+     >>> labels, probs = sim.generate(n_samples=1000)
+     >>>
+     >>> report = generate_rigorous_pac_report(
+     ...     labels, probs,
+     ...     alpha_target=0.10,
+     ...     delta=0.10,
+     ...     verbose=True
+     ... )
+
+     Notes
+     -----
+     **This replaces the old workflow:**
+
+     OLD (incomplete):
+     ```python
+     cal_result, pred_stats = mondrian_conformal_calibrate(...)
+     op_bounds = compute_mondrian_operational_bounds(...)  # No coverage volatility!
+     marginal_bounds = compute_marginal_operational_bounds(...)  # No coverage volatility!
+     report_prediction_stats(...)  # Uses incomplete bounds
+     ```
+
+     NEW (rigorous):
+     ```python
+     report = generate_rigorous_pac_report(labels, probs, alpha_target, delta)
+     # Done! All bounds account for coverage volatility.
+     ```
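+
+     Accessing results programmatically (illustrative; the keys are those listed
+     under Returns, and the rate-bound keys are the ones read by the verbose report):
+     ```python
+     report = generate_rigorous_pac_report(labels, probs, verbose=False)
+     lo, hi = report["pac_bounds_marginal"]["singleton_rate_bounds"]
+     err_lo, err_hi = report["pac_bounds_class_1"]["singleton_error_rate_bounds"]
+     ```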
+     """
+     # Handle scalar inputs - convert to dict format
+     if isinstance(alpha_target, int | float):
+         alpha_dict: dict[int, float] = {0: float(alpha_target), 1: float(alpha_target)}
+     else:
+         alpha_dict = cast(dict[int, float], alpha_target)
+
+     if isinstance(delta, int | float):
+         delta_dict: dict[int, float] = {0: float(delta), 1: float(delta)}
+     else:
+         delta_dict = cast(dict[int, float], delta)
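+     # Illustrative call patterns (both resolve to the dict form above):
+     #   generate_rigorous_pac_report(labels, probs, alpha_target=0.10, delta=0.10)
+     #   generate_rigorous_pac_report(labels, probs, alpha_target={0: 0.05, 1: 0.10}, delta={0: 0.05, 1: 0.10})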
+
+     # Split by class
+     class_data = split_by_class(labels, probs)
+     n_0 = class_data[0]["n"]
+     n_1 = class_data[1]["n"]
+     n_total = len(labels)
+
+     # Set test_size if not provided
+     if test_size is None:
+         test_size = n_total
+
+     # Derive PAC levels from delta values
+     # For marginal: use independence since split (n₀, n₁) is observed
+     # Pr(both coverage guarantees hold) = (1-δ₀)(1-δ₁)
+     pac_level_marginal = (1 - delta_dict[0]) * (1 - delta_dict[1])
+     pac_level_0 = 1 - delta_dict[0]
+     pac_level_1 = 1 - delta_dict[1]
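+     # e.g. with the default delta of 0.10 per class:
+     #   pac_level_0 = pac_level_1 = 0.90 and pac_level_marginal = 0.90 * 0.90 = 0.81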
+
+     # Step 1: Run SSBC for each class
+     ssbc_result_0 = ssbc_correct(alpha_target=alpha_dict[0], n=n_0, delta=delta_dict[0], mode="beta")
+     ssbc_result_1 = ssbc_correct(alpha_target=alpha_dict[1], n=n_1, delta=delta_dict[1], mode="beta")
+
+     # Step 2: Get calibration results (for thresholds and basic stats)
+     cal_result, pred_stats = mondrian_conformal_calibrate(
+         class_data=class_data, alpha_target=alpha_dict, delta=delta_dict, mode="beta"
+     )
+
+     # Step 3: Compute PAC operational bounds - MARGINAL
+     # Uses the joint PAC level (1-δ₀)(1-δ₁) derived above
+     pac_bounds_marginal = compute_pac_operational_bounds_marginal(
+         ssbc_result_0=ssbc_result_0,
+         ssbc_result_1=ssbc_result_1,
+         labels=labels,
+         probs=probs,
+         test_size=test_size,
+         ci_level=ci_level,
+         pac_level=pac_level_marginal,
+         use_union_bound=use_union_bound,
+         n_jobs=n_jobs,
+     )
+
+     # Step 4: Compute PAC operational bounds - PER-CLASS
+     # Each class uses its own delta
+     pac_bounds_class_0 = compute_pac_operational_bounds_perclass(
+         ssbc_result_0=ssbc_result_0,
+         ssbc_result_1=ssbc_result_1,
+         labels=labels,
+         probs=probs,
+         class_label=0,
+         test_size=test_size,
+         ci_level=ci_level,
+         pac_level=pac_level_0,
+         use_union_bound=use_union_bound,
+         n_jobs=n_jobs,
+     )
+
+     pac_bounds_class_1 = compute_pac_operational_bounds_perclass(
+         ssbc_result_0=ssbc_result_0,
+         ssbc_result_1=ssbc_result_1,
+         labels=labels,
+         probs=probs,
+         class_label=1,
+         test_size=test_size,
+         ci_level=ci_level,
+         pac_level=pac_level_1,
+         use_union_bound=use_union_bound,
+         n_jobs=n_jobs,
+     )
+
+     # Bootstrap calibration uncertainty analysis (optional)
+     bootstrap_results = None
+     if run_bootstrap:
+         if simulator is None:
+             raise ValueError("simulator is required when run_bootstrap=True")
+
+         if verbose:
+             print("\n" + "=" * 80)
+             print("BOOTSTRAP CALIBRATION UNCERTAINTY ANALYSIS")
+             print("=" * 80)
+             print(f"\nRunning {n_bootstrap} bootstrap trials...")
+             print(f" Calibration size: n={len(labels)}")
+             print(f" Test size per trial: {test_size if test_size else len(labels)}")
+
+         bootstrap_results = bootstrap_calibration_uncertainty(
+             labels=labels,
+             probs=probs,
+             simulator=simulator,
+             alpha_target=alpha_dict[0],  # Use class 0 alpha
+             delta=delta_dict[0],  # Use class 0 delta
+             test_size=test_size if test_size else len(labels),
+             n_bootstrap=n_bootstrap,
+             n_jobs=n_jobs,
+             seed=None,
+         )
+
+     # Cross-conformal validation for finite-sample diagnostics (optional)
+     cross_conformal_results = None
+     if run_cross_conformal:
+         if verbose:
+             print("\n" + "=" * 80)
+             print("CROSS-CONFORMAL VALIDATION")
+             print("=" * 80)
+             print(f"\nRunning {n_folds}-fold cross-conformal validation...")
+             print(f" Calibration size: n={len(labels)}")
+
+         cross_conformal_results = cross_conformal_validation(
+             labels=labels,
+             probs=probs,
+             alpha_target=alpha_dict[0],  # Use class 0 alpha
+             delta=delta_dict[0],  # Use class 0 delta
+             n_folds=n_folds,
+             stratify=True,
+             seed=None,
+         )
+
+     # Build comprehensive report dict
+     report = {
+         "ssbc_class_0": ssbc_result_0,
+         "ssbc_class_1": ssbc_result_1,
+         "pac_bounds_marginal": pac_bounds_marginal,
+         "pac_bounds_class_0": pac_bounds_class_0,
+         "pac_bounds_class_1": pac_bounds_class_1,
+         "calibration_result": cal_result,
+         "prediction_stats": pred_stats,
+         "bootstrap_results": bootstrap_results,
+         "cross_conformal_results": cross_conformal_results,
+         "parameters": {
+             "alpha_target": alpha_dict,
+             "delta": delta_dict,
+             "test_size": test_size,
+             "ci_level": ci_level,
+             "pac_level_marginal": pac_level_marginal,
+             "pac_level_0": pac_level_0,
+             "pac_level_1": pac_level_1,
+             "use_union_bound": use_union_bound,
+             "run_bootstrap": run_bootstrap,
+             "n_bootstrap": n_bootstrap if run_bootstrap else None,
+             "run_cross_conformal": run_cross_conformal,
+             "n_folds": n_folds if run_cross_conformal else None,
+         },
+     }
+
+     # Print comprehensive report if verbose
+     if verbose:
+         _print_rigorous_report(report)
+
+     return report
+
+
+ def _print_rigorous_report(report: dict) -> None:
+     """Print comprehensive rigorous PAC report."""
+     cal_result = report["calibration_result"]
+     pred_stats = report["prediction_stats"]
+     params = report["parameters"]
+
+     print("=" * 80)
+     print("RIGOROUS PAC-CONTROLLED CONFORMAL PREDICTION REPORT")
+     print("=" * 80)
+     print("\nParameters:")
+     print(f" Test size: {params['test_size']}")
+     print(f" CI level: {params['ci_level']:.0%} (Clopper-Pearson)")
+     pac_0 = params["pac_level_0"]
+     pac_1 = params["pac_level_1"]
+     pac_m = params["pac_level_marginal"]
+     print(f" PAC confidence: Class 0: {pac_0:.0%}, Class 1: {pac_1:.0%}, Marginal: {pac_m:.0%}")
+     union_msg = "YES (all metrics hold simultaneously)" if params["use_union_bound"] else "NO"
+     print(f" Union bound: {union_msg}")
+
+     # Per-class reports
+     for class_label in [0, 1]:
+         ssbc = report[f"ssbc_class_{class_label}"]
+         pac = report[f"pac_bounds_class_{class_label}"]
+         cal = cal_result[class_label]
+
+         print("\n" + "=" * 80)
+         print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
+         print("=" * 80)
+
+         print(f" Calibration size: n = {ssbc.n}")
+         print(f" Target miscoverage: α = {params['alpha_target'][class_label]:.3f}")
+         print(f" SSBC-corrected α: α' = {ssbc.alpha_corrected:.4f}")
+         print(f" PAC risk: δ = {params['delta'][class_label]:.3f}")
+         print(f" Conformal threshold: {cal['threshold']:.4f}")
+
+         # Calibration data statistics
+         stats = pred_stats[class_label]
+         if "error" not in stats:
+             print(f"\n 📊 Statistics from Calibration Data (n={ssbc.n}):")
+             print(" [Basic CP CIs without PAC guarantee - evaluated on calibration data]")
+
+             # Abstentions
+             abst = stats["abstentions"]
+             print(
+                 f" Abstentions: {abst['count']:4d} / {ssbc.n:4d} = {abst['proportion']:6.2%} "
+                 f"95% CI: [{abst['lower']:.3f}, {abst['upper']:.3f}]"
+             )
+
+             # Singletons
+             sing = stats["singletons"]
+             print(
+                 f" Singletons: {sing['count']:4d} / {ssbc.n:4d} = {sing['proportion']:6.2%} "
+                 f"95% CI: [{sing['lower']:.3f}, {sing['upper']:.3f}]"
+             )
+
+             # Correct/incorrect singletons
+             sing_corr = stats["singletons_correct"]
+             print(
+                 f" Correct: {sing_corr['count']:4d} / {ssbc.n:4d} = {sing_corr['proportion']:6.2%} "
+                 f"95% CI: [{sing_corr['lower']:.3f}, {sing_corr['upper']:.3f}]"
+             )
+
+             sing_incorr = stats["singletons_incorrect"]
+             print(
+                 f" Incorrect: {sing_incorr['count']:4d} / {ssbc.n:4d} = {sing_incorr['proportion']:6.2%} "
+                 f"95% CI: [{sing_incorr['lower']:.3f}, {sing_incorr['upper']:.3f}]"
+             )
+
+             # Error | singleton
+             if sing["count"] > 0:
+                 from .statistics import cp_interval
+
+                 error_cond = cp_interval(sing_incorr["count"], sing["count"])
+                 print(
+                     f" Error | singleton: {sing_incorr['count']:4d} / {sing['count']:4d} = "
+                     f"{error_cond['proportion']:6.2%} 95% CI: [{error_cond['lower']:.3f}, {error_cond['upper']:.3f}]"
+                 )
+
+             # Doublets
+             doub = stats["doublets"]
+             print(
+                 f" Doublets: {doub['count']:4d} / {ssbc.n:4d} = {doub['proportion']:6.2%} "
+                 f"95% CI: [{doub['lower']:.3f}, {doub['upper']:.3f}]"
+             )
+
+         print("\n ✅ RIGOROUS PAC-Controlled Operational Bounds")
+         print(" (LOO-CV + Clopper-Pearson for estimation uncertainty)")
+         pac_level_class = params[f"pac_level_{class_label}"]
+         print(f" PAC level: {pac_level_class:.0%} (= 1 - δ), CP level: {params['ci_level']:.0%}")
+         print(f" Grid points evaluated: {pac['n_grid_points']}")
+
+         s_lower, s_upper = pac["singleton_rate_bounds"]
+         print("\n SINGLETON:")
+         print(f" Bounds: [{s_lower:.3f}, {s_upper:.3f}]")
+         print(f" Expected: {pac['expected_singleton_rate']:.3f}")
+
+         d_lower, d_upper = pac["doublet_rate_bounds"]
+         print("\n DOUBLET:")
+         print(f" Bounds: [{d_lower:.3f}, {d_upper:.3f}]")
+         print(f" Expected: {pac['expected_doublet_rate']:.3f}")
+
+         a_lower, a_upper = pac["abstention_rate_bounds"]
+         print("\n ABSTENTION:")
+         print(f" Bounds: [{a_lower:.3f}, {a_upper:.3f}]")
+         print(f" Expected: {pac['expected_abstention_rate']:.3f}")
+
+         se_lower, se_upper = pac["singleton_error_rate_bounds"]
+         print("\n CONDITIONAL ERROR (P(error | singleton)):")
+         print(f" Bounds: [{se_lower:.3f}, {se_upper:.3f}]")
+         print(f" Expected: {pac['expected_singleton_error_rate']:.3f}")
+
+     # Marginal report
+     pac_marg = report["pac_bounds_marginal"]
+     marginal_stats = pred_stats["marginal"]
+
+     print("\n" + "=" * 80)
+     print("MARGINAL STATISTICS (Deployment View - Ignores True Labels)")
+     print("=" * 80)
+     n_total = marginal_stats["n_total"]
+     print(f" Total samples: n = {n_total}")
+
+     # Calibration data statistics (marginal)
+     print(f"\n 📊 Statistics from Calibration Data (n={n_total}):")
+     print(" [Basic CP CIs - evaluated on calibration data]")
+
+     # Coverage
+     cov = marginal_stats["coverage"]
+     print(
+         f" Coverage: {cov['count']:4d} / {n_total:4d} = {cov['rate']:6.2%} "
+         f"95% CI: [{cov['ci_95']['lower']:.3f}, {cov['ci_95']['upper']:.3f}]"
+     )
+
+     # Abstentions
+     abst = marginal_stats["abstentions"]
+     print(
+         f" Abstentions: {abst['count']:4d} / {n_total:4d} = {abst['proportion']:6.2%} "
+         f"95% CI: [{abst['lower']:.3f}, {abst['upper']:.3f}]"
+     )
+
+     # Singletons
+     sing = marginal_stats["singletons"]
+     print(
+         f" Singletons: {sing['count']:4d} / {n_total:4d} = {sing['proportion']:6.2%} "
+         f"95% CI: [{sing['lower']:.3f}, {sing['upper']:.3f}]"
+     )
+
+     # Singleton errors
+     if sing["count"] > 0:
+         from .statistics import cp_interval
+
+         error_cond_marg = cp_interval(sing["errors"], sing["count"])
+         err_prop = error_cond_marg["proportion"]
+         err_lower = error_cond_marg["lower"]
+         err_upper = error_cond_marg["upper"]
+         print(
+             f" Errors: {sing['errors']:4d} / {sing['count']:4d} = "
+             f"{err_prop:6.2%} 95% CI: [{err_lower:.3f}, {err_upper:.3f}]"
+         )
+
+     # Doublets
+     doub = marginal_stats["doublets"]
+     print(
+         f" Doublets: {doub['count']:4d} / {n_total:4d} = {doub['proportion']:6.2%} "
+         f"95% CI: [{doub['lower']:.3f}, {doub['upper']:.3f}]"
+     )
+
+     print("\n ✅ RIGOROUS PAC-Controlled Marginal Bounds")
+     print(" (LOO-CV + Clopper-Pearson for estimation uncertainty)")
+     pac_marginal = params["pac_level_marginal"]
+     ci_lvl = params["ci_level"]
+     print(f" PAC level: {pac_marginal:.0%} (= (1-δ₀)×(1-δ₁), independence), CP level: {ci_lvl:.0%}")
+     print(f" Grid points evaluated: {pac_marg['n_grid_points']}")
+
+     s_lower, s_upper = pac_marg["singleton_rate_bounds"]
+     print("\n SINGLETON:")
+     print(f" Bounds: [{s_lower:.3f}, {s_upper:.3f}]")
+     print(f" Expected: {pac_marg['expected_singleton_rate']:.3f}")
+
+     d_lower, d_upper = pac_marg["doublet_rate_bounds"]
+     print("\n DOUBLET:")
+     print(f" Bounds: [{d_lower:.3f}, {d_upper:.3f}]")
+     print(f" Expected: {pac_marg['expected_doublet_rate']:.3f}")
+
+     a_lower, a_upper = pac_marg["abstention_rate_bounds"]
+     print("\n ABSTENTION:")
+     print(f" Bounds: [{a_lower:.3f}, {a_upper:.3f}]")
+     print(f" Expected: {pac_marg['expected_abstention_rate']:.3f}")
+
+     se_lower, se_upper = pac_marg["singleton_error_rate_bounds"]
+     print("\n CONDITIONAL ERROR (P(error | singleton)):")
+     print(f" Bounds: [{se_lower:.3f}, {se_upper:.3f}]")
+     print(f" Expected: {pac_marg['expected_singleton_error_rate']:.3f}")
+
+     print("\n 📈 Deployment Expectations:")
+     print(f" Automation (singletons): {s_lower:.1%} - {s_upper:.1%}")
+     print(f" Escalation (doublets+abstentions): {a_lower + d_lower:.1%} - {a_upper + d_upper:.1%}")
+
+     # Bootstrap results if available
+     if report["bootstrap_results"] is not None:
+         bootstrap = report["bootstrap_results"]
+         print("\n" + "=" * 80)
+         print("BOOTSTRAP CALIBRATION UNCERTAINTY")
+         print(f"({bootstrap['n_bootstrap']} bootstrap samples)")
+         print("=" * 80)
+         print("\nModels: 'If I recalibrate on similar datasets, how do rates vary?'")
+         print("Method: Bootstrap resample → recalibrate → test on fresh data\n")
+
+         # Marginal
+         print("-" * 80)
+         print("MARGINAL")
+         print("-" * 80)
+         for metric, name in [
+             ("singleton", "SINGLETON"),
+             ("doublet", "DOUBLET"),
+             ("abstention", "ABSTENTION"),
+             ("singleton_error", "SINGLETON ERROR"),
+         ]:
+             m = bootstrap["marginal"][metric]
+             q = m["quantiles"]
+             print(f"\n{name}:")
+             print(f" Mean: {m['mean']:.4f} ± {m['std']:.4f}")
+             print(f" Median: {q['q50']:.4f}")
+             print(f" [5%, 95%]: [{q['q05']:.4f}, {q['q95']:.4f}]")
+
+         # Per-class
+         for class_label in [0, 1]:
+             print(f"\n{'-' * 80}")
+             print(f"CLASS {class_label}")
+             print("-" * 80)
+             for metric, name in [
+                 ("singleton", "SINGLETON"),
+                 ("doublet", "DOUBLET"),
+                 ("abstention", "ABSTENTION"),
+                 ("singleton_error", "SINGLETON ERROR"),
+             ]:
+                 m = bootstrap[f"class_{class_label}"][metric]
+                 q = m["quantiles"]
+                 print(f"\n{name}:")
+                 print(f" Mean: {m['mean']:.4f} ± {m['std']:.4f}")
+                 print(f" Median: {q['q50']:.4f}")
+                 print(f" [5%, 95%]: [{q['q05']:.4f}, {q['q95']:.4f}]")
+
+     # Cross-conformal results if available
+     if report["cross_conformal_results"] is not None:
+         cross_conf = report["cross_conformal_results"]
+         print("\n" + "=" * 80)
+         print("CROSS-CONFORMAL VALIDATION")
+         print(f"({cross_conf['n_folds']}-fold, n={cross_conf['n_samples']})")
+         print("=" * 80)
+         print("\nModels: 'How stable are rates across different calibration subsets?'")
+         print("Method: K-fold split → train on K-1 → test on 1 fold\n")
+
+         # Marginal
+         print("-" * 80)
+         print("MARGINAL")
+         print("-" * 80)
+         for metric, name in [
+             ("singleton", "SINGLETON"),
+             ("doublet", "DOUBLET"),
+             ("abstention", "ABSTENTION"),
+             ("singleton_error", "SINGLETON ERROR"),
+         ]:
+             m = cross_conf["marginal"][metric]
+             q = m["quantiles"]
+             print(f"\n{name}:")
+             print(f" Mean across folds: {m['mean']:.4f} ± {m['std']:.4f}")
+             print(f" Median: {q['q50']:.4f}")
+             print(f" [5%, 95%] range: [{q['q05']:.4f}, {q['q95']:.4f}]")
+
+         # Per-class
+         for class_label in [0, 1]:
+             print(f"\n{'-' * 80}")
+             print(f"CLASS {class_label}")
+             print("-" * 80)
+             for metric, name in [
+                 ("singleton", "SINGLETON"),
+                 ("doublet", "DOUBLET"),
+                 ("abstention", "ABSTENTION"),
+                 ("singleton_error", "SINGLETON ERROR"),
+             ]:
+                 m = cross_conf[f"class_{class_label}"][metric]
+                 q = m["quantiles"]
+                 print(f"\n{name}:")
+                 print(f" Mean across folds: {m['mean']:.4f} ± {m['std']:.4f}")
+                 print(f" Median: {q['q50']:.4f}")
+                 print(f" [5%, 95%] range: [{q['q05']:.4f}, {q['q95']:.4f}]")
+
+     print("\n" + "=" * 80)
+     print("NOTES")
+     print("=" * 80)
+     print("\n✓ PAC BOUNDS (LOO-CV + CP):")
+     print(" • Bound the TRUE rate for THIS fixed calibration")
+     print(" • Valid for any future test set size")
+     print(" • Models: 'Given this calibration, what rates on future test sets?'")
+     if report["bootstrap_results"] is not None:
+         print("\n✓ BOOTSTRAP INTERVALS:")
+         print(" • Show recalibration uncertainty (wider than PAC bounds)")
+         print(" • Models: 'If I recalibrate on similar data, how do rates vary?'")
+         print(" • Complementary to PAC bounds - different question!")
+     if report["cross_conformal_results"] is not None:
+         print("\n✓ CROSS-CONFORMAL VALIDATION:")
+         print(" • Shows rate stability across K-fold calibration splits")
+         print(" • Models: 'How stable are rates across calibration subsets?'")
+         print(" • Use for: Finite-sample diagnostics, sample size planning")
+         print(" • Large std → need more calibration data")
+     print("\n✓ TECHNICAL DETAILS:")
+     print(" • LOO-CV for unbiased rate estimates (no data leakage)")
+     print(" • Clopper-Pearson intervals account for estimation uncertainty")
+     if params["use_union_bound"]:
+         print(" • Union bound ensures ALL metrics hold simultaneously")
+     if report["bootstrap_results"] is not None or report["cross_conformal_results"] is not None:
+         print("\n✓ ALL METHODS ARE COMPLEMENTARY:")
+         print(" • Use PAC bounds for deployment (rigorous guarantees)")
+         if report["bootstrap_results"] is not None:
+             print(" • Use Bootstrap to understand recalibration impact")
+         if report["cross_conformal_results"] is not None:
+             print(" • Use Cross-Conformal to diagnose calibration quality")
+     print("\n" + "=" * 80)
ssbc/statistics.py CHANGED
@@ -4,6 +4,76 @@ from typing import Any
 
  import numpy as np
  from scipy import stats
+ from scipy.stats import beta as beta_dist
+
+
+ def clopper_pearson_lower(k: int, n: int, confidence: float = 0.95) -> float:
+     """Compute lower Clopper-Pearson (one-sided) confidence bound.
+
+     Parameters
+     ----------
+     k : int
+         Number of successes
+     n : int
+         Total number of trials
+     confidence : float, default=0.95
+         Confidence level (e.g., 0.95 for 95% confidence)
+
+     Returns
+     -------
+     float
+         Lower confidence bound for the true proportion
+
+     Examples
+     --------
+     >>> lower = clopper_pearson_lower(k=5, n=10, confidence=0.95)
+     >>> print(f"Lower bound: {lower:.3f}")
+
+     Notes
+     -----
+     Uses Beta distribution quantiles for exact binomial confidence bounds.
+     For PAC-style guarantees, you may want to use delta = 1 - confidence.
+     """
+     if k == 0:
+         return 0.0
+     # L = Beta^{-1}(1-confidence; k, n-k+1)
+     # Note: Using (1-confidence) as the lower tail probability
+     alpha = 1 - confidence
+     return float(beta_dist.ppf(alpha, k, n - k + 1))
+
+
+ def clopper_pearson_upper(k: int, n: int, confidence: float = 0.95) -> float:
+     """Compute upper Clopper-Pearson (one-sided) confidence bound.
+
+     Parameters
+     ----------
+     k : int
+         Number of successes
+     n : int
+         Total number of trials
+     confidence : float, default=0.95
+         Confidence level (e.g., 0.95 for 95% confidence)
+
+     Returns
+     -------
+     float
+         Upper confidence bound for the true proportion
+
+     Examples
+     --------
+     >>> upper = clopper_pearson_upper(k=5, n=10, confidence=0.95)
+     >>> print(f"Upper bound: {upper:.3f}")
+
+     Notes
+     -----
+     Uses Beta distribution quantiles for exact binomial confidence bounds.
+     For PAC-style guarantees, you may want to use delta = 1 - confidence.
+     """
+     if k == n:
+         return 1.0
+     # U = Beta^{-1}(confidence; k+1, n-k)
+     # Note: Using confidence directly for upper tail
+     return float(beta_dist.ppf(confidence, k + 1, n - k))
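+
+ # Illustrative relationship between the one-sided bounds above and a central
+ # interval: a two-sided 95% Clopper-Pearson interval places 2.5% in each tail,
+ # so it can be assembled from the two one-sided bounds at confidence=0.975:
+ #   lo = clopper_pearson_lower(k, n, confidence=0.975)
+ #   hi = clopper_pearson_upper(k, n, confidence=0.975)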
 
 
  def clopper_pearson_intervals(labels: np.ndarray, confidence: float = 0.95) -> dict[int, dict[str, Any]]: