ssbc-1.0.0-py3-none-any.whl → ssbc-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ssbc/__init__.py +47 -1
- ssbc/bootstrap.py +411 -0
- ssbc/cli.py +0 -3
- ssbc/conformal.py +700 -1
- ssbc/cross_conformal.py +425 -0
- ssbc/mcp_server.py +93 -0
- ssbc/operational_bounds_simple.py +367 -0
- ssbc/rigorous_report.py +601 -0
- ssbc/statistics.py +70 -0
- ssbc/utils.py +72 -2
- ssbc/validation.py +409 -0
- ssbc/visualization.py +323 -300
- ssbc-1.1.0.dist-info/METADATA +337 -0
- ssbc-1.1.0.dist-info/RECORD +22 -0
- ssbc-1.1.0.dist-info/licenses/LICENSE +29 -0
- ssbc/ssbc.py +0 -1
- ssbc-1.0.0.dist-info/METADATA +0 -266
- ssbc-1.0.0.dist-info/RECORD +0 -17
- ssbc-1.0.0.dist-info/licenses/LICENSE +0 -21
- {ssbc-1.0.0.dist-info → ssbc-1.1.0.dist-info}/WHEEL +0 -0
- {ssbc-1.0.0.dist-info → ssbc-1.1.0.dist-info}/entry_points.txt +0 -0
- {ssbc-1.0.0.dist-info → ssbc-1.1.0.dist-info}/top_level.txt +0 -0
ssbc/utils.py
CHANGED
@@ -1,2 +1,72 @@
-
-
+"""Utility functions for conformal prediction."""
+
+from typing import Literal
+
+import numpy as np
+
+
+def compute_operational_rate(
+    prediction_sets: list[set | list],
+    true_labels: np.ndarray,
+    rate_type: Literal["singleton", "doublet", "abstention", "error_in_singleton", "correct_in_singleton"],
+) -> np.ndarray:
+    """Compute operational rate indicators for prediction sets.
+
+    For each prediction set, compute a binary indicator showing whether
+    a specific operational event occurred (singleton, doublet, abstention,
+    error in singleton, or correct in singleton).
+
+    Parameters
+    ----------
+    prediction_sets : list[set | list]
+        Prediction sets for each sample. Each set contains predicted labels.
+    true_labels : np.ndarray
+        True labels for each sample
+    rate_type : {"singleton", "doublet", "abstention", "error_in_singleton", "correct_in_singleton"}
+        Type of operational rate to compute:
+        - "singleton": prediction set contains exactly one label
+        - "doublet": prediction set contains exactly two labels
+        - "abstention": prediction set is empty
+        - "error_in_singleton": singleton prediction that doesn't contain true label
+        - "correct_in_singleton": singleton prediction that contains true label
+
+    Returns
+    -------
+    np.ndarray
+        Binary indicators (0 or 1) for whether the event holds for each sample
+
+    Examples
+    --------
+    >>> pred_sets = [{0}, {0, 1}, set(), {1}]
+    >>> true_labels = np.array([0, 0, 1, 0])
+    >>> indicators = compute_operational_rate(pred_sets, true_labels, "singleton")
+    >>> print(indicators)  # [1, 0, 0, 1]
+    >>> indicators = compute_operational_rate(pred_sets, true_labels, "correct_in_singleton")
+    >>> print(indicators)  # [1, 0, 0, 0] - first and last are singletons, first is correct
+
+    Notes
+    -----
+    This function is useful for computing operational statistics on conformal
+    prediction sets, such as singleton rates, escalation rates, and error rates.
+    """
+    n = len(prediction_sets)
+    indicators = np.zeros(n, dtype=int)
+
+    for i in range(n):
+        pred_set = prediction_sets[i]
+        y_true = true_labels[i]
+
+        if rate_type == "singleton":
+            indicators[i] = int(len(pred_set) == 1)
+        elif rate_type == "doublet":
+            indicators[i] = int(len(pred_set) == 2)
+        elif rate_type == "abstention":
+            indicators[i] = int(len(pred_set) == 0)
+        elif rate_type == "error_in_singleton":
+            indicators[i] = int(len(pred_set) == 1 and y_true not in pred_set)
+        elif rate_type == "correct_in_singleton":
+            indicators[i] = int(len(pred_set) == 1 and y_true in pred_set)
+        else:
+            raise ValueError(f"Unknown rate_type: {rate_type}")
+
+    return indicators
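A quick, self-contained sketch of how this helper composes into aggregate rates. The doublet-plus-abstention "escalation" composite below is illustrative and not defined by the package itself; the mean of the binary indicators is the corresponding rate:

import numpy as np

from ssbc.utils import compute_operational_rate

# Four toy prediction sets: singleton, doublet, abstention, singleton
pred_sets = [{0}, {0, 1}, set(), {1}]
true_labels = np.array([0, 0, 1, 0])

# Mean of the 0/1 indicators gives the rate of each event
singleton_rate = compute_operational_rate(pred_sets, true_labels, "singleton").mean()

# Hypothetical composite: anything that is not a singleton gets escalated
doublet = compute_operational_rate(pred_sets, true_labels, "doublet")
abstention = compute_operational_rate(pred_sets, true_labels, "abstention")
escalation_rate = (doublet | abstention).mean()

print(singleton_rate)   # 0.5 -> two of the four sets are singletons
print(escalation_rate)  # 0.5 -> one doublet plus one abstention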
ssbc/validation.py
ADDED
@@ -0,0 +1,409 @@
+"""Validation utilities for PAC-controlled operational bounds.
+
+This module provides tools to empirically validate the theoretical PAC guarantees
+by running simulations with fixed calibration thresholds on independent test sets.
+"""
+
+from typing import Any
+
+import numpy as np
+
+
+def validate_pac_bounds(
+    report: dict[str, Any],
+    simulator: Any,
+    test_size: int,
+    n_trials: int = 1000,
+    seed: int | None = None,
+    verbose: bool = True,
+) -> dict[str, Any]:
+    """Empirically validate PAC operational bounds.
+
+    Takes a PAC report from generate_rigorous_pac_report() and validates that
+    the theoretical bounds actually hold in practice by:
+    1. Extracting the FIXED thresholds from calibration
+    2. Running n_trials simulations with fresh test sets
+    3. Measuring empirical coverage of the PAC bounds
+
+    Parameters
+    ----------
+    report : dict
+        Output from generate_rigorous_pac_report()
+    simulator : DataGenerator
+        Simulator to generate independent test data (e.g., BinaryClassifierSimulator)
+    test_size : int
+        Size of each test set
+    n_trials : int, default=1000
+        Number of independent trials
+    seed : int, optional
+        Random seed for reproducibility
+    verbose : bool, default=True
+        Print validation progress
+
+    Returns
+    -------
+    dict
+        Validation results with:
+        - 'marginal': Marginal operational rates and coverage
+        - 'class_0': Class 0 operational rates and coverage
+        - 'class_1': Class 1 operational rates and coverage
+        Each containing:
+        - 'singleton', 'doublet', 'abstention', 'singleton_error' dicts with:
+            - 'rates': Array of rates across trials
+            - 'mean': Mean rate
+            - 'quantiles': Quantiles (5%, 25%, 50%, 75%, 95%)
+            - 'bounds': PAC bounds from report
+            - 'expected': Expected rate from report
+            - 'empirical_coverage': Fraction of trials within bounds
+
+    Examples
+    --------
+    >>> from ssbc import BinaryClassifierSimulator, generate_rigorous_pac_report, validate_pac_bounds
+    >>> sim = BinaryClassifierSimulator(p_class1=0.2, seed=42)
+    >>> labels, probs = sim.generate(100)
+    >>> report = generate_rigorous_pac_report(labels, probs, delta=0.10)
+    >>> validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000)
+    >>> print(f"Singleton coverage: {validation['marginal']['singleton']['empirical_coverage']:.1%}")
+
+    Notes
+    -----
+    This function is useful for:
+    - Verifying theoretical PAC guarantees empirically
+    - Understanding the tightness of bounds
+    - Debugging issues with bounds calculation
+    - Generating validation plots for papers/reports
+
+    The empirical coverage should be ≥ PAC level (1 - δ) for rigorous bounds.
+    """
+    if seed is not None:
+        np.random.seed(seed)
+
+    # Extract FIXED thresholds from calibration
+    threshold_0 = report["calibration_result"][0]["threshold"]
+    threshold_1 = report["calibration_result"][1]["threshold"]
+
+    if verbose:
+        print(f"Using fixed thresholds: q̂₀={threshold_0:.4f}, q̂₁={threshold_1:.4f}")
+        print(f"Running {n_trials} trials with test_size={test_size}...")
+
+    # Storage for realized rates
+    marginal_singleton_rates = []
+    marginal_doublet_rates = []
+    marginal_abstention_rates = []
+    marginal_singleton_error_rates = []
+
+    class_0_singleton_rates = []
+    class_0_doublet_rates = []
+    class_0_abstention_rates = []
+    class_0_singleton_error_rates = []
+
+    class_1_singleton_rates = []
+    class_1_doublet_rates = []
+    class_1_abstention_rates = []
+    class_1_singleton_error_rates = []
+
+    # Run trials
+    for _ in range(n_trials):
+        # Generate independent test set
+        labels_test, probs_test = simulator.generate(test_size)
+
+        # Apply FIXED Mondrian thresholds and evaluate
+        n_total = len(labels_test)
+        n_singletons = 0
+        n_doublets = 0
+        n_abstentions = 0
+        n_singletons_correct = 0
+
+        # Per-class counters
+        n_0 = np.sum(labels_test == 0)
+        n_1 = np.sum(labels_test == 1)
+
+        n_singletons_0 = 0
+        n_doublets_0 = 0
+        n_abstentions_0 = 0
+        n_singletons_correct_0 = 0
+
+        n_singletons_1 = 0
+        n_doublets_1 = 0
+        n_abstentions_1 = 0
+        n_singletons_correct_1 = 0
+
+        for i in range(n_total):
+            true_label = labels_test[i]
+            score_0 = 1.0 - probs_test[i, 0]
+            score_1 = 1.0 - probs_test[i, 1]
+
+            # Build prediction set using FIXED thresholds
+            in_0 = score_0 <= threshold_0
+            in_1 = score_1 <= threshold_1
+
+            # Marginal counts
+            if in_0 and in_1:
+                n_doublets += 1
+            elif in_0 or in_1:
+                n_singletons += 1
+                if (in_0 and true_label == 0) or (in_1 and true_label == 1):
+                    n_singletons_correct += 1
+            else:
+                n_abstentions += 1
+
+            # Per-class counts
+            if true_label == 0:
+                if in_0 and in_1:
+                    n_doublets_0 += 1
+                elif in_0 or in_1:
+                    n_singletons_0 += 1
+                    if in_0:
+                        n_singletons_correct_0 += 1
+                else:
+                    n_abstentions_0 += 1
+            else:  # true_label == 1
+                if in_0 and in_1:
+                    n_doublets_1 += 1
+                elif in_0 or in_1:
+                    n_singletons_1 += 1
+                    if in_1:
+                        n_singletons_correct_1 += 1
+                else:
+                    n_abstentions_1 += 1
+
+        # Compute marginal rates
+        marginal_singleton_rates.append(n_singletons / n_total)
+        marginal_doublet_rates.append(n_doublets / n_total)
+        marginal_abstention_rates.append(n_abstentions / n_total)
+
+        singleton_error_rate = (n_singletons - n_singletons_correct) / n_singletons if n_singletons > 0 else np.nan
+        marginal_singleton_error_rates.append(singleton_error_rate)
+
+        # Compute per-class rates
+        if n_0 > 0:
+            class_0_singleton_rates.append(n_singletons_0 / n_0)
+            class_0_doublet_rates.append(n_doublets_0 / n_0)
+            class_0_abstention_rates.append(n_abstentions_0 / n_0)
+            singleton_error_0 = (
+                (n_singletons_0 - n_singletons_correct_0) / n_singletons_0 if n_singletons_0 > 0 else np.nan
+            )
+            class_0_singleton_error_rates.append(singleton_error_0)
+
+        if n_1 > 0:
+            class_1_singleton_rates.append(n_singletons_1 / n_1)
+            class_1_doublet_rates.append(n_doublets_1 / n_1)
+            class_1_abstention_rates.append(n_abstentions_1 / n_1)
+            singleton_error_1 = (
+                (n_singletons_1 - n_singletons_correct_1) / n_singletons_1 if n_singletons_1 > 0 else np.nan
+            )
+            class_1_singleton_error_rates.append(singleton_error_1)
+
+    # Convert to arrays
+    marginal_singleton_rates = np.array(marginal_singleton_rates)
+    marginal_doublet_rates = np.array(marginal_doublet_rates)
+    marginal_abstention_rates = np.array(marginal_abstention_rates)
+    marginal_singleton_error_rates = np.array(marginal_singleton_error_rates)
+
+    class_0_singleton_rates = np.array(class_0_singleton_rates)
+    class_0_doublet_rates = np.array(class_0_doublet_rates)
+    class_0_abstention_rates = np.array(class_0_abstention_rates)
+    class_0_singleton_error_rates = np.array(class_0_singleton_error_rates)
+
+    class_1_singleton_rates = np.array(class_1_singleton_rates)
+    class_1_doublet_rates = np.array(class_1_doublet_rates)
+    class_1_abstention_rates = np.array(class_1_abstention_rates)
+    class_1_singleton_error_rates = np.array(class_1_singleton_error_rates)
+
+    # Helper functions
+    def check_coverage(rates: np.ndarray, bounds: tuple[float, float]) -> float:
+        """Check what fraction of rates fall within bounds."""
+        lower, upper = bounds
+        within = np.sum((rates >= lower) & (rates <= upper))
+        return within / len(rates)
+
+    def check_coverage_with_nan(rates: np.ndarray, bounds: tuple[float, float]) -> float:
+        """Check coverage, ignoring NaN values."""
+        lower, upper = bounds
+        valid = ~np.isnan(rates)
+        if np.sum(valid) == 0:
+            return np.nan
+        rates_valid = rates[valid]
+        within = np.sum((rates_valid >= lower) & (rates_valid <= upper))
+        return within / len(rates_valid)
+
+    def compute_quantiles(rates: np.ndarray) -> dict[str, float]:
+        """Compute quantiles, handling NaN."""
+        valid = rates[~np.isnan(rates)] if np.any(np.isnan(rates)) else rates
+        if len(valid) == 0:
+            return {"q05": np.nan, "q25": np.nan, "q50": np.nan, "q75": np.nan, "q95": np.nan}
+        return {
+            "q05": float(np.percentile(valid, 5)),
+            "q25": float(np.percentile(valid, 25)),
+            "q50": float(np.percentile(valid, 50)),
+            "q75": float(np.percentile(valid, 75)),
+            "q95": float(np.percentile(valid, 95)),
+        }
+
+    # Get bounds from report
+    pac_marg = report["pac_bounds_marginal"]
+    pac_0 = report["pac_bounds_class_0"]
+    pac_1 = report["pac_bounds_class_1"]
+
+    return {
+        "n_trials": n_trials,
+        "test_size": test_size,
+        "threshold_0": threshold_0,
+        "threshold_1": threshold_1,
+        "marginal": {
+            "singleton": {
+                "rates": marginal_singleton_rates,
+                "mean": np.mean(marginal_singleton_rates),
+                "quantiles": compute_quantiles(marginal_singleton_rates),
+                "bounds": pac_marg["singleton_rate_bounds"],
+                "expected": pac_marg["expected_singleton_rate"],
+                "empirical_coverage": check_coverage(marginal_singleton_rates, pac_marg["singleton_rate_bounds"]),
+            },
+            "doublet": {
+                "rates": marginal_doublet_rates,
+                "mean": np.mean(marginal_doublet_rates),
+                "quantiles": compute_quantiles(marginal_doublet_rates),
+                "bounds": pac_marg["doublet_rate_bounds"],
+                "expected": pac_marg["expected_doublet_rate"],
+                "empirical_coverage": check_coverage(marginal_doublet_rates, pac_marg["doublet_rate_bounds"]),
+            },
+            "abstention": {
+                "rates": marginal_abstention_rates,
+                "mean": np.mean(marginal_abstention_rates),
+                "quantiles": compute_quantiles(marginal_abstention_rates),
+                "bounds": pac_marg["abstention_rate_bounds"],
+                "expected": pac_marg["expected_abstention_rate"],
+                "empirical_coverage": check_coverage(marginal_abstention_rates, pac_marg["abstention_rate_bounds"]),
+            },
+            "singleton_error": {
+                "rates": marginal_singleton_error_rates,
+                "mean": np.nanmean(marginal_singleton_error_rates),
+                "quantiles": compute_quantiles(marginal_singleton_error_rates),
+                "bounds": pac_marg["singleton_error_rate_bounds"],
+                "expected": pac_marg["expected_singleton_error_rate"],
+                "empirical_coverage": check_coverage_with_nan(
+                    marginal_singleton_error_rates, pac_marg["singleton_error_rate_bounds"]
+                ),
+            },
+        },
+        "class_0": {
+            "singleton": {
+                "rates": class_0_singleton_rates,
+                "mean": np.mean(class_0_singleton_rates),
+                "quantiles": compute_quantiles(class_0_singleton_rates),
+                "bounds": pac_0["singleton_rate_bounds"],
+                "expected": pac_0["expected_singleton_rate"],
+                "empirical_coverage": check_coverage(class_0_singleton_rates, pac_0["singleton_rate_bounds"]),
+            },
+            "doublet": {
+                "rates": class_0_doublet_rates,
+                "mean": np.mean(class_0_doublet_rates),
+                "quantiles": compute_quantiles(class_0_doublet_rates),
+                "bounds": pac_0["doublet_rate_bounds"],
+                "expected": pac_0["expected_doublet_rate"],
+                "empirical_coverage": check_coverage(class_0_doublet_rates, pac_0["doublet_rate_bounds"]),
+            },
+            "abstention": {
+                "rates": class_0_abstention_rates,
+                "mean": np.mean(class_0_abstention_rates),
+                "quantiles": compute_quantiles(class_0_abstention_rates),
+                "bounds": pac_0["abstention_rate_bounds"],
+                "expected": pac_0["expected_abstention_rate"],
+                "empirical_coverage": check_coverage(class_0_abstention_rates, pac_0["abstention_rate_bounds"]),
+            },
+            "singleton_error": {
+                "rates": class_0_singleton_error_rates,
+                "mean": np.nanmean(class_0_singleton_error_rates),
+                "quantiles": compute_quantiles(class_0_singleton_error_rates),
+                "bounds": pac_0["singleton_error_rate_bounds"],
+                "expected": pac_0["expected_singleton_error_rate"],
+                "empirical_coverage": check_coverage_with_nan(
+                    class_0_singleton_error_rates, pac_0["singleton_error_rate_bounds"]
+                ),
+            },
+        },
+        "class_1": {
+            "singleton": {
+                "rates": class_1_singleton_rates,
+                "mean": np.mean(class_1_singleton_rates),
+                "quantiles": compute_quantiles(class_1_singleton_rates),
+                "bounds": pac_1["singleton_rate_bounds"],
+                "expected": pac_1["expected_singleton_rate"],
+                "empirical_coverage": check_coverage(class_1_singleton_rates, pac_1["singleton_rate_bounds"]),
+            },
+            "doublet": {
+                "rates": class_1_doublet_rates,
+                "mean": np.mean(class_1_doublet_rates),
+                "quantiles": compute_quantiles(class_1_doublet_rates),
+                "bounds": pac_1["doublet_rate_bounds"],
+                "expected": pac_1["expected_doublet_rate"],
+                "empirical_coverage": check_coverage(class_1_doublet_rates, pac_1["doublet_rate_bounds"]),
+            },
+            "abstention": {
+                "rates": class_1_abstention_rates,
+                "mean": np.mean(class_1_abstention_rates),
+                "quantiles": compute_quantiles(class_1_abstention_rates),
+                "bounds": pac_1["abstention_rate_bounds"],
+                "expected": pac_1["expected_abstention_rate"],
+                "empirical_coverage": check_coverage(class_1_abstention_rates, pac_1["abstention_rate_bounds"]),
+            },
+            "singleton_error": {
+                "rates": class_1_singleton_error_rates,
+                "mean": np.nanmean(class_1_singleton_error_rates),
+                "quantiles": compute_quantiles(class_1_singleton_error_rates),
+                "bounds": pac_1["singleton_error_rate_bounds"],
+                "expected": pac_1["expected_singleton_error_rate"],
+                "empirical_coverage": check_coverage_with_nan(
+                    class_1_singleton_error_rates, pac_1["singleton_error_rate_bounds"]
+                ),
+            },
+        },
+    }
+
+
+def print_validation_results(validation: dict[str, Any]) -> None:
+    """Pretty print validation results.
+
+    Parameters
+    ----------
+    validation : dict
+        Output from validate_pac_bounds()
+
+    Examples
+    --------
+    >>> validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000)
+    >>> print_validation_results(validation)
+    """
+    print("=" * 80)
+    print("PAC BOUNDS VALIDATION RESULTS")
+    print("=" * 80)
+    print(f"\nTrials: {validation['n_trials']}")
+    print(f"Test size: {validation['test_size']}")
+    print(f"Thresholds: q̂₀={validation['threshold_0']:.4f}, q̂₁={validation['threshold_1']:.4f}")
+
+    for scope in ["marginal", "class_0", "class_1"]:
+        scope_name = scope.upper() if scope == "marginal" else f"CLASS {scope[-1]}"
+        print(f"\n{'=' * 80}")
+        print(f"{scope_name}")
+        print("=" * 80)
+
+        for metric in ["singleton", "doublet", "abstention", "singleton_error"]:
+            m = validation[scope][metric]
+            q = m["quantiles"]
+            coverage = m["empirical_coverage"]
+
+            coverage_check = "✅" if coverage >= 0.90 else "❌"  # Assuming 90% PAC level
+
+            print(f"\n{metric.upper().replace('_', ' ')}:")
+            print(f" Empirical mean: {m['mean']:.4f}")
+            print(f" Expected (LOO): {m['expected']:.4f}")
+            q_str = f"[5%: {q['q05']:.3f}, 25%: {q['q25']:.3f}, 50%: {q['q50']:.3f}, "
+            q_str += f"75%: {q['q75']:.3f}, 95%: {q['q95']:.3f}]"
+            print(f" Quantiles: {q_str}")
+            print(f" PAC bounds: [{m['bounds'][0]:.4f}, {m['bounds'][1]:.4f}]")
+            if not np.isnan(coverage):
+                print(f" Coverage: {coverage:.1%} {coverage_check}")
+            else:
+                print(" Coverage: N/A (no valid samples)")
+
+    print("\n" + "=" * 80)
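Combined with generate_rigorous_pac_report() from the new rigorous_report.py, the intended round trip follows the docstring example above. This sketch just wires the pieces together; the sizes and delta=0.10 are illustrative, and any single run's coverage is a statistical expectation rather than a hard guarantee:

from ssbc import BinaryClassifierSimulator, generate_rigorous_pac_report
from ssbc.validation import print_validation_results, validate_pac_bounds

# Calibrate on a small labeled sample ...
sim = BinaryClassifierSimulator(p_class1=0.2, seed=42)
labels, probs = sim.generate(100)
report = generate_rigorous_pac_report(labels, probs, delta=0.10)

# ... then stress-test the reported bounds on fresh draws from the same simulator
validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000, seed=0)
print_validation_results(validation)

# With delta=0.10, empirical coverage should come out at or above 90%
coverage = validation["marginal"]["singleton"]["empirical_coverage"]
print(f"Singleton coverage: {coverage:.1%}")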