ssbc 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,367 @@
1
+ """Simplified operational bounds for fixed calibration (LOO-CV + CP)."""
2
+
3
+ import numpy as np
4
+ from joblib import Parallel, delayed
5
+
6
+ from .core import SSBCResult
7
+ from .statistics import clopper_pearson_lower, clopper_pearson_upper
8
+
9
+
10
+ def _evaluate_loo_single_sample_marginal(
11
+ idx: int,
12
+ labels: np.ndarray,
13
+ probs: np.ndarray,
14
+ k_0: int,
15
+ k_1: int,
16
+ ) -> tuple[int, int, int, int]:
17
+ """Evaluate single LOO fold for marginal operational rates.
18
+
19
+ Parameters
20
+ ----------
21
+ k_0, k_1 : int
22
+ Quantile positions (1-indexed) from SSBC calibration
23
+
24
+ Returns
25
+ -------
26
+ tuple[int, int, int, int]
27
+ (is_singleton, is_doublet, is_abstention, is_singleton_correct)
28
+ """
29
+ mask_0 = labels == 0
30
+ mask_1 = labels == 1
31
+
32
+ # Compute LOO thresholds (using FIXED k positions)
33
+ # Class 0
34
+ if mask_0[idx]:
35
+ scores_0_loo = 1.0 - probs[mask_0, 0]
36
+ mask_0_idx = np.where(mask_0)[0]
37
+ loo_position = np.where(mask_0_idx == idx)[0][0]
38
+ scores_0_loo = np.delete(scores_0_loo, loo_position)
39
+ else:
40
+ scores_0_loo = 1.0 - probs[mask_0, 0]
41
+
42
+ sorted_0_loo = np.sort(scores_0_loo)
43
+ threshold_0_loo = sorted_0_loo[min(k_0 - 1, len(sorted_0_loo) - 1)]
44
+
45
+ # Class 1
46
+ if mask_1[idx]:
47
+ scores_1_loo = 1.0 - probs[mask_1, 1]
48
+ mask_1_idx = np.where(mask_1)[0]
49
+ loo_position = np.where(mask_1_idx == idx)[0][0]
50
+ scores_1_loo = np.delete(scores_1_loo, loo_position)
51
+ else:
52
+ scores_1_loo = 1.0 - probs[mask_1, 1]
53
+
54
+ sorted_1_loo = np.sort(scores_1_loo)
55
+ threshold_1_loo = sorted_1_loo[min(k_1 - 1, len(sorted_1_loo) - 1)]
56
+
57
+ # Evaluate on held-out sample
58
+ score_0 = 1.0 - probs[idx, 0]
59
+ score_1 = 1.0 - probs[idx, 1]
60
+ true_label = labels[idx]
61
+
62
+ in_0 = score_0 <= threshold_0_loo
63
+ in_1 = score_1 <= threshold_1_loo
64
+
65
+ # Determine prediction set type
66
+ if in_0 and in_1:
67
+ is_singleton, is_doublet, is_abstention = 0, 1, 0
68
+ is_singleton_correct = 0
69
+ elif in_0 or in_1:
70
+ is_singleton, is_doublet, is_abstention = 1, 0, 0
71
+ is_singleton_correct = 1 if (in_0 and true_label == 0) or (in_1 and true_label == 1) else 0
72
+ else:
73
+ is_singleton, is_doublet, is_abstention = 0, 0, 1
74
+ is_singleton_correct = 0
75
+
76
+ return is_singleton, is_doublet, is_abstention, is_singleton_correct
77
+
78
+
79
def compute_pac_operational_bounds_marginal(
    ssbc_result_0: SSBCResult,
    ssbc_result_1: SSBCResult,
    labels: np.ndarray,
    probs: np.ndarray,
    test_size: int,  # Kept for API compatibility (not used)
    ci_level: float = 0.95,
    pac_level: float = 0.95,  # Kept for API compatibility (not used)
    use_union_bound: bool = True,
    n_jobs: int = -1,
) -> dict:
    """Compute marginal operational bounds for FIXED calibration via LOO-CV.

    Pipeline:
    1. Derive the fixed quantile positions from the SSBC-corrected alphas.
    2. Run one leave-one-out fold per calibration sample (in parallel) to
       obtain unbiased set-type counts.
    3. Wrap each count in a Clopper-Pearson interval for binomial
       sampling uncertainty, optionally Bonferroni-adjusted so all four
       metrics hold simultaneously.

    This models: "Given fixed calibration, what are rate distributions on
    future test sets?"

    Parameters
    ----------
    ssbc_result_0, ssbc_result_1 : SSBCResult
        SSBC calibration results for class 0 and class 1.
    labels : np.ndarray
        True binary labels (0/1).
    probs : np.ndarray
        Predicted probabilities, shape (n, 2).
    test_size, pac_level :
        Accepted but unused; retained for API compatibility.
    ci_level : float, default=0.95
        Confidence level for Clopper-Pearson intervals.
    use_union_bound : bool, default=True
        Apply a Bonferroni correction for simultaneous guarantees.
    n_jobs : int, default=-1
        Number of parallel jobs (-1 = all cores).

    Returns
    -------
    dict
        '*_rate_bounds' entries hold [lower, upper]; 'expected_*_rate'
        entries hold the LOO point estimates, plus bookkeeping fields.
    """
    n_cal = len(labels)

    # Fixed 1-indexed quantile positions implied by the corrected alphas:
    # k = ceil((n_class + 1) * (1 - alpha_corrected))
    k_0 = int(np.ceil((ssbc_result_0.n + 1) * (1 - ssbc_result_0.alpha_corrected)))
    k_1 = int(np.ceil((ssbc_result_1.n + 1) * (1 - ssbc_result_1.alpha_corrected)))

    # One LOO fold per calibration sample, evaluated in parallel
    fold_outcomes = Parallel(n_jobs=n_jobs)(
        delayed(_evaluate_loo_single_sample_marginal)(i, labels, probs, k_0, k_1) for i in range(n_cal)
    )

    outcome_matrix = np.array(fold_outcomes)
    n_singleton = int(np.sum(outcome_matrix[:, 0]))
    n_doublet = int(np.sum(outcome_matrix[:, 1]))
    n_abstain = int(np.sum(outcome_matrix[:, 2]))
    n_singleton_ok = int(np.sum(outcome_matrix[:, 3]))
    n_singleton_err = n_singleton - n_singleton_ok

    # Bonferroni-adjust the CP level when simultaneous coverage of all
    # four metrics is requested.
    n_metrics = 4
    level = 1 - (1 - ci_level) / n_metrics if use_union_bound else ci_level

    # CP intervals on the TRUE rates (valid for any future test set size;
    # no scaling needed — the interval already carries estimation error).
    bounds = {
        name: [clopper_pearson_lower(count, n_cal, level), clopper_pearson_upper(count, n_cal, level)]
        for name, count in (
            ("singleton", n_singleton),
            ("doublet", n_doublet),
            ("abstention", n_abstain),
        )
    }

    # Singleton error rate is conditioned on the singleton count
    if n_singleton > 0:
        error_bounds = [
            clopper_pearson_lower(n_singleton_err, n_singleton, level),
            clopper_pearson_upper(n_singleton_err, n_singleton, level),
        ]
        error_rate = n_singleton_err / n_singleton
    else:
        # No singletons observed: the conditional error rate is unconstrained
        error_bounds = [0.0, 1.0]
        error_rate = 0.0

    return {
        "singleton_rate_bounds": bounds["singleton"],
        "doublet_rate_bounds": bounds["doublet"],
        "abstention_rate_bounds": bounds["abstention"],
        "singleton_error_rate_bounds": error_bounds,
        "expected_singleton_rate": n_singleton / n_cal,
        "expected_doublet_rate": n_doublet / n_cal,
        "expected_abstention_rate": n_abstain / n_cal,
        "expected_singleton_error_rate": error_rate,
        "n_grid_points": 1,  # Single scenario (fixed thresholds)
        # NOTE(review): carries the (possibly adjusted) CI level, not the
        # pac_level argument — preserved for backward compatibility.
        "pac_level": level,
        "ci_level": ci_level,
        "test_size": n_cal,
        "use_union_bound": use_union_bound,
        "n_metrics": n_metrics if use_union_bound else None,
    }
