ssbc-1.0.0-py3-none-any.whl → ssbc-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ssbc/visualization.py CHANGED
@@ -5,19 +5,62 @@ from typing import Any
 from .statistics import cp_interval
 
 
+def compute_conditional_rate_bounds(
+    numerator_fold_results: list[dict],
+    denominator_fold_results: list[dict],
+    weights: list[float],
+) -> tuple[float, float, float]:
+    """Compute bounds on conditional rate from fold-level counts.
+
+    For conditional rate P(A | B), computes cross-validated bounds by:
+    1. Computing A_count / B_count in each fold
+    2. Using Clopper-Pearson on aggregated counts
+
+    Parameters
+    ----------
+    numerator_fold_results : list[dict]
+        Fold results for numerator event (e.g., correct_in_singleton)
+    denominator_fold_results : list[dict]
+        Fold results for denominator event (e.g., singleton)
+    weights : list[float]
+        Fold weights
+
+    Returns
+    -------
+    rate : float
+        Estimated conditional rate
+    lower : float
+        Lower CI bound
+    upper : float
+        Upper CI bound
+    """
+    # Aggregate counts across folds
+    total_numerator = sum(fold["K_fr"] for fold in numerator_fold_results)
+    total_denominator = sum(fold["K_fr"] for fold in denominator_fold_results)
+
+    if total_denominator == 0:
+        return 0.0, 0.0, 1.0
+
+    # Compute CP interval on aggregated counts
+    ci = cp_interval(total_numerator, total_denominator)
+    return ci["proportion"], ci["lower"], ci["upper"]
+
+
 def report_prediction_stats(
-    prediction_stats: dict[Any, Any], calibration_result: dict[Any, Any], verbose: bool = True
+    prediction_stats: dict[Any, Any],
+    calibration_result: dict[Any, Any],
+    operational_bounds_per_class: dict[int, Any] | None = None,
+    marginal_operational_bounds: Any | None = None,
+    verbose: bool = True,
 ) -> dict[str | int, Any]:
-    """Pretty/robust summary for Mondrian conformal prediction stats.
+    """Report rigorous statistics for Mondrian conformal prediction with valid CIs.
 
-    Tolerates multiple schema shapes:
-    - dicts with 'rate'/'ci_95' or 'proportion'/'lower'/'upper'
-    - raw ints for counts (e.g., marginal['singletons']['pred_0'] = 339)
-    - per-class singleton correct/incorrect either nested under 'singletons'
-      OR as top-level aliases 'singletons_correct' / 'singletons_incorrect'.
+    Only displays statistics with valid confidence intervals:
+    - Per-class statistics from calibration data (valid within class)
+    - Per-class operational bounds from cross-validation (rigorous PAC bounds)
+    - Marginal operational bounds from cross-validated Mondrian (rigorous PAC bounds)
 
-    Also computes Clopper-Pearson CIs when missing, and splits marginal
-    singleton errors by predicted class.
+    Does NOT display marginal statistics from calibration data (invalid CIs for Mondrian).
 
     Parameters
     ----------
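As a point of reference, here is a minimal usage sketch of the compute_conditional_rate_bounds helper introduced in the hunk above. It assumes the import path ssbc.visualization, that each fold dict stores its event count under the "K_fr" key, and that cp_interval returns a dict with "proportion", "lower" and "upper" keys, as the function body suggests; the fold counts and weights below are illustrative only (the weights argument is accepted but not used in the body shown).

from ssbc.visualization import compute_conditional_rate_bounds  # assumed import path

# Hypothetical fold-level counts: numerator = correct singletons, denominator = singletons
correct_folds = [{"K_fr": 40}, {"K_fr": 37}, {"K_fr": 42}]
singleton_folds = [{"K_fr": 45}, {"K_fr": 44}, {"K_fr": 46}]
weights = [1 / 3, 1 / 3, 1 / 3]  # accepted but unused by the body shown above

# Aggregates to 119/135 correct singletons and wraps a Clopper-Pearson interval around that ratio
rate, lower, upper = compute_conditional_rate_bounds(correct_folds, singleton_folds, weights)
print(f"P(correct | singleton) = {rate:.3f}, 95% CI [{lower:.3f}, {upper:.3f}]")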
@@ -25,338 +68,318 @@ def report_prediction_stats(
         Output from mondrian_conformal_calibrate (second return value)
     calibration_result : dict
         Output from mondrian_conformal_calibrate (first return value)
+    operational_bounds_per_class : dict[int, OperationalRateBoundsResult], optional
+        Per-class operational bounds from compute_mondrian_operational_bounds
+    marginal_operational_bounds : OperationalRateBoundsResult, optional
+        Marginal operational bounds from compute_marginal_operational_bounds
     verbose : bool, default=True
         If True, print detailed statistics to stdout
 
     Returns
     -------
     dict
-        Structured summary with CIs for all metrics, containing:
+        Structured summary with valid CIs:
         - Keys 0, 1 for per-class statistics
-        - Key 'marginal' for overall deployment statistics
+        - Key 'marginal_bounds' if marginal_operational_bounds provided
 
     Examples
     --------
-    >>> # After calibration
+    >>> # Basic: Only calibration data (limited info)
     >>> cal_result, pred_stats = mondrian_conformal_calibrate(...)
-    >>> summary = report_prediction_stats(pred_stats, cal_result, verbose=True)
-    >>> print(summary['marginal']['coverage']['rate'])
+    >>> summary = report_prediction_stats(pred_stats, cal_result)
+    >>>
+    >>> # With per-class operational bounds (rigorous)
+    >>> op_bounds = compute_mondrian_operational_bounds(cal_result, class_data, delta=0.05)
+    >>> summary = report_prediction_stats(pred_stats, cal_result, op_bounds)
+    >>>
+    >>> # With marginal bounds too (complete SLA)
+    >>> marginal = compute_marginal_operational_bounds(labels, probs, 0.1, 0.05, 0.05)
+    >>> summary = report_prediction_stats(pred_stats, cal_result, op_bounds, marginal)
     """
-
-    # Helper functions
-    def as_dict(x: Any) -> dict[str, Any]:
-        """Ensure x is a dict."""
-        return x if isinstance(x, dict) else {}
-
-    def get_count(x: Any, default: int = 0) -> int:
-        """Extract count from dict or int."""
-        if isinstance(x, dict):
-            return int(x.get("count", default))
-        if isinstance(x, int):
-            return int(x)
-        return default
-
-    def get_rate(x: Any, default: float | None = 0.0) -> float | None:
-        """Extract rate from dict or float."""
-        if isinstance(x, dict):
-            if "rate" in x:
-                return float(x["rate"])
-            if "proportion" in x:
-                return float(x["proportion"])
-            return default
-        if isinstance(x, float):
-            return float(x)
-        return default
-
-    def get_ci_tuple(x: Any) -> tuple[float, float]:
-        """Extract CI bounds from dict."""
-        if not isinstance(x, dict):
-            return (0.0, 0.0)
-        if "ci_95" in x and isinstance(x["ci_95"], tuple | list) and len(x["ci_95"]) == 2:
-            return float(x["ci_95"][0]), float(x["ci_95"][1])
-        lo = x.get("lower", 0.0)
-        hi = x.get("upper", 0.0)
-        return float(lo), float(hi)
-
-    def ensure_ci(d: dict[str, Any], count: int, total: int) -> tuple[float, float, float]:
-        """Return (rate, lo, hi). If d already has rate/CI, use them; else compute CP from count/total."""
-        r = get_rate(d, default=None)
-        lo, hi = get_ci_tuple(d)
-        if r is None or (lo == 0.0 and hi == 0.0 and (count > 0 or total > 0)):
-            ci = cp_interval(count, total)
-            return ci["proportion"], ci["lower"], ci["upper"]
-        return float(r), float(lo), float(hi)
-
-    def pct(x: float) -> str:
-        """Format percentage."""
-        return f"{x:6.2%}"
+    from .statistics import cp_interval
 
     summary: dict[str | int, Any] = {}
 
     if verbose:
         print("=" * 80)
-        print("PREDICTION SET STATISTICS (All rates with 95% Clopper-Pearson CIs)")
+        print("MONDRIAN CONFORMAL PREDICTION REPORT")
         print("=" * 80)
 
-    # ----------------- per-class (conditioned on Y) -----------------
-    for class_label in [0, 1]:
-        if class_label not in prediction_stats:
-            continue
+    # ==================== PER-CLASS STATISTICS ====================
+    for class_label in sorted([k for k in prediction_stats.keys() if isinstance(k, int)]):
         cls = prediction_stats[class_label]
 
         if isinstance(cls, dict) and "error" in cls:
             if verbose:
-                print(f"\nClass {class_label}: {cls['error']}")
+                print(f"\nCLASS {class_label}: {cls['error']}")
             summary[class_label] = {"error": cls["error"]}
             continue
 
         n = int(cls.get("n", cls.get("n_class", 0)))
-        alpha_target = cls.get("alpha_target", calibration_result.get(class_label, {}).get("alpha_target", None))
-        delta = cls.get("delta", calibration_result.get(class_label, {}).get("delta", None))
-
-        abst = as_dict(cls.get("abstentions", {}))
-        sing = as_dict(cls.get("singletons", {}))
-        # Accept both nested and flat aliases
-        sing_corr = as_dict(sing.get("correct", cls.get("singletons_correct", {})))
-        sing_inc = as_dict(sing.get("incorrect", cls.get("singletons_incorrect", {})))
-        doub = as_dict(cls.get("doublets", {}))
-        pac = as_dict(cls.get("pac_bounds", {}))
-
-        # Counts
-        abst_count = get_count(abst)
-        sing_count = get_count(sing)
-        sing_corr_count = get_count(sing_corr)
-        sing_inc_count = get_count(sing_inc)
-        doub_count = get_count(doub)
-
-        # Ensure rates/CIs (fallback to CP if missing)
-        abst_rate, abst_lo, abst_hi = ensure_ci(abst, abst_count, n)
-        sing_rate, sing_lo, sing_hi = ensure_ci(sing, sing_count, n)
-        sing_corr_rate, sing_corr_lo, sing_corr_hi = ensure_ci(sing_corr, sing_corr_count, n)
-        sing_inc_rate, sing_inc_lo, sing_inc_hi = ensure_ci(sing_inc, sing_inc_count, n)
-        doub_rate, doub_lo, doub_hi = ensure_ci(doub, doub_count, n)
-
-        # P(error | singleton, Y=class)
-        err_given_single_ci = cp_interval(sing_inc_count, sing_count if sing_count > 0 else 1)
-
-        class_summary = {
-            "n": n,
-            "alpha_target": alpha_target,
-            "delta": delta,
-            "abstentions": {"count": abst_count, "rate": abst_rate, "ci_95": (abst_lo, abst_hi)},
-            "singletons": {
-                "count": sing_count,
-                "rate": sing_rate,
-                "ci_95": (sing_lo, sing_hi),
-                "correct": {"count": sing_corr_count, "rate": sing_corr_rate, "ci_95": (sing_corr_lo, sing_corr_hi)},
-                "incorrect": {"count": sing_inc_count, "rate": sing_inc_rate, "ci_95": (sing_inc_lo, sing_inc_hi)},
-                "error_given_singleton": {
-                    "count": sing_inc_count,
-                    "denom": sing_count,
-                    "rate": err_given_single_ci["proportion"],
-                    "ci_95": (err_given_single_ci["lower"], err_given_single_ci["upper"]),
-                },
-            },
-            "doublets": {"count": doub_count, "rate": doub_rate, "ci_95": (doub_lo, doub_hi)},
-            "pac_bounds": pac,
-        }
-        summary[class_label] = class_summary
+        if n == 0:
+            continue
+
+        # Get calibration info
+        cal = calibration_result.get(class_label, {})
+        alpha_target = cal.get("alpha_target")
+        alpha_corrected = cal.get("alpha_corrected")
+        delta = cal.get("delta")
+        threshold = cal.get("threshold")
 
         if verbose:
             print(f"\n{'=' * 80}")
             print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
             print(f"{'=' * 80}")
-            alpha_str = f"{alpha_target:.3f}" if alpha_target is not None else "n/a"
-            delta_str = f"{delta:.3f}" if delta is not None else "n/a"
-            print(f" n={n}, α_target={alpha_str}, δ={delta_str}")
-
-            print("\nPrediction Set Breakdown:")
-            print(
-                f" Abstentions: {abst_count:4d} / {n:4d} = {pct(abst_rate)} 95% CI: [{abst_lo:.4f}, {abst_hi:.4f}]"
-            )
-            print(
-                f" Singletons: {sing_count:4d} / {n:4d} = {pct(sing_rate)} 95% CI: [{sing_lo:.4f}, {sing_hi:.4f}]"
-            )
-            print(
-                f" ├─ Correct: {sing_corr_count:4d} / {n:4d} = {pct(sing_corr_rate)} "
-                f"95% CI: [{sing_corr_lo:.4f}, {sing_corr_hi:.4f}]"
-            )
-            print(
-                f" └─ Incorrect: {sing_inc_count:4d} / {n:4d} = {pct(sing_inc_rate)} "
-                f"95% CI: [{sing_inc_lo:.4f}, {sing_inc_hi:.4f}]"
-            )
-
-            print(
-                f" Singleton error | Y={class_label}: "
-                f"{sing_inc_count:4d} / {sing_count:4d} = {pct(err_given_single_ci['proportion'])} "
-                f"95% CI: [{err_given_single_ci['lower']:.4f}, {err_given_single_ci['upper']:.4f}]"
-            )
-
-            print(
-                f"\n Doublets: {doub_count:4d} / {n:4d} = {pct(doub_rate)} 95% CI: [{doub_lo:.4f}, {doub_hi:.4f}]"
-            )
-
-            if pac and pac.get("rho", None) is not None:
-                print(f"\n PAC Singleton Error Rate (δ={delta_str}):")
-                print(f" ρ = {pac.get('rho', 0):.3f}, κ = {pac.get('kappa', 0):.3f}")
-                if "alpha_singlet_bound" in pac and "alpha_singlet_observed" in pac:
-                    bound = float(pac["alpha_singlet_bound"])
-                    observed = float(pac["alpha_singlet_observed"])
+            print(f" Calibration size: n = {n}")
+            if alpha_target is not None:
+                print(f" Target miscoverage: α = {alpha_target:.3f}")
+            if alpha_corrected is not None:
+                print(f" SSBC-corrected α: α' = {alpha_corrected:.4f}")
+            if delta is not None:
+                print(f" PAC risk: δ = {delta:.3f}")
+            if threshold is not None:
+                print(f" Conformal threshold: {threshold:.4f}")
+
+        # Per-class stats from calibration data (VALID - exchangeable within class)
+        if verbose:
+            print(f"\n 📊 Statistics from Calibration Data (n={n}):")
+            print(" [Basic CP CIs without PAC guarantee - evaluated on calibration data]")
+
+        # Abstentions
+        abstentions = cls.get("abstentions", {})
+        if isinstance(abstentions, dict):
+            abst_count = abstentions.get("count", 0)
+            abst_ci = cp_interval(abst_count, n)
+            if verbose:
+                print(
+                    f" Abstentions: {abst_count:4d} / {n:4d} = {abst_ci['proportion']:6.2%} "
+                    f"95% CI: [{abst_ci['lower']:.3f}, {abst_ci['upper']:.3f}]"
+                )
+
+        # Singletons (note: singletons_correct/incorrect are at top level, not nested)
+        singletons = cls.get("singletons", {})
+        singletons_correct = cls.get("singletons_correct", {})
+        singletons_incorrect = cls.get("singletons_incorrect", {})
+
+        if isinstance(singletons, dict):
+            sing_count = singletons.get("count", 0)
+            sing_correct = singletons_correct.get("count", 0) if isinstance(singletons_correct, dict) else 0
+            sing_incorrect = singletons_incorrect.get("count", 0) if isinstance(singletons_incorrect, dict) else 0
+
+            # Compute valid CIs (exchangeable within class)
+            sing_ci = cp_interval(sing_count, n)
+            sing_corr_ci = cp_interval(sing_correct, n)
+            sing_inc_ci = cp_interval(sing_incorrect, n)
+
+            if verbose:
+                print(
+                    f" Singletons: {sing_count:4d} / {n:4d} = {sing_ci['proportion']:6.2%} "
+                    f"95% CI: [{sing_ci['lower']:.3f}, {sing_ci['upper']:.3f}]"
+                )
+                print(
+                    f" Correct: {sing_correct:4d} / {n:4d} = {sing_corr_ci['proportion']:6.2%} "
+                    f"95% CI: [{sing_corr_ci['lower']:.3f}, {sing_corr_ci['upper']:.3f}]"
+                )
+                print(
+                    f" Incorrect: {sing_incorrect:4d} / {n:4d} = {sing_inc_ci['proportion']:6.2%} "
+                    f"95% CI: [{sing_inc_ci['lower']:.3f}, {sing_inc_ci['upper']:.3f}]"
+                )
+
+                # Error rate given singleton
+                if sing_count > 0:
+                    err_given_sing = cp_interval(sing_incorrect, sing_count)
+                    print(
+                        f" Error | singleton: {sing_incorrect:4d} / {sing_count:4d} = "
+                        f"{err_given_sing['proportion']:6.2%} "
+                        f"95% CI: [{err_given_sing['lower']:.3f}, {err_given_sing['upper']:.3f}]"
+                    )
+
+        # Doublets
+        doublets = cls.get("doublets", {})
+        if isinstance(doublets, dict):
+            doub_count = doublets.get("count", 0)
+            doub_ci = cp_interval(doub_count, n)
+            if verbose:
+                print(
+                    f" Doublets: {doub_count:4d} / {n:4d} = {doub_ci['proportion']:6.2%} "
+                    f"95% CI: [{doub_ci['lower']:.3f}, {doub_ci['upper']:.3f}]"
+                )
+
+        # PAC bounds (ρ, κ, α'_bound) - important theoretical guarantees
+        pac_bounds = cls.get("pac_bounds", {})
+        if isinstance(pac_bounds, dict) and pac_bounds.get("rho") is not None:
+            if verbose:
+                print(f"\n 📐 PAC Singleton Error Bound (δ={delta:.3f}):")
+                print(f" ρ = {pac_bounds.get('rho', 0):.3f}, κ = {pac_bounds.get('kappa', 0):.3f}")
+                if "alpha_singlet_bound" in pac_bounds and "alpha_singlet_observed" in pac_bounds:
+                    bound = float(pac_bounds["alpha_singlet_bound"])
+                    observed = float(pac_bounds["alpha_singlet_observed"])
                     ok = "✓" if observed <= bound else "✗"
-                    print(f" α'_bound: {bound:.4f}")
-                    print(f" α'_observed: {observed:.4f} {ok}")
-
-    # ----------------- marginal / deployment view -----------------
-    if "marginal" in prediction_stats:
-        marg = prediction_stats["marginal"]
-        n_total = int(marg["n_total"])
-
-        cov = as_dict(marg.get("coverage", {}))
-        abst_m = as_dict(marg.get("abstentions", {}))
-        sing_m = as_dict(marg.get("singletons", {}))
-        doub_m = as_dict(marg.get("doublets", {}))
-        pac_m = as_dict(marg.get("pac_bounds", {}))
-
-        cov_count = get_count(cov)
-        abst_m_count = get_count(abst_m)
-        sing_total = get_count(sing_m)
-        doub_m_count = get_count(doub_m)
-
-        cov_rate, cov_lo, cov_hi = ensure_ci(cov, cov_count, n_total)
-        abst_m_rate, abst_m_lo, abst_m_hi = ensure_ci(abst_m, abst_m_count, n_total)
-        sing_m_rate, sing_m_lo, sing_m_hi = ensure_ci(sing_m, sing_total, n_total)
-        doub_m_rate, doub_m_lo, doub_m_hi = ensure_ci(doub_m, doub_m_count, n_total)
-
-        # pred_0 / pred_1 may be dicts or ints (counts)
-        raw_s0 = sing_m.get("pred_0", 0)
-        raw_s1 = sing_m.get("pred_1", 0)
-        s0_count = get_count(raw_s0)
-        s1_count = get_count(raw_s1)
-
-        # Prefer provided rate/CI, else compute off n_total
-        if isinstance(raw_s0, dict):
-            s0_rate, s0_lo, s0_hi = ensure_ci(raw_s0, s0_count, n_total)
-        else:
-            s0_ci = cp_interval(s0_count, n_total)
-            s0_rate, s0_lo, s0_hi = s0_ci["proportion"], s0_ci["lower"], s0_ci["upper"]
+                    print(f" α'_bound: {bound:.4f}")
+                    print(f" α'_observed: {observed:.4f} {ok}")
 
-        if isinstance(raw_s1, dict):
-            s1_rate, s1_lo, s1_hi = ensure_ci(raw_s1, s1_count, n_total)
-        else:
-            s1_ci = cp_interval(s1_count, n_total)
-            s1_rate, s1_lo, s1_hi = s1_ci["proportion"], s1_ci["lower"], s1_ci["upper"]
-
-        # Overall singleton errors (dict or int). Denominator should be sing_total.
-        raw_s_err = sing_m.get("errors", 0)
-        s_err_count = get_count(raw_s_err)
-        if isinstance(raw_s_err, dict):
-            s_err_rate, s_err_lo, s_err_hi = ensure_ci(raw_s_err, s_err_count, sing_total if sing_total > 0 else 1)
-        else:
-            se_ci = cp_interval(s_err_count, sing_total if sing_total > 0 else 1)
-            s_err_rate, s_err_lo, s_err_hi = se_ci["proportion"], se_ci["lower"], se_ci["upper"]
-
-        # Errors by predicted class via per-class incorrect singletons
-        # (pred 0 errors happen when Y=1 singleton is wrong; pred 1 errors when Y=0 singleton is wrong)
-        err_pred0_count = int(
-            prediction_stats.get(1, {})
-            .get("singletons", {})
-            .get("incorrect", prediction_stats.get(1, {}).get("singletons_incorrect", {}))
-            .get("count", 0)
-        )
-        err_pred1_count = int(
-            prediction_stats.get(0, {})
-            .get("singletons", {})
-            .get("incorrect", prediction_stats.get(0, {}).get("singletons_incorrect", {}))
-            .get("count", 0)
-        )
-        pred0_err_ci = cp_interval(err_pred0_count, s0_count if s0_count > 0 else 1)
-        pred1_err_ci = cp_interval(err_pred1_count, s1_count if s1_count > 0 else 1)
-
-        marginal_summary = {
-            "n_total": n_total,
-            "coverage": {"count": cov_count, "rate": cov_rate, "ci_95": (cov_lo, cov_hi)},
-            "abstentions": {"count": abst_m_count, "rate": abst_m_rate, "ci_95": (abst_m_lo, abst_m_hi)},
-            "singletons": {
-                "count": sing_total,
-                "rate": sing_m_rate,
-                "ci_95": (sing_m_lo, sing_m_hi),
-                "pred_0": {"count": s0_count, "rate": s0_rate, "ci_95": (s0_lo, s0_hi)},
-                "pred_1": {"count": s1_count, "rate": s1_rate, "ci_95": (s1_lo, s1_hi)},
-                "errors": {"count": s_err_count, "rate": s_err_rate, "ci_95": (s_err_lo, s_err_hi)},
-                "errors_by_pred": {
-                    "pred_0": {
-                        "count": err_pred0_count,
-                        "denom": s0_count,
-                        "rate": pred0_err_ci["proportion"],
-                        "ci_95": (pred0_err_ci["lower"], pred0_err_ci["upper"]),
-                    },
-                    "pred_1": {
-                        "count": err_pred1_count,
-                        "denom": s1_count,
-                        "rate": pred1_err_ci["proportion"],
-                        "ci_95": (pred1_err_ci["lower"], pred1_err_ci["upper"]),
-                    },
-                },
+        # Operational bounds (RIGOROUS - cross-validated with PAC guarantees)
+        if operational_bounds_per_class and class_label in operational_bounds_per_class:
+            op_bounds = operational_bounds_per_class[class_label]
+
+            if verbose:
+                print("\n ✅ RIGOROUS Operational Bounds (LOO-CV)")
+                print(f" CI width: {op_bounds.ci_width:.1%}")
+                print(f" Calibration size: n = {op_bounds.n_calibration}")
+
+            # Show main rates (singleton, doublet, abstention)
+            for rate_name in ["abstention", "singleton", "doublet"]:
+                if rate_name in op_bounds.rate_bounds:
+                    bounds = op_bounds.rate_bounds[rate_name]
+                    if verbose:
+                        print(f"\n {rate_name.upper()}:")
+                        print(f" Bounds: [{bounds.lower_bound:.3f}, {bounds.upper_bound:.3f}]")
+                        print(f" Count: {bounds.n_successes}/{bounds.n_evaluations}")
+
+            # Show conditional singleton rates (conditional on having a singleton)
+            has_correct = "correct_in_singleton" in op_bounds.rate_bounds
+            has_error = "error_in_singleton" in op_bounds.rate_bounds
+            has_singleton = "singleton" in op_bounds.rate_bounds
+
+            if verbose and (has_correct or has_error) and has_singleton:
+                print("\n CONDITIONAL RATES (conditioned on singleton, with CP+PAC bounds):")
+
+                singleton_bounds = op_bounds.rate_bounds["singleton"]
+                n_singletons = singleton_bounds.n_successes
+
+                # P(correct | singleton) with rigorous CP bounds
+                if has_correct and n_singletons > 0:
+                    correct_bounds = op_bounds.rate_bounds["correct_in_singleton"]
+                    n_correct = correct_bounds.n_successes
+
+                    # Conditional rate and CP interval
+                    rate = n_correct / n_singletons if n_singletons > 0 else 0.0
+                    ci = cp_interval(n_correct, n_singletons)
+
+                    print(f" P(correct | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
+
+                # P(error | singleton) with rigorous CP bounds
+                if has_error and n_singletons > 0:
+                    error_bounds = op_bounds.rate_bounds["error_in_singleton"]
+                    n_error = error_bounds.n_successes
+
+                    # Conditional rate and CP interval
+                    rate = n_error / n_singletons if n_singletons > 0 else 0.0
+                    ci = cp_interval(n_error, n_singletons)
+
+                    print(f" P(error | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
+
+        # Store in summary
+        summary[class_label] = {
+            "n": n,
+            "alpha_target": alpha_target,
+            "alpha_corrected": alpha_corrected,
+            "threshold": threshold,
+            "calibration_stats": {
+                "abstentions": abstentions,
+                "singletons": singletons,
+                "doublets": doublets,
             },
-            "doublets": {"count": doub_m_count, "rate": doub_m_rate, "ci_95": (doub_m_lo, doub_m_hi)},
-            "pac_bounds": pac_m,
+            "pac_bounds": pac_bounds,
         }
-        summary["marginal"] = marginal_summary
+        if operational_bounds_per_class and class_label in operational_bounds_per_class:
+            summary[class_label]["operational_bounds"] = operational_bounds_per_class[class_label]
 
+    # ==================== MARGINAL STATISTICS ====================
+    if marginal_operational_bounds is not None:
         if verbose:
             print(f"\n{'=' * 80}")
-            print("MARGINAL ANALYSIS (Deployment View - Ignores True Labels)")
+            print("MARGINAL STATISTICS (Deployment View - Ignores True Labels)")
             print(f"{'=' * 80}")
-            print(f" Total samples: {n_total}")
-
-            print("\nOverall Coverage:")
-            print(f" Covered: {cov_count:4d} / {n_total:4d} = {pct(cov_rate)} 95% CI: [{cov_lo:.4f}, {cov_hi:.4f}]")
-
-            print("\nPrediction Set Distribution:")
-            print(
-                f" Abstentions: {abst_m_count:4d} / {n_total:4d} = {pct(abst_m_rate)} "
-                f"95% CI: [{abst_m_lo:.4f}, {abst_m_hi:.4f}]"
-            )
-            print(
-                f" Singletons: {sing_total:4d} / {n_total:4d} = {pct(sing_m_rate)} "
-                f"95% CI: [{sing_m_lo:.4f}, {sing_m_hi:.4f}]"
-            )
-            print(f" ├─ Pred 0: {s0_count:4d} / {n_total:4d} = {pct(s0_rate)} 95% CI: [{s0_lo:.4f}, {s0_hi:.4f}]")
-            print(f" ├─ Pred 1: {s1_count:4d} / {n_total:4d} = {pct(s1_rate)} 95% CI: [{s1_lo:.4f}, {s1_hi:.4f}]")
-            print(
-                f" ├─ Errors (overall): {s_err_count:4d} / {sing_total:4d} = {pct(s_err_rate)} "
-                f"95% CI: [{s_err_lo:.4f}, {s_err_hi:.4f}]"
-            )
-            print(
-                f" ├─ Pred 0 errors: {err_pred0_count:4d} / {s0_count:4d} = {pct(pred0_err_ci['proportion'])} "
-                f"95% CI: [{pred0_err_ci['lower']:.4f}, {pred0_err_ci['upper']:.4f}]"
-            )
-            print(
-                f" └─ Pred 1 errors: {err_pred1_count:4d} / {s1_count:4d} = {pct(pred1_err_ci['proportion'])} "
-                f"95% CI: [{pred1_err_ci['lower']:.4f}, {pred1_err_ci['upper']:.4f}]"
-            )
-
-            print(
-                f" Doublets: {doub_m_count:4d} / {n_total:4d} = {pct(doub_m_rate)} "
-                f"95% CI: [{doub_m_lo:.4f}, {doub_m_hi:.4f}]"
-            )
-
-            if pac_m and pac_m.get("rho", None) is not None:
-                aw = pac_m.get("alpha_weighted", None)
-                aw_str = f"{float(aw):.3f}" if aw is not None else "n/a"
-                print(f"\n Overall PAC Bounds (weighted α={aw_str}):")
-                print(f" ρ = {pac_m.get('rho', 0):.3f}, κ = {pac_m.get('kappa', 0):.3f}")
-                if "alpha_singlet_bound" in pac_m and "alpha_singlet_observed" in pac_m:
-                    bound = float(pac_m["alpha_singlet_bound"])
-                    observed = float(pac_m["alpha_singlet_observed"])
-                    ok = "✓" if observed <= bound else "✗"
-                    print(f" α'_bound: {bound:.4f}")
-                    print(f" α'_observed: {observed:.4f} {ok}")
+            print(f" Total samples: n = {marginal_operational_bounds.n_calibration}")
+
+            print("\n ✅ RIGOROUS Marginal Bounds (LOO-CV)")
+            print(f" CI width: {marginal_operational_bounds.ci_width:.1%}")
+            print(f" Total evaluations: n = {marginal_operational_bounds.n_calibration}")
+
+        # Show main rates
+        for rate_name in ["abstention", "singleton", "doublet"]:
+            if rate_name in marginal_operational_bounds.rate_bounds:
+                bounds = marginal_operational_bounds.rate_bounds[rate_name]
+                if verbose:
+                    print(f"\n {rate_name.upper()}:")
+                    print(f" Bounds: [{bounds.lower_bound:.3f}, {bounds.upper_bound:.3f}]")
+                    print(f" Count: {bounds.n_successes}/{bounds.n_evaluations}")
+
+        # Show conditional singleton rates (marginal)
+        has_correct = "correct_in_singleton" in marginal_operational_bounds.rate_bounds
+        has_error = "error_in_singleton" in marginal_operational_bounds.rate_bounds
+        has_singleton = "singleton" in marginal_operational_bounds.rate_bounds
+
+        if verbose and (has_correct or has_error) and has_singleton:
+            print("\n CONDITIONAL RATES (conditioned on singleton, with CP+PAC bounds):")
+
+            singleton_bounds = marginal_operational_bounds.rate_bounds["singleton"]
+            n_singletons = singleton_bounds.n_successes
+
+            if has_correct and n_singletons > 0:
+                correct_bounds = marginal_operational_bounds.rate_bounds["correct_in_singleton"]
+                n_correct = correct_bounds.n_successes
+
+                # Conditional rate and CP interval
+                rate = n_correct / n_singletons if n_singletons > 0 else 0.0
+                ci = cp_interval(n_correct, n_singletons)
+
+                print(f" P(correct | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
+
+            if has_error and n_singletons > 0:
+                error_bounds = marginal_operational_bounds.rate_bounds["error_in_singleton"]
+                n_error = error_bounds.n_successes
+
+                # Conditional rate and CP interval
+                rate = n_error / n_singletons if n_singletons > 0 else 0.0
+                ci = cp_interval(n_error, n_singletons)
+
+                print(f" P(error | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]")
+
+        summary["marginal_bounds"] = marginal_operational_bounds
+
+        if verbose:
+            # Deployment interpretation
+            sing_bounds = marginal_operational_bounds.rate_bounds.get("singleton")
+            doub_bounds = marginal_operational_bounds.rate_bounds.get("doublet")
+            abst_bounds = marginal_operational_bounds.rate_bounds.get("abstention")
+
+            if sing_bounds:
+                print("\n 📈 Deployment Expectations:")
+                print(
+                    f" Automation (singletons): "
+                    f"{sing_bounds.lower_bound:.1%} - {sing_bounds.upper_bound:.1%} of cases"
+                )
+
+                # Escalation = doublets + abstentions
+                if doub_bounds and abst_bounds:
+                    esc_lower = doub_bounds.lower_bound + abst_bounds.lower_bound
+                    esc_upper = doub_bounds.upper_bound + abst_bounds.upper_bound
+                    print(f" Escalation (doublets+abstentions): {esc_lower:.1%} - {esc_upper:.1%} of cases")
+                elif doub_bounds:
+                    print(
+                        f" Escalation (doublets): "
+                        f"{doub_bounds.lower_bound:.1%} - {doub_bounds.upper_bound:.1%} of cases"
+                    )
+
+    # ==================== WARNINGS ====================
+    if verbose:
+        print(f"\n{'=' * 80}")
+        print("NOTES")
+        print(f"{'=' * 80}")
+        print("\n✓ Per-class CIs are valid (Clopper-Pearson, exchangeable within class)")
+
+        if operational_bounds_per_class or marginal_operational_bounds:
+            print("✓ Operational bounds have PAC guarantees via cross-validation")
+        else:
+            print("\n⚠️ For rigorous deployment bounds, run:")
+            print(" - compute_mondrian_operational_bounds() for per-class bounds")
+            print(" - compute_marginal_operational_bounds() for marginal bounds")
 
-            n_escalations = int(pac_m.get("n_escalations", doub_m_count + abst_m_count))
-            print("\n Deployment Decision Mix:")
-            print(f" Automate: {sing_total} singletons ({sing_m_rate:.1%})")
-            print(f" Escalate: {n_escalations} doublets+abstentions ({n_escalations / n_total:.1%})")
+        if prediction_stats.get("marginal") and marginal_operational_bounds is None:
+            print("\n⚠️ Marginal stats from calibration data NOT shown (invalid CIs for Mondrian)")
+            print(" Use compute_marginal_operational_bounds() for valid marginal bounds")
 
     return summary
 
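For readers consuming the new return structure, here is a minimal sketch of how the summary produced by the updated report_prediction_stats can be inspected. It assumes summary was produced by the calls shown in the docstring Examples above; the dictionary keys and the attribute names on the bounds objects (rate_bounds, lower_bound, upper_bound, n_successes, n_evaluations, n_calibration, ci_width) follow the code in this diff.

# Per-class entries live under integer keys (0, 1)
cls0 = summary[0]
print(cls0["n"], cls0["alpha_target"], cls0["alpha_corrected"], cls0["threshold"])
print(cls0["calibration_stats"]["singletons"])  # raw dict copied from prediction_stats

# Present only when operational_bounds_per_class was passed in
if "operational_bounds" in cls0:
    sing = cls0["operational_bounds"].rate_bounds["singleton"]
    print(sing.lower_bound, sing.upper_bound, sing.n_successes, sing.n_evaluations)

# Present only when marginal_operational_bounds was passed in
if "marginal_bounds" in summary:
    mb = summary["marginal_bounds"]
    print(mb.n_calibration, mb.ci_width)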