ssbc 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ssbc/__init__.py +50 -2
- ssbc/bootstrap.py +411 -0
- ssbc/cli.py +0 -3
- ssbc/conformal.py +700 -1
- ssbc/cross_conformal.py +425 -0
- ssbc/mcp_server.py +93 -0
- ssbc/operational_bounds_simple.py +367 -0
- ssbc/rigorous_report.py +601 -0
- ssbc/statistics.py +70 -0
- ssbc/utils.py +72 -2
- ssbc/validation.py +409 -0
- ssbc/visualization.py +323 -300
- ssbc-1.1.0.dist-info/METADATA +337 -0
- ssbc-1.1.0.dist-info/RECORD +22 -0
- ssbc-1.1.0.dist-info/licenses/LICENSE +29 -0
- ssbc/ssbc.py +0 -1
- ssbc-0.1.0.dist-info/METADATA +0 -266
- ssbc-0.1.0.dist-info/RECORD +0 -17
- ssbc-0.1.0.dist-info/licenses/LICENSE +0 -21
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/WHEEL +0 -0
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/entry_points.txt +0 -0
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/top_level.txt +0 -0
ssbc/__init__.py
CHANGED
@@ -1,12 +1,21 @@
|
|
1
1
|
"""Top-level package for SSBC (Small-Sample Beta Correction)."""
|
2
2
|
|
3
|
+
from importlib.metadata import version
|
4
|
+
|
3
5
|
__author__ = """Petrus H Zwart"""
|
4
6
|
__email__ = "phzwart@lbl.gov"
|
5
|
-
__version__ = "
|
7
|
+
__version__ = version("ssbc") # Read from package metadata (pyproject.toml)
|
6
8
|
|
7
9
|
# Core SSBC algorithm
|
8
10
|
# Conformal prediction
|
11
|
+
# Bootstrap uncertainty analysis
|
12
|
+
from .bootstrap import (
|
13
|
+
bootstrap_calibration_uncertainty,
|
14
|
+
plot_bootstrap_distributions,
|
15
|
+
)
|
9
16
|
from .conformal import (
|
17
|
+
alpha_scan,
|
18
|
+
compute_pac_operational_metrics,
|
10
19
|
mondrian_conformal_calibrate,
|
11
20
|
split_by_class,
|
12
21
|
)
|
@@ -15,12 +24,23 @@ from .core import (
|
|
15
24
|
ssbc_correct,
|
16
25
|
)
|
17
26
|
|
27
|
+
# Cross-conformal validation
|
28
|
+
from .cross_conformal import (
|
29
|
+
cross_conformal_validation,
|
30
|
+
print_cross_conformal_results,
|
31
|
+
)
|
32
|
+
|
18
33
|
# Hyperparameter tuning
|
19
34
|
from .hyperparameter import (
|
20
35
|
sweep_and_plot_parallel_plotly,
|
21
36
|
sweep_hyperparams_and_collect,
|
22
37
|
)
|
23
38
|
|
39
|
+
# Visualization and reporting
|
40
|
+
from .rigorous_report import (
|
41
|
+
generate_rigorous_pac_report,
|
42
|
+
)
|
43
|
+
|
24
44
|
# Simulation (for testing and examples)
|
25
45
|
from .simulation import (
|
26
46
|
BinaryClassifierSimulator,
|
@@ -29,10 +49,21 @@ from .simulation import (
|
|
29
49
|
# Statistics utilities
|
30
50
|
from .statistics import (
|
31
51
|
clopper_pearson_intervals,
|
52
|
+
clopper_pearson_lower,
|
53
|
+
clopper_pearson_upper,
|
32
54
|
cp_interval,
|
33
55
|
)
|
34
56
|
|
35
|
-
#
|
57
|
+
# Utility functions
|
58
|
+
from .utils import (
|
59
|
+
compute_operational_rate,
|
60
|
+
)
|
61
|
+
|
62
|
+
# Validation utilities
|
63
|
+
from .validation import (
|
64
|
+
print_validation_results,
|
65
|
+
validate_pac_bounds,
|
66
|
+
)
|
36
67
|
from .visualization import (
|
37
68
|
plot_parallel_coordinates_plotly,
|
38
69
|
report_prediction_stats,
|
@@ -43,16 +74,33 @@ __all__ = [
|
|
43
74
|
"SSBCResult",
|
44
75
|
"ssbc_correct",
|
45
76
|
# Conformal
|
77
|
+
"alpha_scan",
|
78
|
+
"compute_pac_operational_metrics",
|
46
79
|
"mondrian_conformal_calibrate",
|
47
80
|
"split_by_class",
|
48
81
|
# Statistics
|
49
82
|
"clopper_pearson_intervals",
|
83
|
+
"clopper_pearson_lower",
|
84
|
+
"clopper_pearson_upper",
|
50
85
|
"cp_interval",
|
86
|
+
# Utilities
|
87
|
+
"compute_operational_rate",
|
51
88
|
# Simulation
|
52
89
|
"BinaryClassifierSimulator",
|
53
90
|
# Visualization
|
54
91
|
"report_prediction_stats",
|
55
92
|
"plot_parallel_coordinates_plotly",
|
93
|
+
# Bootstrap uncertainty
|
94
|
+
"bootstrap_calibration_uncertainty",
|
95
|
+
"plot_bootstrap_distributions",
|
96
|
+
# Cross-conformal validation
|
97
|
+
"cross_conformal_validation",
|
98
|
+
"print_cross_conformal_results",
|
99
|
+
# Validation utilities
|
100
|
+
"validate_pac_bounds",
|
101
|
+
"print_validation_results",
|
102
|
+
# Rigorous reporting
|
103
|
+
"generate_rigorous_pac_report",
|
56
104
|
# Hyperparameter
|
57
105
|
"sweep_hyperparams_and_collect",
|
58
106
|
"sweep_and_plot_parallel_plotly",
|
ssbc/bootstrap.py
ADDED
@@ -0,0 +1,411 @@
|
|
1
|
+
"""Bootstrap analysis of calibration uncertainty for operational rates.
|
2
|
+
|
3
|
+
This models: "If I recalibrate many times on similar datasets, how do rates vary?"
|
4
|
+
Different from LOO-CV which models: "Given ONE fixed calibration, how do test sets vary?"
|
5
|
+
"""
|
6
|
+
|
7
|
+
from typing import Protocol
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from joblib import Parallel, delayed
|
11
|
+
|
12
|
+
from ssbc.conformal import split_by_class
|
13
|
+
from ssbc.core import ssbc_correct
|
14
|
+
|
15
|
+
# Optional plotting dependencies
|
16
|
+
try:
|
17
|
+
import matplotlib.pyplot as plt
|
18
|
+
|
19
|
+
HAS_MATPLOTLIB = True
|
20
|
+
except ImportError:
|
21
|
+
HAS_MATPLOTLIB = False
|
22
|
+
|
23
|
+
|
24
|
+
class DataGenerator(Protocol):
    """Structural interface for sample generators (e.g., BinaryClassifierSimulator).

    Any object that provides a compatible ``generate`` method satisfies this
    protocol; no inheritance is required.
    """

    def generate(self, n_samples: int) -> tuple[np.ndarray, np.ndarray]:
        """Produce a batch of samples.

        Parameters
        ----------
        n_samples : int
            Number of samples requested.

        Returns
        -------
        tuple
            (labels, probabilities)
        """
        ...
|
36
|
+
|
37
|
+
|
38
|
+
def _bootstrap_single_trial(
    labels: np.ndarray,
    probs: np.ndarray,
    alpha_target: float,
    delta: float,
    test_size: int,
    bootstrap_seed: int,
    simulator: DataGenerator,
) -> dict[str, float]:
    """Single bootstrap trial: resample calibration → calibrate → evaluate on fresh test set.

    Parameters
    ----------
    labels : np.ndarray
        Calibration labels (compared against 0 and 1)
    probs : np.ndarray
        Calibration probabilities; indexed as ``probs[i, 0]`` and ``probs[i, 1]``
    alpha_target : float
        Target miscoverage
    delta : float
        PAC risk
    test_size : int
        Test set size
    bootstrap_seed : int
        Random seed for this trial
    simulator : DataGenerator
        Simulator to generate fresh test sets

    Returns
    -------
    dict
        Operational rates for this bootstrap sample, keyed by ``singleton``,
        ``doublet``, ``abstention`` and ``singleton_error`` (marginal) plus the
        same names suffixed with ``_0`` and ``_1`` (per true class). Every
        value is NaN when the trial cannot be evaluated (calibration failed,
        a class is absent from the resample, or the test set is empty).
    """
    metric_names = ("singleton", "doublet", "abstention", "singleton_error")
    # Result used for every trial that cannot be evaluated; the aggregation
    # step drops NaNs, so such a trial is simply ignored.
    undefined_trial = {f"{metric}{suffix}": np.nan for suffix in ("", "_0", "_1") for metric in metric_names}

    # Seed the global legacy RNG: it drives the resampling below.
    # NOTE(review): presumably the simulator also draws from this global RNG,
    # which is what would make its test sets reproducible - confirm.
    np.random.seed(bootstrap_seed)

    n = len(labels)

    # Bootstrap resample calibration data (with replacement)
    bootstrap_idx = np.random.choice(n, size=n, replace=True)
    labels_boot = labels[bootstrap_idx]
    probs_boot = probs[bootstrap_idx]

    # Split by class
    class_data_boot = split_by_class(labels_boot, probs_boot)

    # Calibrate on bootstrap sample
    try:
        n_0 = class_data_boot[0]["n"]
        n_1 = class_data_boot[1]["n"]
        ssbc_0 = ssbc_correct(alpha_target=alpha_target, n=n_0, delta=delta)
        ssbc_1 = ssbc_correct(alpha_target=alpha_target, n=n_1, delta=delta)
    except Exception:
        # Deliberate best effort: degenerate resamples (e.g., all samples from
        # one class) yield an undefined trial instead of aborting the run.
        return undefined_trial

    # Nonconformity score of a sample for class c is 1 - P(class c).
    sorted_0 = np.sort(1.0 - probs_boot[labels_boot == 0, 0])
    sorted_1 = np.sort(1.0 - probs_boot[labels_boot == 1, 1])
    if sorted_0.size == 0 or sorted_1.size == 0:
        # No calibration scores for one class: there is no threshold to pick.
        return undefined_trial

    # Threshold is the k-th smallest score, clamped to the largest one.
    k_0 = int(np.ceil((n_0 + 1) * (1 - ssbc_0.alpha_corrected)))
    k_1 = int(np.ceil((n_1 + 1) * (1 - ssbc_1.alpha_corrected)))
    threshold_0 = sorted_0[min(k_0 - 1, len(sorted_0) - 1)]
    threshold_1 = sorted_1[min(k_1 - 1, len(sorted_1) - 1)]

    # Generate FRESH test set
    labels_test, probs_test = simulator.generate(test_size)
    labels_test = np.asarray(labels_test)
    n_test = len(labels_test)
    if n_test == 0:
        # Rates over an empty test set are undefined.
        return undefined_trial
    probs_test = np.asarray(probs_test)[:n_test]

    # Prediction-set membership of each label for every test sample
    in_0 = (1.0 - probs_test[:, 0]) <= threshold_0
    in_1 = (1.0 - probs_test[:, 1]) <= threshold_1

    is_doublet = in_0 & in_1
    is_singleton = in_0 ^ in_1
    is_abstention = ~(in_0 | in_1)
    is_class_0 = labels_test == 0

    def _subset_rates(subset: np.ndarray, correct: np.ndarray) -> tuple[float, float, float, float]:
        """Return (singleton, doublet, abstention, singleton_error) rates over ``subset``."""
        n_subset = int(np.count_nonzero(subset))
        if n_subset == 0:
            return (np.nan, np.nan, np.nan, np.nan)
        n_single = int(np.count_nonzero(is_singleton & subset))
        n_double = int(np.count_nonzero(is_doublet & subset))
        n_abstain = int(np.count_nonzero(is_abstention & subset))
        n_wrong = n_single - int(np.count_nonzero(is_singleton & subset & correct))
        error_rate = n_wrong / n_single if n_single > 0 else np.nan
        return (n_single / n_subset, n_double / n_subset, n_abstain / n_subset, error_rate)

    # (key suffix, samples belonging to the group, "set contains the true label").
    # Every label other than 0 is grouped with class 1 for the per-class rates,
    # while the marginal correctness test requires the label to equal 1 exactly.
    groups = (
        ("", np.ones(n_test, dtype=bool), (in_0 & is_class_0) | (in_1 & (labels_test == 1))),
        ("_0", is_class_0, in_0),
        ("_1", ~is_class_0, in_1),
    )

    rates: dict[str, float] = {}
    for suffix, subset, correct in groups:
        for metric, value in zip(metric_names, _subset_rates(subset, correct), strict=True):
            rates[f"{metric}{suffix}"] = value
    return rates
|
217
|
+
|
218
|
+
|
219
|
+
def bootstrap_calibration_uncertainty(
    labels: np.ndarray,
    probs: np.ndarray,
    simulator: DataGenerator,
    alpha_target: float = 0.10,
    delta: float = 0.10,
    test_size: int = 1000,
    n_bootstrap: int = 1000,
    n_jobs: int = -1,
    seed: int | None = None,
) -> dict:
    """Bootstrap analysis of calibration uncertainty.

    For each bootstrap iteration:
    1. Resample calibration data with replacement
    2. Calibrate (compute SSBC thresholds)
    3. Evaluate on fresh independent test set
    4. Record operational rates

    This models: "If I recalibrate on similar datasets, how do rates vary?"

    Parameters
    ----------
    labels : np.ndarray
        Calibration labels
    probs : np.ndarray
        Calibration probabilities
    simulator : DataGenerator
        Simulator to generate independent test sets
    alpha_target : float, default=0.10
        Target miscoverage
    delta : float, default=0.10
        PAC risk
    test_size : int, default=1000
        Size of test sets for evaluation
    n_bootstrap : int, default=1000
        Number of bootstrap iterations
    n_jobs : int, default=-1
        Parallel jobs (-1 for all cores)
    seed : int, optional
        Random seed used to derive the per-trial seeds. When None, the
        per-trial seeds are drawn from NumPy's global RNG.

    Returns
    -------
    dict
        Bootstrap distributions with keys:
        - 'n_bootstrap': number of bootstrap iterations requested
        - 'n_calibration': number of calibration samples (``len(labels)``)
        - 'test_size': requested test set size
        - 'marginal': dict with 'singleton', 'doublet', 'abstention', 'singleton_error'
        - 'class_0': dict with same metrics
        - 'class_1': dict with same metrics
        Each metric contains:
        - 'samples': array of rates across bootstrap trials (may contain NaN)
        - 'mean': mean rate over the non-NaN trials
        - 'std': standard deviation over the non-NaN trials
        - 'quantiles': dict with q05, q25, q50, q75, q95

    Examples
    --------
    >>> from ssbc import BinaryClassifierSimulator, bootstrap_calibration_uncertainty
    >>> sim = BinaryClassifierSimulator(p_class1=0.2, beta_params_class0=(1,7), beta_params_class1=(5,2))
    >>> labels, probs = sim.generate(100)
    >>> results = bootstrap_calibration_uncertainty(labels, probs, sim, n_bootstrap=100)
    >>> print(results['marginal']['singleton']['mean'])
    """
    # A private RandomState(seed) produces the same stream that
    # np.random.seed(seed) + np.random.randint would, but leaves the
    # process-wide RNG state of the caller untouched.
    seed_source = np.random if seed is None else np.random.RandomState(seed)

    # Generate bootstrap seeds
    bootstrap_seeds = seed_source.randint(0, 2**31, size=n_bootstrap)

    # Parallel bootstrap
    results = Parallel(n_jobs=n_jobs)(
        delayed(_bootstrap_single_trial)(labels, probs, alpha_target, delta, test_size, bs_seed, simulator)
        for bs_seed in bootstrap_seeds
    )

    metrics = ["singleton", "doublet", "abstention", "singleton_error"]
    quantile_levels = {"q05": 5, "q25": 25, "q50": 50, "q75": 75, "q95": 95}

    def compute_stats(values):
        """Compute statistics for a metric, ignoring NaN trials."""
        arr = np.array(values)
        valid = arr[~np.isnan(arr)]
        if valid.size == 0:
            return {
                "samples": arr,
                "mean": np.nan,
                "std": np.nan,
                "quantiles": dict.fromkeys(quantile_levels, np.nan),
            }
        # One percentile call for all levels instead of one call per level.
        cuts = np.percentile(valid, list(quantile_levels.values()))
        return {
            "samples": arr,
            "mean": np.mean(valid),
            "std": np.std(valid),
            "quantiles": dict(zip(quantile_levels, cuts, strict=True)),
        }

    def collect(suffix):
        """Summarize every metric whose per-trial key ends with ``suffix``."""
        return {metric: compute_stats([trial[f"{metric}{suffix}"] for trial in results]) for metric in metrics}

    # Organize results
    return {
        "n_bootstrap": n_bootstrap,
        "n_calibration": len(labels),
        "test_size": test_size,
        "marginal": collect(""),
        "class_0": collect("_0"),
        "class_1": collect("_1"),
    }
|
330
|
+
|
331
|
+
|
332
|
+
def plot_bootstrap_distributions(
    bootstrap_results: dict,
    figsize: tuple[int, int] = (16, 12),
    save_path: str | None = None,
) -> None:
    """Plot bootstrap distributions.

    Draws a 3 x 4 grid of histograms: one row each for the marginal, class 0
    and class 1 results, one column per metric (singleton, doublet,
    abstention, singleton error). Each panel marks the median, the 5% and 95%
    quantiles and the mean. Panels without any non-NaN sample show "No data".

    Parameters
    ----------
    bootstrap_results : dict
        Results from bootstrap_calibration_uncertainty()
    figsize : tuple, default=(16, 12)
        Figure size
    save_path : str, optional
        Path to save figure. If None, displays interactively.

    Raises
    ------
    ImportError
        If matplotlib is not installed

    Examples
    --------
    >>> from ssbc import bootstrap_calibration_uncertainty, plot_bootstrap_distributions
    >>> results = bootstrap_calibration_uncertainty(...)
    >>> plot_bootstrap_distributions(results, save_path='bootstrap_results.png')
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")

    fig, axes = plt.subplots(3, 4, figsize=figsize)
    fig.suptitle(
        f"Bootstrap Calibration Uncertainty ({bootstrap_results['n_bootstrap']} trials)\n"
        f"Calibration n={bootstrap_results['n_calibration']}, Test size={bootstrap_results['test_size']}",
        fontsize=14,
        fontweight="bold",
    )

    metrics = ["singleton", "doublet", "abstention", "singleton_error"]
    metric_names = ["Singleton Rate", "Doublet Rate", "Abstention Rate", "Singleton Error Rate"]
    colors = ["steelblue", "coral", "mediumpurple"]
    row_names = ["MARGINAL", "CLASS 0", "CLASS 1"]
    data_keys = ["marginal", "class_0", "class_1"]

    # strict=True: these parallel lists must stay the same length.
    for row, (row_name, data_key, color) in enumerate(zip(row_names, data_keys, colors, strict=True)):
        for col, (metric, name) in enumerate(zip(metrics, metric_names, strict=True)):
            ax = axes[row, col]
            m = bootstrap_results[data_key][metric]

            # Title first, so panels without data remain identifiable.
            ax.set_title(f"{row_name}: {name}", fontweight="bold")

            # Filter NaNs
            samples = np.asarray(m["samples"], dtype=float)
            samples = samples[~np.isnan(samples)]

            if samples.size == 0:
                # Axes coordinates keep the note centred whatever the data limits are.
                ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
                continue

            # Histogram
            ax.hist(samples, bins=50, alpha=0.7, color=color, edgecolor="black")

            # Quantiles
            q = m["quantiles"]
            ax.axvline(q["q50"], color="green", linestyle="-", linewidth=2, label=f"Median: {q['q50']:.3f}")
            ax.axvline(q["q05"], color="red", linestyle="--", linewidth=2, label=f"5%: {q['q05']:.3f}")
            ax.axvline(q["q95"], color="red", linestyle="--", linewidth=2, label=f"95%: {q['q95']:.3f}")
            ax.axvline(m["mean"], color="orange", linestyle=":", linewidth=2, label=f"Mean: {m['mean']:.3f}")

            ax.set_xlabel("Rate")
            ax.set_ylabel("Count")
            ax.legend(loc="best", fontsize=8)
            ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        # Save through the figure object rather than pyplot's implicit current figure.
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"✅ Saved bootstrap visualization to: {save_path}")
    else:
        plt.show()
|
ssbc/cli.py
CHANGED
@@ -3,8 +3,6 @@
|
|
3
3
|
import typer
|
4
4
|
from rich.console import Console
|
5
5
|
|
6
|
-
from ssbc import utils
|
7
|
-
|
8
6
|
app = typer.Typer()
|
9
7
|
console = Console()
|
10
8
|
|
@@ -14,7 +12,6 @@ def main():
|
|
14
12
|
"""Console script for ssbc."""
|
15
13
|
console.print("Replace this message by putting your code into ssbc.cli.main")
|
16
14
|
console.print("See Typer documentation at https://typer.tiangolo.com/")
|
17
|
-
utils.do_something_useful()
|
18
15
|
|
19
16
|
|
20
17
|
if __name__ == "__main__":
|