199
+
200
+
201
+ def _evaluate_loo_single_sample_perclass(
202
+ idx: int,
203
+ labels: np.ndarray,
204
+ probs: np.ndarray,
205
+ k_0: int,
206
+ k_1: int,
207
+ class_label: int,
208
+ ) -> tuple[int, int, int, int]:
209
+ """Evaluate single LOO fold for per-class operational rates.
210
+
211
+ Returns
212
+ -------
213
+ tuple[int, int, int, int]
214
+ (is_singleton, is_doublet, is_abstention, is_singleton_correct)
215
+ """
216
+ # Only evaluate if sample is from class_label
217
+ if labels[idx] != class_label:
218
+ return 0, 0, 0, 0
219
+
220
+ mask_0 = labels == 0
221
+ mask_1 = labels == 1
222
+
223
+ # Compute LOO thresholds
224
+ # Class 0
225
+ if mask_0[idx]:
226
+ scores_0_loo = 1.0 - probs[mask_0, 0]
227
+ mask_0_idx = np.where(mask_0)[0]
228
+ loo_position = np.where(mask_0_idx == idx)[0][0]
229
+ scores_0_loo = np.delete(scores_0_loo, loo_position)
230
+ else:
231
+ scores_0_loo = 1.0 - probs[mask_0, 0]
232
+
233
+ sorted_0_loo = np.sort(scores_0_loo)
234
+ threshold_0_loo = sorted_0_loo[min(k_0 - 1, len(sorted_0_loo) - 1)]
235
+
236
+ # Class 1
237
+ if mask_1[idx]:
238
+ scores_1_loo = 1.0 - probs[mask_1, 1]
239
+ mask_1_idx = np.where(mask_1)[0]
240
+ loo_position = np.where(mask_1_idx == idx)[0][0]
241
+ scores_1_loo = np.delete(scores_1_loo, loo_position)
242
+ else:
243
+ scores_1_loo = 1.0 - probs[mask_1, 1]
244
+
245
+ sorted_1_loo = np.sort(scores_1_loo)
246
+ threshold_1_loo = sorted_1_loo[min(k_1 - 1, len(sorted_1_loo) - 1)]
247
+
248
+ # Evaluate on held-out sample
249
+ score_0 = 1.0 - probs[idx, 0]
250
+ score_1 = 1.0 - probs[idx, 1]
251
+ true_label = labels[idx]
252
+
253
+ in_0 = score_0 <= threshold_0_loo
254
+ in_1 = score_1 <= threshold_1_loo
255
+
256
+ # Determine prediction set type
257
+ if in_0 and in_1:
258
+ is_singleton, is_doublet, is_abstention = 0, 1, 0
259
+ is_singleton_correct = 0
260
+ elif in_0 or in_1:
261
+ is_singleton, is_doublet, is_abstention = 1, 0, 0
262
+ is_singleton_correct = 1 if (in_0 and true_label == 0) or (in_1 and true_label == 1) else 0
263
+ else:
264
+ is_singleton, is_doublet, is_abstention = 0, 0, 1
265
+ is_singleton_correct = 0
266
+
267
+ return is_singleton, is_doublet, is_abstention, is_singleton_correct
268
+
269
+
270
def compute_pac_operational_bounds_perclass(
    ssbc_result_0: SSBCResult,
    ssbc_result_1: SSBCResult,
    labels: np.ndarray,
    probs: np.ndarray,
    class_label: int,
    test_size: int,  # Kept for API compatibility (not used)
    ci_level: float = 0.95,
    pac_level: float = 0.95,  # Kept for API compatibility (not used)
    use_union_bound: bool = True,
    n_jobs: int = -1,
) -> dict:
    """Compute per-class operational bounds for FIXED calibration via LOO-CV.

    Same approach as the marginal version, but counts and rates are
    restricted to calibration samples whose true label equals
    ``class_label``.

    Parameters
    ----------
    ssbc_result_0, ssbc_result_1 : SSBCResult
        SSBC calibration results for class 0 and class 1.
    labels : np.ndarray
        True binary labels (0/1).
    probs : np.ndarray
        Predicted probabilities, shape (n, 2).
    class_label : int
        Which class to analyze (0 or 1).
    test_size, pac_level :
        Accepted but unused; retained for API compatibility.
    ci_level : float, default=0.95
        Confidence level for Clopper-Pearson intervals.
    use_union_bound : bool, default=True
        Apply a Bonferroni correction for simultaneous guarantees.
    n_jobs : int, default=-1
        Number of parallel jobs (-1 = all cores).

    Returns
    -------
    dict
        Per-class operational bounds (same keys as the marginal version).

    Raises
    ------
    ValueError
        If the calibration set contains no samples with ``class_label``.
    """
    # Number of class_label samples in calibration. Cast to a plain int so
    # the returned "test_size" is a Python scalar (np.int64 trips up JSON
    # serializers and strict type checks downstream).
    n_class_cal = int(np.sum(labels == class_label))
    if n_class_cal == 0:
        # Fail fast instead of silently dividing by zero below
        raise ValueError(
            f"No calibration samples with label {class_label}; cannot compute per-class bounds."
        )

    # Quantile positions (1-indexed) implied by the SSBC-corrected alphas:
    # k = ceil((n_class + 1) * (1 - alpha_corrected))
    k_0 = int(np.ceil((ssbc_result_0.n + 1) * (1 - ssbc_result_0.alpha_corrected)))
    k_1 = int(np.ceil((ssbc_result_1.n + 1) * (1 - ssbc_result_1.alpha_corrected)))

    # Parallel LOO-CV over ALL samples; folds from other classes return
    # all-zero tuples and therefore do not affect the counts.
    n = len(labels)
    results = Parallel(n_jobs=n_jobs)(
        delayed(_evaluate_loo_single_sample_perclass)(idx, labels, probs, k_0, k_1, class_label) for idx in range(n)
    )

    # Aggregate results (only class_label samples contribute)
    results_array = np.array(results)
    n_singletons = int(np.sum(results_array[:, 0]))
    n_doublets = int(np.sum(results_array[:, 1]))
    n_abstentions = int(np.sum(results_array[:, 2]))
    n_singletons_correct = int(np.sum(results_array[:, 3]))

    # Point estimates relative to the class-conditional sample count
    singleton_rate = n_singletons / n_class_cal
    doublet_rate = n_doublets / n_class_cal
    abstention_rate = n_abstentions / n_class_cal
    n_errors = n_singletons - n_singletons_correct
    singleton_error_rate = n_errors / n_singletons if n_singletons > 0 else 0.0

    # Bonferroni-adjust the CP level when simultaneous coverage of all
    # four metrics is requested.
    n_metrics = 4
    if use_union_bound:
        adjusted_ci_level = 1 - (1 - ci_level) / n_metrics
    else:
        adjusted_ci_level = ci_level

    # Clopper-Pearson bounds on the TRUE class-conditional rates
    # (valid for any future test set size; no scaling needed)
    singleton_lower = clopper_pearson_lower(n_singletons, n_class_cal, adjusted_ci_level)
    singleton_upper = clopper_pearson_upper(n_singletons, n_class_cal, adjusted_ci_level)

    doublet_lower = clopper_pearson_lower(n_doublets, n_class_cal, adjusted_ci_level)
    doublet_upper = clopper_pearson_upper(n_doublets, n_class_cal, adjusted_ci_level)

    abstention_lower = clopper_pearson_lower(n_abstentions, n_class_cal, adjusted_ci_level)
    abstention_upper = clopper_pearson_upper(n_abstentions, n_class_cal, adjusted_ci_level)

    # Singleton error rate is conditioned on the singleton count
    if n_singletons > 0:
        error_lower = clopper_pearson_lower(n_errors, n_singletons, adjusted_ci_level)
        error_upper = clopper_pearson_upper(n_errors, n_singletons, adjusted_ci_level)
    else:
        # No singletons observed: the conditional error rate is unconstrained
        error_lower = 0.0
        error_upper = 1.0

    return {
        "singleton_rate_bounds": [singleton_lower, singleton_upper],
        "doublet_rate_bounds": [doublet_lower, doublet_upper],
        "abstention_rate_bounds": [abstention_lower, abstention_upper],
        "singleton_error_rate_bounds": [error_lower, error_upper],
        "expected_singleton_rate": singleton_rate,
        "expected_doublet_rate": doublet_rate,
        "expected_abstention_rate": abstention_rate,
        "expected_singleton_error_rate": singleton_error_rate,
        "n_grid_points": 1,  # Single scenario (fixed thresholds)
        # NOTE(review): carries the (possibly adjusted) CI level, not the
        # pac_level argument — preserved for backward compatibility.
        "pac_level": adjusted_ci_level,
        "ci_level": ci_level,
        "test_size": n_class_cal,  # Calibration class size, not the test_size argument
        "use_union_bound": use_union_bound,
        "n_metrics": n_metrics if use_union_bound else None,
    }