ssbc 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
ssbc/utils.py CHANGED
@@ -1,2 +1,72 @@
- def do_something_useful():
-     print("Replace this with a utility function")
+ """Utility functions for conformal prediction."""
+
+ from typing import Literal
+
+ import numpy as np
+
+
+ def compute_operational_rate(
+     prediction_sets: list[set | list],
+     true_labels: np.ndarray,
+     rate_type: Literal["singleton", "doublet", "abstention", "error_in_singleton", "correct_in_singleton"],
+ ) -> np.ndarray:
+     """Compute operational rate indicators for prediction sets.
+
+     For each prediction set, compute a binary indicator showing whether
+     a specific operational event occurred (singleton, doublet, abstention,
+     error in singleton, or correct in singleton).
+
+     Parameters
+     ----------
+     prediction_sets : list[set | list]
+         Prediction sets for each sample. Each set contains predicted labels.
+     true_labels : np.ndarray
+         True labels for each sample
+     rate_type : {"singleton", "doublet", "abstention", "error_in_singleton", "correct_in_singleton"}
+         Type of operational rate to compute:
+         - "singleton": prediction set contains exactly one label
+         - "doublet": prediction set contains exactly two labels
+         - "abstention": prediction set is empty
+         - "error_in_singleton": singleton prediction that doesn't contain the true label
+         - "correct_in_singleton": singleton prediction that contains the true label
+
+     Returns
+     -------
+     np.ndarray
+         Binary indicators (0 or 1) for whether the event holds for each sample
+
+     Examples
+     --------
+     >>> pred_sets = [{0}, {0, 1}, set(), {1}]
+     >>> true_labels = np.array([0, 0, 1, 0])
+     >>> indicators = compute_operational_rate(pred_sets, true_labels, "singleton")
+     >>> print(indicators)  # [1, 0, 0, 1]
+     >>> indicators = compute_operational_rate(pred_sets, true_labels, "correct_in_singleton")
+     >>> print(indicators)  # [1, 0, 0, 0] - first and last are singletons, first is correct
+
+     Notes
+     -----
+     This function is useful for computing operational statistics on conformal
+     prediction sets, such as singleton rates, escalation rates, and error rates.
+     """
+     n = len(prediction_sets)
+     indicators = np.zeros(n, dtype=int)
+
+     for i in range(n):
+         pred_set = prediction_sets[i]
+         y_true = true_labels[i]
+
+         if rate_type == "singleton":
+             indicators[i] = int(len(pred_set) == 1)
+         elif rate_type == "doublet":
+             indicators[i] = int(len(pred_set) == 2)
+         elif rate_type == "abstention":
+             indicators[i] = int(len(pred_set) == 0)
+         elif rate_type == "error_in_singleton":
+             indicators[i] = int(len(pred_set) == 1 and y_true not in pred_set)
+         elif rate_type == "correct_in_singleton":
+             indicators[i] = int(len(pred_set) == 1 and y_true in pred_set)
+         else:
+             raise ValueError(f"Unknown rate_type: {rate_type}")
+
+     return indicators
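
A minimal usage sketch for the new utility (illustrative only: it assumes `ssbc.utils` is importable as packaged above, and the aggregation into rates shown here is not part of the package API):

import numpy as np

from ssbc.utils import compute_operational_rate

# Toy prediction sets: correct singleton, doublet, abstention, wrong singleton
pred_sets = [{0}, {0, 1}, set(), {1}]
true_labels = np.array([0, 0, 1, 0])

singletons = compute_operational_rate(pred_sets, true_labels, "singleton")
errors = compute_operational_rate(pred_sets, true_labels, "error_in_singleton")

print(singletons.mean())  # 0.5: half of the prediction sets are singletons
if singletons.sum() > 0:
    # Error rate conditional on producing a singleton (illustrative aggregation)
    print(errors.sum() / singletons.sum())  # 0.5: one of the two singletons is wrong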
ssbc/validation.py ADDED
@@ -0,0 +1,409 @@
+ """Validation utilities for PAC-controlled operational bounds.
+
+ This module provides tools to empirically validate the theoretical PAC guarantees
+ by running simulations with fixed calibration thresholds on independent test sets.
+ """
+
+ from typing import Any
+
+ import numpy as np
+
+
+ def validate_pac_bounds(
+     report: dict[str, Any],
+     simulator: Any,
+     test_size: int,
+     n_trials: int = 1000,
+     seed: int | None = None,
+     verbose: bool = True,
+ ) -> dict[str, Any]:
+     """Empirically validate PAC operational bounds.
+
+     Takes a PAC report from generate_rigorous_pac_report() and validates that
+     the theoretical bounds actually hold in practice by:
+     1. Extracting the FIXED thresholds from calibration
+     2. Running n_trials simulations with fresh test sets
+     3. Measuring empirical coverage of the PAC bounds
+
+     Parameters
+     ----------
+     report : dict
+         Output from generate_rigorous_pac_report()
+     simulator : DataGenerator
+         Simulator to generate independent test data (e.g., BinaryClassifierSimulator)
+     test_size : int
+         Size of each test set
+     n_trials : int, default=1000
+         Number of independent trials
+     seed : int, optional
+         Random seed for reproducibility
+     verbose : bool, default=True
+         Print validation progress
+
+     Returns
+     -------
+     dict
+         Validation results with:
+         - 'marginal': Marginal operational rates and coverage
+         - 'class_0': Class 0 operational rates and coverage
+         - 'class_1': Class 1 operational rates and coverage
+         Each containing:
+         - 'singleton', 'doublet', 'abstention', 'singleton_error' dicts with:
+             - 'rates': Array of rates across trials
+             - 'mean': Mean rate
+             - 'quantiles': Quantiles (5%, 25%, 50%, 75%, 95%)
+             - 'bounds': PAC bounds from report
+             - 'expected': Expected rate from report
+             - 'empirical_coverage': Fraction of trials within bounds
+
+     Examples
+     --------
+     >>> from ssbc import BinaryClassifierSimulator, generate_rigorous_pac_report, validate_pac_bounds
+     >>> sim = BinaryClassifierSimulator(p_class1=0.2, seed=42)
+     >>> labels, probs = sim.generate(100)
+     >>> report = generate_rigorous_pac_report(labels, probs, delta=0.10)
+     >>> validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000)
+     >>> print(f"Singleton coverage: {validation['marginal']['singleton']['empirical_coverage']:.1%}")
+
+     Notes
+     -----
+     This function is useful for:
+     - Verifying theoretical PAC guarantees empirically
+     - Understanding the tightness of bounds
+     - Debugging issues with bounds calculation
+     - Generating validation plots for papers/reports
+
+     The empirical coverage should be ≥ PAC level (1 - δ) for rigorous bounds.
+     """
+     if seed is not None:
+         np.random.seed(seed)
+
+     # Extract FIXED thresholds from calibration
+     threshold_0 = report["calibration_result"][0]["threshold"]
+     threshold_1 = report["calibration_result"][1]["threshold"]
+
+     if verbose:
+         print(f"Using fixed thresholds: q̂₀={threshold_0:.4f}, q̂₁={threshold_1:.4f}")
+         print(f"Running {n_trials} trials with test_size={test_size}...")
+
+     # Storage for realized rates
+     marginal_singleton_rates = []
+     marginal_doublet_rates = []
+     marginal_abstention_rates = []
+     marginal_singleton_error_rates = []
+
+     class_0_singleton_rates = []
+     class_0_doublet_rates = []
+     class_0_abstention_rates = []
+     class_0_singleton_error_rates = []
+
+     class_1_singleton_rates = []
+     class_1_doublet_rates = []
+     class_1_abstention_rates = []
+     class_1_singleton_error_rates = []
+
+     # Run trials
+     for _ in range(n_trials):
+         # Generate independent test set
+         labels_test, probs_test = simulator.generate(test_size)
+
+         # Apply FIXED Mondrian thresholds and evaluate
+         n_total = len(labels_test)
+         n_singletons = 0
+         n_doublets = 0
+         n_abstentions = 0
+         n_singletons_correct = 0
+
+         # Per-class counters
+         n_0 = np.sum(labels_test == 0)
+         n_1 = np.sum(labels_test == 1)
+
+         n_singletons_0 = 0
+         n_doublets_0 = 0
+         n_abstentions_0 = 0
+         n_singletons_correct_0 = 0
+
+         n_singletons_1 = 0
+         n_doublets_1 = 0
+         n_abstentions_1 = 0
+         n_singletons_correct_1 = 0
+
+         for i in range(n_total):
+             true_label = labels_test[i]
+             score_0 = 1.0 - probs_test[i, 0]
+             score_1 = 1.0 - probs_test[i, 1]
+
+             # Build prediction set using FIXED thresholds
+             in_0 = score_0 <= threshold_0
+             in_1 = score_1 <= threshold_1
+
+             # Marginal counts
+             if in_0 and in_1:
+                 n_doublets += 1
+             elif in_0 or in_1:
+                 n_singletons += 1
+                 if (in_0 and true_label == 0) or (in_1 and true_label == 1):
+                     n_singletons_correct += 1
+             else:
+                 n_abstentions += 1
+
+             # Per-class counts
+             if true_label == 0:
+                 if in_0 and in_1:
+                     n_doublets_0 += 1
+                 elif in_0 or in_1:
+                     n_singletons_0 += 1
+                     if in_0:
+                         n_singletons_correct_0 += 1
+                 else:
+                     n_abstentions_0 += 1
+             else:  # true_label == 1
+                 if in_0 and in_1:
+                     n_doublets_1 += 1
+                 elif in_0 or in_1:
+                     n_singletons_1 += 1
+                     if in_1:
+                         n_singletons_correct_1 += 1
+                 else:
+                     n_abstentions_1 += 1
+
+         # Compute marginal rates
+         marginal_singleton_rates.append(n_singletons / n_total)
+         marginal_doublet_rates.append(n_doublets / n_total)
+         marginal_abstention_rates.append(n_abstentions / n_total)
+
+         singleton_error_rate = (n_singletons - n_singletons_correct) / n_singletons if n_singletons > 0 else np.nan
+         marginal_singleton_error_rates.append(singleton_error_rate)
+
+         # Compute per-class rates
+         if n_0 > 0:
+             class_0_singleton_rates.append(n_singletons_0 / n_0)
+             class_0_doublet_rates.append(n_doublets_0 / n_0)
+             class_0_abstention_rates.append(n_abstentions_0 / n_0)
+             singleton_error_0 = (
+                 (n_singletons_0 - n_singletons_correct_0) / n_singletons_0 if n_singletons_0 > 0 else np.nan
+             )
+             class_0_singleton_error_rates.append(singleton_error_0)
+
+         if n_1 > 0:
+             class_1_singleton_rates.append(n_singletons_1 / n_1)
+             class_1_doublet_rates.append(n_doublets_1 / n_1)
+             class_1_abstention_rates.append(n_abstentions_1 / n_1)
+             singleton_error_1 = (
+                 (n_singletons_1 - n_singletons_correct_1) / n_singletons_1 if n_singletons_1 > 0 else np.nan
+             )
+             class_1_singleton_error_rates.append(singleton_error_1)
+
+     # Convert to arrays
+     marginal_singleton_rates = np.array(marginal_singleton_rates)
+     marginal_doublet_rates = np.array(marginal_doublet_rates)
+     marginal_abstention_rates = np.array(marginal_abstention_rates)
+     marginal_singleton_error_rates = np.array(marginal_singleton_error_rates)
+
+     class_0_singleton_rates = np.array(class_0_singleton_rates)
+     class_0_doublet_rates = np.array(class_0_doublet_rates)
+     class_0_abstention_rates = np.array(class_0_abstention_rates)
+     class_0_singleton_error_rates = np.array(class_0_singleton_error_rates)
+
+     class_1_singleton_rates = np.array(class_1_singleton_rates)
+     class_1_doublet_rates = np.array(class_1_doublet_rates)
+     class_1_abstention_rates = np.array(class_1_abstention_rates)
+     class_1_singleton_error_rates = np.array(class_1_singleton_error_rates)
+
+     # Helper functions
+     def check_coverage(rates: np.ndarray, bounds: tuple[float, float]) -> float:
+         """Check what fraction of rates fall within bounds."""
+         lower, upper = bounds
+         within = np.sum((rates >= lower) & (rates <= upper))
+         return within / len(rates)
+
+     def check_coverage_with_nan(rates: np.ndarray, bounds: tuple[float, float]) -> float:
+         """Check coverage, ignoring NaN values."""
+         lower, upper = bounds
+         valid = ~np.isnan(rates)
+         if np.sum(valid) == 0:
+             return np.nan
+         rates_valid = rates[valid]
+         within = np.sum((rates_valid >= lower) & (rates_valid <= upper))
+         return within / len(rates_valid)
+
+     def compute_quantiles(rates: np.ndarray) -> dict[str, float]:
+         """Compute quantiles, handling NaN."""
+         valid = rates[~np.isnan(rates)] if np.any(np.isnan(rates)) else rates
+         if len(valid) == 0:
+             return {"q05": np.nan, "q25": np.nan, "q50": np.nan, "q75": np.nan, "q95": np.nan}
+         return {
+             "q05": float(np.percentile(valid, 5)),
+             "q25": float(np.percentile(valid, 25)),
+             "q50": float(np.percentile(valid, 50)),
+             "q75": float(np.percentile(valid, 75)),
+             "q95": float(np.percentile(valid, 95)),
+         }
+
+     # Get bounds from report
+     pac_marg = report["pac_bounds_marginal"]
+     pac_0 = report["pac_bounds_class_0"]
+     pac_1 = report["pac_bounds_class_1"]
+
+     return {
+         "n_trials": n_trials,
+         "test_size": test_size,
+         "threshold_0": threshold_0,
+         "threshold_1": threshold_1,
+         "marginal": {
+             "singleton": {
+                 "rates": marginal_singleton_rates,
+                 "mean": np.mean(marginal_singleton_rates),
+                 "quantiles": compute_quantiles(marginal_singleton_rates),
+                 "bounds": pac_marg["singleton_rate_bounds"],
+                 "expected": pac_marg["expected_singleton_rate"],
+                 "empirical_coverage": check_coverage(marginal_singleton_rates, pac_marg["singleton_rate_bounds"]),
+             },
+             "doublet": {
+                 "rates": marginal_doublet_rates,
+                 "mean": np.mean(marginal_doublet_rates),
+                 "quantiles": compute_quantiles(marginal_doublet_rates),
+                 "bounds": pac_marg["doublet_rate_bounds"],
+                 "expected": pac_marg["expected_doublet_rate"],
+                 "empirical_coverage": check_coverage(marginal_doublet_rates, pac_marg["doublet_rate_bounds"]),
+             },
+             "abstention": {
+                 "rates": marginal_abstention_rates,
+                 "mean": np.mean(marginal_abstention_rates),
+                 "quantiles": compute_quantiles(marginal_abstention_rates),
+                 "bounds": pac_marg["abstention_rate_bounds"],
+                 "expected": pac_marg["expected_abstention_rate"],
+                 "empirical_coverage": check_coverage(marginal_abstention_rates, pac_marg["abstention_rate_bounds"]),
+             },
+             "singleton_error": {
+                 "rates": marginal_singleton_error_rates,
+                 "mean": np.nanmean(marginal_singleton_error_rates),
+                 "quantiles": compute_quantiles(marginal_singleton_error_rates),
+                 "bounds": pac_marg["singleton_error_rate_bounds"],
+                 "expected": pac_marg["expected_singleton_error_rate"],
+                 "empirical_coverage": check_coverage_with_nan(
+                     marginal_singleton_error_rates, pac_marg["singleton_error_rate_bounds"]
+                 ),
+             },
+         },
+         "class_0": {
+             "singleton": {
+                 "rates": class_0_singleton_rates,
+                 "mean": np.mean(class_0_singleton_rates),
+                 "quantiles": compute_quantiles(class_0_singleton_rates),
+                 "bounds": pac_0["singleton_rate_bounds"],
+                 "expected": pac_0["expected_singleton_rate"],
+                 "empirical_coverage": check_coverage(class_0_singleton_rates, pac_0["singleton_rate_bounds"]),
+             },
+             "doublet": {
+                 "rates": class_0_doublet_rates,
+                 "mean": np.mean(class_0_doublet_rates),
+                 "quantiles": compute_quantiles(class_0_doublet_rates),
+                 "bounds": pac_0["doublet_rate_bounds"],
+                 "expected": pac_0["expected_doublet_rate"],
+                 "empirical_coverage": check_coverage(class_0_doublet_rates, pac_0["doublet_rate_bounds"]),
+             },
+             "abstention": {
+                 "rates": class_0_abstention_rates,
+                 "mean": np.mean(class_0_abstention_rates),
+                 "quantiles": compute_quantiles(class_0_abstention_rates),
+                 "bounds": pac_0["abstention_rate_bounds"],
+                 "expected": pac_0["expected_abstention_rate"],
+                 "empirical_coverage": check_coverage(class_0_abstention_rates, pac_0["abstention_rate_bounds"]),
+             },
+             "singleton_error": {
+                 "rates": class_0_singleton_error_rates,
+                 "mean": np.nanmean(class_0_singleton_error_rates),
+                 "quantiles": compute_quantiles(class_0_singleton_error_rates),
+                 "bounds": pac_0["singleton_error_rate_bounds"],
+                 "expected": pac_0["expected_singleton_error_rate"],
+                 "empirical_coverage": check_coverage_with_nan(
+                     class_0_singleton_error_rates, pac_0["singleton_error_rate_bounds"]
+                 ),
+             },
+         },
+         "class_1": {
+             "singleton": {
+                 "rates": class_1_singleton_rates,
+                 "mean": np.mean(class_1_singleton_rates),
+                 "quantiles": compute_quantiles(class_1_singleton_rates),
+                 "bounds": pac_1["singleton_rate_bounds"],
+                 "expected": pac_1["expected_singleton_rate"],
+                 "empirical_coverage": check_coverage(class_1_singleton_rates, pac_1["singleton_rate_bounds"]),
+             },
+             "doublet": {
+                 "rates": class_1_doublet_rates,
+                 "mean": np.mean(class_1_doublet_rates),
+                 "quantiles": compute_quantiles(class_1_doublet_rates),
+                 "bounds": pac_1["doublet_rate_bounds"],
+                 "expected": pac_1["expected_doublet_rate"],
+                 "empirical_coverage": check_coverage(class_1_doublet_rates, pac_1["doublet_rate_bounds"]),
+             },
+             "abstention": {
+                 "rates": class_1_abstention_rates,
+                 "mean": np.mean(class_1_abstention_rates),
+                 "quantiles": compute_quantiles(class_1_abstention_rates),
+                 "bounds": pac_1["abstention_rate_bounds"],
+                 "expected": pac_1["expected_abstention_rate"],
+                 "empirical_coverage": check_coverage(class_1_abstention_rates, pac_1["abstention_rate_bounds"]),
+             },
+             "singleton_error": {
+                 "rates": class_1_singleton_error_rates,
+                 "mean": np.nanmean(class_1_singleton_error_rates),
+                 "quantiles": compute_quantiles(class_1_singleton_error_rates),
+                 "bounds": pac_1["singleton_error_rate_bounds"],
+                 "expected": pac_1["expected_singleton_error_rate"],
+                 "empirical_coverage": check_coverage_with_nan(
+                     class_1_singleton_error_rates, pac_1["singleton_error_rate_bounds"]
+                 ),
+             },
+         },
+     }
+
+
+ def print_validation_results(validation: dict[str, Any]) -> None:
+     """Pretty print validation results.
+
+     Parameters
+     ----------
+     validation : dict
+         Output from validate_pac_bounds()
+
+     Examples
+     --------
+     >>> validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000)
+     >>> print_validation_results(validation)
+     """
+     print("=" * 80)
+     print("PAC BOUNDS VALIDATION RESULTS")
+     print("=" * 80)
+     print(f"\nTrials: {validation['n_trials']}")
+     print(f"Test size: {validation['test_size']}")
+     print(f"Thresholds: q̂₀={validation['threshold_0']:.4f}, q̂₁={validation['threshold_1']:.4f}")
+
+     for scope in ["marginal", "class_0", "class_1"]:
+         scope_name = scope.upper() if scope == "marginal" else f"CLASS {scope[-1]}"
+         print(f"\n{'=' * 80}")
+         print(f"{scope_name}")
+         print("=" * 80)
+
+         for metric in ["singleton", "doublet", "abstention", "singleton_error"]:
+             m = validation[scope][metric]
+             q = m["quantiles"]
+             coverage = m["empirical_coverage"]
+
+             coverage_check = "✅" if coverage >= 0.90 else "❌"  # Assuming 90% PAC level
+
+             print(f"\n{metric.upper().replace('_', ' ')}:")
+             print(f"  Empirical mean: {m['mean']:.4f}")
+             print(f"  Expected (LOO): {m['expected']:.4f}")
+             q_str = f"[5%: {q['q05']:.3f}, 25%: {q['q25']:.3f}, 50%: {q['q50']:.3f}, "
+             q_str += f"75%: {q['q75']:.3f}, 95%: {q['q95']:.3f}]"
+             print(f"  Quantiles: {q_str}")
+             print(f"  PAC bounds: [{m['bounds'][0]:.4f}, {m['bounds'][1]:.4f}]")
+             if not np.isnan(coverage):
+                 print(f"  Coverage: {coverage:.1%} {coverage_check}")
+             else:
+                 print("  Coverage: N/A (no valid samples)")
+
+     print("\n" + "=" * 80)
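
Taken together, the two files added in 1.1.0 support a calibrate-then-validate workflow. A minimal end-to-end sketch, assuming (as the docstrings above do) that BinaryClassifierSimulator and generate_rigorous_pac_report are importable from the top-level ssbc package:

from ssbc import BinaryClassifierSimulator, generate_rigorous_pac_report
from ssbc.validation import print_validation_results, validate_pac_bounds

# Calibrate once on a small sample ...
sim = BinaryClassifierSimulator(p_class1=0.2, seed=42)
labels, probs = sim.generate(100)
report = generate_rigorous_pac_report(labels, probs, delta=0.10)

# ... then validate the PAC bounds on 1000 fresh test sets, thresholds held fixed
validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000, seed=123)
print_validation_results(validation)

# With delta=0.10, each empirical coverage should come out >= 90% if the bounds are rigorous
print(validation["marginal"]["singleton"]["empirical_coverage"])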