ssbc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ssbc/__init__.py +59 -0
- ssbc/__main__.py +4 -0
- ssbc/cli.py +21 -0
- ssbc/conformal.py +333 -0
- ssbc/core.py +205 -0
- ssbc/hyperparameter.py +258 -0
- ssbc/simulation.py +148 -0
- ssbc/ssbc.py +1 -0
- ssbc/statistics.py +158 -0
- ssbc/utils.py +2 -0
- ssbc/visualization.py +459 -0
- ssbc-0.1.0.dist-info/METADATA +266 -0
- ssbc-0.1.0.dist-info/RECORD +17 -0
- ssbc-0.1.0.dist-info/WHEEL +5 -0
- ssbc-0.1.0.dist-info/entry_points.txt +2 -0
- ssbc-0.1.0.dist-info/licenses/LICENSE +21 -0
- ssbc-0.1.0.dist-info/top_level.txt +1 -0
ssbc/visualization.py
ADDED
@@ -0,0 +1,459 @@
|
|
1
|
+
"""Visualization and reporting utilities for conformal prediction results."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .statistics import cp_interval
|
6
|
+
|
7
|
+
|
8
|
+
def report_prediction_stats(
    prediction_stats: dict[Any, Any], calibration_result: dict[Any, Any], verbose: bool = True
) -> dict[str | int, Any]:
    """Pretty/robust summary for Mondrian conformal prediction stats.

    Tolerates multiple schema shapes:
    - dicts with 'rate'/'ci_95' or 'proportion'/'lower'/'upper'
    - raw ints for counts (e.g., marginal['singletons']['pred_0'] = 339);
      a bare int is accepted for any count-bearing entry, per-class or marginal
    - per-class singleton correct/incorrect either nested under 'singletons'
      OR as top-level aliases 'singletons_correct' / 'singletons_incorrect'.

    Also computes Clopper-Pearson CIs when missing, and splits marginal
    singleton errors by predicted class.

    Parameters
    ----------
    prediction_stats : dict
        Output from mondrian_conformal_calibrate (second return value)
    calibration_result : dict
        Output from mondrian_conformal_calibrate (first return value)
    verbose : bool, default=True
        If True, print detailed statistics to stdout

    Returns
    -------
    dict
        Structured summary with CIs for all metrics, containing:
        - Keys 0, 1 for per-class statistics (only for classes present in
          ``prediction_stats``; an entry carrying an 'error' key is passed
          through as ``{"error": ...}``)
        - Key 'marginal' for overall deployment statistics (only when
          ``prediction_stats`` contains a 'marginal' entry)

    Raises
    ------
    KeyError
        If ``prediction_stats['marginal']`` is present but has no 'n_total'.

    Examples
    --------
    >>> # After calibration
    >>> cal_result, pred_stats = mondrian_conformal_calibrate(...)
    >>> summary = report_prediction_stats(pred_stats, cal_result, verbose=True)
    >>> print(summary['marginal']['coverage']['rate'])
    """

    # Helper functions
    def as_dict(x: Any) -> dict[str, Any]:
        """Return ``x`` if it is a dict, otherwise an empty dict."""
        return x if isinstance(x, dict) else {}

    def get_count(x: Any, default: int = 0) -> int:
        """Extract a count from a dict (its 'count' key) or from a bare int."""
        if isinstance(x, dict):
            return int(x.get("count", default))
        if isinstance(x, int):
            return int(x)
        return default

    def get_rate(x: Any, default: float | None = 0.0) -> float | None:
        """Extract a rate from a dict ('rate' or 'proportion' key) or a bare float."""
        if isinstance(x, dict):
            if "rate" in x:
                return float(x["rate"])
            if "proportion" in x:
                return float(x["proportion"])
            return default
        if isinstance(x, float):
            return float(x)
        return default

    def get_ci_tuple(x: Any) -> tuple[float, float]:
        """Extract CI bounds from 'ci_95' (a 2-sequence) or 'lower'/'upper'; (0, 0) if absent."""
        if not isinstance(x, dict):
            return (0.0, 0.0)
        ci = x.get("ci_95")
        if isinstance(ci, tuple | list) and len(ci) == 2:
            return float(ci[0]), float(ci[1])
        return float(x.get("lower", 0.0)), float(x.get("upper", 0.0))

    def ensure_ci(d: dict[str, Any], count: int, total: int) -> tuple[float, float, float]:
        """Return (rate, lo, hi). If d already has rate/CI, use them; else compute CP from count/total.

        A (0, 0) CI is treated as a missing placeholder whenever there is data
        (count or total > 0) to compute a real interval from.
        """
        r = get_rate(d, default=None)
        lo, hi = get_ci_tuple(d)
        if r is None or (lo == 0.0 and hi == 0.0 and (count > 0 or total > 0)):
            ci = cp_interval(count, total)
            return ci["proportion"], ci["lower"], ci["upper"]
        return float(r), float(lo), float(hi)

    def count_and_ci(raw: Any, total: int) -> tuple[int, float, float, float]:
        """Return (count, rate, lo, hi) for an entry that may be a dict or a bare int count."""
        # get_count sees the raw value so bare ints are counted; ensure_ci only
        # reads rate/CI keys, so a non-dict falls through to a CP interval.
        count = get_count(raw)
        rate, lo, hi = ensure_ci(as_dict(raw), count, total)
        return count, rate, lo, hi

    def singleton_part(cls: dict[Any, Any], part: str) -> Any:
        """Raw per-class singleton 'correct'/'incorrect' entry, nested or via the flat alias."""
        nested = as_dict(cls.get("singletons", {}))
        return nested.get(part, cls.get(f"singletons_{part}", {}))

    def pct(x: float) -> str:
        """Format percentage."""
        return f"{x:6.2%}"

    def fmt3(x: Any) -> str:
        """Format an optional number with 3 decimals, or 'n/a' when it is None."""
        return f"{float(x):.3f}" if x is not None else "n/a"

    def print_pac(pac: dict[str, Any]) -> None:
        """Print rho/kappa and, when both are present, the alpha' bound vs. observed check."""
        print(f" ρ = {fmt3(pac['rho'])}, κ = {fmt3(pac.get('kappa', 0))}")
        if "alpha_singlet_bound" in pac and "alpha_singlet_observed" in pac:
            bound = float(pac["alpha_singlet_bound"])
            observed = float(pac["alpha_singlet_observed"])
            ok = "✓" if observed <= bound else "✗"
            print(f" α'_bound: {bound:.4f}")
            print(f" α'_observed: {observed:.4f} {ok}")

    summary: dict[str | int, Any] = {}

    if verbose:
        print("=" * 80)
        print("PREDICTION SET STATISTICS (All rates with 95% Clopper-Pearson CIs)")
        print("=" * 80)

    # ----------------- per-class (conditioned on Y) -----------------
    for class_label in [0, 1]:
        if class_label not in prediction_stats:
            continue
        cls = prediction_stats[class_label]

        if isinstance(cls, dict) and "error" in cls:
            if verbose:
                print(f"\nClass {class_label}: {cls['error']}")
            summary[class_label] = {"error": cls["error"]}
            continue

        n = int(cls.get("n", cls.get("n_class", 0)))
        cal_cls = as_dict(calibration_result.get(class_label, {}))
        alpha_target = cls.get("alpha_target", cal_cls.get("alpha_target", None))
        delta = cls.get("delta", cal_cls.get("delta", None))
        pac = as_dict(cls.get("pac_bounds", {}))

        # Counts + rates/CIs (fallback to CP if missing); all rates are over n.
        abst_count, abst_rate, abst_lo, abst_hi = count_and_ci(cls.get("abstentions", {}), n)
        sing_count, sing_rate, sing_lo, sing_hi = count_and_ci(cls.get("singletons", {}), n)
        sing_corr_count, sing_corr_rate, sing_corr_lo, sing_corr_hi = count_and_ci(
            singleton_part(cls, "correct"), n
        )
        sing_inc_count, sing_inc_rate, sing_inc_lo, sing_inc_hi = count_and_ci(
            singleton_part(cls, "incorrect"), n
        )
        doub_count, doub_rate, doub_lo, doub_hi = count_and_ci(cls.get("doublets", {}), n)

        # P(error | singleton, Y=class); denominator clamped to 1 to avoid 0/0.
        err_given_single_ci = cp_interval(sing_inc_count, sing_count if sing_count > 0 else 1)

        summary[class_label] = {
            "n": n,
            "alpha_target": alpha_target,
            "delta": delta,
            "abstentions": {"count": abst_count, "rate": abst_rate, "ci_95": (abst_lo, abst_hi)},
            "singletons": {
                "count": sing_count,
                "rate": sing_rate,
                "ci_95": (sing_lo, sing_hi),
                "correct": {"count": sing_corr_count, "rate": sing_corr_rate, "ci_95": (sing_corr_lo, sing_corr_hi)},
                "incorrect": {"count": sing_inc_count, "rate": sing_inc_rate, "ci_95": (sing_inc_lo, sing_inc_hi)},
                "error_given_singleton": {
                    "count": sing_inc_count,
                    "denom": sing_count,
                    "rate": err_given_single_ci["proportion"],
                    "ci_95": (err_given_single_ci["lower"], err_given_single_ci["upper"]),
                },
            },
            "doublets": {"count": doub_count, "rate": doub_rate, "ci_95": (doub_lo, doub_hi)},
            "pac_bounds": pac,
        }

        if verbose:
            print(f"\n{'=' * 80}")
            print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
            print(f"{'=' * 80}")
            alpha_str = fmt3(alpha_target)
            delta_str = fmt3(delta)
            print(f" n={n}, α_target={alpha_str}, δ={delta_str}")

            print("\nPrediction Set Breakdown:")
            print(
                f" Abstentions: {abst_count:4d} / {n:4d} = {pct(abst_rate)} 95% CI: [{abst_lo:.4f}, {abst_hi:.4f}]"
            )
            print(
                f" Singletons: {sing_count:4d} / {n:4d} = {pct(sing_rate)} 95% CI: [{sing_lo:.4f}, {sing_hi:.4f}]"
            )
            print(
                f" ├─ Correct: {sing_corr_count:4d} / {n:4d} = {pct(sing_corr_rate)} "
                f"95% CI: [{sing_corr_lo:.4f}, {sing_corr_hi:.4f}]"
            )
            print(
                f" └─ Incorrect: {sing_inc_count:4d} / {n:4d} = {pct(sing_inc_rate)} "
                f"95% CI: [{sing_inc_lo:.4f}, {sing_inc_hi:.4f}]"
            )

            print(
                f" Singleton error | Y={class_label}: "
                f"{sing_inc_count:4d} / {sing_count:4d} = {pct(err_given_single_ci['proportion'])} "
                f"95% CI: [{err_given_single_ci['lower']:.4f}, {err_given_single_ci['upper']:.4f}]"
            )

            print(
                f"\n Doublets: {doub_count:4d} / {n:4d} = {pct(doub_rate)} 95% CI: [{doub_lo:.4f}, {doub_hi:.4f}]"
            )

            if pac and pac.get("rho", None) is not None:
                print(f"\n PAC Singleton Error Rate (δ={delta_str}):")
                print_pac(pac)

    # ----------------- marginal / deployment view -----------------
    if "marginal" in prediction_stats:
        marg = prediction_stats["marginal"]
        n_total = int(marg["n_total"])

        sing_m = as_dict(marg.get("singletons", {}))
        pac_m = as_dict(marg.get("pac_bounds", {}))

        cov_count, cov_rate, cov_lo, cov_hi = count_and_ci(marg.get("coverage", {}), n_total)
        abst_m_count, abst_m_rate, abst_m_lo, abst_m_hi = count_and_ci(marg.get("abstentions", {}), n_total)
        sing_total, sing_m_rate, sing_m_lo, sing_m_hi = count_and_ci(marg.get("singletons", {}), n_total)
        doub_m_count, doub_m_rate, doub_m_lo, doub_m_hi = count_and_ci(marg.get("doublets", {}), n_total)

        # pred_0 / pred_1 may be dicts or ints (counts); prefer provided rate/CI, else CP over n_total.
        s0_count, s0_rate, s0_lo, s0_hi = count_and_ci(sing_m.get("pred_0", 0), n_total)
        s1_count, s1_rate, s1_lo, s1_hi = count_and_ci(sing_m.get("pred_1", 0), n_total)

        # Overall singleton errors (dict or int). Denominator is sing_total (clamped to 1).
        s_err_count, s_err_rate, s_err_lo, s_err_hi = count_and_ci(
            sing_m.get("errors", 0), sing_total if sing_total > 0 else 1
        )

        # Errors by predicted class via per-class incorrect singletons:
        # a wrong "pred 0" singleton means Y=1, and a wrong "pred 1" singleton means Y=0.
        err_pred0_count = get_count(singleton_part(as_dict(prediction_stats.get(1, {})), "incorrect"))
        err_pred1_count = get_count(singleton_part(as_dict(prediction_stats.get(0, {})), "incorrect"))
        pred0_err_ci = cp_interval(err_pred0_count, s0_count if s0_count > 0 else 1)
        pred1_err_ci = cp_interval(err_pred1_count, s1_count if s1_count > 0 else 1)

        summary["marginal"] = {
            "n_total": n_total,
            "coverage": {"count": cov_count, "rate": cov_rate, "ci_95": (cov_lo, cov_hi)},
            "abstentions": {"count": abst_m_count, "rate": abst_m_rate, "ci_95": (abst_m_lo, abst_m_hi)},
            "singletons": {
                "count": sing_total,
                "rate": sing_m_rate,
                "ci_95": (sing_m_lo, sing_m_hi),
                "pred_0": {"count": s0_count, "rate": s0_rate, "ci_95": (s0_lo, s0_hi)},
                "pred_1": {"count": s1_count, "rate": s1_rate, "ci_95": (s1_lo, s1_hi)},
                "errors": {"count": s_err_count, "rate": s_err_rate, "ci_95": (s_err_lo, s_err_hi)},
                "errors_by_pred": {
                    "pred_0": {
                        "count": err_pred0_count,
                        "denom": s0_count,
                        "rate": pred0_err_ci["proportion"],
                        "ci_95": (pred0_err_ci["lower"], pred0_err_ci["upper"]),
                    },
                    "pred_1": {
                        "count": err_pred1_count,
                        "denom": s1_count,
                        "rate": pred1_err_ci["proportion"],
                        "ci_95": (pred1_err_ci["lower"], pred1_err_ci["upper"]),
                    },
                },
            },
            "doublets": {"count": doub_m_count, "rate": doub_m_rate, "ci_95": (doub_m_lo, doub_m_hi)},
            "pac_bounds": pac_m,
        }

        if verbose:
            print(f"\n{'=' * 80}")
            print("MARGINAL ANALYSIS (Deployment View - Ignores True Labels)")
            print(f"{'=' * 80}")
            print(f" Total samples: {n_total}")

            print("\nOverall Coverage:")
            print(f" Covered: {cov_count:4d} / {n_total:4d} = {pct(cov_rate)} 95% CI: [{cov_lo:.4f}, {cov_hi:.4f}]")

            print("\nPrediction Set Distribution:")
            print(
                f" Abstentions: {abst_m_count:4d} / {n_total:4d} = {pct(abst_m_rate)} "
                f"95% CI: [{abst_m_lo:.4f}, {abst_m_hi:.4f}]"
            )
            print(
                f" Singletons: {sing_total:4d} / {n_total:4d} = {pct(sing_m_rate)} "
                f"95% CI: [{sing_m_lo:.4f}, {sing_m_hi:.4f}]"
            )
            print(f" ├─ Pred 0: {s0_count:4d} / {n_total:4d} = {pct(s0_rate)} 95% CI: [{s0_lo:.4f}, {s0_hi:.4f}]")
            print(f" ├─ Pred 1: {s1_count:4d} / {n_total:4d} = {pct(s1_rate)} 95% CI: [{s1_lo:.4f}, {s1_hi:.4f}]")
            print(
                f" ├─ Errors (overall): {s_err_count:4d} / {sing_total:4d} = {pct(s_err_rate)} "
                f"95% CI: [{s_err_lo:.4f}, {s_err_hi:.4f}]"
            )
            print(
                f" ├─ Pred 0 errors: {err_pred0_count:4d} / {s0_count:4d} = {pct(pred0_err_ci['proportion'])} "
                f"95% CI: [{pred0_err_ci['lower']:.4f}, {pred0_err_ci['upper']:.4f}]"
            )
            print(
                f" └─ Pred 1 errors: {err_pred1_count:4d} / {s1_count:4d} = {pct(pred1_err_ci['proportion'])} "
                f"95% CI: [{pred1_err_ci['lower']:.4f}, {pred1_err_ci['upper']:.4f}]"
            )

            print(
                f" Doublets: {doub_m_count:4d} / {n_total:4d} = {pct(doub_m_rate)} "
                f"95% CI: [{doub_m_lo:.4f}, {doub_m_hi:.4f}]"
            )

            if pac_m and pac_m.get("rho", None) is not None:
                aw_str = fmt3(pac_m.get("alpha_weighted", None))
                print(f"\n Overall PAC Bounds (weighted α={aw_str}):")
                print_pac(pac_m)

            n_escalations = int(pac_m.get("n_escalations", doub_m_count + abst_m_count))
            # Guard against an empty evaluation set (n_total == 0).
            esc_share = n_escalations / n_total if n_total > 0 else 0.0
            print("\n Deployment Decision Mix:")
            print(f" Automate: {sing_total} singletons ({sing_m_rate:.1%})")
            print(f" Escalate: {n_escalations} doublets+abstentions ({esc_share:.1%})")

    return summary
