ssbc 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ssbc/__init__.py +50 -2
- ssbc/bootstrap.py +411 -0
- ssbc/cli.py +0 -3
- ssbc/conformal.py +700 -1
- ssbc/cross_conformal.py +425 -0
- ssbc/mcp_server.py +93 -0
- ssbc/operational_bounds_simple.py +367 -0
- ssbc/rigorous_report.py +601 -0
- ssbc/statistics.py +70 -0
- ssbc/utils.py +72 -2
- ssbc/validation.py +409 -0
- ssbc/visualization.py +323 -300
- ssbc-1.1.0.dist-info/METADATA +337 -0
- ssbc-1.1.0.dist-info/RECORD +22 -0
- ssbc-1.1.0.dist-info/licenses/LICENSE +29 -0
- ssbc/ssbc.py +0 -1
- ssbc-0.1.0.dist-info/METADATA +0 -266
- ssbc-0.1.0.dist-info/RECORD +0 -17
- ssbc-0.1.0.dist-info/licenses/LICENSE +0 -21
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/WHEEL +0 -0
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/entry_points.txt +0 -0
- {ssbc-0.1.0.dist-info → ssbc-1.1.0.dist-info}/top_level.txt +0 -0
ssbc/visualization.py
CHANGED
@@ -5,19 +5,62 @@ from typing import Any
|
|
5
5
|
from .statistics import cp_interval
|
6
6
|
|
7
7
|
|
8
|
+
def compute_conditional_rate_bounds(
    numerator_fold_results: list[dict],
    denominator_fold_results: list[dict],
    weights: list[float],
) -> tuple[float, float, float]:
    """Compute bounds on a conditional rate P(A | B) from fold-level counts.

    Counts are pooled across folds: the ``"K_fr"`` counts of the numerator
    event are summed, the ``"K_fr"`` counts of the denominator event are
    summed, and a single Clopper-Pearson interval (via ``cp_interval``) is
    computed on the pooled ``(numerator, denominator)`` pair. No per-fold
    ratios are formed.

    Parameters
    ----------
    numerator_fold_results : list[dict]
        Fold results for the numerator event (e.g., correct_in_singleton).
        Each entry must provide a count under the key ``"K_fr"``.
    denominator_fold_results : list[dict]
        Fold results for the denominator event (e.g., singleton). Each entry
        must provide a count under the key ``"K_fr"``.
    weights : list[float]
        Fold weights. Currently unused: pooling raw counts already gives each
        fold influence proportional to its number of events. Kept so existing
        callers keep working.

    Returns
    -------
    rate : float
        Estimated conditional rate, as reported by ``cp_interval``.
    lower : float
        Lower CI bound.
    upper : float
        Upper CI bound.
        When the pooled denominator is 0 (no B events at all), the
        uninformative triple ``(0.0, 0.0, 1.0)`` is returned.

    Raises
    ------
    ValueError
        If the pooled numerator exceeds the pooled denominator, i.e. the
        numerator event is not a sub-event of the denominator event.
    """
    # Pool counts across folds (weights intentionally not applied; see docstring).
    total_numerator = sum(fold["K_fr"] for fold in numerator_fold_results)
    total_denominator = sum(fold["K_fr"] for fold in denominator_fold_results)

    # A conditional rate needs A ⊆ B; more A-events than B-events means the
    # inputs were mismatched, and a binomial CI on k > n is meaningless.
    if total_numerator > total_denominator:
        raise ValueError(
            f"Numerator count ({total_numerator}) exceeds denominator count "
            f"({total_denominator}); cannot form a conditional rate."
        )

    # No conditioning events observed: return an uninformative interval.
    if total_denominator == 0:
        return 0.0, 0.0, 1.0

    # Compute CP interval on aggregated counts
    ci = cp_interval(total_numerator, total_denominator)
    return ci["proportion"], ci["lower"], ci["upper"]
