ssbc 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,425 @@
+ """Cross-conformal validation for estimating rate variability.
+ 
+ This implements K-fold cross-validation specifically for conformal prediction:
+ - Split the calibration data into K folds
+ - For each fold: train thresholds on the other K-1 folds, evaluate rates on the held-out fold
+ - Aggregate rates across folds to quantify finite-sample variability
+ 
+ How this differs from related methods:
+ - LOO-CV: leave-one-out; aggregates counts rather than per-fold rates
+ - Bootstrap: resamples with replacement and tests on fresh data
+ - Cross-conformal (this module): K-fold split; estimates the rate distribution from the finite calibration set
+ """
+ 
+ from typing import Any
+ 
+ import numpy as np
+ 
+ from ssbc.conformal import split_by_class
+ from ssbc.core import ssbc_correct
+ 
+ 
+ def _compute_fold_rates_mondrian(
+     train_labels: np.ndarray,
+     train_probs: np.ndarray,
+     test_labels: np.ndarray,
+     test_probs: np.ndarray,
+     alpha_target: float,
+     delta: float,
+ ) -> dict[str, dict[str, float]]:
+     """Compute operational rates for one fold in Mondrian conformal prediction.
+ 
+     Parameters
+     ----------
+     train_labels : np.ndarray
+         Training fold labels
+     train_probs : np.ndarray
+         Training fold probabilities
+     test_labels : np.ndarray
+         Test fold labels
+     test_probs : np.ndarray
+         Test fold probabilities
+     alpha_target : float
+         Target miscoverage rate
+     delta : float
+         PAC risk (for SSBC correction)
+ 
+     Returns
+     -------
+     dict
+         Rates for this fold: marginal and per-class
+     """
+     # Split training data by class
+     train_class_data = split_by_class(train_labels, train_probs)
+ 
+     # SSBC correction and per-class threshold computation
+     thresholds = {}
+     for class_label in [0, 1]:
+         class_data = train_class_data[class_label]
+         if class_data["n"] == 0:
+             # Degenerate case: no training examples of this class in the fold;
+             # a threshold of 0.0 effectively never admits this label.
+             thresholds[class_label] = 0.0
+             continue
+ 
+         # SSBC correction of the target miscoverage rate
+         ssbc_result = ssbc_correct(alpha_target=alpha_target, n=class_data["n"], delta=delta)
+ 
+         # Conformal quantile index: take the k-th smallest score at level (1 - alpha_corrected)
+         n_class = class_data["n"]
+         k = int(np.ceil((n_class + 1) * (1 - ssbc_result.alpha_corrected)))
+ 
+         # Nonconformity scores: 1 - P(true class)
+         mask = train_labels == class_label
+         scores = 1.0 - train_probs[mask, class_label]
+         sorted_scores = np.sort(scores)
+ 
+         thresholds[class_label] = sorted_scores[min(k - 1, len(sorted_scores) - 1)]
+ 
+     # Evaluate on test fold
+     n_test = len(test_labels)
+ 
+     # Marginal counters
+     n_abstentions = 0
+     n_singletons = 0
+     n_doublets = 0
+     n_singletons_correct = 0
+ 
+     # Per-class counters
+     counts_0 = {"abstentions": 0, "singletons": 0, "doublets": 0, "singletons_correct": 0, "n": 0}
+     counts_1 = {"abstentions": 0, "singletons": 0, "doublets": 0, "singletons_correct": 0, "n": 0}
+ 
+     for i in range(n_test):
+         true_label = test_labels[i]
+         score_0 = 1.0 - test_probs[i, 0]
+         score_1 = 1.0 - test_probs[i, 1]
+ 
+         in_0 = score_0 <= thresholds[0]
+         in_1 = score_1 <= thresholds[1]
+ 
+         # Marginal
+         if in_0 and in_1:
+             n_doublets += 1
+         elif in_0 or in_1:
+             n_singletons += 1
+             if (in_0 and true_label == 0) or (in_1 and true_label == 1):
+                 n_singletons_correct += 1
+         else:
+             n_abstentions += 1
+ 
+         # Per-class
+         if true_label == 0:
+             counts_0["n"] += 1
+             if in_0 and in_1:
+                 counts_0["doublets"] += 1
+             elif in_0 or in_1:
+                 counts_0["singletons"] += 1
+                 if in_0:
+                     counts_0["singletons_correct"] += 1
+             else:
+                 counts_0["abstentions"] += 1
+         else:
+             counts_1["n"] += 1
+             if in_0 and in_1:
+                 counts_1["doublets"] += 1
+             elif in_0 or in_1:
+                 counts_1["singletons"] += 1
+                 if in_1:
+                     counts_1["singletons_correct"] += 1
+             else:
+                 counts_1["abstentions"] += 1
+ 
+     # Compute rates
+     marginal_rates = {
+         "abstention": n_abstentions / n_test,
+         "singleton": n_singletons / n_test,
+         "doublet": n_doublets / n_test,
+         "singleton_error": (n_singletons - n_singletons_correct) / n_singletons if n_singletons > 0 else np.nan,
+     }
+ 
+     class_0_rates = {
+         "abstention": counts_0["abstentions"] / counts_0["n"] if counts_0["n"] > 0 else np.nan,
+         "singleton": counts_0["singletons"] / counts_0["n"] if counts_0["n"] > 0 else np.nan,
+         "doublet": counts_0["doublets"] / counts_0["n"] if counts_0["n"] > 0 else np.nan,
+         "singleton_error": (
+             (counts_0["singletons"] - counts_0["singletons_correct"]) / counts_0["singletons"]
+             if counts_0["singletons"] > 0
+             else np.nan
+         ),
+     }
+ 
+     class_1_rates = {
+         "abstention": counts_1["abstentions"] / counts_1["n"] if counts_1["n"] > 0 else np.nan,
+         "singleton": counts_1["singletons"] / counts_1["n"] if counts_1["n"] > 0 else np.nan,
+         "doublet": counts_1["doublets"] / counts_1["n"] if counts_1["n"] > 0 else np.nan,
+         "singleton_error": (
+             (counts_1["singletons"] - counts_1["singletons_correct"]) / counts_1["singletons"]
+             if counts_1["singletons"] > 0
+             else np.nan
+         ),
+     }
+ 
+     return {
+         "marginal": marginal_rates,
+         "class_0": class_0_rates,
+         "class_1": class_1_rates,
+     }
+ 
+ 
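Editor's note (not part of the released file): a small worked example of the order-statistic threshold rule above, reusing the illustrative numbers from the `compute_ssbc_correction` docstring further down (n = 100, corrected α ≈ 0.0571):

import numpy as np

n_class, alpha_corrected = 100, 0.0571  # illustrative values, not computed here
k = int(np.ceil((n_class + 1) * (1 - alpha_corrected)))  # ceil(101 * 0.9429) = 96
# The per-class threshold is then the k-th smallest nonconformity score,
# i.e. sorted_scores[k - 1], clamped to the last index when k exceeds n_class.
print(k)  # -> 96
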
+ def cross_conformal_validation(
+     labels: np.ndarray,
+     probs: np.ndarray,
+     alpha_target: float = 0.10,
+     delta: float = 0.10,
+     n_folds: int = 5,
+     stratify: bool = True,
+     seed: int | None = None,
+ ) -> dict[str, Any]:
+     """K-fold cross-conformal validation for Mondrian conformal prediction.
+ 
+     Estimates the variability of operational rates (abstentions, singletons, doublets)
+     due to finite calibration sample effects by splitting data into K folds.
+ 
+     For each fold:
+     1. Train: compute SSBC-corrected thresholds on the other K-1 folds
+     2. Test: evaluate operational rates on the held-out fold
+     3. Record: store the rates for this fold
+ 
+     Rates are then aggregated across folds to quantify finite-sample variability.
+ 
+     Parameters
+     ----------
+     labels : np.ndarray, shape (n,)
+         Calibration labels (0 or 1)
+     probs : np.ndarray, shape (n, 2)
+         Calibration probabilities [P(class=0), P(class=1)]
+     alpha_target : float, default=0.10
+         Target miscoverage rate
+     delta : float, default=0.10
+         PAC risk for SSBC correction
+     n_folds : int, default=5
+         Number of folds (K)
+     stratify : bool, default=True
+         Stratify folds by class labels
+     seed : int, optional
+         Random seed for reproducibility
+ 
+     Returns
+     -------
+     dict
+         Cross-conformal results with keys:
+         - 'fold_rates': List of rate dicts, one per fold
+         - 'marginal': Statistics for marginal rates
+         - 'class_0': Statistics for class 0 rates
+         - 'class_1': Statistics for class 1 rates
+         Each statistics dict contains:
+         - 'samples': Array of rates across folds
+         - 'mean': Mean rate
+         - 'std': Standard deviation
+         - 'quantiles': Dict with q05, q25, q50, q75, q95
+         - 'ci_95': Empirical interval spanning the q05-q95 fold quantiles
+           (approximate; see the comment in compute_stats)
+ 
+     Examples
+     --------
+     >>> from ssbc import cross_conformal_validation
+     >>> results = cross_conformal_validation(labels, probs, n_folds=10)
+     >>> m = results['marginal']['singleton']
+     >>> print(f"Singleton rate: {m['mean']:.3f} ± {m['std']:.3f}")
+     >>> print(f"[5%, 95%] range: [{m['quantiles']['q05']:.3f}, {m['quantiles']['q95']:.3f}]")
+ 
+     Notes
+     -----
+     How this differs from related methods:
+     - **LOO-CV**: leave-one-out; aggregates counts, not fold-level rates
+     - **Bootstrap**: resamples with replacement and tests on fresh data
+     - **Cross-conformal** (this function): K-fold split; estimates the rate distribution from the calibration data
+ 
+     This method directly estimates the variability of rates due to finite calibration
+     samples, without requiring a data simulator.
+     """
+     if seed is not None:
+         np.random.seed(seed)
+ 
+     n = len(labels)
+ 
+     # Create fold indices
+     indices = np.arange(n)
+ 
+     if stratify:
+         # Stratified K-fold: maintain class proportions in each fold
+         class_0_idx = indices[labels == 0]
+         class_1_idx = indices[labels == 1]
+ 
+         np.random.shuffle(class_0_idx)
+         np.random.shuffle(class_1_idx)
+ 
+         class_0_folds = np.array_split(class_0_idx, n_folds)
+         class_1_folds = np.array_split(class_1_idx, n_folds)
+ 
+         folds = [np.concatenate([class_0_folds[i], class_1_folds[i]]) for i in range(n_folds)]
+     else:
+         # Standard K-fold
+         np.random.shuffle(indices)
+         folds = np.array_split(indices, n_folds)
+ 
+     # Compute rates for each fold
+     fold_rates = []
+ 
+     for fold_idx in range(n_folds):
+         # Test fold
+         test_idx = folds[fold_idx]
+ 
+         # Train folds (all except test)
+         train_idx = np.concatenate([folds[i] for i in range(n_folds) if i != fold_idx])
+ 
+         # Compute fold rates
+         rates = _compute_fold_rates_mondrian(
+             train_labels=labels[train_idx],
+             train_probs=probs[train_idx],
+             test_labels=labels[test_idx],
+             test_probs=probs[test_idx],
+             alpha_target=alpha_target,
+             delta=delta,
+         )
+ 
+         fold_rates.append(rates)
+ 
+     # Aggregate statistics
+     metrics = ["abstention", "singleton", "doublet", "singleton_error"]
+ 
+     def compute_stats(values: list[float], metric_name: str) -> dict[str, Any]:
+         """Compute statistics for a metric across folds (NaN folds are dropped)."""
+         arr = np.array(values)
+         valid = arr[~np.isnan(arr)]
+ 
+         if len(valid) == 0:
+             return {
+                 "samples": arr,
+                 "mean": np.nan,
+                 "std": np.nan,
+                 "quantiles": {"q05": np.nan, "q25": np.nan, "q50": np.nan, "q75": np.nan, "q95": np.nan},
+                 "ci_95": {"lower": np.nan, "upper": np.nan},
+             }
+ 
+         quantiles = {
+             "q05": float(np.percentile(valid, 5)),
+             "q25": float(np.percentile(valid, 25)),
+             "q50": float(np.percentile(valid, 50)),
+             "q75": float(np.percentile(valid, 75)),
+             "q95": float(np.percentile(valid, 95)),
+         }
+ 
+         stats = {
+             "samples": arr,
+             "mean": float(np.mean(valid)),
+             "std": float(np.std(valid, ddof=1)) if len(valid) > 1 else 0.0,
+             "quantiles": quantiles,
+         }
+ 
+         # Report an empirical "ci_95" built from the 5th and 95th fold percentiles.
+         # This is approximate (it summarizes the fold distribution, not a binomial CI)
+         # and is kept under this key for backward compatibility.
+         stats["ci_95"] = {
+             "lower": quantiles["q05"],
+             "upper": quantiles["q95"],
+         }
+ 
+         return stats
+ 
+     # Aggregate marginal statistics
+     marginal_stats = {
+         metric: compute_stats([fold["marginal"][metric] for fold in fold_rates], metric) for metric in metrics
+     }
+ 
+     # Aggregate class-specific statistics
+     class_0_stats = {
+         metric: compute_stats([fold["class_0"][metric] for fold in fold_rates], metric) for metric in metrics
+     }
+ 
+     class_1_stats = {
+         metric: compute_stats([fold["class_1"][metric] for fold in fold_rates], metric) for metric in metrics
+     }
+ 
+     return {
+         "n_folds": n_folds,
+         "n_samples": n,
+         "stratified": stratify,
+         "fold_rates": fold_rates,
+         "marginal": marginal_stats,
+         "class_0": class_0_stats,
+         "class_1": class_1_stats,
+         "parameters": {
+             "alpha_target": alpha_target,
+             "delta": delta,
+             "n_folds": n_folds,
+             "stratify": stratify,
+         },
+     }
+ 
+ 
+ def print_cross_conformal_results(results: dict) -> None:
+     """Pretty-print cross-conformal validation results.
+ 
+     Parameters
+     ----------
+     results : dict
+         Results from cross_conformal_validation()
+     """
+     print("=" * 80)
+     print("CROSS-CONFORMAL VALIDATION RESULTS")
+     print("=" * 80)
+     print("\nParameters:")
+     print(f"  K-folds: {results['n_folds']}")
+     print(f"  Samples: {results['n_samples']}")
+     print(f"  Stratified: {results['stratified']}")
+     print(f"  Alpha target: {results['parameters']['alpha_target']:.3f}")
+     print(f"  Delta (PAC): {results['parameters']['delta']:.3f}")
+ 
+     # Marginal
+     print("\n" + "-" * 80)
+     print("MARGINAL RATES (Across All Samples)")
+     print("-" * 80)
+ 
+     for metric, name in [
+         ("singleton", "SINGLETON"),
+         ("doublet", "DOUBLET"),
+         ("abstention", "ABSTENTION"),
+         ("singleton_error", "SINGLETON ERROR"),
+     ]:
+         m = results["marginal"][metric]
+         q = m["quantiles"]
+ 
+         print(f"\n{name}:")
+         print(f"  Mean across folds: {m['mean']:.4f} ± {m['std']:.4f}")
+         print(f"  Median: {q['q50']:.4f}")
+         print(f"  [5%, 95%] range: [{q['q05']:.4f}, {q['q95']:.4f}]")
+         print(f"  [25%, 75%] IQR: [{q['q25']:.4f}, {q['q75']:.4f}]")
+ 
+     # Per-class
+     for class_label in [0, 1]:
+         print(f"\n{'-' * 80}")
+         print(f"CLASS {class_label} RATES")
+         print("-" * 80)
+ 
+         for metric, name in [
+             ("singleton", "SINGLETON"),
+             ("doublet", "DOUBLET"),
+             ("abstention", "ABSTENTION"),
+             ("singleton_error", "SINGLETON ERROR"),
+         ]:
+             m = results[f"class_{class_label}"][metric]
+             q = m["quantiles"]
+ 
+             print(f"\n{name}:")
+             print(f"  Mean across folds: {m['mean']:.4f} ± {m['std']:.4f}")
+             print(f"  Median: {q['q50']:.4f}")
+             print(f"  [5%, 95%] range: [{q['q05']:.4f}, {q['q95']:.4f}]")
+ 
+     print("\n" + "=" * 80)
+     print("INTERPRETATION")
+     print("=" * 80)
+     print("\n✓ Shows finite-sample variability from K-fold splits of calibration data")
+     print("✓ [5%, 95%] range indicates expected rate fluctuations")
+     print("✓ Smaller std → more stable rates across different calibration subsets")
+     print("✓ Complementary to:")
+     print("  • LOO-CV bounds: Uncertainty for fixed full calibration")
+     print("  • Bootstrap: Recalibration uncertainty with fresh test data")
+     print("  • Cross-conformal: Rate variability from finite calibration splits")
+     print("\n" + "=" * 80)
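Editor's note (not part of the diff): the module above ships no usage entry point, so here is a minimal driving sketch. It assumes `cross_conformal_validation` and `print_cross_conformal_results` are re-exported from the `ssbc` package, as the docstring's own Examples section suggests; the synthetic labels and probabilities are purely illustrative.

import numpy as np

from ssbc import cross_conformal_validation, print_cross_conformal_results

rng = np.random.default_rng(0)

# Illustrative calibration data: 200 binary labels with noisy class probabilities
labels = rng.integers(0, 2, size=200)
p1 = np.clip(0.2 + 0.6 * labels + rng.normal(0.0, 0.15, size=200), 0.01, 0.99)
probs = np.column_stack([1.0 - p1, p1])  # columns: [P(class=0), P(class=1)]

results = cross_conformal_validation(labels, probs, alpha_target=0.10, delta=0.10, n_folds=5, seed=42)
print_cross_conformal_results(results)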
ssbc/mcp_server.py ADDED
@@ -0,0 +1,93 @@
+ """MCP Server for SSBC - Small-Sample Beta Correction.
+ 
+ Exposes SSBC functionality as MCP tools for AI assistants.
+ """
+ 
+ from typing import Literal
+ 
+ from mcp.server.fastmcp import FastMCP  # type: ignore[import-untyped]
+ 
+ from .core import ssbc_correct
+ 
+ # Initialize FastMCP server
+ mcp = FastMCP("SSBC Server")
+ 
+ 
+ @mcp.tool()
+ def compute_ssbc_correction(
+     alpha_target: float,
+     n: int,
+     delta: float,
+     mode: Literal["beta", "beta-binomial"] = "beta",
+ ) -> dict[str, float | int | str]:
+     """Compute the Small-Sample Beta Correction for conformal prediction.
+ 
+     Corrects the miscoverage rate α to provide finite-sample PAC guarantees.
+     Unlike asymptotic methods or concentration inequalities (Hoeffding, DKWM),
+     SSBC uses the exact induced beta distribution for tighter bounds.
+ 
+     Parameters
+     ----------
+     alpha_target : float
+         Target miscoverage rate (e.g., 0.10 for a 90% coverage target)
+     n : int
+         Calibration set size (number of calibration points)
+     delta : float
+         PAC risk parameter (e.g., 0.05 for a 95% probability guarantee)
+     mode : str, default="beta"
+         "beta" for an infinite test window, "beta-binomial" for a finite test window
+ 
+     Returns
+     -------
+     result : dict
+         Dictionary containing:
+         - alpha_corrected: Corrected miscoverage rate (α')
+         - u_star: Optimal threshold index (1-based)
+         - guarantee: Human-readable guarantee statement
+         - explanation: How to apply the correction
+         - calibration_size, target_coverage, pac_confidence: Echoed settings
+ 
+     Examples
+     --------
+     For a calibration set of 100 points targeting 90% coverage with 95% confidence:
+ 
+     >>> compute_ssbc_correction(alpha_target=0.10, n=100, delta=0.05)
+     {
+         "alpha_corrected": 0.0571,
+         "u_star": 95,
+         "guarantee": "With 95.0% probability, coverage ≥ 90.0%",
+         ...
+     }
+ 
+     Notes
+     -----
+     Statistical properties:
+     - Distribution-free: no assumptions about P(X,Y)
+     - Frequentist: a valid frequentist guarantee (no priors)
+     - Finite-sample: exact for ANY n (not asymptotic)
+     - Model-agnostic: works with any probabilistic classifier
+ 
+     The corrected α' < α_target yields more conservative thresholds,
+     leading to larger prediction sets and higher coverage guarantees.
+     """
+     # Call the core SSBC function
+     result = ssbc_correct(alpha_target=alpha_target, n=n, delta=delta, mode=mode)
+ 
+     # Format response
+     return {
+         "alpha_corrected": float(result.alpha_corrected),
+         "u_star": int(result.u_star),
+         "guarantee": f"With {100 * (1 - delta):.1f}% probability, coverage ≥ {100 * (1 - alpha_target):.1f}%",
+         "explanation": (
+             f"Use α'={result.alpha_corrected:.4f} instead of α={alpha_target:.4f}. "
+             f"This ensures coverage ≥ {100 * (1 - alpha_target):.1f}% with "
+             f"{100 * (1 - delta):.1f}% probability over calibration sets of size {n}."
+         ),
+         "calibration_size": n,
+         "target_coverage": f"{100 * (1 - alpha_target):.1f}%",
+         "pac_confidence": f"{100 * (1 - delta):.1f}%",
+     }
+ 
+ 
+ if __name__ == "__main__":
+     # Run the MCP server
+     mcp.run()
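
Editor's note (not part of the diff): the tool is a thin wrapper over `ssbc.core.ssbc_correct`, so the same numbers can be obtained without an MCP client. A minimal sketch, using only the call and result attributes that appear in the wrapper above:

from ssbc.core import ssbc_correct

# Same parameters as the docstring example: n=100 calibration points,
# 90% target coverage (alpha=0.10), 95% PAC confidence (delta=0.05)
result = ssbc_correct(alpha_target=0.10, n=100, delta=0.05, mode="beta")
print(result.alpha_corrected)  # corrected miscoverage rate α' (≈ 0.0571 per the example above)
print(result.u_star)           # optimal threshold index (95 per the example above)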