|
362
|
+
|
363
|
+
|
364
|
+
def plot_parallel_coordinates_plotly(
    df,
    columns: list[str] | None = None,
    color: str = "err_all",
    color_continuous_scale=None,
    title: str = "Mondrian sweep – interactive parallel coordinates",
    height: int = 600,
    base_opacity: float = 0.9,
    unselected_opacity: float = 0.06,
):
    """Build an interactive plotly parallel-coordinates figure of sweep results.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with hyperparameter sweep results.
    columns : list of str, optional
        Columns shown as parallel axes. When omitted, the subset of
        ['a0','d0','a1','d1','cov','sing_rate','err_all','err_pred0',
        'err_pred1','err_y0','err_y1','esc_rate'] present in ``df`` is used,
        in that order.
    color : str, default='err_all'
        Column used to color the lines. If it is not a column of ``df``, the
        lines are left uncolored and no colorbar title is set.
    color_continuous_scale : plotly colorscale, optional
        Color scale for the lines; a falsy value selects
        ``px.colors.sequential.Blugrn``.
    title : str, default="Mondrian sweep – interactive parallel coordinates"
        Plot title.
    height : int, default=600
        Plot height in pixels.
    base_opacity : float, default=0.9
        Intended opacity of selected lines. NOTE(review): this argument is
        accepted but currently not applied anywhere in the figure.
    unselected_opacity : float, default=0.06
        Alpha channel of the near-black color given to unselected (brushed-out)
        lines, which creates contrast with the selection.

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive plotly figure.

    Examples
    --------
    >>> import pandas as pd
    >>> df = sweep_hyperparams_and_collect(...)
    >>> fig = plot_parallel_coordinates_plotly(df, color='err_all')
    >>> fig.show()  # In notebook
    >>> # Or save: fig.write_html("sweep_results.html")
    """
    # Imported lazily so the rest of the module works without plotly installed.
    import plotly.express as px

    if columns is None:
        preferred = (
            "a0", "d0", "a1", "d1",
            "cov", "sing_rate",
            "err_all", "err_pred0", "err_pred1", "err_y0", "err_y1",
            "esc_rate",
        )
        columns = [name for name in preferred if name in df.columns]

    color_is_column = color in df.columns
    scale = color_continuous_scale if color_continuous_scale else px.colors.sequential.Blugrn
    axis_labels = dict(zip(columns, columns))

    fig = px.parallel_coordinates(
        df,
        dimensions=columns,
        color=color if color_is_column else None,
        color_continuous_scale=scale,
        labels=axis_labels,
    )

    # Fade brushed-out lines by encoding the opacity in the rgba alpha channel.
    if fig.data:
        faded_rgba = f"rgba(1,1,1,{float(unselected_opacity)})"
        fig.data[0].unselected.update(line=dict(color=faded_rgba))

    fig.update_layout(
        title=title,
        height=height,
        margin=dict(l=40, r=40, t=60, b=40),
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(size=14),
        uirevision=True,  # keep user brushing across updates
    )

    # Larger axis label / range / tick fonts for readability.
    fig.update_traces(labelfont=dict(size=14), rangefont=dict(size=12), tickfont=dict(size=12))

    # Label the colorbar only when lines are actually colored by a column.
    if not (color_is_column and fig.data):
        return fig
    trace = fig.data[0]
    if getattr(trace, "line", None) and getattr(trace.line, "colorbar", None) is not None:
        trace.line.colorbar.title = color

    return fig
|