|
47
|
+
|
48
|
+
|
8
49
|
def report_prediction_stats(
|
9
|
-
prediction_stats: dict[Any, Any],
|
50
|
+
prediction_stats: dict[Any, Any],
|
51
|
+
calibration_result: dict[Any, Any],
|
52
|
+
operational_bounds_per_class: dict[int, Any] | None = None,
|
53
|
+
marginal_operational_bounds: Any | None = None,
|
54
|
+
verbose: bool = True,
|
10
55
|
) -> dict[str | int, Any]:
|
11
|
-
"""
|
56
|
+
"""Report rigorous statistics for Mondrian conformal prediction with valid CIs.
|
12
57
|
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
OR as top-level aliases 'singletons_correct' / 'singletons_incorrect'.
|
58
|
+
Only displays statistics with valid confidence intervals:
|
59
|
+
- Per-class statistics from calibration data (valid within class)
|
60
|
+
- Per-class operational bounds from cross-validation (rigorous PAC bounds)
|
61
|
+
- Marginal operational bounds from cross-validated Mondrian (rigorous PAC bounds)
|
18
62
|
|
19
|
-
|
20
|
-
singleton errors by predicted class.
|
63
|
+
Does NOT display marginal statistics from calibration data (invalid CIs for Mondrian).
|
21
64
|
|
22
65
|
Parameters
|
23
66
|
----------
|
@@ -25,338 +68,318 @@ def report_prediction_stats(
|
|
25
68
|
Output from mondrian_conformal_calibrate (second return value)
|
26
69
|
calibration_result : dict
|
27
70
|
Output from mondrian_conformal_calibrate (first return value)
|
71
|
+
operational_bounds_per_class : dict[int, OperationalRateBoundsResult], optional
|
72
|
+
Per-class operational bounds from compute_mondrian_operational_bounds
|
73
|
+
marginal_operational_bounds : OperationalRateBoundsResult, optional
|
74
|
+
Marginal operational bounds from compute_marginal_operational_bounds
|
28
75
|
verbose : bool, default=True
|
29
76
|
If True, print detailed statistics to stdout
|
30
77
|
|
31
78
|
Returns
|
32
79
|
-------
|
33
80
|
dict
|
34
|
-
Structured summary with CIs
|
81
|
+
Structured summary with valid CIs:
|
35
82
|
- Keys 0, 1 for per-class statistics
|
36
|
-
- Key '
|
83
|
+
- Key 'marginal_bounds' if marginal_operational_bounds provided
|
37
84
|
|
38
85
|
Examples
|
39
86
|
--------
|
40
|
-
>>> #
|
87
|
+
>>> # Basic: Only calibration data (limited info)
|
41
88
|
>>> cal_result, pred_stats = mondrian_conformal_calibrate(...)
|
42
|
-
>>> summary = report_prediction_stats(pred_stats, cal_result
|
43
|
-
>>>
|
89
|
+
>>> summary = report_prediction_stats(pred_stats, cal_result)
|
90
|
+
>>>
|
91
|
+
>>> # With per-class operational bounds (rigorous)
|
92
|
+
>>> op_bounds = compute_mondrian_operational_bounds(cal_result, class_data, delta=0.05)
|
93
|
+
>>> summary = report_prediction_stats(pred_stats, cal_result, op_bounds)
|
94
|
+
>>>
|
95
|
+
>>> # With marginal bounds too (complete SLA)
|
96
|
+
>>> marginal = compute_marginal_operational_bounds(labels, probs, 0.1, 0.05, 0.05)
|
97
|
+
>>> summary = report_prediction_stats(pred_stats, cal_result, op_bounds, marginal)
|
44
98
|
"""
|
45
|
-
|
46
|
-
# Helper functions
|
47
|
-
def as_dict(x: Any) -> dict[str, Any]:
|
48
|
-
"""Ensure x is a dict."""
|
49
|
-
return x if isinstance(x, dict) else {}
|
50
|
-
|
51
|
-
def get_count(x: Any, default: int = 0) -> int:
|
52
|
-
"""Extract count from dict or int."""
|
53
|
-
if isinstance(x, dict):
|
54
|
-
return int(x.get("count", default))
|
55
|
-
if isinstance(x, int):
|
56
|
-
return int(x)
|
57
|
-
return default
|
58
|
-
|
59
|
-
def get_rate(x: Any, default: float | None = 0.0) -> float | None:
|
60
|
-
"""Extract rate from dict or float."""
|
61
|
-
if isinstance(x, dict):
|
62
|
-
if "rate" in x:
|
63
|
-
return float(x["rate"])
|
64
|
-
if "proportion" in x:
|
65
|
-
return float(x["proportion"])
|
66
|
-
return default
|
67
|
-
if isinstance(x, float):
|
68
|
-
return float(x)
|
69
|
-
return default
|
70
|
-
|
71
|
-
def get_ci_tuple(x: Any) -> tuple[float, float]:
|
72
|
-
"""Extract CI bounds from dict."""
|
73
|
-
if not isinstance(x, dict):
|
74
|
-
return (0.0, 0.0)
|
75
|
-
if "ci_95" in x and isinstance(x["ci_95"], tuple | list) and len(x["ci_95"]) == 2:
|
76
|
-
return float(x["ci_95"][0]), float(x["ci_95"][1])
|
77
|
-
lo = x.get("lower", 0.0)
|
78
|
-
hi = x.get("upper", 0.0)
|
79
|
-
return float(lo), float(hi)
|
80
|
-
|
81
|
-
def ensure_ci(d: dict[str, Any], count: int, total: int) -> tuple[float, float, float]:
|
82
|
-
"""Return (rate, lo, hi). If d already has rate/CI, use them; else compute CP from count/total."""
|
83
|
-
r = get_rate(d, default=None)
|
84
|
-
lo, hi = get_ci_tuple(d)
|
85
|
-
if r is None or (lo == 0.0 and hi == 0.0 and (count > 0 or total > 0)):
|
86
|
-
ci = cp_interval(count, total)
|
87
|
-
return ci["proportion"], ci["lower"], ci["upper"]
|
88
|
-
return float(r), float(lo), float(hi)
|
89
|
-
|
90
|
-
def pct(x: float) -> str:
|
91
|
-
"""Format percentage."""
|
92
|
-
return f"{x:6.2%}"
|
99
|
+
from .statistics import cp_interval
|
93
100
|
|
94
101
|
summary: dict[str | int, Any] = {}
|
95
102
|
|
96
103
|
if verbose:
|
97
104
|
print("=" * 80)
|
98
|
-
print("
|
105
|
+
print("MONDRIAN CONFORMAL PREDICTION REPORT")
|
99
106
|
print("=" * 80)
|
100
107
|
|
101
|
-
#
|
102
|
-
for class_label in [
|
103
|
-
if class_label not in prediction_stats:
|
104
|
-
continue
|
108
|
+
# ==================== PER-CLASS STATISTICS ====================
|
109
|
+
for class_label in sorted([k for k in prediction_stats.keys() if isinstance(k, int)]):
|
105
110
|
cls = prediction_stats[class_label]
|
106
111
|
|
107
112
|
if isinstance(cls, dict) and "error" in cls:
|
108
113
|
if verbose:
|
109
|
-
print(f"\
|
114
|
+
print(f"\nCLASS {class_label}: {cls['error']}")
|
110
115
|
summary[class_label] = {"error": cls["error"]}
|
111
116
|
continue
|
112
117
|
|
113
118
|
n = int(cls.get("n", cls.get("n_class", 0)))
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
pac = as_dict(cls.get("pac_bounds", {}))
|
124
|
-
|
125
|
-
# Counts
|
126
|
-
abst_count = get_count(abst)
|
127
|
-
sing_count = get_count(sing)
|
128
|
-
sing_corr_count = get_count(sing_corr)
|
129
|
-
sing_inc_count = get_count(sing_inc)
|
130
|
-
doub_count = get_count(doub)
|
131
|
-
|
132
|
-
# Ensure rates/CIs (fallback to CP if missing)
|
133
|
-
abst_rate, abst_lo, abst_hi = ensure_ci(abst, abst_count, n)
|
134
|
-
sing_rate, sing_lo, sing_hi = ensure_ci(sing, sing_count, n)
|
135
|
-
sing_corr_rate, sing_corr_lo, sing_corr_hi = ensure_ci(sing_corr, sing_corr_count, n)
|
136
|
-
sing_inc_rate, sing_inc_lo, sing_inc_hi = ensure_ci(sing_inc, sing_inc_count, n)
|
137
|
-
doub_rate, doub_lo, doub_hi = ensure_ci(doub, doub_count, n)
|
138
|
-
|
139
|
-
# P(error | singleton, Y=class)
|
140
|
-
err_given_single_ci = cp_interval(sing_inc_count, sing_count if sing_count > 0 else 1)
|
141
|
-
|
142
|
-
class_summary = {
|
143
|
-
"n": n,
|
144
|
-
"alpha_target": alpha_target,
|
145
|
-
"delta": delta,
|
146
|
-
"abstentions": {"count": abst_count, "rate": abst_rate, "ci_95": (abst_lo, abst_hi)},
|
147
|
-
"singletons": {
|
148
|
-
"count": sing_count,
|
149
|
-
"rate": sing_rate,
|
150
|
-
"ci_95": (sing_lo, sing_hi),
|
151
|
-
"correct": {"count": sing_corr_count, "rate": sing_corr_rate, "ci_95": (sing_corr_lo, sing_corr_hi)},
|
152
|
-
"incorrect": {"count": sing_inc_count, "rate": sing_inc_rate, "ci_95": (sing_inc_lo, sing_inc_hi)},
|
153
|
-
"error_given_singleton": {
|
154
|
-
"count": sing_inc_count,
|
155
|
-
"denom": sing_count,
|
156
|
-
"rate": err_given_single_ci["proportion"],
|
157
|
-
"ci_95": (err_given_single_ci["lower"], err_given_single_ci["upper"]),
|
158
|
-
},
|
159
|
-
},
|
160
|
-
"doublets": {"count": doub_count, "rate": doub_rate, "ci_95": (doub_lo, doub_hi)},
|
161
|
-
"pac_bounds": pac,
|
162
|
-
}
|
163
|
-
summary[class_label] = class_summary
|
119
|
+
if n == 0:
|
120
|
+
continue
|
121
|
+
|
122
|
+
# Get calibration info
|
123
|
+
cal = calibration_result.get(class_label, {})
|
124
|
+
alpha_target = cal.get("alpha_target")
|
125
|
+
alpha_corrected = cal.get("alpha_corrected")
|
126
|
+
delta = cal.get("delta")
|
127
|
+
threshold = cal.get("threshold")
|
164
128
|
|
165
129
|
if verbose:
|
166
130
|
print(f"\n{'=' * 80}")
|
167
131
|
print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
|
168
132
|
print(f"{'=' * 80}")
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
f"
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
)
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
133
|
+
print(f" Calibration size: n = {n}")
|
134
|
+
if alpha_target is not None:
|
135
|
+
print(f" Target miscoverage: α = {alpha_target:.3f}")
|
136
|
+
if alpha_corrected is not None:
|
137
|
+
print(f" SSBC-corrected α: α' = {alpha_corrected:.4f}")
|
138
|
+
if delta is not None:
|
139
|
+
print(f" PAC risk: δ = {delta:.3f}")
|
140
|
+
if threshold is not None:
|
141
|
+
print(f" Conformal threshold: {threshold:.4f}")
|
142
|
+
|
143
|
+
# Per-class stats from calibration data (VALID - exchangeable within class)
|
144
|
+
if verbose:
|
145
|
+
print(f"\n 📊 Statistics from Calibration Data (n={n}):")
|
146
|
+
print(" [Basic CP CIs without PAC guarantee - evaluated on calibration data]")
|
147
|
+
|
148
|
+
# Abstentions
|
149
|
+
abstentions = cls.get("abstentions", {})
|
150
|
+
if isinstance(abstentions, dict):
|
151
|
+
abst_count = abstentions.get("count", 0)
|
152
|
+
abst_ci = cp_interval(abst_count, n)
|
153
|
+
if verbose:
|
154
|
+
print(
|
155
|
+
f" Abstentions: {abst_count:4d} / {n:4d} = {abst_ci['proportion']:6.2%} "
|
156
|
+
f"95% CI: [{abst_ci['lower']:.3f}, {abst_ci['upper']:.3f}]"
|
157
|
+
)
|
158
|
+
|
159
|
+
# Singletons (note: singletons_correct/incorrect are at top level, not nested)
|
160
|
+
singletons = cls.get("singletons", {})
|
161
|
+
singletons_correct = cls.get("singletons_correct", {})
|
162
|
+
singletons_incorrect = cls.get("singletons_incorrect", {})
|
163
|
+
|
164
|
+
if isinstance(singletons, dict):
|
165
|
+
sing_count = singletons.get("count", 0)
|
166
|
+
sing_correct = singletons_correct.get("count", 0) if isinstance(singletons_correct, dict) else 0
|
167
|
+
sing_incorrect = singletons_incorrect.get("count", 0) if isinstance(singletons_incorrect, dict) else 0
|
168
|
+
|
169
|
+
# Compute valid CIs (exchangeable within class)
|
170
|
+
sing_ci = cp_interval(sing_count, n)
|
171
|
+
sing_corr_ci = cp_interval(sing_correct, n)
|
172
|
+
sing_inc_ci = cp_interval(sing_incorrect, n)
|
173
|
+
|
174
|
+
if verbose:
|
175
|
+
print(
|
176
|
+
f" Singletons: {sing_count:4d} / {n:4d} = {sing_ci['proportion']:6.2%} "
|
177
|
+
f"95% CI: [{sing_ci['lower']:.3f}, {sing_ci['upper']:.3f}]"
|
178
|
+
)
|
179
|
+
print(
|
180
|
+
f" Correct: {sing_correct:4d} / {n:4d} = {sing_corr_ci['proportion']:6.2%} "
|
181
|
+
f"95% CI: [{sing_corr_ci['lower']:.3f}, {sing_corr_ci['upper']:.3f}]"
|
182
|
+
)
|
183
|
+
print(
|
184
|
+
f" Incorrect: {sing_incorrect:4d} / {n:4d} = {sing_inc_ci['proportion']:6.2%} "
|
185
|
+
f"95% CI: [{sing_inc_ci['lower']:.3f}, {sing_inc_ci['upper']:.3f}]"
|
186
|
+
)
|
187
|
+
|
188
|
+
# Error rate given singleton
|
189
|
+
if sing_count > 0:
|
190
|
+
err_given_sing = cp_interval(sing_incorrect, sing_count)
|
191
|
+
print(
|
192
|
+
f" Error | singleton: {sing_incorrect:4d} / {sing_count:4d} = "
|
193
|
+
f"{err_given_sing['proportion']:6.2%} "
|
194
|
+
f"95% CI: [{err_given_sing['lower']:.3f}, {err_given_sing['upper']:.3f}]"
|
195
|
+
)
|
196
|
+
|
197
|
+
# Doublets
|
198
|
+
doublets = cls.get("doublets", {})
|
199
|
+
if isinstance(doublets, dict):
|
200
|
+
doub_count = doublets.get("count", 0)
|
201
|
+
doub_ci = cp_interval(doub_count, n)
|
202
|
+
if verbose:
|
203
|
+
print(
|
204
|
+
f" Doublets: {doub_count:4d} / {n:4d} = {doub_ci['proportion']:6.2%} "
|
205
|
+
f"95% CI: [{doub_ci['lower']:.3f}, {doub_ci['upper']:.3f}]"
|
206
|
+
)
|
207
|
+
|
208
|
+
# PAC bounds (ρ, κ, α'_bound) - important theoretical guarantees
|
209
|
+
pac_bounds = cls.get("pac_bounds", {})
|
210
|
+
if isinstance(pac_bounds, dict) and pac_bounds.get("rho") is not None:
|
211
|
+
if verbose:
|
212
|
+
print(f"\n 📐 PAC Singleton Error Bound (δ={delta:.3f}):")
|
213
|
+
print(f" ρ = {pac_bounds.get('rho', 0):.3f}, κ = {pac_bounds.get('kappa', 0):.3f}")
|
214
|
+
if "alpha_singlet_bound" in pac_bounds and "alpha_singlet_observed" in pac_bounds:
|
215
|
+
bound = float(pac_bounds["alpha_singlet_bound"])
|
216
|
+
observed = float(pac_bounds["alpha_singlet_observed"])
|
205
217
|
ok = "✓" if observed <= bound else "✗"
|
206
|
-
print(f"
|
207
|
-
print(f"
|
208
|
-
|
209
|
-
# ----------------- marginal / deployment view -----------------
|
210
|
-
if "marginal" in prediction_stats:
|
211
|
-
marg = prediction_stats["marginal"]
|
212
|
-
n_total = int(marg["n_total"])
|
213
|
-
|
214
|
-
cov = as_dict(marg.get("coverage", {}))
|
215
|
-
abst_m = as_dict(marg.get("abstentions", {}))
|
216
|
-
sing_m = as_dict(marg.get("singletons", {}))
|
217
|
-
doub_m = as_dict(marg.get("doublets", {}))
|
218
|
-
pac_m = as_dict(marg.get("pac_bounds", {}))
|
219
|
-
|
220
|
-
cov_count = get_count(cov)
|
221
|
-
abst_m_count = get_count(abst_m)
|
222
|
-
sing_total = get_count(sing_m)
|
223
|
-
doub_m_count = get_count(doub_m)
|
224
|
-
|
225
|
-
cov_rate, cov_lo, cov_hi = ensure_ci(cov, cov_count, n_total)
|
226
|
-
abst_m_rate, abst_m_lo, abst_m_hi = ensure_ci(abst_m, abst_m_count, n_total)
|
227
|
-
sing_m_rate, sing_m_lo, sing_m_hi = ensure_ci(sing_m, sing_total, n_total)
|
228
|
-
doub_m_rate, doub_m_lo, doub_m_hi = ensure_ci(doub_m, doub_m_count, n_total)
|
229
|
-
|
230
|
-
# pred_0 / pred_1 may be dicts or ints (counts)
|
231
|
-
raw_s0 = sing_m.get("pred_0", 0)
|
232
|
-
raw_s1 = sing_m.get("pred_1", 0)
|
233
|
-
s0_count = get_count(raw_s0)
|
234
|
-
s1_count = get_count(raw_s1)
|
235
|
-
|
236
|
-
# Prefer provided rate/CI, else compute off n_total
|
237
|
-
if isinstance(raw_s0, dict):
|
238
|
-
s0_rate, s0_lo, s0_hi = ensure_ci(raw_s0, s0_count, n_total)
|
239
|
-
else:
|
240
|
-
s0_ci = cp_interval(s0_count, n_total)
|
241
|
-
s0_rate, s0_lo, s0_hi = s0_ci["proportion"], s0_ci["lower"], s0_ci["upper"]
|
218
|
+
print(f" α'_bound: {bound:.4f}")
|
219
|
+
print(f" α'_observed: {observed:.4f} {ok}")
|
242
220
|
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
},
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
221
|
+
# Operational bounds (RIGOROUS - cross-validated with PAC guarantees)
|
222
|
+
if operational_bounds_per_class and class_label in operational_bounds_per_class:
|
223
|
+
op_bounds = operational_bounds_per_class[class_label]
|
224
|
+
|
225
|
+
if verbose:
|
226
|
+
print("\n ✅ RIGOROUS Operational Bounds (LOO-CV)")
|
227
|
+
print(f" CI width: {op_bounds.ci_width:.1%}")
|
228
|
+
print(f" Calibration size: n = {op_bounds.n_calibration}")
|
229
|
+
|
230
|
+
# Show main rates (singleton, doublet, abstention)
|
231
|
+
for rate_name in ["abstention", "singleton", "doublet"]:
|
232
|
+
if rate_name in op_bounds.rate_bounds:
|
233
|
+
bounds = op_bounds.rate_bounds[rate_name]
|
234
|
+
if verbose:
|
235
|
+
print(f"\n {rate_name.upper()}:")
|
236
|
+
print(f" Bounds: [{bounds.lower_bound:.3f}, {bounds.upper_bound:.3f}]")
|
237
|
+
print(f" Count: {bounds.n_successes}/{bounds.n_evaluations}")
|
238
|
+
|
239
|
+
# Show conditional singleton rates (conditional on having a singleton)
|
240
|
+
has_correct = "correct_in_singleton" in op_bounds.rate_bounds
|
241
|
+
has_error = "error_in_singleton" in op_bounds.rate_bounds
|
242
|
+
has_singleton = "singleton" in op_bounds.rate_bounds
|
243
|
+
|
244
|
+
if verbose and (has_correct or has_error) and has_singleton:
|
245
|
+
print("\n CONDITIONAL RATES (conditioned on singleton, with CP+PAC bounds):")
|
246
|
+
|
247
|
+
singleton_bounds = op_bounds.rate_bounds["singleton"]
|
248
|
+
n_singletons = singleton_bounds.n_successes
|
249
|
+
|
250
|
+
# P(correct | singleton) with rigorous CP bounds
|
251
|
+
if has_correct and n_singletons > 0:
|
252
|
+
correct_bounds = op_bounds.rate_bounds["correct_in_singleton"]
|
253
|
+
n_correct = correct_bounds.n_successes
|
254
|
+
|
255
|
+
# Conditional rate and CP interval
|
256
|
+
rate = n_correct / n_singletons if n_singletons > 0 else 0.0
|
257
|
+
ci = cp_interval(n_correct, n_singletons)
|
258
|
+
|
259
|
+
print(f" P(correct | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
|
260
|
+
|
261
|
+
# P(error | singleton) with rigorous CP bounds
|
262
|
+
if has_error and n_singletons > 0:
|
263
|
+
error_bounds = op_bounds.rate_bounds["error_in_singleton"]
|
264
|
+
n_error = error_bounds.n_successes
|
265
|
+
|
266
|
+
# Conditional rate and CP interval
|
267
|
+
rate = n_error / n_singletons if n_singletons > 0 else 0.0
|
268
|
+
ci = cp_interval(n_error, n_singletons)
|
269
|
+
|
270
|
+
print(f" P(error | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
|
271
|
+
|
272
|
+
# Store in summary
|
273
|
+
summary[class_label] = {
|
274
|
+
"n": n,
|
275
|
+
"alpha_target": alpha_target,
|
276
|
+
"alpha_corrected": alpha_corrected,
|
277
|
+
"threshold": threshold,
|
278
|
+
"calibration_stats": {
|
279
|
+
"abstentions": abstentions,
|
280
|
+
"singletons": singletons,
|
281
|
+
"doublets": doublets,
|
300
282
|
},
|
301
|
-
"
|
302
|
-
"pac_bounds": pac_m,
|
283
|
+
"pac_bounds": pac_bounds,
|
303
284
|
}
|
304
|
-
|
285
|
+
if operational_bounds_per_class and class_label in operational_bounds_per_class:
|
286
|
+
summary[class_label]["operational_bounds"] = operational_bounds_per_class[class_label]
|
305
287
|
|
288
|
+
# ==================== MARGINAL STATISTICS ====================
|
289
|
+
if marginal_operational_bounds is not None:
|
306
290
|
if verbose:
|
307
291
|
print(f"\n{'=' * 80}")
|
308
|
-
print("MARGINAL
|
292
|
+
print("MARGINAL STATISTICS (Deployment View - Ignores True Labels)")
|
309
293
|
print(f"{'=' * 80}")
|
310
|
-
print(f" Total samples: {
|
311
|
-
|
312
|
-
print("\
|
313
|
-
print(f"
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
294
|
+
print(f" Total samples: n = {marginal_operational_bounds.n_calibration}")
|
295
|
+
|
296
|
+
print("\n ✅ RIGOROUS Marginal Bounds (LOO-CV)")
|
297
|
+
print(f" CI width: {marginal_operational_bounds.ci_width:.1%}")
|
298
|
+
print(f" Total evaluations: n = {marginal_operational_bounds.n_calibration}")
|
299
|
+
|
300
|
+
# Show main rates
|
301
|
+
for rate_name in ["abstention", "singleton", "doublet"]:
|
302
|
+
if rate_name in marginal_operational_bounds.rate_bounds:
|
303
|
+
bounds = marginal_operational_bounds.rate_bounds[rate_name]
|
304
|
+
if verbose:
|
305
|
+
print(f"\n {rate_name.upper()}:")
|
306
|
+
print(f" Bounds: [{bounds.lower_bound:.3f}, {bounds.upper_bound:.3f}]")
|
307
|
+
print(f" Count: {bounds.n_successes}/{bounds.n_evaluations}")
|
308
|
+
|
309
|
+
# Show conditional singleton rates (marginal)
|
310
|
+
has_correct = "correct_in_singleton" in marginal_operational_bounds.rate_bounds
|
311
|
+
has_error = "error_in_singleton" in marginal_operational_bounds.rate_bounds
|
312
|
+
has_singleton = "singleton" in marginal_operational_bounds.rate_bounds
|
313
|
+
|
314
|
+
if verbose and (has_correct or has_error) and has_singleton:
|
315
|
+
print("\n CONDITIONAL RATES (conditioned on singleton, with CP+PAC bounds):")
|
316
|
+
|
317
|
+
singleton_bounds = marginal_operational_bounds.rate_bounds["singleton"]
|
318
|
+
n_singletons = singleton_bounds.n_successes
|
319
|
+
|
320
|
+
if has_correct and n_singletons > 0:
|
321
|
+
correct_bounds = marginal_operational_bounds.rate_bounds["correct_in_singleton"]
|
322
|
+
n_correct = correct_bounds.n_successes
|
323
|
+
|
324
|
+
# Conditional rate and CP interval
|
325
|
+
rate = n_correct / n_singletons if n_singletons > 0 else 0.0
|
326
|
+
ci = cp_interval(n_correct, n_singletons)
|
327
|
+
|
328
|
+
print(f" P(correct | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
|
329
|
+
|
330
|
+
if has_error and n_singletons > 0:
|
331
|
+
error_bounds = marginal_operational_bounds.rate_bounds["error_in_singleton"]
|
332
|
+
n_error = error_bounds.n_successes
|
333
|
+
|
334
|
+
# Conditional rate and CP interval
|
335
|
+
rate = n_error / n_singletons if n_singletons > 0 else 0.0
|
336
|
+
ci = cp_interval(n_error, n_singletons)
|
337
|
+
|
338
|
+
print(f" P(error | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
|
339
|
+
|
340
|
+
summary["marginal_bounds"] = marginal_operational_bounds
|
341
|
+
|
342
|
+
if verbose:
|
343
|
+
# Deployment interpretation
|
344
|
+
sing_bounds = marginal_operational_bounds.rate_bounds.get("singleton")
|
345
|
+
doub_bounds = marginal_operational_bounds.rate_bounds.get("doublet")
|
346
|
+
abst_bounds = marginal_operational_bounds.rate_bounds.get("abstention")
|
347
|
+
|
348
|
+
if sing_bounds:
|
349
|
+
print("\n 📈 Deployment Expectations:")
|
350
|
+
print(
|
351
|
+
f" Automation (singletons): "
|
352
|
+
f"{sing_bounds.lower_bound:.1%} - {sing_bounds.upper_bound:.1%} of cases"
|
353
|
+
)
|
354
|
+
|
355
|
+
# Escalation = doublets + abstentions
|
356
|
+
if doub_bounds and abst_bounds:
|
357
|
+
esc_lower = doub_bounds.lower_bound + abst_bounds.lower_bound
|
358
|
+
esc_upper = doub_bounds.upper_bound + abst_bounds.upper_bound
|
359
|
+
print(f" Escalation (doublets+abstentions): {esc_lower:.1%} - {esc_upper:.1%} of cases")
|
360
|
+
elif doub_bounds:
|
361
|
+
print(
|
362
|
+
f" Escalation (doublets): "
|
363
|
+
f"{doub_bounds.lower_bound:.1%} - {doub_bounds.upper_bound:.1%} of cases"
|
364
|
+
)
|
365
|
+
|
366
|
+
# ==================== WARNINGS ====================
|
367
|
+
if verbose:
|
368
|
+
print(f"\n{'=' * 80}")
|
369
|
+
print("NOTES")
|
370
|
+
print(f"{'=' * 80}")
|
371
|
+
print("\n✓ Per-class CIs are valid (Clopper-Pearson, exchangeable within class)")
|
372
|
+
|
373
|
+
if operational_bounds_per_class or marginal_operational_bounds:
|
374
|
+
print("✓ Operational bounds have PAC guarantees via cross-validation")
|
375
|
+
else:
|
376
|
+
print("\n⚠️ For rigorous deployment bounds, run:")
|
377
|
+
print(" - compute_mondrian_operational_bounds() for per-class bounds")
|
378
|
+
print(" - compute_marginal_operational_bounds() for marginal bounds")
|
355
379
|
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
print(f" Escalate: {n_escalations} doublets+abstentions ({n_escalations / n_total:.1%})")
|
380
|
+
if prediction_stats.get("marginal") and marginal_operational_bounds is None:
|
381
|
+
print("\n⚠️ Marginal stats from calibration data NOT shown (invalid CIs for Mondrian)")
|
382
|
+
print(" Use compute_marginal_operational_bounds() for valid marginal bounds")
|
360
383
|
|
361
384
|
return summary
|
362
385
|
|