matplotly 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
matplotly/_code_gen.py ADDED
@@ -0,0 +1,1793 @@
1
+ """Generate Python code from the current figure state."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ from matplotlib.colors import to_hex
8
+ from matplotlib.figure import Figure
9
+ from matplotlib.lines import Line2D
10
+
11
+ from ._commands import CommandStack
12
+
13
+
14
+ def generate_code(fig: Figure, stack: CommandStack) -> str:
15
+ """Produce a Python code snippet reproducing the figure's current styling.
16
+
17
+ Emits style modifications that apply on top of the user's original
18
+ plotting code. For distribution plots (boxplot/violinplot), emits
19
+ ax.bxp() calls with compact statistics and styling kwargs.
20
+ """
21
+ all_axes = fig.get_axes()
22
+ # Separate main axes from marginal histogram and colorbar axes
23
+ axes_list = [a for a in all_axes
24
+ if not getattr(a, '_matplotly_marginal', False)
25
+ and not hasattr(a, '_colorbar')]
26
+ marginal_axes = [a for a in all_axes if getattr(a, '_matplotly_marginal', False)]
27
+
28
+ if not axes_list:
29
+ return "# No axes found."
30
+
31
+ multi = len(axes_list) > 1
32
+
33
+ lines: list[str] = ["# Generated by matplotly v4"]
34
+
35
+ if multi:
36
+ lines.append("axes = fig.get_axes()")
37
+
38
+ for ax_i, ax in enumerate(axes_list):
39
+ ax_var = f"axes[{ax_i}]" if multi else "ax"
40
+
41
+ # --- Identify distribution-managed artists (for skip logic) ---
42
+ dist_artist_ids: set[int] = set()
43
+ _dist_infos = getattr(ax, '_matplotly_dist_info', [])
44
+ has_dist = bool(_dist_infos)
45
+
46
+ # --- Identify heatmap-managed artists (for skip logic) ---
47
+ heatmap_artist_ids: set[int] = set()
48
+ for _hi in getattr(ax, '_matplotly_heatmap_info', []):
49
+ m = _hi.get('mappable')
50
+ if m is not None:
51
+ heatmap_artist_ids.add(id(m))
52
+
53
+ if has_dist:
54
+ for art in ax.lines:
55
+ if art.get_label().startswith("_"):
56
+ dist_artist_ids.add(id(art))
57
+ for coll in ax.collections:
58
+ if getattr(coll, '_matplotly_dist', False):
59
+ dist_artist_ids.add(id(coll))
60
+ if getattr(coll, '_matplotly_dist_jitter', False):
61
+ dist_artist_ids.add(id(coll))
62
+ for patch in ax.patches:
63
+ if getattr(patch, '_matplotly_dist', False):
64
+ dist_artist_ids.add(id(patch))
65
+
66
+ # --- Identify errorbar-managed artists (for skip logic) ---
67
+ errorbar_artist_ids: set[int] = set()
68
+ for l in ax.lines:
69
+ if getattr(l, '_matplotly_errorbar', False):
70
+ errorbar_artist_ids.add(id(l))
71
+ for coll in ax.collections:
72
+ if getattr(coll, '_matplotly_errorbar', False):
73
+ errorbar_artist_ids.add(id(coll))
74
+
75
+ # --- Lines (filtered list for stable indexing) ---
76
+ _user_lines = [l for l in ax.lines
77
+ if not l.get_label().startswith("_")
78
+ and id(l) not in errorbar_artist_ids]
79
+ if _user_lines:
80
+ lines.append(
81
+ f"\n_lines = [l for l in {ax_var}.lines "
82
+ f"if not l.get_label().startswith('_')]")
83
+ for j, line in enumerate(_user_lines):
84
+ label = line.get_label()
85
+ acc = f"_lines[{j}]"
86
+ lines.append(f"\n# {label or f'Line {j}'}")
87
+ lines.append(f"{acc}.set_color({_fmt(to_hex(line.get_color()))})")
88
+ lines.append(f"{acc}.set_linewidth({_fmt(line.get_linewidth())})")
89
+ ls = line.get_linestyle()
90
+ if ls != "-":
91
+ lines.append(f"{acc}.set_linestyle({_fmt(ls)})")
92
+ alpha = line.get_alpha()
93
+ if alpha is not None and alpha != 1.0:
94
+ lines.append(f"{acc}.set_alpha({_fmt(alpha)})")
95
+ marker = line.get_marker()
96
+ if marker and marker != "None" and marker != "none":
97
+ lines.append(f"{acc}.set_marker({_fmt(marker)})")
98
+ lines.append(f"{acc}.set_markersize({_fmt(line.get_markersize())})")
99
+
100
+ # --- Histograms ---
101
+ hist_patch_ids: set[int] = set()
102
+ hist_container_idxs: list[int] = []
103
+ from matplotlib.container import BarContainer as _BCt
104
+ from matplotlib.patches import Rectangle as _Rect
105
+
106
+ _hist_infos = getattr(ax, '_matplotly_hist_info', [])
107
+ _hist_merged = any(hi.get('merged', False) for hi in _hist_infos)
108
+ _skip_hist_style = False
109
+
110
+ if _hist_infos and _hist_merged:
111
+ # Merged: all BarContainers are histogram containers
112
+ for ci, _c in enumerate(ax.containers):
113
+ if isinstance(_c, _BCt):
114
+ hist_container_idxs.append(ci)
115
+ for _p in _c:
116
+ hist_patch_ids.add(id(_p))
117
+ for hi in _hist_infos:
118
+ for art in hi.get('artists', []):
119
+ hist_patch_ids.add(id(art))
120
+ _data_vars = getattr(ax, '_matplotly_hist_data_vars', None)
121
+ _skip_hist_style = _emit_hist_merged(
122
+ lines, ax_var, _hist_infos, data_vars=_data_vars)
123
+ else:
124
+ # Non-merged: detect histograms via geometry or attribute
125
+ for ci, _c in enumerate(ax.containers):
126
+ if isinstance(_c, _BCt):
127
+ from ._introspect import FigureIntrospector as _FI_cg
128
+ if (_FI_cg._is_histogram_container(_c)
129
+ or getattr(_c, '_matplotly_is_histogram', False)):
130
+ hist_container_idxs.append(ci)
131
+ for _p in _c:
132
+ hist_patch_ids.add(id(_p))
133
+
134
+ if not _skip_hist_style:
135
+ # Build label lookup from hist_info
136
+ _hi_label_map = {}
137
+ for _hi in _hist_infos:
138
+ _hi_arts = _hi.get('artists', [])
139
+ if _hi_arts:
140
+ _hi_label_map[id(_hi_arts[0])] = _hi.get('label', '')
141
+
142
+ # Style-only code for all detected histogram containers
143
+ for _hj, _hci in enumerate(hist_container_idxs):
144
+ _hcont = ax.containers[_hci]
145
+ _hpats = [p for p in _hcont.patches
146
+ if isinstance(p, _Rect)]
147
+ if not _hpats:
148
+ continue
149
+ _hp0 = _hpats[0]
150
+ _hlbl = _hi_label_map.get(id(_hp0), '')
151
+ if not _hlbl or _hlbl.startswith("_"):
152
+ _hlbl = _hp0.get_label()
153
+ if not _hlbl or _hlbl.startswith("_"):
154
+ _hlbl = _hcont.get_label()
155
+ if not _hlbl or _hlbl.startswith("_"):
156
+ _hlbl = f"Histogram {_hj}"
157
+ _hfc = to_hex(_hp0.get_facecolor())
158
+ _hec = to_hex(_hp0.get_edgecolor())
159
+ _halpha = _hp0.get_alpha()
160
+ _hlw = _hp0.get_linewidth()
161
+
162
+ lines.append(f"\n# Histogram: {_hlbl}")
163
+ _hacc = f"{ax_var}.containers[{_hci}]"
164
+ lines.append(f"for _p in {_hacc}:")
165
+ lines.append(f" _p.set_facecolor({_fmt(_hfc)})")
166
+ lines.append(f" _p.set_edgecolor({_fmt(_hec)})")
167
+ if _halpha is not None and round(_halpha, 2) != 1.0:
168
+ lines.append(
169
+ f" _p.set_alpha({_fmt(round(_halpha, 2))})")
170
+ if round(_hlw, 1) != 1.0:
171
+ lines.append(
172
+ f" _p.set_linewidth({_fmt(round(_hlw, 1))})")
173
+ _hhatch = _hp0.get_hatch()
174
+ if _hhatch:
175
+ lines.append(f" _p.set_hatch({_fmt(_hhatch)})")
176
+
177
+ # --- Bars (style existing containers) ---
178
+ bar_patch_ids: set[int] = set()
179
+ bar_container_idxs: list[int] = []
180
+ for ci, _c in enumerate(ax.containers):
181
+ if isinstance(_c, _BCt) and ci not in hist_container_idxs:
182
+ bar_container_idxs.append(ci)
183
+ for _p in _c:
184
+ bar_patch_ids.add(id(_p))
185
+
186
+ for _bj, _bci in enumerate(bar_container_idxs):
187
+ _bcont = ax.containers[_bci]
188
+ _bpats = [p for p in _bcont.patches if isinstance(p, _Rect)]
189
+ if not _bpats:
190
+ continue
191
+ _bp0 = _bpats[0]
192
+ _blbl = _bcont.get_label()
193
+ if not _blbl or _blbl.startswith("_"):
194
+ _blbl = _bp0.get_label()
195
+ if not _blbl or _blbl.startswith("_"):
196
+ _blbl = f"Bar group {_bj}"
197
+
198
+ _bfc = to_hex(_bp0.get_facecolor())
199
+ _bec = to_hex(_bp0.get_edgecolor())
200
+ _balpha = _bp0.get_alpha()
201
+ _blw = _bp0.get_linewidth()
202
+
203
+ lines.append(f"\n# {_blbl}")
204
+ _bacc = f"{ax_var}.containers[{_bci}]"
205
+ lines.append(f"for _p in {_bacc}:")
206
+ lines.append(f" _p.set_facecolor({_fmt(_bfc)})")
207
+ lines.append(f" _p.set_edgecolor({_fmt(_bec)})")
208
+ if _balpha is not None and round(_balpha, 2) != 1.0:
209
+ lines.append(
210
+ f" _p.set_alpha({_fmt(round(_balpha, 2))})")
211
+ if round(_blw, 1) != 1.0:
212
+ lines.append(
213
+ f" _p.set_linewidth({_fmt(round(_blw, 1))})")
214
+ _bhatch = _bp0.get_hatch()
215
+ if _bhatch:
216
+ lines.append(f" _p.set_hatch({_fmt(_bhatch)})")
217
+
218
+ # Per-patch geometry (width + position)
219
+ _bi_orient = (getattr(ax, '_matplotly_bar_info', [{}])[0]
220
+ .get('orientation', 'vertical')
221
+ if getattr(ax, '_matplotly_bar_info', None)
222
+ else 'vertical')
223
+ _b_vert = _bi_orient == 'vertical'
224
+ if _b_vert:
225
+ _bw_cur = round(float(_bp0.get_width()), 4)
226
+ _bx_list = [round(float(p.get_x()), 4) for p in _bpats]
227
+ lines.append(
228
+ f"for _p, _x in zip({_bacc}, {_bx_list!r}):")
229
+ lines.append(f" _p.set_width({_fmt(_bw_cur)})")
230
+ lines.append(f" _p.set_x(_x)")
231
+ else:
232
+ _bh_cur = round(float(_bp0.get_height()), 4)
233
+ _by_list = [round(float(p.get_y()), 4) for p in _bpats]
234
+ lines.append(
235
+ f"for _p, _y in zip({_bacc}, {_by_list!r}):")
236
+ lines.append(f" _p.set_height({_fmt(_bh_cur)})")
237
+ lines.append(f" _p.set_y(_y)")
238
+
239
+ # Bar tick labels (from bar panel adjustments)
240
+ _bar_infos = getattr(ax, '_matplotly_bar_info', [])
241
+ if _bar_infos:
242
+ _bi_ref = _bar_infos[0]
243
+ _bi_tl = _bi_ref.get('tick_labels', [])
244
+ _bi_tc = _bi_ref.get('tick_centers', [])
245
+ _bi_orient = _bi_ref.get('orientation', 'vertical')
246
+ if _bi_tl and _bi_tc:
247
+ _bi_tax = 'x' if _bi_orient == 'vertical' else 'y'
248
+ _bi_tc_str = ", ".join(
249
+ str(round(float(t), 4)) for t in _bi_tc)
250
+ lines.append(
251
+ f"{ax_var}.set_{_bi_tax}ticks([{_bi_tc_str}])")
252
+ _bi_rot = _bi_ref.get('tick_rotation', 0)
253
+ _bi_ha = _bi_ref.get('tick_ha', 'center')
254
+ _bi_tl_args = f"{_bi_tl!r}"
255
+ if _bi_rot:
256
+ _bi_tl_args += f", rotation={_fmt(_bi_rot)}"
257
+ if _bi_ha != "center":
258
+ _bi_tl_args += f", ha={_fmt(_bi_ha)}"
259
+ _bi_tl_args += ", rotation_mode='anchor'"
260
+ lines.append(
261
+ f"{ax_var}.set_{_bi_tax}ticklabels({_bi_tl_args})")
262
+
263
+ # --- Distribution plots (bxp-based recreation) ---
264
+ _dist_data_vars = getattr(ax, '_matplotly_dist_data_vars', None)
265
+
266
+ if _dist_infos:
267
+ _di_ref = _dist_infos[0]
268
+ _di_orient = _di_ref.get('orientation', 'vertical')
269
+ _di_width = _di_ref.get('width', 0.5)
270
+ _di_tl = _di_ref.get('tick_labels', [])
271
+ _di_tc = _di_ref.get('tick_centers',
272
+ _di_ref.get('positions', []))
273
+ _di_ng = len(_dist_infos)
274
+
275
+ if _dist_data_vars:
276
+ # --- Apply path: use real data variable ---
277
+ _emit_dist_with_data(
278
+ lines, ax_var, _dist_infos, _dist_data_vars,
279
+ _di_orient, _di_width, _di_ng)
280
+ else:
281
+ # --- Copy/fallback path: stats + fabricated data ---
282
+ lines.append(f"\n# Distribution plot")
283
+
284
+ # Cleanup original distribution artists so bxp() doesn't double
285
+ lines.append(
286
+ f"for _l in list({ax_var}.lines):")
287
+ lines.append(
288
+ f" if _l.get_label().startswith('_'):")
289
+ lines.append(
290
+ f" _l.remove()")
291
+ lines.append(
292
+ f"for _p in list({ax_var}.patches):")
293
+ lines.append(
294
+ f" if _p.get_label() == '' or _p.get_label().startswith('_'):")
295
+ lines.append(
296
+ f" try:")
297
+ lines.append(
298
+ f" _p.remove()")
299
+ lines.append(
300
+ f" except ValueError:")
301
+ lines.append(
302
+ f" pass")
303
+
304
+ for _dj, _dinfo in enumerate(_dist_infos):
305
+ _d_mode = _dinfo.get('display_mode', 'box')
306
+ _d_label = _dinfo.get('label', f'Group {_dj}')
307
+ _d_pos = _dinfo.get('positions', _di_tc)
308
+ _d_raw = _dinfo.get('raw_data', [])
309
+
310
+ lines.append(f"\n# {_d_label}")
311
+
312
+ if not _d_raw:
313
+ continue
314
+
315
+ # Prefer original introspected stats (avoids double
316
+ # reconstruction that degrades violin KDE shapes).
317
+ _orig_box_stats = _dinfo.get('box_stats', [])
318
+ if _orig_box_stats:
319
+ bxp_stats = _bxp_stats_from_original(
320
+ _orig_box_stats, _dinfo)
321
+ else:
322
+ bxp_stats = _compute_bxp_stats(_d_raw, _dinfo)
323
+ if not bxp_stats:
324
+ continue
325
+ svar = _emit_stats_var(lines, bxp_stats, _dj, _di_ng)
326
+
327
+ # Reconstruct data from stats for violin/jitter
328
+ _needs_data = 'violin' in _d_mode or 'jitter' in _d_mode
329
+ raw_var = None
330
+ if _needs_data:
331
+ raw_var = _emit_data_from_stats(
332
+ lines, svar, _dj, _di_ng)
333
+
334
+ # Violin (background layer)
335
+ if 'violin' in _d_mode and raw_var:
336
+ _emit_violin(lines, ax_var, _dinfo, raw_var, _d_pos,
337
+ _di_orient, _di_width, _d_mode,
338
+ _d_label)
339
+
340
+ # Box (middle layer)
341
+ if 'box' in _d_mode:
342
+ _emit_bxp_call(lines, ax_var, _dinfo, svar, _d_pos,
343
+ _d_label, _dj, _di_ng, _di_orient,
344
+ _di_width, _d_mode)
345
+
346
+ # Jitter (top layer)
347
+ if 'jitter' in _d_mode and raw_var:
348
+ _emit_jitter(lines, ax_var, _dinfo, raw_var, _d_pos,
349
+ _di_orient, _d_mode, _d_label)
350
+
351
+ # Tick labels (shared by both paths)
352
+ if _di_tl and _di_tc:
353
+ _d_tax = 'x' if _di_orient == 'vertical' else 'y'
354
+ _d_tc_str = ", ".join(
355
+ str(round(float(t), 4)) for t in _di_tc)
356
+ lines.append(
357
+ f"\n{ax_var}.set_{_d_tax}ticks([{_d_tc_str}])")
358
+ _d_ha = _di_ref.get('tick_ha', 'center')
359
+ _d_rot = _di_ref.get('tick_rotation', 0)
360
+ _d_pad = _di_ref.get('tick_pad', 4.0)
361
+ _d_tl_args = f"{_di_tl!r}"
362
+ if _d_rot:
363
+ _d_tl_args += f", rotation={_fmt(_d_rot)}"
364
+ if _d_ha != 'center':
365
+ _d_tl_args += f", ha={_fmt(_d_ha)}"
366
+ _d_tl_args += ", rotation_mode='anchor'"
367
+ lines.append(
368
+ f"{ax_var}.set_{_d_tax}ticklabels({_d_tl_args})")
369
+ if abs(_d_pad - 4.0) > 0.1:
370
+ lines.append(
371
+ f"{ax_var}.tick_params(axis={_fmt(_d_tax)},"
372
+ f" pad={_fmt(_d_pad)})")
373
+
374
+ # --- Heatmaps ---
375
+ _heatmap_infos = getattr(ax, '_matplotly_heatmap_info', [])
376
+ if _heatmap_infos:
377
+ _emit_heatmap(lines, ax_var, _heatmap_infos)
378
+
379
+ # --- Colorbar ---
380
+ _cbar_info = getattr(ax, '_matplotly_colorbar_info', None)
381
+ if _cbar_info and _cbar_info.get('show', False) and _heatmap_infos:
382
+ _emit_colorbar(lines, ax_var, _cbar_info, _heatmap_infos)
383
+
384
+ # --- Errorbars ---
385
+ _errorbar_infos = getattr(ax, '_matplotly_errorbar_info', [])
386
+ if _errorbar_infos:
387
+ _emit_errorbars(lines, ax_var, _errorbar_infos)
388
+
389
+ # --- Patches (generic — skip histogram + bar + dist patches) ---
390
+ _skip_ids = hist_patch_ids | bar_patch_ids | dist_artist_ids
391
+ for i, patch in enumerate(ax.patches):
392
+ if id(patch) in _skip_ids:
393
+ continue
394
+ if hasattr(patch, "get_facecolor"):
395
+ fc = to_hex(patch.get_facecolor())
396
+ if fc != "#1f77b4": # only if non-default
397
+ acc = f"{ax_var}.patches[{i}]"
398
+ lines.append(f"{acc}.set_facecolor({_fmt(fc)})")
399
+
400
+ # --- Collections (scatter — filtered list for stable indexing) ---
401
+ from matplotlib.collections import PathCollection
402
+ _user_colls = []
403
+ for coll in ax.collections:
404
+ if not isinstance(coll, PathCollection):
405
+ continue
406
+ if id(coll) in dist_artist_ids:
407
+ continue
408
+ if id(coll) in heatmap_artist_ids:
409
+ continue
410
+ if id(coll) in errorbar_artist_ids:
411
+ continue
412
+ _user_colls.append(coll)
413
+
414
+ if _user_colls:
415
+ lines.append(
416
+ f"\nfrom matplotlib.collections import "
417
+ f"PathCollection as _PC")
418
+ lines.append(
419
+ f"_scatter = [c for c in {ax_var}.collections "
420
+ f"if isinstance(c, _PC)]")
421
+ for j, coll in enumerate(_user_colls):
422
+ acc = f"_scatter[{j}]"
423
+ try:
424
+ fc = coll.get_facecolor()
425
+ if len(fc) > 0:
426
+ lines.append(
427
+ f"{acc}.set_facecolor({_fmt(to_hex(fc[0]))})")
428
+ ec = coll.get_edgecolor()
429
+ if len(ec) > 0:
430
+ _ec_hex = to_hex(ec[0])
431
+ if _ec_hex != "#000000":
432
+ lines.append(
433
+ f"{acc}.set_edgecolor({_fmt(_ec_hex)})")
434
+ alpha = coll.get_alpha()
435
+ if alpha is not None:
436
+ lines.append(f"{acc}.set_alpha({_fmt(alpha)})")
437
+ sizes = coll.get_sizes()
438
+ if len(sizes) > 0:
439
+ lines.append(
440
+ f"{acc}.set_sizes([{_fmt(float(sizes[0]))}])")
441
+ except Exception:
442
+ pass
443
+
444
+ # --- Text labels ---
445
+ title = ax.get_title()
446
+ xlabel = ax.get_xlabel()
447
+ ylabel = ax.get_ylabel()
448
+ if title or xlabel or ylabel:
449
+ lines.append(f"\n# Labels")
450
+ if title:
451
+ lines.append(f"{ax_var}.set_title({_fmt(title)})")
452
+ if xlabel:
453
+ lines.append(f"{ax_var}.set_xlabel({_fmt(xlabel)})")
454
+ if ylabel:
455
+ lines.append(f"{ax_var}.set_ylabel({_fmt(ylabel)})")
456
+
457
+ # --- Axis limits ---
458
+ xlim = ax.get_xlim()
459
+ ylim = ax.get_ylim()
460
+ lines.append(f"\n# Limits")
461
+ lines.append(f"{ax_var}.set_xlim({_fmt(round(float(xlim[0]), 4))}, "
462
+ f"{_fmt(round(float(xlim[1]), 4))})")
463
+ lines.append(f"{ax_var}.set_ylim({_fmt(round(float(ylim[0]), 4))}, "
464
+ f"{_fmt(round(float(ylim[1]), 4))})")
465
+
466
+ # --- Axis scale ---
467
+ xscale = ax.get_xscale()
468
+ yscale = ax.get_yscale()
469
+ if xscale != "linear" or yscale != "linear":
470
+ lines.append(f"\n# Scale")
471
+ if xscale != "linear":
472
+ lines.append(f"{ax_var}.set_xscale({_fmt(xscale)})")
473
+ if yscale != "linear":
474
+ lines.append(f"{ax_var}.set_yscale({_fmt(yscale)})")
475
+
476
+ # --- Fonts ---
477
+ title_size = ax.title.get_fontsize()
478
+ xlabel_size = ax.xaxis.label.get_fontsize()
479
+ ylabel_size = ax.yaxis.label.get_fontsize()
480
+ title_family = ax.title.get_fontfamily()
481
+ if isinstance(title_family, list):
482
+ title_family = title_family[0] if title_family else "Arial"
483
+
484
+ lines.append(f"\n# Fonts")
485
+ lines.append(f"{ax_var}.title.set_fontsize({_fmt(title_size)})")
486
+ lines.append(f"{ax_var}.title.set_fontfamily({_fmt(title_family)})")
487
+ lines.append(f"{ax_var}.xaxis.label.set_fontsize({_fmt(xlabel_size)})")
488
+ lines.append(f"{ax_var}.yaxis.label.set_fontsize({_fmt(ylabel_size)})")
489
+
490
+ # Title pad (distance from axes top, in points)
491
+ try:
492
+ _tpad_disp = ax.titleOffsetTrans.get_matrix()[1, 2]
493
+ _tpad_pts = round(_tpad_disp / fig.dpi * 72, 1)
494
+ if abs(_tpad_pts - 6.0) > 0.5:
495
+ lines.append(
496
+ f"{ax_var}.set_title({ax_var}.get_title(), "
497
+ f"pad={_fmt(_tpad_pts)})")
498
+ except Exception:
499
+ pass
500
+
501
+ # Text color / weight / style
502
+ for _tname, _tobj in [("title", ax.title),
503
+ ("xaxis.label", ax.xaxis.label),
504
+ ("yaxis.label", ax.yaxis.label)]:
505
+ try:
506
+ _tc = to_hex(_tobj.get_color())
507
+ if _tc != "#000000":
508
+ lines.append(f"{ax_var}.{_tname}.set_color({_fmt(_tc)})")
509
+ except Exception:
510
+ pass
511
+ _fw = _tobj.get_fontweight()
512
+ if _fw == 'bold' or (isinstance(_fw, (int, float)) and _fw >= 600):
513
+ lines.append(f"{ax_var}.{_tname}.set_fontweight('bold')")
514
+ if _tobj.get_fontstyle() == 'italic':
515
+ lines.append(f"{ax_var}.{_tname}.set_fontstyle('italic')")
516
+
517
+ # Tick label font size (use first tick as representative)
518
+ xticks = ax.get_xticklabels()
519
+ yticks = ax.get_yticklabels()
520
+ if xticks:
521
+ tick_sz = xticks[0].get_fontsize()
522
+ lines.append(f"{ax_var}.tick_params(labelsize={_fmt(tick_sz)})")
523
+
524
+ # Tick label colors
525
+ _xtl = ax.get_xticklabels()
526
+ if _xtl:
527
+ try:
528
+ _xtc = to_hex(_xtl[0].get_color())
529
+ if _xtc != "#000000":
530
+ lines.append(f"{ax_var}.tick_params(axis='x', "
531
+ f"labelcolor={_fmt(_xtc)})")
532
+ except Exception:
533
+ pass
534
+ _ytl = ax.get_yticklabels()
535
+ if _ytl:
536
+ try:
537
+ _ytc = to_hex(_ytl[0].get_color())
538
+ if _ytc != "#000000":
539
+ lines.append(f"{ax_var}.tick_params(axis='y', "
540
+ f"labelcolor={_fmt(_ytc)})")
541
+ except Exception:
542
+ pass
543
+
544
+ # --- Tick params ---
545
+ # Read from the actual tick objects
546
+ xtick_objs = ax.xaxis.get_major_ticks()
547
+ if xtick_objs:
548
+ tick = xtick_objs[0]
549
+ tick_line = tick.tick1line
550
+ direction = "in" if tick_line.get_marker() == 2 else "out"
551
+ # Check inout: tick2line visible means inout
552
+ if tick.tick2line.get_visible() and not ax.spines["top"].get_visible():
553
+ direction = "inout"
554
+ tick_len = tick_line.get_markersize()
555
+ tick_width = tick_line.get_markeredgewidth()
556
+ lines.append(f"{ax_var}.tick_params(direction={_fmt(direction)}, "
557
+ f"length={_fmt(tick_len)}, width={_fmt(tick_width)})")
558
+
559
+ # --- Tick spacing ---
560
+ from matplotlib.ticker import MultipleLocator as _ML
561
+ x_loc = ax.xaxis.get_major_locator()
562
+ y_loc = ax.yaxis.get_major_locator()
563
+ if isinstance(x_loc, _ML) or isinstance(y_loc, _ML):
564
+ if isinstance(x_loc, _ML):
565
+ xt = list(ax.get_xticks())
566
+ if len(xt) >= 2:
567
+ step = round(xt[1] - xt[0], 6)
568
+ if step > 0:
569
+ lines.append(
570
+ f"{ax_var}.xaxis.set_major_locator("
571
+ f"matplotlib.ticker.MultipleLocator({_fmt(step)}))")
572
+ if isinstance(y_loc, _ML):
573
+ yt = list(ax.get_yticks())
574
+ if len(yt) >= 2:
575
+ step = round(yt[1] - yt[0], 6)
576
+ if step > 0:
577
+ lines.append(
578
+ f"{ax_var}.yaxis.set_major_locator("
579
+ f"matplotlib.ticker.MultipleLocator({_fmt(step)}))")
580
+
581
+ # --- Spines ---
582
+ spine_changes = []
583
+ for name in ("top", "right", "bottom", "left"):
584
+ sp = ax.spines[name]
585
+ if not sp.get_visible():
586
+ spine_changes.append(
587
+ f"{ax_var}.spines[{_fmt(name)}].set_visible(False)")
588
+ lw = sp.get_linewidth()
589
+ if lw != 1.0:
590
+ spine_changes.append(
591
+ f"{ax_var}.spines[{_fmt(name)}].set_linewidth({_fmt(lw)})")
592
+ if spine_changes:
593
+ lines.append(f"\n# Spines")
594
+ lines.extend(spine_changes)
595
+
596
+ # --- Grid ---
597
+ # Check if grid lines are visible
598
+ has_grid = any(gl.get_visible() for gl in ax.xaxis.get_gridlines())
599
+ if has_grid:
600
+ gl = ax.xaxis.get_gridlines()[0]
601
+ lines.append(f"\n# Grid")
602
+ lines.append(
603
+ f"{ax_var}.grid(True, alpha={_fmt(gl.get_alpha() or 0.5)}, "
604
+ f"linewidth={_fmt(gl.get_linewidth())}, "
605
+ f"linestyle={_fmt(gl.get_linestyle())})")
606
+
607
+ # --- Legend ---
608
+ leg = ax.get_legend()
609
+ _leg_handles, _leg_labels = ax.get_legend_handles_labels()
610
+ if leg is None and _leg_handles:
611
+ # Legend was removed by user
612
+ lines.append(f"\n# Legend removed")
613
+ lines.append(f"if {ax_var}.get_legend() is not None:")
614
+ lines.append(f" {ax_var}.get_legend().remove()")
615
+ elif leg is not None:
616
+ if _leg_handles:
617
+ lines.append(f"\n# Legend")
618
+
619
+ # Get legend properties
620
+ leg_kwargs = []
621
+ try:
622
+ loc = leg._loc
623
+ _loc_names = {
624
+ 0: "upper right", 1: "upper right", 2: "upper left",
625
+ 3: "lower left", 4: "lower right", 5: "right",
626
+ 6: "center left", 7: "center right",
627
+ 8: "lower center", 9: "upper center", 10: "center",
628
+ }
629
+ loc_str = _loc_names.get(loc, "upper right")
630
+ leg_kwargs.append(f"loc={_fmt(loc_str)}")
631
+ except Exception:
632
+ pass
633
+
634
+ frameon = leg.get_frame().get_visible()
635
+ leg_kwargs.append(f"frameon={_fmt(frameon)}")
636
+
637
+ fontsize = leg._fontsize
638
+ if fontsize:
639
+ leg_kwargs.append(f"fontsize={_fmt(fontsize)}")
640
+
641
+ ncols = leg._ncols
642
+ if ncols != 1:
643
+ leg_kwargs.append(f"ncol={_fmt(ncols)}")
644
+
645
+ markerfirst = getattr(leg, '_markerfirst', True)
646
+ if not markerfirst:
647
+ leg_kwargs.append("markerfirst=False")
648
+
649
+ handletextpad = getattr(leg, 'handletextpad', 0.8)
650
+ if round(handletextpad, 2) != 0.8:
651
+ leg_kwargs.append(f"handletextpad={_fmt(round(handletextpad, 2))}")
652
+
653
+ handleheight = getattr(leg, 'handleheight', 0.7)
654
+ if round(handleheight, 2) != 0.7:
655
+ leg_kwargs.append(f"handleheight={_fmt(round(handleheight, 2))}")
656
+
657
+ # bbox_to_anchor
658
+ if hasattr(leg, "_bbox_to_anchor") and leg._bbox_to_anchor is not None:
659
+ try:
660
+ bbox = leg._bbox_to_anchor
661
+ if bbox.width < 1 and bbox.height < 1:
662
+ inv = ax.transAxes.inverted()
663
+ x_ax, y_ax = inv.transform((bbox.x0, bbox.y0))
664
+ leg_kwargs.append(
665
+ f"bbox_to_anchor=({_fmt(round(x_ax, 2))}, "
666
+ f"{_fmt(round(y_ax, 2))})")
667
+ except Exception:
668
+ pass
669
+
670
+ kwargs_str = ", ".join(leg_kwargs)
671
+ lines.append(f"{ax_var}.legend({kwargs_str})")
672
+
673
+ # Legend text colors
674
+ for _li, _lt in enumerate(leg.get_texts()):
675
+ try:
676
+ _ltc = to_hex(_lt.get_color())
677
+ if _ltc != "#000000":
678
+ lines.append(
679
+ f"{ax_var}.get_legend().get_texts()"
680
+ f"[{_li}].set_color({_fmt(_ltc)})")
681
+ except Exception:
682
+ pass
683
+
684
+ # --- Layout (figure-level, emitted once) ---
685
+ w, h = fig.get_size_inches()
686
+ lines.append(f"\n# Layout")
687
+ lines.append(f"fig.set_size_inches({_fmt(round(w, 2))}, {_fmt(round(h, 2))})")
688
+
689
+ fc = to_hex(fig.get_facecolor())
690
+ if fc != "#ffffff":
691
+ lines.append(f"fig.set_facecolor({_fmt(fc)})")
692
+
693
+ lines.append(f"fig.tight_layout()")
694
+
695
+ # --- Subplot spacing (multi-subplot only) ---
696
+ if multi:
697
+ try:
698
+ sp = fig.subplotpars
699
+ _h = round(sp.hspace, 2) if sp.hspace else 0.0
700
+ _w = round(sp.wspace, 2) if sp.wspace else 0.0
701
+ if _h > 0 or _w > 0:
702
+ lines.append(f"fig.subplots_adjust(hspace={_fmt(_h)}, wspace={_fmt(_w)})")
703
+ except Exception:
704
+ pass
705
+
706
+ # --- Marginal histograms ---
707
+ if marginal_axes:
708
+ lines.append("\n# Marginal histograms")
709
+
710
+ # Pre-scan: determine space needed per parent axes
711
+ _marg_infos = [] # (m_ax, info) pairs
712
+ _parent_space: dict[int, list] = {} # pidx -> [(axis, pos, height, pad)]
713
+ for m_ax in marginal_axes:
714
+ info = getattr(m_ax, '_matplotly_marginal_info', None)
715
+ if not info:
716
+ continue
717
+ _marg_infos.append((m_ax, info))
718
+ pidx = info.get('parent_ax_index', 0)
719
+ axis = info.get('axis', 'x')
720
+ position = info.get('position', 'top' if axis == 'x' else 'right')
721
+ m_height = info.get('height', 0.8)
722
+ m_pad = info.get('pad', 0.05)
723
+ _parent_space.setdefault(pidx, []).append(
724
+ (axis, position, m_height, m_pad))
725
+
726
+ # Shrink parent axes to make room for marginals
727
+ for pidx, needs in sorted(_parent_space.items()):
728
+ p_var = f"axes[{pidx}]" if multi else "ax"
729
+ lines.append(f"_pos = {p_var}.get_position()")
730
+ lines.append("_fig_w, _fig_h = fig.get_size_inches()")
731
+ lines.append(
732
+ "_x0, _y0, _w, _h = _pos.x0, _pos.y0, "
733
+ "_pos.width, _pos.height")
734
+ for axis, position, m_height, m_pad in needs:
735
+ if axis == 'x' and position == 'top':
736
+ lines.append(
737
+ f"_h -= ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_h")
738
+ elif axis == 'x' and position == 'bottom':
739
+ lines.append(
740
+ f"_bump = ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_h")
741
+ lines.append("_y0 += _bump; _h -= _bump")
742
+ elif axis == 'y' and position == 'right':
743
+ lines.append(
744
+ f"_w -= ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_w")
745
+ elif axis == 'y' and position == 'left':
746
+ lines.append(
747
+ f"_bump = ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_w")
748
+ lines.append("_x0 += _bump; _w -= _bump")
749
+ lines.append(f"{p_var}.set_position([_x0, _y0, _w, _h])")
750
+
751
+ # Now create each marginal axes in the freed space
752
+ for m_ax, info in _marg_infos:
753
+ parent_idx = info.get('parent_ax_index', 0)
754
+ axis = info.get('axis', 'x')
755
+ position = info.get('position', 'top' if axis == 'x' else 'right')
756
+ m_height = info.get('height', 0.8)
757
+ m_pad = info.get('pad', 0.05)
758
+ mode = info.get('mode', 'overlay')
759
+ n_bins = info.get('bins', 20)
760
+ alpha = info.get('alpha', 0.5)
761
+ separation = info.get('separation', 0.1)
762
+ inverted = info.get('inverted', False)
763
+ tick_side = info.get('tick_side',
764
+ 'left' if axis == 'x' else 'bottom')
765
+ tick_fs = info.get('tick_fontsize', 8)
766
+ tick_step = info.get('tick_step', 0)
767
+ range_min = info.get('range_min', 0)
768
+ range_max = info.get('range_max', 0)
769
+ label = info.get('label', '')
770
+ label_fs = info.get('label_fontsize', 8)
771
+ label_bold = info.get('label_bold', False)
772
+ label_italic = info.get('label_italic', False)
773
+ label_color = info.get('label_color', '#000000')
774
+ m_title = info.get('title', '')
775
+ title_fs = info.get('title_fontsize', 8)
776
+ title_bold = info.get('title_bold', False)
777
+ title_italic = info.get('title_italic', False)
778
+ title_color = info.get('title_color', '#000000')
779
+ colls = info.get('collections', [])
780
+ share = 'sharex' if axis == 'x' else 'sharey'
781
+ data_col = 0 if axis == 'x' else 1
782
+ p_var = f"axes[{parent_idx}]" if multi else "ax"
783
+
784
+ # Collect data + compute global bins
785
+ data_exprs = []
786
+ colors = []
787
+ for ci in colls:
788
+ idx = ci['coll_index']
789
+ data_exprs.append(
790
+ f"{p_var}.collections[{idx}].get_offsets()[:, {data_col}]")
791
+ colors.append(ci['color'])
792
+
793
+ lines.append(
794
+ f"_m_data = [{', '.join(data_exprs)}]")
795
+ lines.append(
796
+ f"_m_bins = np.histogram_bin_edges("
797
+ f"np.concatenate(_m_data), bins={_fmt(n_bins)})")
798
+
799
+ # Compute rect from adjusted main axes position
800
+ lines.append(f"_pos = {p_var}.get_position()")
801
+ lines.append("_fig_w, _fig_h = fig.get_size_inches()")
802
+ if axis == 'x':
803
+ lines.append(
804
+ f"_h_frac = {_fmt(m_height)} / _fig_h")
805
+ lines.append(
806
+ f"_pad = {_fmt(m_pad)} / _fig_h")
807
+ if position == 'top':
808
+ lines.append(
809
+ "_m_rect = [_pos.x0, _pos.y1 + _pad, "
810
+ "_pos.width, _h_frac]")
811
+ else:
812
+ lines.append(
813
+ "_m_rect = [_pos.x0, _pos.y0 - _h_frac - _pad, "
814
+ "_pos.width, _h_frac]")
815
+ else:
816
+ lines.append(
817
+ f"_w_frac = {_fmt(m_height)} / _fig_w")
818
+ lines.append(
819
+ f"_pad = {_fmt(m_pad)} / _fig_w")
820
+ if position == 'right':
821
+ lines.append(
822
+ "_m_rect = [_pos.x1 + _pad, _pos.y0, "
823
+ "_w_frac, _pos.height]")
824
+ else:
825
+ lines.append(
826
+ "_m_rect = [_pos.x0 - _w_frac - _pad, _pos.y0, "
827
+ "_w_frac, _pos.height]")
828
+
829
+ lines.append(
830
+ f"ax_m = fig.add_axes(_m_rect, {share}={p_var})")
831
+
832
+ if mode == 'overlay':
833
+ for i, color in enumerate(colors):
834
+ orient = (", orientation='horizontal'"
835
+ if axis == 'y' else "")
836
+ lines.append(
837
+ f"ax_m.hist(_m_data[{i}], bins=_m_bins, "
838
+ f"color={_fmt(color)}, alpha={_fmt(alpha)}, "
839
+ f"edgecolor='none'{orient})")
840
+ else: # dodge
841
+ n = len(colls)
842
+ lines.append(f"_n = {n}")
843
+ lines.append(f"_bw = _m_bins[1] - _m_bins[0]")
844
+ lines.append(f"_sub = _bw / _n")
845
+ lines.append(
846
+ f"_bar = _sub * {_fmt(1 - separation)}")
847
+ lines.append(
848
+ f"_ctrs = (_m_bins[:-1] + _m_bins[1:]) / 2")
849
+ if axis == 'x':
850
+ lines.append(
851
+ f"for _i, (_d, _c) in enumerate("
852
+ f"zip(_m_data, {colors!r})):")
853
+ lines.append(
854
+ f" _cnt, _ = np.histogram(_d, bins=_m_bins)")
855
+ lines.append(
856
+ f" _off = _sub * (_i - (_n - 1) / 2)")
857
+ lines.append(
858
+ f" ax_m.bar(_ctrs + _off, _cnt, width=_bar, "
859
+ f"color=_c, alpha={_fmt(alpha)}, edgecolor='none')")
860
+ else: # y
861
+ lines.append(
862
+ f"for _i, (_d, _c) in enumerate("
863
+ f"zip(_m_data, {colors!r})):")
864
+ lines.append(
865
+ f" _cnt, _ = np.histogram(_d, bins=_m_bins)")
866
+ lines.append(
867
+ f" _off = _sub * (_i - (_n - 1) / 2)")
868
+ lines.append(
869
+ f" ax_m.barh(_ctrs + _off, _cnt, height=_bar, "
870
+ f"color=_c, alpha={_fmt(alpha)}, edgecolor='none')")
871
+
872
+ # Inversion
873
+ if inverted:
874
+ if axis == 'x':
875
+ lines.append("ax_m.invert_yaxis()")
876
+ else:
877
+ lines.append("ax_m.invert_xaxis()")
878
+
879
+ # Spines
880
+ for spine_name in ('top', 'right', 'bottom', 'left'):
881
+ if not m_ax.spines[spine_name].get_visible():
882
+ lines.append(
883
+ f"ax_m.spines[{_fmt(spine_name)}].set_visible(False)")
884
+
885
+ # Tick configuration
886
+ _tfs = _fmt(tick_fs)
887
+ if axis == 'x':
888
+ lines.append(
889
+ "ax_m.tick_params(axis='x', bottom=False, top=False, "
890
+ "labelbottom=False, labeltop=False)")
891
+ if tick_side == 'left':
892
+ lines.append(
893
+ f"ax_m.tick_params(axis='y', left=True, "
894
+ f"labelleft=True, right=False, labelright=False, "
895
+ f"labelsize={_tfs})")
896
+ elif tick_side == 'right':
897
+ lines.append(
898
+ f"ax_m.tick_params(axis='y', left=False, "
899
+ f"labelleft=False, right=True, labelright=True, "
900
+ f"labelsize={_tfs})")
901
+ else:
902
+ lines.append(
903
+ "ax_m.tick_params(axis='y', left=False, "
904
+ "labelleft=False, right=False, labelright=False)")
905
+ else:
906
+ lines.append(
907
+ "ax_m.tick_params(axis='y', left=False, right=False, "
908
+ "labelleft=False, labelright=False)")
909
+ if tick_side == 'bottom':
910
+ lines.append(
911
+ f"ax_m.tick_params(axis='x', bottom=True, "
912
+ f"labelbottom=True, top=False, labeltop=False, "
913
+ f"labelsize={_tfs})")
914
+ elif tick_side == 'top':
915
+ lines.append(
916
+ f"ax_m.tick_params(axis='x', bottom=False, "
917
+ f"labelbottom=False, top=True, labeltop=True, "
918
+ f"labelsize={_tfs})")
919
+ else:
920
+ lines.append(
921
+ "ax_m.tick_params(axis='x', bottom=False, "
922
+ "labelbottom=False, top=False, labeltop=False)")
923
+
924
+ # Tick step
925
+ if tick_step > 0:
926
+ count_axis = 'yaxis' if axis == 'x' else 'xaxis'
927
+ lines.append(
928
+ f"ax_m.{count_axis}.set_major_locator("
929
+ f"matplotlib.ticker.MultipleLocator({_fmt(tick_step)}))")
930
+
931
+ # Range (count-axis limits)
932
+ if range_max > 0:
933
+ if axis == 'x':
934
+ if inverted:
935
+ lines.append(
936
+ f"ax_m.set_ylim({_fmt(range_max)}, "
937
+ f"{_fmt(range_min)})")
938
+ else:
939
+ lines.append(
940
+ f"ax_m.set_ylim({_fmt(range_min)}, "
941
+ f"{_fmt(range_max)})")
942
+ else:
943
+ if inverted:
944
+ lines.append(
945
+ f"ax_m.set_xlim({_fmt(range_max)}, "
946
+ f"{_fmt(range_min)})")
947
+ else:
948
+ lines.append(
949
+ f"ax_m.set_xlim({_fmt(range_min)}, "
950
+ f"{_fmt(range_max)})")
951
+
952
+ # Label
953
+ if label:
954
+ _lfs = _fmt(label_fs)
955
+ _lw = "'bold'" if label_bold else "'normal'"
956
+ _ls = "'italic'" if label_italic else "'normal'"
957
+ _lc_arg = ""
958
+ if label_color != '#000000':
959
+ _lc_arg = f", color={_fmt(label_color)}"
960
+ if axis == 'x':
961
+ if tick_side == 'right':
962
+ lines.append(
963
+ "ax_m.yaxis.set_label_position('right')")
964
+ lines.append(
965
+ f"ax_m.set_ylabel({_fmt(label)}, fontsize={_lfs}, "
966
+ f"fontweight={_lw}, fontstyle={_ls}{_lc_arg})")
967
+ else:
968
+ if tick_side == 'top':
969
+ lines.append(
970
+ "ax_m.xaxis.set_label_position('top')")
971
+ lines.append(
972
+ f"ax_m.set_xlabel({_fmt(label)}, fontsize={_lfs}, "
973
+ f"fontweight={_lw}, fontstyle={_ls}{_lc_arg})")
974
+
975
+ # Title
976
+ if m_title:
977
+ _tifs = _fmt(title_fs)
978
+ _tw = "'bold'" if title_bold else "'normal'"
979
+ _ts = "'italic'" if title_italic else "'normal'"
980
+ _tc_arg = ""
981
+ if title_color != '#000000':
982
+ _tc_arg = f", color={_fmt(title_color)}"
983
+ lines.append(
984
+ f"ax_m.set_title({_fmt(m_title)}, fontsize={_tifs}, "
985
+ f"fontweight={_tw}, fontstyle={_ts}{_tc_arg})")
986
+
987
+ return "\n".join(lines)
988
+
989
+
990
+ def _bxp_stats_from_original(orig_stats, dinfo):
991
+ """Convert original introspected box_stats to bxp-compatible format.
992
+
993
+ The introspector stores stats with a 'median' key; bxp() expects 'med'.
994
+ Using original stats avoids recomputing from reconstructed data, which
995
+ would degrade the violin KDE shape (double reconstruction).
996
+ """
997
+ bxp_stats = []
998
+ for s in orig_stats:
999
+ med = s.get('med') or s.get('median', 0)
1000
+ stat = {
1001
+ 'med': round(float(med), 4),
1002
+ 'q1': round(float(s['q1']), 4),
1003
+ 'q3': round(float(s['q3']), 4),
1004
+ 'whislo': round(float(s['whislo']), 4),
1005
+ 'whishi': round(float(s['whishi']), 4),
1006
+ 'fliers': [round(float(f), 4) for f in s.get('fliers', [])],
1007
+ }
1008
+ if dinfo.get('notch', False):
1009
+ iqr = stat['q3'] - stat['q1']
1010
+ # Use raw_data length if available, else 100 as proxy
1011
+ stat['cilo'] = round(stat['med'] - 1.57 * iqr / 10, 4)
1012
+ stat['cihi'] = round(stat['med'] + 1.57 * iqr / 10, 4)
1013
+ bxp_stats.append(stat)
1014
+ return bxp_stats
1015
+
1016
+
1017
+ def _compute_bxp_stats(raw_data, dinfo):
1018
+ """Compute bxp-compatible stats from raw data arrays."""
1019
+ bxp_stats = []
1020
+ for rd in raw_data:
1021
+ arr = np.asarray(rd, dtype=float)
1022
+ if len(arr) == 0:
1023
+ continue
1024
+ q1, med, q3 = np.percentile(arr, [25, 50, 75])
1025
+ iqr = q3 - q1
1026
+ wlo_d = arr[arr >= q1 - 1.5 * iqr]
1027
+ whi_d = arr[arr <= q3 + 1.5 * iqr]
1028
+ whislo = float(np.min(wlo_d)) if len(wlo_d) > 0 else float(q1)
1029
+ whishi = float(np.max(whi_d)) if len(whi_d) > 0 else float(q3)
1030
+ fliers = arr[(arr < whislo) | (arr > whishi)]
1031
+ stat = {
1032
+ 'med': round(float(med), 4),
1033
+ 'q1': round(float(q1), 4),
1034
+ 'q3': round(float(q3), 4),
1035
+ 'whislo': round(float(whislo), 4),
1036
+ 'whishi': round(float(whishi), 4),
1037
+ 'fliers': [round(float(f), 4) for f in fliers],
1038
+ }
1039
+ if dinfo.get('notch', False):
1040
+ n = len(arr)
1041
+ ci = 1.57 * iqr / np.sqrt(n) if n > 0 else 0
1042
+ stat['cilo'] = round(float(med) - ci, 4)
1043
+ stat['cihi'] = round(float(med) + ci, 4)
1044
+ bxp_stats.append(stat)
1045
+ return bxp_stats
1046
+
1047
+
1048
+ def _emit_stats_var(lines, bxp_stats, group_idx, n_groups):
1049
+ """Emit the _stats variable assignment. Returns the variable name."""
1050
+ svar = f"_stats_{group_idx}" if n_groups > 1 else "_stats"
1051
+ lines.append(f"{svar} = [")
1052
+ for s in bxp_stats:
1053
+ lines.append(f" {s!r},")
1054
+ lines.append(f"]")
1055
+ return svar
1056
+
1057
+
1058
def _emit_bxp_call(lines, ax_var, dinfo, svar, positions, label,
                   group_idx, n_groups, orient, width, mode):
    """Append an ``ax.bxp(...)`` call drawing boxes from a stats variable.

    *svar* must already be bound in the emitted code (it is produced by
    ``_emit_stats_var``). Face/edge colours and the median colour are
    always written. Line widths, alpha, hatch and whisker style are written
    only when they differ from the defaults assumed here. The first box
    receives *label* so it appears in legends.
    """
    show_mean = dinfo.get('show_mean', False)
    flier_marker = dinfo.get('flier_marker', 'o')
    whisker_lw = dinfo.get('whisker_lw', 1.0)

    # Inside a violin the box shrinks to 30% of the slot width.
    if 'violin' in mode:
        box_width = width * 0.3
    else:
        box_width = width
    pos_text = ", ".join(str(round(float(p), 4)) for p in positions)

    call_args = [
        f"positions=[{pos_text}]",
        f"vert={orient == 'vertical'}",
        "patch_artist=True",
        f"widths={_fmt(box_width)}",
    ]
    if dinfo.get('notch', False):
        call_args.append("notch=True")
    if show_mean:
        call_args += ["showmeans=True", "meanline=True"]
    if not flier_marker:
        call_args.append("showfliers=False")

    # Box patches.
    box_props = [
        f"facecolor={_fmt(dinfo.get('box_color', '#1f77b4'))}",
        f"edgecolor={_fmt(dinfo.get('box_edgecolor', '#000000'))}",
    ]
    box_lw = dinfo.get('box_lw', 1.0)
    box_alpha = dinfo.get('box_alpha', 1.0)
    box_hatch = dinfo.get('box_hatch', '')
    if box_lw != 1.0:
        box_props.append(f"linewidth={_fmt(box_lw)}")
    if box_alpha != 1.0:
        box_props.append(f"alpha={_fmt(box_alpha)}")
    if box_hatch:
        box_props.append(f"hatch={_fmt(box_hatch)}")
    call_args.append("boxprops=dict(" + ", ".join(box_props) + ")")

    # Median line.
    median_props = [f"color={_fmt(dinfo.get('median_color', '#ff7f0e'))}"]
    median_lw = dinfo.get('median_lw', 2.0)
    if median_lw != 2.0:
        median_props.append(f"linewidth={_fmt(median_lw)}")
    call_args.append("medianprops=dict(" + ", ".join(median_props) + ")")

    # Whiskers, and caps sharing the whisker line width.
    whisker_props = []
    whisker_style = dinfo.get('whisker_style', '-')
    if whisker_style != '-':
        whisker_props.append(f"linestyle={_fmt(whisker_style)}")
    if whisker_lw != 1.0:
        whisker_props.append(f"linewidth={_fmt(whisker_lw)}")
    if whisker_props:
        call_args.append(
            "whiskerprops=dict(" + ", ".join(whisker_props) + ")")
    if whisker_lw != 1.0:
        call_args.append(f"capprops=dict(linewidth={_fmt(whisker_lw)})")

    # Outlier markers.
    if flier_marker:
        call_args.append(
            f"flierprops=dict(marker={_fmt(flier_marker)}, "
            f"markersize={_fmt(dinfo.get('flier_size', 6.0))}, "
            f"markerfacecolor={_fmt(dinfo.get('flier_color', '#000000'))})")

    # Mean line reuses the median colour.
    if show_mean:
        mean_style = dinfo.get('mean_style', '--')
        mean_color = dinfo.get('median_color', '#ff7f0e')
        call_args.append(
            f"meanprops=dict(linestyle={_fmt(mean_style)}, "
            f"color={_fmt(mean_color)})")

    bp_name = "_bp" if n_groups <= 1 else f"_bp_{group_idx}"
    lines.append(f"{bp_name} = {ax_var}.bxp({svar},")
    *leading, final = call_args
    lines.extend(f" {arg}," for arg in leading)
    lines.append(f" {final})")
    lines.append(f"{bp_name}['boxes'][0].set_label({_fmt(label)})")
1141
+
1142
+
1143
+ def _emit_data_from_stats(lines, svar, group_idx, n_groups):
1144
+ """Emit code to reconstruct approximate data from box stats.
1145
+
1146
+ Returns the variable name holding the list of data arrays.
1147
+ Mirrors ``_reconstruct_data_from_stats`` in ``_introspect.py`` exactly
1148
+ (including fliers and adjusted sample counts) so the violin KDE shapes
1149
+ match the matplotly rendering.
1150
+ """
1151
+ raw_var = f"_raw_{group_idx}" if n_groups > 1 else "_raw"
1152
+ lines.append(
1153
+ f"# Approximate data from stats "
1154
+ f"(replace {raw_var} with your original data for best fidelity)")
1155
+ lines.append(f"{raw_var} = []")
1156
+ lines.append(f"for _s in {svar}:")
1157
+ lines.append(f" _fl = _s.get('fliers', [])")
1158
+ lines.append(f" _n = max(100 - len(_fl), 20)")
1159
+ lines.append(f" _q = _n // 4")
1160
+ lines.append(f" _rem = _n - 4 * _q")
1161
+ lines.append(f" _rng = np.random.RandomState(42)")
1162
+ lines.append(f" {raw_var}.append(np.concatenate([")
1163
+ lines.append(f" _rng.uniform(_s['whislo'], _s['q1'], _q),")
1164
+ lines.append(f" _rng.uniform(_s['q1'], _s['med'], _q),")
1165
+ lines.append(f" _rng.uniform(_s['med'], _s['q3'], _q),")
1166
+ lines.append(f" _rng.uniform(_s['q3'], _s['whishi'], _q + _rem),")
1167
+ lines.append(f" np.array(_fl, dtype=float),")
1168
+ lines.append(f" ]))")
1169
+ return raw_var
1170
+
1171
+
1172
def _emit_violin(lines, ax_var, dinfo, raw_var, positions,
                 orient, width, mode, label):
    """Emit an ``ax.violinplot()`` call followed by body/statistic styling.

    Every emitted block body is indented one level deeper than its header,
    so the nested ``for``/``if`` over the statistic artists is valid
    Python.

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Expression naming the target axes in the emitted code.
    dinfo : dict
        Distribution style options (``violin_color``, ``violin_edgecolor``,
        ``violin_alpha``, ``violin_inner``).
    raw_var : str
        Name of the emitted variable holding the per-group data arrays.
    positions : sequence of float
        Violin positions along the categorical axis.
    orient : str
        ``"vertical"`` for vertical violins; anything else is horizontal.
    width : float
        Base slot width; violins are drawn 1.5x as wide.
    mode : str
        Display mode. When it contains ``'box'``, the legend label is left
        to the box call.
    label : str
        Legend label for the group.
    """
    vc = dinfo.get('violin_color', '#1f77b4')
    vec = dinfo.get('violin_edgecolor', '#000000')
    va = dinfo.get('violin_alpha', 0.3)
    vi = dinfo.get('violin_inner', 'box')
    vw = width * 1.5
    is_vert = orient == "vertical"
    # Any inner style other than 'none' shows medians and extrema.
    show_stats = (vi != 'none')
    pos_str = ", ".join(str(round(float(p), 4)) for p in positions)

    lines.append(f"_vp = {ax_var}.violinplot({raw_var},")
    lines.append(f"    positions=[{pos_str}],")
    lines.append(f"    widths={_fmt(vw)}, vert={is_vert},")
    lines.append(
        f"    showmeans=False, showmedians={show_stats}, "
        f"showextrema={show_stats})")
    lines.append("for _body in _vp['bodies']:")
    lines.append(f"    _body.set_facecolor({_fmt(vc)})")
    lines.append(f"    _body.set_edgecolor({_fmt(vec)})")
    lines.append(f"    _body.set_alpha({_fmt(va)})")
    if show_stats:
        lines.append(
            "for _key in ('cmedians', 'cmins', 'cmaxes', 'cbars'):")
        lines.append("    if _key in _vp:")
        lines.append(f"        _vp[_key].set_color({_fmt(vec)})")
        lines.append("        _vp[_key].set_alpha(0.8)")
    # Legend label only when the violin is the primary plot component.
    if 'box' not in mode:
        lines.append(f"_vp['bodies'][0].set_label({_fmt(label)})")
1203
+
1204
+
1205
def _emit_jitter(lines, ax_var, dinfo, raw_var, positions,
                 orient, mode, label):
    """Emit a seeded strip-plot loop (jittered ``ax.scatter`` per group).

    Points are offset along the categorical axis by uniform noise in
    ``[-jitter_spread, jitter_spread]`` drawn from ``RandomState(42)``.
    Marker area is ``jitter_size ** 2`` and ``zorder=3``. A legend label is
    attached to the first group only when jitter is the sole component of
    *mode*.
    """
    spread = dinfo.get('jitter_spread', 0.2)
    pos_text = ", ".join(str(round(float(p), 4)) for p in positions)

    scatter_kw = [
        f"s={_fmt(dinfo.get('jitter_size', 3.0) ** 2)}",
        f"c={_fmt(dinfo.get('jitter_color', '#1f77b4'))}",
        f"alpha={_fmt(dinfo.get('jitter_alpha', 0.5))}",
        f"marker={_fmt(dinfo.get('jitter_marker', 'o'))}",
        "zorder=3",
    ]
    if not ('box' in mode or 'violin' in mode):
        scatter_kw.append(
            f"label=({_fmt(label)} if _i == 0 else '_nolegend_')")

    if orient == "vertical":
        coords = "_pos + _jitter, _d"
    else:
        coords = "_d, _pos + _jitter"

    lines.extend([
        "_rng_j = np.random.RandomState(42)",
        f"for _i, (_d, _pos) in enumerate(zip({raw_var}, [{pos_text}])):",
        f" _jitter = _rng_j.uniform({_fmt(-spread)}, {_fmt(spread)}, len(_d))",
        f" {ax_var}.scatter({coords}, {', '.join(scatter_kw)})",
    ])
1236
+
1237
+
1238
def _emit_dist_with_data(lines, ax_var, dist_infos, data_vars,
                         orient, width, n_groups):
    """Emit distribution-plot code driven by the user's real data variable(s).

    No existing artists are removed and no data is fabricated. Each group
    is redrawn with ``ax.violinplot()``, ``ax.boxplot()`` and/or a jitter
    scatter according to its ``display_mode``.

    Two data layouts are supported:

    * ``len(data_vars) == n_groups > 1``: one variable per group, bound
      directly (multi-call scenario).
    * otherwise: ``data_vars[0]`` is indexed, and each group picks the
      elements whose index is the rank of its original positions among all
      groups' positions (single-call, multi-group scenario).
    """
    lines.append("\n# Distribution plot")

    one_var_per_group = len(data_vars) == n_groups and n_groups > 1

    rank_of_pos = {}
    if not one_var_per_group:
        shared_var = data_vars[0]
        every_pos = set()
        for info in dist_infos:
            every_pos.update(
                info.get('original_positions', info['positions']))
        rank_of_pos = {pos: rank
                       for rank, pos in enumerate(sorted(every_pos))}

    for gi, info in enumerate(dist_infos):
        mode = info.get('display_mode', 'box')
        label = info.get('label', f'Group {gi}')
        positions = info.get('positions', [])
        data_name = f"_data_{gi}" if n_groups > 1 else "_data"

        lines.append(f"\n# {label}")
        if one_var_per_group:
            lines.append(f"{data_name} = {data_vars[gi]}")
        else:
            picks = [rank_of_pos[p]
                     for p in info.get('original_positions', positions)
                     if p in rank_of_pos]
            lines.append(
                f"{data_name} = [{shared_var}[i] for i in {picks!r}]")

        # Emission order: violin (background), then box, then jitter.
        if 'violin' in mode:
            _emit_violin(lines, ax_var, info, data_name, positions,
                         orient, width, mode, label)
        if 'box' in mode:
            _emit_boxplot_call(lines, ax_var, info, data_name, positions,
                               label, gi, n_groups, orient, width, mode)
        if 'jitter' in mode:
            _emit_jitter(lines, ax_var, info, data_name, positions,
                         orient, mode, label)
1296
+
1297
+
1298
def _emit_boxplot_call(lines, ax_var, dinfo, data_var, positions,
                       label, group_idx, n_groups, orient, width, mode):
    """Emit an ``ax.boxplot(...)`` call on the user's real data (Apply path).

    Face/edge colours and the median colour are always written. Other
    properties are written only when they differ from the defaults assumed
    here. The first box receives *label* for legends.
    """
    def props(name, items):
        # Render one "xxxprops=dict(k=v, ...)" keyword argument.
        return f"{name}=dict({', '.join(items)})"

    get = dinfo.get
    flier_marker = get('flier_marker', 'o')
    show_mean = get('show_mean', False)
    whisker_lw = get('whisker_lw', 1.0)
    whisker_ls = get('whisker_style', '-')
    box_lw = get('box_lw', 1.0)
    box_alpha = get('box_alpha', 1.0)
    box_hatch = get('box_hatch', '')
    median_lw = get('median_lw', 2.0)

    # Boxes drawn inside violins are narrowed to 30% of the slot width.
    box_width = width * 0.3 if 'violin' in mode else width
    pos_text = ", ".join(str(round(float(p), 4)) for p in positions)

    args = [f"positions=[{pos_text}]",
            f"vert={orient == 'vertical'}",
            "patch_artist=True",
            f"widths={_fmt(box_width)}"]
    if get('notch', False):
        args.append("notch=True")
    if show_mean:
        args.extend(("showmeans=True", "meanline=True"))
    if not flier_marker:
        args.append("showfliers=False")

    box_items = [f"facecolor={_fmt(get('box_color', '#1f77b4'))}",
                 f"edgecolor={_fmt(get('box_edgecolor', '#000000'))}"]
    if box_lw != 1.0:
        box_items.append(f"linewidth={_fmt(box_lw)}")
    if box_alpha != 1.0:
        box_items.append(f"alpha={_fmt(box_alpha)}")
    if box_hatch:
        box_items.append(f"hatch={_fmt(box_hatch)}")
    args.append(props("boxprops", box_items))

    median_items = [f"color={_fmt(get('median_color', '#ff7f0e'))}"]
    if median_lw != 2.0:
        median_items.append(f"linewidth={_fmt(median_lw)}")
    args.append(props("medianprops", median_items))

    whisker_items = []
    if whisker_ls != '-':
        whisker_items.append(f"linestyle={_fmt(whisker_ls)}")
    if whisker_lw != 1.0:
        whisker_items.append(f"linewidth={_fmt(whisker_lw)}")
    if whisker_items:
        args.append(props("whiskerprops", whisker_items))
    # Caps follow the whisker width.
    if whisker_lw != 1.0:
        args.append(props("capprops", [f"linewidth={_fmt(whisker_lw)}"]))

    if flier_marker:
        args.append(props("flierprops", [
            f"marker={_fmt(flier_marker)}",
            f"markersize={_fmt(get('flier_size', 6.0))}",
            f"markerfacecolor={_fmt(get('flier_color', '#000000'))}",
        ]))

    if show_mean:
        # The mean line borrows the median colour.
        args.append(props("meanprops", [
            f"linestyle={_fmt(get('mean_style', '--'))}",
            f"color={_fmt(get('median_color', '#ff7f0e'))}",
        ]))

    var = f"_bp_{group_idx}" if n_groups > 1 else "_bp"
    lines.append(f"{var} = {ax_var}.boxplot({data_var},")
    lines.extend(f" {a}," for a in args[:-1])
    lines.append(f" {args[-1]})")
    lines.append(f"{var}['boxes'][0].set_label({_fmt(label)})")
1381
+
1382
+
1383
def _emit_hist_merged(lines, ax_var, hist_infos, data_vars=None):
    """Emit a single ``ax.hist()`` call drawing several histograms together.

    When *data_vars* are provided (Apply path, AST-extracted from the
    user's cell), emits an executable ``ax.hist()`` call. It replaces the
    user's original separate calls, which get commented out by
    ``_close_apply``. Returns ``True`` so the caller can skip the redundant
    style-only loop.

    Without usable *data_vars* (Copy / fallback), emits a commented
    template and returns ``False`` so the caller still emits style-only
    code.

    Per-histogram overrides (edge colour or alpha that differ from the
    first histogram, and hatch) are applied to the patches *returned* by
    ``hist``. Indexing ``ax.containers`` is avoided for two reasons: step
    histtypes create no ``BarContainer``, and containers already on the
    axes would shift the indices.

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Expression naming the target axes in the emitted code.
    hist_infos : list of dict
        Non-empty list of per-histogram style dicts; the first supplies the
        shared keyword arguments.
    data_vars : list of str, optional
        Source expressions for each histogram's data, in the same order.
        Entries starting with ``'<'`` are placeholders and disable the
        executable path.

    Returns
    -------
    bool
        ``True`` if an executable call was emitted, else ``False``.
    """
    ref = hist_infos[0]
    n_hist = len(hist_infos)

    colors = []
    labels = []
    for hi in hist_infos:
        colors.append(hi.get('color', '#1f77b4'))
        labels.append(hi.get('label', f'Hist {len(labels)}'))

    orientation = ref.get('orientation', 'vertical')
    ec = ref.get('edgecolor', '#000000')
    alpha = ref.get('alpha', 0.7)

    kw_parts = [f"bins={_fmt(ref.get('bins', 20))}",
                f"histtype={_fmt(ref.get('histtype', 'bar'))}"]
    if ref.get('mode', 'count') == 'density':
        kw_parts.append("density=True")
    if ref.get('cumulative', False):
        kw_parts.append("cumulative=True")
    if orientation != 'vertical':
        kw_parts.append(f"orientation={_fmt(orientation)}")
    kw_parts.append(f"color={colors!r}")
    kw_parts.append(f"edgecolor={_fmt(ec)}")
    kw_parts.append(f"linewidth={_fmt(ref.get('linewidth', 1.0))}")
    kw_parts.append(f"alpha={_fmt(alpha)}")
    kw_parts.append(f"label={labels!r}")
    kw_parts.append(f"rwidth={_fmt(ref.get('rwidth', 0.8))}")
    last_kw = len(kw_parts) - 1

    def patches_of(i):
        # hist() returns one container (not a list of them) for one dataset.
        return "_hist_patches" if n_hist == 1 else f"_hist_patches[{i}]"

    have_vars = bool(data_vars
                     and len(data_vars) == n_hist
                     and all(not v.startswith('<') for v in data_vars))

    if have_vars:
        # --- Executable merged call (Apply path) ---
        fixups_per_hist = []
        for hi in hist_infos:
            fixups = []
            hi_ec = hi.get('edgecolor', ec)
            if hi_ec != ec:
                fixups.append(f"_p.set_edgecolor({_fmt(hi_ec)})")
            hi_alpha = hi.get('alpha', alpha)
            if hi_alpha != alpha:
                fixups.append(f"_p.set_alpha({_fmt(hi_alpha)})")
            hatch = hi.get('hatch', '')
            if hatch:
                fixups.append(f"_p.set_hatch({_fmt(hatch)})")
            fixups_per_hist.append(fixups)

        target = "_, _, _hist_patches = " if any(fixups_per_hist) else ""
        lines.append(f"\n{target}{ax_var}.hist([{', '.join(data_vars)}],")
        for k, kw in enumerate(kw_parts):
            closer = "," if k < last_kw else ")"
            lines.append(f"    {kw}{closer}")
        for i, fixups in enumerate(fixups_per_hist):
            if fixups:
                lines.append(f"for _p in {patches_of(i)}:")
                lines.extend(f"    {fix}" for fix in fixups)
        return True  # skip style-only loop

    # --- Comment template (Copy / fallback) ---
    lines.append(
        "\n# Merged histograms \u2014 replace your separate "
        "ax.hist() calls with:")
    hatches = [hi.get('hatch', '') for hi in hist_infos]
    target = "_, _, _hist_patches = " if any(hatches) else ""
    data_hint = ", ".join(f"<{lbl}>" for lbl in labels)
    lines.append(f"# {target}{ax_var}.hist([{data_hint}],")
    for k, kw in enumerate(kw_parts):
        closer = "," if k < last_kw else ")"
        lines.append(f"#     {kw}{closer}")
    for i, hatch in enumerate(hatches):
        if hatch:
            lines.append(f"# for _p in {patches_of(i)}:")
            lines.append(f"#     _p.set_hatch({_fmt(hatch)})")
    return False  # emit style-only loop
1475
+
1476
+
1477
def _emit_heatmap(lines, ax_var, heatmap_infos):
    """Emit code for heatmap styling (cmap, norm, clim, annotations, grid, ticks).

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Expression naming the target axes in the emitted code.
    heatmap_infos : list of dict
        One dict per heatmap. The *idx*-th entry is addressed as
        ``ax.images[idx]`` for ``'imshow'`` heatmaps and as
        ``ax.collections[idx]`` otherwise.

    Notes
    -----
    Per-cell annotations are emitted only for 2-D data. RGB(A) image data
    has no single scalar per cell, so it is skipped. Grid lines use the
    first two data dimensions. The emitted nested annotation loops are
    indented one level per nesting depth, so they are valid Python.
    """
    for idx, info in enumerate(heatmap_infos):
        htype = info.get('heatmap_type', 'imshow')
        cmap = info.get('cmap', 'viridis')
        vmin = info.get('vmin')
        vmax = info.get('vmax')
        norm_type = info.get('norm_type', 'linear')
        alpha = info.get('alpha', 1.0)
        interp = info.get('interpolation')
        aspect = info.get('aspect')
        annot_enabled = info.get('annot_enabled', False)
        annot_fmt = info.get('annot_fmt', '.2f')
        annot_fontsize = info.get('annot_fontsize', 8.0)
        annot_color = info.get('annot_color', 'auto')
        grid_enabled = info.get('grid_enabled', False)
        grid_lw = info.get('grid_lw', 1.0)
        grid_color = info.get('grid_color', '#ffffff')
        data = info.get('data')
        data_arr = None if data is None else np.asarray(data)

        lines.append(f"\n# Heatmap ({htype})")

        if htype == 'imshow':
            acc = "_im"
            lines.append(f"_im = {ax_var}.images[{idx}]")
        else:
            acc = "_qm"
            lines.append(f"_qm = {ax_var}.collections[{idx}]")

        lines.append(f"{acc}.set_cmap({_fmt(cmap)})")

        if norm_type == 'log':
            # LogNorm rejects non-positive limits; clamp vmin to a tiny
            # positive value.
            _vmin = max(vmin, 1e-10) if vmin is not None else 1e-10
            lines.append("from matplotlib.colors import LogNorm")
            lines.append(
                f"{acc}.set_norm(LogNorm("
                f"vmin={_fmt(_vmin)}, vmax={_fmt(vmax)}))")
        elif norm_type == 'symlog':
            lines.append("from matplotlib.colors import SymLogNorm")
            lines.append(
                f"{acc}.set_norm(SymLogNorm("
                f"linthresh=1.0, vmin={_fmt(vmin)}, vmax={_fmt(vmax)}))")
        elif norm_type == 'centered':
            lines.append("from matplotlib.colors import CenteredNorm")
            lines.append(f"{acc}.set_norm(CenteredNorm(vcenter=0))")
        else:
            lines.append(f"{acc}.set_clim({_fmt(vmin)}, {_fmt(vmax)})")

        if htype == 'imshow' and interp:
            lines.append(f"{acc}.set_interpolation({_fmt(interp)})")

        if htype == 'imshow' and aspect and str(aspect) != 'equal':
            lines.append(f"{ax_var}.set_aspect({_fmt(str(aspect))})")

        if alpha is not None and round(alpha, 2) != 1.0:
            lines.append(f"{acc}.set_alpha({_fmt(alpha)})")

        # Annotations: one text per cell, needs a 2-D grid of scalars.
        if annot_enabled and data_arr is not None and data_arr.ndim == 2:
            nrows, ncols = data_arr.shape
            lines.append("\n# Annotations")
            lines.append(f"_data = {acc}.get_array()")
            lines.append("if hasattr(_data, 'reshape'):")
            lines.append(f"    _data = _data.reshape({nrows}, {ncols})")
            lines.append(f"_vmin, _vmax = {acc}.get_clim()")
            lines.append("_vmid = (_vmin + _vmax) / 2.0")
            lines.append(f"for _i in range({nrows}):")
            lines.append(f"    for _j in range({ncols}):")
            lines.append("        _val = _data[_i, _j]")
            if annot_color == 'auto':
                lines.append(
                    "        _c = 'white' if _val > _vmid else 'black'")
            else:
                lines.append(f"        _c = {_fmt(annot_color)}")
            lines.append(
                f"        {ax_var}.text(_j, _i, "
                f"format(_val, {_fmt(annot_fmt)}), "
                f"ha='center', va='center', "
                f"fontsize={_fmt(annot_fontsize)}, color=_c)")

        # Grid: minor ticks on cell boundaries, drawn as grid lines.
        if grid_enabled and data_arr is not None and data_arr.ndim >= 2:
            nrows, ncols = data_arr.shape[:2]
            lines.append("\n# Grid lines")
            lines.append(
                f"{ax_var}.set_xticks("
                f"np.arange(-0.5, {ncols}, 1), minor=True)")
            lines.append(
                f"{ax_var}.set_yticks("
                f"np.arange(-0.5, {nrows}, 1), minor=True)")
            lines.append(
                f"{ax_var}.grid(which='minor', "
                f"color={_fmt(grid_color)}, "
                f"linewidth={_fmt(grid_lw)}, linestyle='-')")
            lines.append(f"{ax_var}.tick_params(which='minor', length=0)")

        # Tick labels (a stored None is treated like an empty string).
        xtick_show = info.get('xtick_show', True)
        ytick_show = info.get('ytick_show', True)
        xtick_labels_str = info.get('xtick_labels') or ''
        ytick_labels_str = info.get('ytick_labels') or ''

        if (not xtick_show or not ytick_show
                or xtick_labels_str.strip() or ytick_labels_str.strip()):
            lines.append("\n# Tick labels")

        if not xtick_show:
            lines.append(
                f"{ax_var}.tick_params(axis='x', "
                f"bottom=False, labelbottom=False)")
        elif xtick_labels_str.strip():
            tick_labels = [part.strip() for part in xtick_labels_str.split(',')]
            lines.append(f"{ax_var}.set_xticks({list(range(len(tick_labels)))})")
            lines.append(f"{ax_var}.set_xticklabels({tick_labels!r})")

        if not ytick_show:
            lines.append(
                f"{ax_var}.tick_params(axis='y', "
                f"left=False, labelleft=False)")
        elif ytick_labels_str.strip():
            tick_labels = [part.strip() for part in ytick_labels_str.split(',')]
            lines.append(f"{ax_var}.set_yticks({list(range(len(tick_labels)))})")
            lines.append(f"{ax_var}.set_yticklabels({tick_labels!r})")
1612
+
1613
+
1614
def _emit_colorbar(lines, ax_var, cbar_info, heatmap_infos):
    """Emit code that creates or updates the colorbar of the first heatmap.

    The mappable is ``ax.images[0]`` for an ``'imshow'`` heatmap and
    ``ax.collections[0]`` otherwise. If ``location``, ``shrink`` (default
    1.0) or ``pad`` (default 0.05) differ from their defaults, the emitted
    code removes any existing colorbar and builds a new one with those
    arguments. Otherwise it reuses the existing colorbar when present. The
    label is emitted when set, and the tick label size only when it is not
    10.
    """
    kind = heatmap_infos[0].get('heatmap_type', 'imshow')
    bucket = "images" if kind == 'imshow' else "collections"
    mappable = f"{ax_var}.{bucket}[0]"

    loc = cbar_info.get('location', 'right')
    shrink = cbar_info.get('shrink', 1.0)
    pad = cbar_info.get('pad', 0.05)
    label = cbar_info.get('label', '')
    label_fs = cbar_info.get('label_fontsize', 12.0)
    tick_fs = cbar_info.get('tick_fontsize', 10.0)

    lines.append("\n# Colorbar")

    # Non-default layout arguments; any of them forces a recreate.
    layout_args = []
    if loc != 'right':
        layout_args.append(f"location={_fmt(loc)}")
    if round(shrink, 2) != 1.0:
        layout_args.append(f"shrink={_fmt(shrink)}")
    if round(pad, 2) != 0.05:
        layout_args.append(f"pad={_fmt(pad)}")

    if layout_args:
        call_args = ", ".join([mappable, f"ax={ax_var}", *layout_args])
        lines.extend([
            f"if getattr({mappable}, 'colorbar', None) is not None:",
            f" {mappable}.colorbar.remove()",
            f"_cbar = fig.colorbar({call_args})",
        ])
    else:
        lines.extend([
            f"_cbar = getattr({mappable}, 'colorbar', None)",
            "if _cbar is None:",
            f" _cbar = fig.colorbar({mappable}, ax={ax_var})",
            "else:",
            f" _cbar.update_normal({mappable})",
        ])

    if label:
        lines.append(
            f"_cbar.set_label({_fmt(label)}, fontsize={_fmt(label_fs)})")
    if round(tick_fs, 1) != 10.0:
        lines.append(f"_cbar.ax.tick_params(labelsize={_fmt(tick_fs)})")
1663
+
1664
+
1665
def _emit_errorbars(lines, ax_var, errorbar_infos):
    """Emit code for errorbar styling.

    Targets the ``ErrorbarContainer`` objects created by ``ax.errorbar()``.
    The *idx*-th info dict styles the *idx*-th such container. Four toggles
    control visibility (error bars, markers, connecting line, shaded
    region), and each section has its own colour.

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Expression naming the target axes in the emitted code.
    errorbar_infos : list of dict
        Per-errorbar style dicts.
    """
    if not errorbar_infos:
        return

    # Resolve the containers once up front. Index idx then names the same
    # container for every entry, and the import is emitted only once.
    lines.append(
        "\nfrom matplotlib.container import ErrorbarContainer as _EBC")
    lines.append(
        f"_eb_containers = [c for c in {ax_var}.containers "
        f"if isinstance(c, _EBC)]")

    for idx, info in enumerate(errorbar_infos):
        show_bars = info.get('show_bars', True)
        show_line = info.get('show_line', True)
        show_markers = info.get('show_markers', False)
        show_shaded = info.get('show_shaded', False)
        bar_color = info.get('bar_color', info.get('color', '#1f77b4'))
        marker_color = info.get('marker_color', bar_color)
        line_color = info.get('line_color', bar_color)
        shade_color = info.get('shade_color', bar_color)
        bar_alpha = info.get('bar_alpha', 1.0)
        marker_alpha = info.get('marker_alpha', 1.0)
        line_alpha = info.get('line_alpha', 1.0)
        lw = info.get('line_width', 1.5)
        ls = info.get('line_style', '-')
        bar_lw = info.get('bar_lw', 1.5)
        cap_size = info.get('cap_size', 3.0)
        marker = info.get('marker', '')
        marker_size = info.get('marker_size', 6.0)
        label = info.get('label', f'Errorbar {idx}')

        lines.append(f"\n# Errorbar: {label}")
        lines.append(f"_eb = _eb_containers[{idx}]")

        # _eb[0]: data line (guarded because it may be None).
        lines.append("if _eb[0] is not None:")
        if show_line:
            lines.append(f"    _eb[0].set_color({_fmt(line_color)})")
            lines.append(f"    _eb[0].set_linewidth({_fmt(lw)})")
            lines.append(f"    _eb[0].set_linestyle({_fmt(ls)})")
            if line_alpha != 1.0:
                lines.append(f"    _eb[0].set_alpha({_fmt(line_alpha)})")
        else:
            lines.append("    _eb[0].set_linestyle('none')")
            lines.append("    _eb[0].set_linewidth(0)")
        if show_markers and marker and marker not in ('', 'None', 'none'):
            lines.append(f"    _eb[0].set_marker({_fmt(marker)})")
            lines.append(f"    _eb[0].set_markersize({_fmt(marker_size)})")
            if marker_alpha != 1.0:
                # Apply marker alpha via RGBA so the line alpha is untouched.
                lines.append("    import matplotlib.colors as _mc")
                lines.append(
                    f"    _mkr_rgba = list(_mc.to_rgba({_fmt(marker_color)}))")
                lines.append(f"    _mkr_rgba[3] = {_fmt(marker_alpha)}")
                lines.append("    _eb[0].set_markerfacecolor(_mkr_rgba)")
                lines.append("    _eb[0].set_markeredgecolor(_mkr_rgba)")
            else:
                lines.append(
                    f"    _eb[0].set_markerfacecolor({_fmt(marker_color)})")
                lines.append(
                    f"    _eb[0].set_markeredgecolor({_fmt(marker_color)})")

        # _eb[1]: cap lines; hidden bars also get zero-size caps.
        lines.append("for _cap in _eb[1]:")
        lines.append(f"    _cap.set_color({_fmt(bar_color)})")
        if show_bars:
            if cap_size > 0:
                lines.append(f"    _cap.set_markersize({_fmt(cap_size)})")
        else:
            lines.append("    _cap.set_markersize(0)")
        if bar_alpha != 1.0:
            lines.append(f"    _cap.set_alpha({_fmt(bar_alpha)})")

        # _eb[2]: error-bar line collections.
        lines.append("for _bar in _eb[2]:")
        lines.append(f"    _bar.set_color({_fmt(bar_color)})")
        if show_bars:
            lines.append(f"    _bar.set_linewidth({_fmt(bar_lw)})")
        else:
            lines.append("    _bar.set_linewidth(0)")
        if bar_alpha != 1.0:
            lines.append(f"    _bar.set_alpha({_fmt(bar_alpha)})")

        # Shaded region built from the vertical (y-error) segments.
        # ax.errorbar appends the x-error collection before the y-error
        # one, so the y-error bars are the LAST entry of _eb[2]; [0] would
        # pick the horizontal x-error bars when both are present.
        if show_shaded and info.get('has_yerr', False):
            shade_alpha = info.get('shade_alpha', 0.3)
            lines.append("_segs = _eb[2][-1].get_segments()")
            lines.append("if _segs:")
            lines.append("    _x_s = np.array([s[0][0] for s in _segs])")
            lines.append("    _y_lo = np.array([s[0][1] for s in _segs])")
            lines.append("    _y_hi = np.array([s[1][1] for s in _segs])")
            lines.append(
                f"    {ax_var}.fill_between(_x_s, _y_lo, _y_hi, "
                f"color={_fmt(shade_color)}, "
                f"alpha={_fmt(shade_alpha)})")
1777
+
1778
+
1779
def _fmt(val: Any, ndigits: int = 2) -> str:
    """Format *val* as a Python source literal for the generated code.

    Args:
        val: Value to render. Strings, booleans (including ``np.bool_``),
            integers (including numpy integer scalars) and anything
            ``float()``-convertible get dedicated handling; everything else
            falls back to ``repr``.
        ndigits: Decimal places that float values are rounded to, keeping
            the emitted code compact. Defaults to 2 (the historical value).

    Returns:
        A string that, pasted into Python source, evaluates to *val* (or,
        for floats, to *val* rounded to *ndigits* places).
    """
    import math
    import numbers

    # Strings must be handled before the float() coercion below: matplotlib
    # accepts numeric-looking strings such as grayscale colors ('0.5') and
    # markers ('1'-'4', '8'), and converting them to numbers changes their
    # meaning. str() also drops numpy's np.str_ wrapper so the literal is a
    # plain string rather than "np.str_('...')".
    if isinstance(val, str):
        return repr(str(val))
    # bool is a subclass of int, and np.bool_ is neither bool nor Integral;
    # check both first so they render as True/False rather than 1 / 1.0.
    if isinstance(val, (bool, np.bool_)):
        return repr(bool(val))
    # Preserve integers (including numpy int types).
    if isinstance(val, numbers.Integral):
        return repr(int(val))
    # Handle numpy floats, regular floats and other float()-convertible values.
    try:
        f = float(val)
    except (TypeError, ValueError):
        return repr(val)
    # repr() of a non-finite float ('nan', 'inf', '-inf') is not valid
    # Python source, so spell it as a float() call instead.
    if not math.isfinite(f):
        return f"float({repr(f)!r})"
    return repr(round(f, ndigits))