ssbc-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ssbc/visualization.py ADDED
@@ -0,0 +1,459 @@
+ """Visualization and reporting utilities for conformal prediction results."""
+
+ from typing import Any
+
+ from .statistics import cp_interval
+
+
+ def report_prediction_stats(
+     prediction_stats: dict[Any, Any], calibration_result: dict[Any, Any], verbose: bool = True
+ ) -> dict[str | int, Any]:
+     """Pretty/robust summary for Mondrian conformal prediction stats.
+
+     Tolerates multiple schema shapes:
+     - dicts with 'rate'/'ci_95' or 'proportion'/'lower'/'upper'
+     - raw ints for counts (e.g., marginal['singletons']['pred_0'] = 339)
+     - per-class singleton correct/incorrect either nested under 'singletons'
+       OR as top-level aliases 'singletons_correct' / 'singletons_incorrect'.
+
+     Also computes Clopper-Pearson CIs when missing, and splits marginal
+     singleton errors by predicted class.
+
+     Parameters
+     ----------
+     prediction_stats : dict
+         Output from mondrian_conformal_calibrate (second return value)
+     calibration_result : dict
+         Output from mondrian_conformal_calibrate (first return value)
+     verbose : bool, default=True
+         If True, print detailed statistics to stdout
+
+     Returns
+     -------
+     dict
+         Structured summary with CIs for all metrics, containing:
+         - Keys 0, 1 for per-class statistics
+         - Key 'marginal' for overall deployment statistics
+
+     Examples
+     --------
+     >>> # After calibration
+     >>> cal_result, pred_stats = mondrian_conformal_calibrate(...)
+     >>> summary = report_prediction_stats(pred_stats, cal_result, verbose=True)
+     >>> print(summary['marginal']['coverage']['rate'])
+     """
+
+     # Helper functions
+     def as_dict(x: Any) -> dict[str, Any]:
+         """Ensure x is a dict."""
+         return x if isinstance(x, dict) else {}
+
+     def get_count(x: Any, default: int = 0) -> int:
+         """Extract count from dict or int."""
+         if isinstance(x, dict):
+             return int(x.get("count", default))
+         if isinstance(x, int):
+             return int(x)
+         return default
+
+     def get_rate(x: Any, default: float | None = 0.0) -> float | None:
+         """Extract rate from dict or float."""
+         if isinstance(x, dict):
+             if "rate" in x:
+                 return float(x["rate"])
+             if "proportion" in x:
+                 return float(x["proportion"])
+             return default
+         if isinstance(x, float):
+             return float(x)
+         return default
+
+     def get_ci_tuple(x: Any) -> tuple[float, float]:
+         """Extract CI bounds from dict."""
+         if not isinstance(x, dict):
+             return (0.0, 0.0)
+         if "ci_95" in x and isinstance(x["ci_95"], tuple | list) and len(x["ci_95"]) == 2:
+             return float(x["ci_95"][0]), float(x["ci_95"][1])
+         lo = x.get("lower", 0.0)
+         hi = x.get("upper", 0.0)
+         return float(lo), float(hi)
+
+     def ensure_ci(d: dict[str, Any], count: int, total: int) -> tuple[float, float, float]:
+         """Return (rate, lo, hi). If d already has rate/CI, use them; else compute CP from count/total."""
+         r = get_rate(d, default=None)
+         lo, hi = get_ci_tuple(d)
+         if r is None or (lo == 0.0 and hi == 0.0 and (count > 0 or total > 0)):
+             ci = cp_interval(count, total)
+             return ci["proportion"], ci["lower"], ci["upper"]
+         return float(r), float(lo), float(hi)
+
+     def pct(x: float) -> str:
+         """Format percentage."""
+         return f"{x:6.2%}"
+
+     summary: dict[str | int, Any] = {}
+
+     if verbose:
+         print("=" * 80)
+         print("PREDICTION SET STATISTICS (All rates with 95% Clopper-Pearson CIs)")
+         print("=" * 80)
+
+     # ----------------- per-class (conditioned on Y) -----------------
+     for class_label in [0, 1]:
+         if class_label not in prediction_stats:
+             continue
+         cls = prediction_stats[class_label]
+
+         if isinstance(cls, dict) and "error" in cls:
+             if verbose:
+                 print(f"\nClass {class_label}: {cls['error']}")
+             summary[class_label] = {"error": cls["error"]}
+             continue
+
+         n = int(cls.get("n", cls.get("n_class", 0)))
+         alpha_target = cls.get("alpha_target", calibration_result.get(class_label, {}).get("alpha_target", None))
+         delta = cls.get("delta", calibration_result.get(class_label, {}).get("delta", None))
+
+         abst = as_dict(cls.get("abstentions", {}))
+         sing = as_dict(cls.get("singletons", {}))
+         # Accept both nested and flat aliases
+         sing_corr = as_dict(sing.get("correct", cls.get("singletons_correct", {})))
+         sing_inc = as_dict(sing.get("incorrect", cls.get("singletons_incorrect", {})))
+         doub = as_dict(cls.get("doublets", {}))
+         pac = as_dict(cls.get("pac_bounds", {}))
+
+         # Counts
+         abst_count = get_count(abst)
+         sing_count = get_count(sing)
+         sing_corr_count = get_count(sing_corr)
+         sing_inc_count = get_count(sing_inc)
+         doub_count = get_count(doub)
+
+         # Ensure rates/CIs (fallback to CP if missing)
+         abst_rate, abst_lo, abst_hi = ensure_ci(abst, abst_count, n)
+         sing_rate, sing_lo, sing_hi = ensure_ci(sing, sing_count, n)
+         sing_corr_rate, sing_corr_lo, sing_corr_hi = ensure_ci(sing_corr, sing_corr_count, n)
+         sing_inc_rate, sing_inc_lo, sing_inc_hi = ensure_ci(sing_inc, sing_inc_count, n)
+         doub_rate, doub_lo, doub_hi = ensure_ci(doub, doub_count, n)
+
+         # P(error | singleton, Y=class)
+         err_given_single_ci = cp_interval(sing_inc_count, sing_count if sing_count > 0 else 1)
+
+         class_summary = {
+             "n": n,
+             "alpha_target": alpha_target,
+             "delta": delta,
+             "abstentions": {"count": abst_count, "rate": abst_rate, "ci_95": (abst_lo, abst_hi)},
+             "singletons": {
+                 "count": sing_count,
+                 "rate": sing_rate,
+                 "ci_95": (sing_lo, sing_hi),
+                 "correct": {"count": sing_corr_count, "rate": sing_corr_rate, "ci_95": (sing_corr_lo, sing_corr_hi)},
+                 "incorrect": {"count": sing_inc_count, "rate": sing_inc_rate, "ci_95": (sing_inc_lo, sing_inc_hi)},
+                 "error_given_singleton": {
+                     "count": sing_inc_count,
+                     "denom": sing_count,
+                     "rate": err_given_single_ci["proportion"],
+                     "ci_95": (err_given_single_ci["lower"], err_given_single_ci["upper"]),
+                 },
+             },
+             "doublets": {"count": doub_count, "rate": doub_rate, "ci_95": (doub_lo, doub_hi)},
+             "pac_bounds": pac,
+         }
+         summary[class_label] = class_summary
+
+         if verbose:
+             print(f"\n{'=' * 80}")
+             print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
+             print(f"{'=' * 80}")
+             alpha_str = f"{alpha_target:.3f}" if alpha_target is not None else "n/a"
+             delta_str = f"{delta:.3f}" if delta is not None else "n/a"
+             print(f" n={n}, α_target={alpha_str}, δ={delta_str}")
+
+             print("\nPrediction Set Breakdown:")
+             print(
+                 f" Abstentions: {abst_count:4d} / {n:4d} = {pct(abst_rate)} 95% CI: [{abst_lo:.4f}, {abst_hi:.4f}]"
+             )
+             print(
+                 f" Singletons: {sing_count:4d} / {n:4d} = {pct(sing_rate)} 95% CI: [{sing_lo:.4f}, {sing_hi:.4f}]"
+             )
+             print(
+                 f" ├─ Correct: {sing_corr_count:4d} / {n:4d} = {pct(sing_corr_rate)} "
+                 f"95% CI: [{sing_corr_lo:.4f}, {sing_corr_hi:.4f}]"
+             )
+             print(
+                 f" └─ Incorrect: {sing_inc_count:4d} / {n:4d} = {pct(sing_inc_rate)} "
+                 f"95% CI: [{sing_inc_lo:.4f}, {sing_inc_hi:.4f}]"
+             )
+
+             print(
+                 f" Singleton error | Y={class_label}: "
+                 f"{sing_inc_count:4d} / {sing_count:4d} = {pct(err_given_single_ci['proportion'])} "
+                 f"95% CI: [{err_given_single_ci['lower']:.4f}, {err_given_single_ci['upper']:.4f}]"
+             )
+
+             print(
+                 f"\n Doublets: {doub_count:4d} / {n:4d} = {pct(doub_rate)} 95% CI: [{doub_lo:.4f}, {doub_hi:.4f}]"
+             )
+
+             if pac and pac.get("rho", None) is not None:
+                 print(f"\n PAC Singleton Error Rate (δ={delta_str}):")
+                 print(f" ρ = {pac.get('rho', 0):.3f}, κ = {pac.get('kappa', 0):.3f}")
+                 if "alpha_singlet_bound" in pac and "alpha_singlet_observed" in pac:
+                     bound = float(pac["alpha_singlet_bound"])
+                     observed = float(pac["alpha_singlet_observed"])
+                     ok = "✓" if observed <= bound else "✗"
+                     print(f" α'_bound: {bound:.4f}")
+                     print(f" α'_observed: {observed:.4f} {ok}")
+
+     # ----------------- marginal / deployment view -----------------
+     if "marginal" in prediction_stats:
+         marg = prediction_stats["marginal"]
+         n_total = int(marg["n_total"])
+
+         cov = as_dict(marg.get("coverage", {}))
+         abst_m = as_dict(marg.get("abstentions", {}))
+         sing_m = as_dict(marg.get("singletons", {}))
+         doub_m = as_dict(marg.get("doublets", {}))
+         pac_m = as_dict(marg.get("pac_bounds", {}))
+
+         cov_count = get_count(cov)
+         abst_m_count = get_count(abst_m)
+         sing_total = get_count(sing_m)
+         doub_m_count = get_count(doub_m)
+
+         cov_rate, cov_lo, cov_hi = ensure_ci(cov, cov_count, n_total)
+         abst_m_rate, abst_m_lo, abst_m_hi = ensure_ci(abst_m, abst_m_count, n_total)
+         sing_m_rate, sing_m_lo, sing_m_hi = ensure_ci(sing_m, sing_total, n_total)
+         doub_m_rate, doub_m_lo, doub_m_hi = ensure_ci(doub_m, doub_m_count, n_total)
+
+         # pred_0 / pred_1 may be dicts or ints (counts)
+         raw_s0 = sing_m.get("pred_0", 0)
+         raw_s1 = sing_m.get("pred_1", 0)
+         s0_count = get_count(raw_s0)
+         s1_count = get_count(raw_s1)
+
+         # Prefer provided rate/CI, else compute off n_total
+         if isinstance(raw_s0, dict):
+             s0_rate, s0_lo, s0_hi = ensure_ci(raw_s0, s0_count, n_total)
+         else:
+             s0_ci = cp_interval(s0_count, n_total)
+             s0_rate, s0_lo, s0_hi = s0_ci["proportion"], s0_ci["lower"], s0_ci["upper"]
+
+         if isinstance(raw_s1, dict):
+             s1_rate, s1_lo, s1_hi = ensure_ci(raw_s1, s1_count, n_total)
+         else:
+             s1_ci = cp_interval(s1_count, n_total)
+             s1_rate, s1_lo, s1_hi = s1_ci["proportion"], s1_ci["lower"], s1_ci["upper"]
+
+         # Overall singleton errors (dict or int). Denominator should be sing_total.
+         raw_s_err = sing_m.get("errors", 0)
+         s_err_count = get_count(raw_s_err)
+         if isinstance(raw_s_err, dict):
+             s_err_rate, s_err_lo, s_err_hi = ensure_ci(raw_s_err, s_err_count, sing_total if sing_total > 0 else 1)
+         else:
+             se_ci = cp_interval(s_err_count, sing_total if sing_total > 0 else 1)
+             s_err_rate, s_err_lo, s_err_hi = se_ci["proportion"], se_ci["lower"], se_ci["upper"]
+
+         # Errors by predicted class via per-class incorrect singletons
+         # (pred 0 errors happen when Y=1 singleton is wrong; pred 1 errors when Y=0 singleton is wrong)
+         err_pred0_count = int(
+             prediction_stats.get(1, {})
+             .get("singletons", {})
+             .get("incorrect", prediction_stats.get(1, {}).get("singletons_incorrect", {}))
+             .get("count", 0)
+         )
+         err_pred1_count = int(
+             prediction_stats.get(0, {})
+             .get("singletons", {})
+             .get("incorrect", prediction_stats.get(0, {}).get("singletons_incorrect", {}))
+             .get("count", 0)
+         )
+         pred0_err_ci = cp_interval(err_pred0_count, s0_count if s0_count > 0 else 1)
+         pred1_err_ci = cp_interval(err_pred1_count, s1_count if s1_count > 0 else 1)
+
+         marginal_summary = {
+             "n_total": n_total,
+             "coverage": {"count": cov_count, "rate": cov_rate, "ci_95": (cov_lo, cov_hi)},
+             "abstentions": {"count": abst_m_count, "rate": abst_m_rate, "ci_95": (abst_m_lo, abst_m_hi)},
+             "singletons": {
+                 "count": sing_total,
+                 "rate": sing_m_rate,
+                 "ci_95": (sing_m_lo, sing_m_hi),
+                 "pred_0": {"count": s0_count, "rate": s0_rate, "ci_95": (s0_lo, s0_hi)},
+                 "pred_1": {"count": s1_count, "rate": s1_rate, "ci_95": (s1_lo, s1_hi)},
+                 "errors": {"count": s_err_count, "rate": s_err_rate, "ci_95": (s_err_lo, s_err_hi)},
+                 "errors_by_pred": {
+                     "pred_0": {
+                         "count": err_pred0_count,
+                         "denom": s0_count,
+                         "rate": pred0_err_ci["proportion"],
+                         "ci_95": (pred0_err_ci["lower"], pred0_err_ci["upper"]),
+                     },
+                     "pred_1": {
+                         "count": err_pred1_count,
+                         "denom": s1_count,
+                         "rate": pred1_err_ci["proportion"],
+                         "ci_95": (pred1_err_ci["lower"], pred1_err_ci["upper"]),
+                     },
+                 },
+             },
+             "doublets": {"count": doub_m_count, "rate": doub_m_rate, "ci_95": (doub_m_lo, doub_m_hi)},
+             "pac_bounds": pac_m,
+         }
+         summary["marginal"] = marginal_summary
+
+         if verbose:
+             print(f"\n{'=' * 80}")
+             print("MARGINAL ANALYSIS (Deployment View - Ignores True Labels)")
+             print(f"{'=' * 80}")
+             print(f" Total samples: {n_total}")
+
+             print("\nOverall Coverage:")
+             print(f" Covered: {cov_count:4d} / {n_total:4d} = {pct(cov_rate)} 95% CI: [{cov_lo:.4f}, {cov_hi:.4f}]")
+
+             print("\nPrediction Set Distribution:")
+             print(
+                 f" Abstentions: {abst_m_count:4d} / {n_total:4d} = {pct(abst_m_rate)} "
+                 f"95% CI: [{abst_m_lo:.4f}, {abst_m_hi:.4f}]"
+             )
+             print(
+                 f" Singletons: {sing_total:4d} / {n_total:4d} = {pct(sing_m_rate)} "
+                 f"95% CI: [{sing_m_lo:.4f}, {sing_m_hi:.4f}]"
+             )
+             print(f" ├─ Pred 0: {s0_count:4d} / {n_total:4d} = {pct(s0_rate)} 95% CI: [{s0_lo:.4f}, {s0_hi:.4f}]")
+             print(f" ├─ Pred 1: {s1_count:4d} / {n_total:4d} = {pct(s1_rate)} 95% CI: [{s1_lo:.4f}, {s1_hi:.4f}]")
+             print(
+                 f" ├─ Errors (overall): {s_err_count:4d} / {sing_total:4d} = {pct(s_err_rate)} "
+                 f"95% CI: [{s_err_lo:.4f}, {s_err_hi:.4f}]"
+             )
+             print(
+                 f" ├─ Pred 0 errors: {err_pred0_count:4d} / {s0_count:4d} = {pct(pred0_err_ci['proportion'])} "
+                 f"95% CI: [{pred0_err_ci['lower']:.4f}, {pred0_err_ci['upper']:.4f}]"
+             )
+             print(
+                 f" └─ Pred 1 errors: {err_pred1_count:4d} / {s1_count:4d} = {pct(pred1_err_ci['proportion'])} "
+                 f"95% CI: [{pred1_err_ci['lower']:.4f}, {pred1_err_ci['upper']:.4f}]"
+             )
+
+             print(
+                 f" Doublets: {doub_m_count:4d} / {n_total:4d} = {pct(doub_m_rate)} "
+                 f"95% CI: [{doub_m_lo:.4f}, {doub_m_hi:.4f}]"
+             )
+
+             if pac_m and pac_m.get("rho", None) is not None:
+                 aw = pac_m.get("alpha_weighted", None)
+                 aw_str = f"{float(aw):.3f}" if aw is not None else "n/a"
+                 print(f"\n Overall PAC Bounds (weighted α={aw_str}):")
+                 print(f" ρ = {pac_m.get('rho', 0):.3f}, κ = {pac_m.get('kappa', 0):.3f}")
+                 if "alpha_singlet_bound" in pac_m and "alpha_singlet_observed" in pac_m:
+                     bound = float(pac_m["alpha_singlet_bound"])
+                     observed = float(pac_m["alpha_singlet_observed"])
+                     ok = "✓" if observed <= bound else "✗"
+                     print(f" α'_bound: {bound:.4f}")
+                     print(f" α'_observed: {observed:.4f} {ok}")
+
+             n_escalations = int(pac_m.get("n_escalations", doub_m_count + abst_m_count))
+             print("\n Deployment Decision Mix:")
+             print(f" Automate: {sing_total} singletons ({sing_m_rate:.1%})")
+             print(f" Escalate: {n_escalations} doublets+abstentions ({n_escalations / n_total:.1%})")
+
+     return summary
+
+
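A minimal usage sketch of report_prediction_stats follows, assuming a hand-built prediction_stats dict rather than real output from mondrian_conformal_calibrate. It exercises the schema shapes the docstring tolerates: nested 'correct'/'incorrect' for class 0, the flat 'singletons_correct'/'singletons_incorrect' aliases for class 1, and raw integer counts in the marginal 'pred_0'/'pred_1'/'errors' slots.

    # Hypothetical, hand-built inputs; cp_interval fills in the missing CIs.
    pred_stats = {
        0: {  # nested correct/incorrect form
            "n": 100,
            "abstentions": {"count": 5},
            "singletons": {"count": 90, "correct": {"count": 85}, "incorrect": {"count": 5}},
            "doublets": {"count": 5},
        },
        1: {  # flat alias form
            "n": 50,
            "abstentions": {"count": 2},
            "singletons": {"count": 45},
            "singletons_correct": {"count": 43},
            "singletons_incorrect": {"count": 2},
            "doublets": {"count": 3},
        },
        "marginal": {
            "n_total": 150,
            "coverage": {"count": 140},
            "abstentions": {"count": 7},
            "singletons": {"count": 135, "pred_0": 80, "pred_1": 55, "errors": 7},  # raw int counts
            "doublets": {"count": 8},
        },
    }
    summary = report_prediction_stats(pred_stats, calibration_result={}, verbose=False)
    print(summary["marginal"]["singletons"]["errors_by_pred"]["pred_0"]["rate"])

Note that raw integers are only picked up where get_count() sees them directly (the marginal pred_0/pred_1/errors slots); entries that pass through as_dict() need to stay dicts, or their counts fall back to zero.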
+ def plot_parallel_coordinates_plotly(
+     df,
+     columns: list[str] | None = None,
+     color: str = "err_all",
+     color_continuous_scale=None,
+     title: str = "Mondrian sweep – interactive parallel coordinates",
+     height: int = 600,
+     base_opacity: float = 0.9,
+     unselected_opacity: float = 0.06,
+ ):
+     """Create interactive parallel coordinates plot for hyperparameter sweep results.
+
+     Parameters
+     ----------
+     df : pd.DataFrame
+         DataFrame with hyperparameter sweep results
+     columns : list of str, optional
+         Columns to display in parallel coordinates
+         Default: ['a0','d0','a1','d1','cov','sing_rate','err_all','err_pred0','err_pred1','err_y0','err_y1','esc_rate']
+     color : str, default='err_all'
+         Column to use for coloring lines
+     color_continuous_scale : plotly colorscale, optional
+         Color scale for the lines
+     title : str, default="Mondrian sweep – interactive parallel coordinates"
+         Plot title
+     height : int, default=600
+         Plot height in pixels
+     base_opacity : float, default=0.9
+         Opacity of selected lines
+     unselected_opacity : float, default=0.06
+         Opacity of unselected lines (creates contrast)
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Interactive plotly figure
+
+     Examples
+     --------
+     >>> import pandas as pd
+     >>> df = sweep_hyperparams_and_collect(...)
+     >>> fig = plot_parallel_coordinates_plotly(df, color='err_all')
+     >>> fig.show()  # In notebook
+     >>> # Or save: fig.write_html("sweep_results.html")
+     """
+     import plotly.express as px
+
+     if columns is None:
+         default_cols = [
+             "a0",
+             "d0",
+             "a1",
+             "d1",
+             "cov",
+             "sing_rate",
+             "err_all",
+             "err_pred0",
+             "err_pred1",
+             "err_y0",
+             "err_y1",
+             "esc_rate",
+         ]
+         columns = [c for c in default_cols if c in df.columns]
+
+     fig = px.parallel_coordinates(
+         df,
+         dimensions=columns,
+         color=color if color in df.columns else None,
+         color_continuous_scale=color_continuous_scale or px.colors.sequential.Blugrn,
+         labels={c: c for c in columns},
+     )
+
+     # Maximize contrast between selected and unselected lines
+     if fig.data:
+         # Fade unselected lines
+         fig.data[0].unselected.update(line=dict(color=f"rgba(1,1,1,{float(unselected_opacity)})"))
+
+     fig.update_layout(
+         title=title,
+         height=height,
+         margin=dict(l=40, r=40, t=60, b=40),
+         plot_bgcolor="white",
+         paper_bgcolor="white",
+         font=dict(size=14),
+         uirevision=True,  # keep user brushing across updates
+     )
+
+     # Make axis labels and ranges more readable
+     fig.update_traces(labelfont=dict(size=14), rangefont=dict(size=12), tickfont=dict(size=12))
+
+     # Optional: title for colorbar if we're coloring by a column
+     if color in df.columns and fig.data and getattr(fig.data[0], "line", None):
+         if getattr(fig.data[0].line, "colorbar", None) is not None:
+             fig.data[0].line.colorbar.title = color
+
+     return fig
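
A quick sketch of the plotting helper on a toy sweep table. The DataFrame below is illustrative only; a real one would come from a hyperparameter sweep such as the sweep_hyperparams_and_collect call referenced in the docstring.

    import pandas as pd

    # Toy sweep results: three hypothetical (alpha, delta) settings.
    toy = pd.DataFrame(
        {
            "a0": [0.05, 0.10, 0.10],
            "d0": [0.05, 0.05, 0.10],
            "cov": [0.96, 0.93, 0.91],
            "sing_rate": [0.72, 0.81, 0.85],
            "err_all": [0.03, 0.05, 0.07],
        }
    )
    fig = plot_parallel_coordinates_plotly(toy, color="err_all", title="Toy sweep")
    fig.write_html("toy_sweep.html")  # or fig.show() in a notebook

Default columns that are missing from the DataFrame are silently dropped, so partial sweep tables plot without extra configuration.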