matplotly 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matplotly/__init__.py +124 -0
- matplotly/_api.py +984 -0
- matplotly/_code_gen.py +1793 -0
- matplotly/_commands.py +109 -0
- matplotly/_introspect.py +1197 -0
- matplotly/_profiles.py +241 -0
- matplotly/_renderer.py +79 -0
- matplotly/_style_import.py +155 -0
- matplotly/_types.py +31 -0
- matplotly/panels/__init__.py +37 -0
- matplotly/panels/_bar.py +788 -0
- matplotly/panels/_base.py +38 -0
- matplotly/panels/_color_utils.py +221 -0
- matplotly/panels/_distribution.py +1605 -0
- matplotly/panels/_errorbar.py +652 -0
- matplotly/panels/_fill.py +90 -0
- matplotly/panels/_global.py +1507 -0
- matplotly/panels/_heatmap.py +898 -0
- matplotly/panels/_histogram.py +938 -0
- matplotly/panels/_line.py +709 -0
- matplotly/panels/_marginal.py +944 -0
- matplotly/panels/_scatter.py +428 -0
- matplotly/panels/_subplot.py +846 -0
- matplotly-0.1.0.dist-info/METADATA +120 -0
- matplotly-0.1.0.dist-info/RECORD +27 -0
- matplotly-0.1.0.dist-info/WHEEL +4 -0
- matplotly-0.1.0.dist-info/licenses/LICENSE +21 -0
matplotly/_code_gen.py
ADDED
|
@@ -0,0 +1,1793 @@
|
|
|
1
|
+
"""Generate Python code from the current figure state."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib.colors import to_hex
|
|
8
|
+
from matplotlib.figure import Figure
|
|
9
|
+
from matplotlib.lines import Line2D
|
|
10
|
+
|
|
11
|
+
from ._commands import CommandStack
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def generate_code(fig: Figure, stack: CommandStack) -> str:
|
|
15
|
+
"""Produce a Python code snippet reproducing the figure's current styling.
|
|
16
|
+
|
|
17
|
+
Emits style modifications that apply on top of the user's original
|
|
18
|
+
plotting code. For distribution plots (boxplot/violinplot), emits
|
|
19
|
+
ax.bxp() calls with compact statistics and styling kwargs.
|
|
20
|
+
"""
|
|
21
|
+
all_axes = fig.get_axes()
|
|
22
|
+
# Separate main axes from marginal histogram and colorbar axes
|
|
23
|
+
axes_list = [a for a in all_axes
|
|
24
|
+
if not getattr(a, '_matplotly_marginal', False)
|
|
25
|
+
and not hasattr(a, '_colorbar')]
|
|
26
|
+
marginal_axes = [a for a in all_axes if getattr(a, '_matplotly_marginal', False)]
|
|
27
|
+
|
|
28
|
+
if not axes_list:
|
|
29
|
+
return "# No axes found."
|
|
30
|
+
|
|
31
|
+
multi = len(axes_list) > 1
|
|
32
|
+
|
|
33
|
+
lines: list[str] = ["# Generated by matplotly v4"]
|
|
34
|
+
|
|
35
|
+
if multi:
|
|
36
|
+
lines.append("axes = fig.get_axes()")
|
|
37
|
+
|
|
38
|
+
for ax_i, ax in enumerate(axes_list):
|
|
39
|
+
ax_var = f"axes[{ax_i}]" if multi else "ax"
|
|
40
|
+
|
|
41
|
+
# --- Identify distribution-managed artists (for skip logic) ---
|
|
42
|
+
dist_artist_ids: set[int] = set()
|
|
43
|
+
_dist_infos = getattr(ax, '_matplotly_dist_info', [])
|
|
44
|
+
has_dist = bool(_dist_infos)
|
|
45
|
+
|
|
46
|
+
# --- Identify heatmap-managed artists (for skip logic) ---
|
|
47
|
+
heatmap_artist_ids: set[int] = set()
|
|
48
|
+
for _hi in getattr(ax, '_matplotly_heatmap_info', []):
|
|
49
|
+
m = _hi.get('mappable')
|
|
50
|
+
if m is not None:
|
|
51
|
+
heatmap_artist_ids.add(id(m))
|
|
52
|
+
|
|
53
|
+
if has_dist:
|
|
54
|
+
for art in ax.lines:
|
|
55
|
+
if art.get_label().startswith("_"):
|
|
56
|
+
dist_artist_ids.add(id(art))
|
|
57
|
+
for coll in ax.collections:
|
|
58
|
+
if getattr(coll, '_matplotly_dist', False):
|
|
59
|
+
dist_artist_ids.add(id(coll))
|
|
60
|
+
if getattr(coll, '_matplotly_dist_jitter', False):
|
|
61
|
+
dist_artist_ids.add(id(coll))
|
|
62
|
+
for patch in ax.patches:
|
|
63
|
+
if getattr(patch, '_matplotly_dist', False):
|
|
64
|
+
dist_artist_ids.add(id(patch))
|
|
65
|
+
|
|
66
|
+
# --- Identify errorbar-managed artists (for skip logic) ---
|
|
67
|
+
errorbar_artist_ids: set[int] = set()
|
|
68
|
+
for l in ax.lines:
|
|
69
|
+
if getattr(l, '_matplotly_errorbar', False):
|
|
70
|
+
errorbar_artist_ids.add(id(l))
|
|
71
|
+
for coll in ax.collections:
|
|
72
|
+
if getattr(coll, '_matplotly_errorbar', False):
|
|
73
|
+
errorbar_artist_ids.add(id(coll))
|
|
74
|
+
|
|
75
|
+
# --- Lines (filtered list for stable indexing) ---
|
|
76
|
+
_user_lines = [l for l in ax.lines
|
|
77
|
+
if not l.get_label().startswith("_")
|
|
78
|
+
and id(l) not in errorbar_artist_ids]
|
|
79
|
+
if _user_lines:
|
|
80
|
+
lines.append(
|
|
81
|
+
f"\n_lines = [l for l in {ax_var}.lines "
|
|
82
|
+
f"if not l.get_label().startswith('_')]")
|
|
83
|
+
for j, line in enumerate(_user_lines):
|
|
84
|
+
label = line.get_label()
|
|
85
|
+
acc = f"_lines[{j}]"
|
|
86
|
+
lines.append(f"\n# {label or f'Line {j}'}")
|
|
87
|
+
lines.append(f"{acc}.set_color({_fmt(to_hex(line.get_color()))})")
|
|
88
|
+
lines.append(f"{acc}.set_linewidth({_fmt(line.get_linewidth())})")
|
|
89
|
+
ls = line.get_linestyle()
|
|
90
|
+
if ls != "-":
|
|
91
|
+
lines.append(f"{acc}.set_linestyle({_fmt(ls)})")
|
|
92
|
+
alpha = line.get_alpha()
|
|
93
|
+
if alpha is not None and alpha != 1.0:
|
|
94
|
+
lines.append(f"{acc}.set_alpha({_fmt(alpha)})")
|
|
95
|
+
marker = line.get_marker()
|
|
96
|
+
if marker and marker != "None" and marker != "none":
|
|
97
|
+
lines.append(f"{acc}.set_marker({_fmt(marker)})")
|
|
98
|
+
lines.append(f"{acc}.set_markersize({_fmt(line.get_markersize())})")
|
|
99
|
+
|
|
100
|
+
# --- Histograms ---
|
|
101
|
+
hist_patch_ids: set[int] = set()
|
|
102
|
+
hist_container_idxs: list[int] = []
|
|
103
|
+
from matplotlib.container import BarContainer as _BCt
|
|
104
|
+
from matplotlib.patches import Rectangle as _Rect
|
|
105
|
+
|
|
106
|
+
_hist_infos = getattr(ax, '_matplotly_hist_info', [])
|
|
107
|
+
_hist_merged = any(hi.get('merged', False) for hi in _hist_infos)
|
|
108
|
+
_skip_hist_style = False
|
|
109
|
+
|
|
110
|
+
if _hist_infos and _hist_merged:
|
|
111
|
+
# Merged: all BarContainers are histogram containers
|
|
112
|
+
for ci, _c in enumerate(ax.containers):
|
|
113
|
+
if isinstance(_c, _BCt):
|
|
114
|
+
hist_container_idxs.append(ci)
|
|
115
|
+
for _p in _c:
|
|
116
|
+
hist_patch_ids.add(id(_p))
|
|
117
|
+
for hi in _hist_infos:
|
|
118
|
+
for art in hi.get('artists', []):
|
|
119
|
+
hist_patch_ids.add(id(art))
|
|
120
|
+
_data_vars = getattr(ax, '_matplotly_hist_data_vars', None)
|
|
121
|
+
_skip_hist_style = _emit_hist_merged(
|
|
122
|
+
lines, ax_var, _hist_infos, data_vars=_data_vars)
|
|
123
|
+
else:
|
|
124
|
+
# Non-merged: detect histograms via geometry or attribute
|
|
125
|
+
for ci, _c in enumerate(ax.containers):
|
|
126
|
+
if isinstance(_c, _BCt):
|
|
127
|
+
from ._introspect import FigureIntrospector as _FI_cg
|
|
128
|
+
if (_FI_cg._is_histogram_container(_c)
|
|
129
|
+
or getattr(_c, '_matplotly_is_histogram', False)):
|
|
130
|
+
hist_container_idxs.append(ci)
|
|
131
|
+
for _p in _c:
|
|
132
|
+
hist_patch_ids.add(id(_p))
|
|
133
|
+
|
|
134
|
+
if not _skip_hist_style:
|
|
135
|
+
# Build label lookup from hist_info
|
|
136
|
+
_hi_label_map = {}
|
|
137
|
+
for _hi in _hist_infos:
|
|
138
|
+
_hi_arts = _hi.get('artists', [])
|
|
139
|
+
if _hi_arts:
|
|
140
|
+
_hi_label_map[id(_hi_arts[0])] = _hi.get('label', '')
|
|
141
|
+
|
|
142
|
+
# Style-only code for all detected histogram containers
|
|
143
|
+
for _hj, _hci in enumerate(hist_container_idxs):
|
|
144
|
+
_hcont = ax.containers[_hci]
|
|
145
|
+
_hpats = [p for p in _hcont.patches
|
|
146
|
+
if isinstance(p, _Rect)]
|
|
147
|
+
if not _hpats:
|
|
148
|
+
continue
|
|
149
|
+
_hp0 = _hpats[0]
|
|
150
|
+
_hlbl = _hi_label_map.get(id(_hp0), '')
|
|
151
|
+
if not _hlbl or _hlbl.startswith("_"):
|
|
152
|
+
_hlbl = _hp0.get_label()
|
|
153
|
+
if not _hlbl or _hlbl.startswith("_"):
|
|
154
|
+
_hlbl = _hcont.get_label()
|
|
155
|
+
if not _hlbl or _hlbl.startswith("_"):
|
|
156
|
+
_hlbl = f"Histogram {_hj}"
|
|
157
|
+
_hfc = to_hex(_hp0.get_facecolor())
|
|
158
|
+
_hec = to_hex(_hp0.get_edgecolor())
|
|
159
|
+
_halpha = _hp0.get_alpha()
|
|
160
|
+
_hlw = _hp0.get_linewidth()
|
|
161
|
+
|
|
162
|
+
lines.append(f"\n# Histogram: {_hlbl}")
|
|
163
|
+
_hacc = f"{ax_var}.containers[{_hci}]"
|
|
164
|
+
lines.append(f"for _p in {_hacc}:")
|
|
165
|
+
lines.append(f" _p.set_facecolor({_fmt(_hfc)})")
|
|
166
|
+
lines.append(f" _p.set_edgecolor({_fmt(_hec)})")
|
|
167
|
+
if _halpha is not None and round(_halpha, 2) != 1.0:
|
|
168
|
+
lines.append(
|
|
169
|
+
f" _p.set_alpha({_fmt(round(_halpha, 2))})")
|
|
170
|
+
if round(_hlw, 1) != 1.0:
|
|
171
|
+
lines.append(
|
|
172
|
+
f" _p.set_linewidth({_fmt(round(_hlw, 1))})")
|
|
173
|
+
_hhatch = _hp0.get_hatch()
|
|
174
|
+
if _hhatch:
|
|
175
|
+
lines.append(f" _p.set_hatch({_fmt(_hhatch)})")
|
|
176
|
+
|
|
177
|
+
# --- Bars (style existing containers) ---
|
|
178
|
+
bar_patch_ids: set[int] = set()
|
|
179
|
+
bar_container_idxs: list[int] = []
|
|
180
|
+
for ci, _c in enumerate(ax.containers):
|
|
181
|
+
if isinstance(_c, _BCt) and ci not in hist_container_idxs:
|
|
182
|
+
bar_container_idxs.append(ci)
|
|
183
|
+
for _p in _c:
|
|
184
|
+
bar_patch_ids.add(id(_p))
|
|
185
|
+
|
|
186
|
+
for _bj, _bci in enumerate(bar_container_idxs):
|
|
187
|
+
_bcont = ax.containers[_bci]
|
|
188
|
+
_bpats = [p for p in _bcont.patches if isinstance(p, _Rect)]
|
|
189
|
+
if not _bpats:
|
|
190
|
+
continue
|
|
191
|
+
_bp0 = _bpats[0]
|
|
192
|
+
_blbl = _bcont.get_label()
|
|
193
|
+
if not _blbl or _blbl.startswith("_"):
|
|
194
|
+
_blbl = _bp0.get_label()
|
|
195
|
+
if not _blbl or _blbl.startswith("_"):
|
|
196
|
+
_blbl = f"Bar group {_bj}"
|
|
197
|
+
|
|
198
|
+
_bfc = to_hex(_bp0.get_facecolor())
|
|
199
|
+
_bec = to_hex(_bp0.get_edgecolor())
|
|
200
|
+
_balpha = _bp0.get_alpha()
|
|
201
|
+
_blw = _bp0.get_linewidth()
|
|
202
|
+
|
|
203
|
+
lines.append(f"\n# {_blbl}")
|
|
204
|
+
_bacc = f"{ax_var}.containers[{_bci}]"
|
|
205
|
+
lines.append(f"for _p in {_bacc}:")
|
|
206
|
+
lines.append(f" _p.set_facecolor({_fmt(_bfc)})")
|
|
207
|
+
lines.append(f" _p.set_edgecolor({_fmt(_bec)})")
|
|
208
|
+
if _balpha is not None and round(_balpha, 2) != 1.0:
|
|
209
|
+
lines.append(
|
|
210
|
+
f" _p.set_alpha({_fmt(round(_balpha, 2))})")
|
|
211
|
+
if round(_blw, 1) != 1.0:
|
|
212
|
+
lines.append(
|
|
213
|
+
f" _p.set_linewidth({_fmt(round(_blw, 1))})")
|
|
214
|
+
_bhatch = _bp0.get_hatch()
|
|
215
|
+
if _bhatch:
|
|
216
|
+
lines.append(f" _p.set_hatch({_fmt(_bhatch)})")
|
|
217
|
+
|
|
218
|
+
# Per-patch geometry (width + position)
|
|
219
|
+
_bi_orient = (getattr(ax, '_matplotly_bar_info', [{}])[0]
|
|
220
|
+
.get('orientation', 'vertical')
|
|
221
|
+
if getattr(ax, '_matplotly_bar_info', None)
|
|
222
|
+
else 'vertical')
|
|
223
|
+
_b_vert = _bi_orient == 'vertical'
|
|
224
|
+
if _b_vert:
|
|
225
|
+
_bw_cur = round(float(_bp0.get_width()), 4)
|
|
226
|
+
_bx_list = [round(float(p.get_x()), 4) for p in _bpats]
|
|
227
|
+
lines.append(
|
|
228
|
+
f"for _p, _x in zip({_bacc}, {_bx_list!r}):")
|
|
229
|
+
lines.append(f" _p.set_width({_fmt(_bw_cur)})")
|
|
230
|
+
lines.append(f" _p.set_x(_x)")
|
|
231
|
+
else:
|
|
232
|
+
_bh_cur = round(float(_bp0.get_height()), 4)
|
|
233
|
+
_by_list = [round(float(p.get_y()), 4) for p in _bpats]
|
|
234
|
+
lines.append(
|
|
235
|
+
f"for _p, _y in zip({_bacc}, {_by_list!r}):")
|
|
236
|
+
lines.append(f" _p.set_height({_fmt(_bh_cur)})")
|
|
237
|
+
lines.append(f" _p.set_y(_y)")
|
|
238
|
+
|
|
239
|
+
# Bar tick labels (from bar panel adjustments)
|
|
240
|
+
_bar_infos = getattr(ax, '_matplotly_bar_info', [])
|
|
241
|
+
if _bar_infos:
|
|
242
|
+
_bi_ref = _bar_infos[0]
|
|
243
|
+
_bi_tl = _bi_ref.get('tick_labels', [])
|
|
244
|
+
_bi_tc = _bi_ref.get('tick_centers', [])
|
|
245
|
+
_bi_orient = _bi_ref.get('orientation', 'vertical')
|
|
246
|
+
if _bi_tl and _bi_tc:
|
|
247
|
+
_bi_tax = 'x' if _bi_orient == 'vertical' else 'y'
|
|
248
|
+
_bi_tc_str = ", ".join(
|
|
249
|
+
str(round(float(t), 4)) for t in _bi_tc)
|
|
250
|
+
lines.append(
|
|
251
|
+
f"{ax_var}.set_{_bi_tax}ticks([{_bi_tc_str}])")
|
|
252
|
+
_bi_rot = _bi_ref.get('tick_rotation', 0)
|
|
253
|
+
_bi_ha = _bi_ref.get('tick_ha', 'center')
|
|
254
|
+
_bi_tl_args = f"{_bi_tl!r}"
|
|
255
|
+
if _bi_rot:
|
|
256
|
+
_bi_tl_args += f", rotation={_fmt(_bi_rot)}"
|
|
257
|
+
if _bi_ha != "center":
|
|
258
|
+
_bi_tl_args += f", ha={_fmt(_bi_ha)}"
|
|
259
|
+
_bi_tl_args += ", rotation_mode='anchor'"
|
|
260
|
+
lines.append(
|
|
261
|
+
f"{ax_var}.set_{_bi_tax}ticklabels({_bi_tl_args})")
|
|
262
|
+
|
|
263
|
+
# --- Distribution plots (bxp-based recreation) ---
|
|
264
|
+
_dist_data_vars = getattr(ax, '_matplotly_dist_data_vars', None)
|
|
265
|
+
|
|
266
|
+
if _dist_infos:
|
|
267
|
+
_di_ref = _dist_infos[0]
|
|
268
|
+
_di_orient = _di_ref.get('orientation', 'vertical')
|
|
269
|
+
_di_width = _di_ref.get('width', 0.5)
|
|
270
|
+
_di_tl = _di_ref.get('tick_labels', [])
|
|
271
|
+
_di_tc = _di_ref.get('tick_centers',
|
|
272
|
+
_di_ref.get('positions', []))
|
|
273
|
+
_di_ng = len(_dist_infos)
|
|
274
|
+
|
|
275
|
+
if _dist_data_vars:
|
|
276
|
+
# --- Apply path: use real data variable ---
|
|
277
|
+
_emit_dist_with_data(
|
|
278
|
+
lines, ax_var, _dist_infos, _dist_data_vars,
|
|
279
|
+
_di_orient, _di_width, _di_ng)
|
|
280
|
+
else:
|
|
281
|
+
# --- Copy/fallback path: stats + fabricated data ---
|
|
282
|
+
lines.append(f"\n# Distribution plot")
|
|
283
|
+
|
|
284
|
+
# Cleanup original distribution artists so bxp() doesn't double
|
|
285
|
+
lines.append(
|
|
286
|
+
f"for _l in list({ax_var}.lines):")
|
|
287
|
+
lines.append(
|
|
288
|
+
f" if _l.get_label().startswith('_'):")
|
|
289
|
+
lines.append(
|
|
290
|
+
f" _l.remove()")
|
|
291
|
+
lines.append(
|
|
292
|
+
f"for _p in list({ax_var}.patches):")
|
|
293
|
+
lines.append(
|
|
294
|
+
f" if _p.get_label() == '' or _p.get_label().startswith('_'):")
|
|
295
|
+
lines.append(
|
|
296
|
+
f" try:")
|
|
297
|
+
lines.append(
|
|
298
|
+
f" _p.remove()")
|
|
299
|
+
lines.append(
|
|
300
|
+
f" except ValueError:")
|
|
301
|
+
lines.append(
|
|
302
|
+
f" pass")
|
|
303
|
+
|
|
304
|
+
for _dj, _dinfo in enumerate(_dist_infos):
|
|
305
|
+
_d_mode = _dinfo.get('display_mode', 'box')
|
|
306
|
+
_d_label = _dinfo.get('label', f'Group {_dj}')
|
|
307
|
+
_d_pos = _dinfo.get('positions', _di_tc)
|
|
308
|
+
_d_raw = _dinfo.get('raw_data', [])
|
|
309
|
+
|
|
310
|
+
lines.append(f"\n# {_d_label}")
|
|
311
|
+
|
|
312
|
+
if not _d_raw:
|
|
313
|
+
continue
|
|
314
|
+
|
|
315
|
+
# Prefer original introspected stats (avoids double
|
|
316
|
+
# reconstruction that degrades violin KDE shapes).
|
|
317
|
+
_orig_box_stats = _dinfo.get('box_stats', [])
|
|
318
|
+
if _orig_box_stats:
|
|
319
|
+
bxp_stats = _bxp_stats_from_original(
|
|
320
|
+
_orig_box_stats, _dinfo)
|
|
321
|
+
else:
|
|
322
|
+
bxp_stats = _compute_bxp_stats(_d_raw, _dinfo)
|
|
323
|
+
if not bxp_stats:
|
|
324
|
+
continue
|
|
325
|
+
svar = _emit_stats_var(lines, bxp_stats, _dj, _di_ng)
|
|
326
|
+
|
|
327
|
+
# Reconstruct data from stats for violin/jitter
|
|
328
|
+
_needs_data = 'violin' in _d_mode or 'jitter' in _d_mode
|
|
329
|
+
raw_var = None
|
|
330
|
+
if _needs_data:
|
|
331
|
+
raw_var = _emit_data_from_stats(
|
|
332
|
+
lines, svar, _dj, _di_ng)
|
|
333
|
+
|
|
334
|
+
# Violin (background layer)
|
|
335
|
+
if 'violin' in _d_mode and raw_var:
|
|
336
|
+
_emit_violin(lines, ax_var, _dinfo, raw_var, _d_pos,
|
|
337
|
+
_di_orient, _di_width, _d_mode,
|
|
338
|
+
_d_label)
|
|
339
|
+
|
|
340
|
+
# Box (middle layer)
|
|
341
|
+
if 'box' in _d_mode:
|
|
342
|
+
_emit_bxp_call(lines, ax_var, _dinfo, svar, _d_pos,
|
|
343
|
+
_d_label, _dj, _di_ng, _di_orient,
|
|
344
|
+
_di_width, _d_mode)
|
|
345
|
+
|
|
346
|
+
# Jitter (top layer)
|
|
347
|
+
if 'jitter' in _d_mode and raw_var:
|
|
348
|
+
_emit_jitter(lines, ax_var, _dinfo, raw_var, _d_pos,
|
|
349
|
+
_di_orient, _d_mode, _d_label)
|
|
350
|
+
|
|
351
|
+
# Tick labels (shared by both paths)
|
|
352
|
+
if _di_tl and _di_tc:
|
|
353
|
+
_d_tax = 'x' if _di_orient == 'vertical' else 'y'
|
|
354
|
+
_d_tc_str = ", ".join(
|
|
355
|
+
str(round(float(t), 4)) for t in _di_tc)
|
|
356
|
+
lines.append(
|
|
357
|
+
f"\n{ax_var}.set_{_d_tax}ticks([{_d_tc_str}])")
|
|
358
|
+
_d_ha = _di_ref.get('tick_ha', 'center')
|
|
359
|
+
_d_rot = _di_ref.get('tick_rotation', 0)
|
|
360
|
+
_d_pad = _di_ref.get('tick_pad', 4.0)
|
|
361
|
+
_d_tl_args = f"{_di_tl!r}"
|
|
362
|
+
if _d_rot:
|
|
363
|
+
_d_tl_args += f", rotation={_fmt(_d_rot)}"
|
|
364
|
+
if _d_ha != 'center':
|
|
365
|
+
_d_tl_args += f", ha={_fmt(_d_ha)}"
|
|
366
|
+
_d_tl_args += ", rotation_mode='anchor'"
|
|
367
|
+
lines.append(
|
|
368
|
+
f"{ax_var}.set_{_d_tax}ticklabels({_d_tl_args})")
|
|
369
|
+
if abs(_d_pad - 4.0) > 0.1:
|
|
370
|
+
lines.append(
|
|
371
|
+
f"{ax_var}.tick_params(axis={_fmt(_d_tax)},"
|
|
372
|
+
f" pad={_fmt(_d_pad)})")
|
|
373
|
+
|
|
374
|
+
# --- Heatmaps ---
|
|
375
|
+
_heatmap_infos = getattr(ax, '_matplotly_heatmap_info', [])
|
|
376
|
+
if _heatmap_infos:
|
|
377
|
+
_emit_heatmap(lines, ax_var, _heatmap_infos)
|
|
378
|
+
|
|
379
|
+
# --- Colorbar ---
|
|
380
|
+
_cbar_info = getattr(ax, '_matplotly_colorbar_info', None)
|
|
381
|
+
if _cbar_info and _cbar_info.get('show', False) and _heatmap_infos:
|
|
382
|
+
_emit_colorbar(lines, ax_var, _cbar_info, _heatmap_infos)
|
|
383
|
+
|
|
384
|
+
# --- Errorbars ---
|
|
385
|
+
_errorbar_infos = getattr(ax, '_matplotly_errorbar_info', [])
|
|
386
|
+
if _errorbar_infos:
|
|
387
|
+
_emit_errorbars(lines, ax_var, _errorbar_infos)
|
|
388
|
+
|
|
389
|
+
# --- Patches (generic — skip histogram + bar + dist patches) ---
|
|
390
|
+
_skip_ids = hist_patch_ids | bar_patch_ids | dist_artist_ids
|
|
391
|
+
for i, patch in enumerate(ax.patches):
|
|
392
|
+
if id(patch) in _skip_ids:
|
|
393
|
+
continue
|
|
394
|
+
if hasattr(patch, "get_facecolor"):
|
|
395
|
+
fc = to_hex(patch.get_facecolor())
|
|
396
|
+
if fc != "#1f77b4": # only if non-default
|
|
397
|
+
acc = f"{ax_var}.patches[{i}]"
|
|
398
|
+
lines.append(f"{acc}.set_facecolor({_fmt(fc)})")
|
|
399
|
+
|
|
400
|
+
# --- Collections (scatter — filtered list for stable indexing) ---
|
|
401
|
+
from matplotlib.collections import PathCollection
|
|
402
|
+
_user_colls = []
|
|
403
|
+
for coll in ax.collections:
|
|
404
|
+
if not isinstance(coll, PathCollection):
|
|
405
|
+
continue
|
|
406
|
+
if id(coll) in dist_artist_ids:
|
|
407
|
+
continue
|
|
408
|
+
if id(coll) in heatmap_artist_ids:
|
|
409
|
+
continue
|
|
410
|
+
if id(coll) in errorbar_artist_ids:
|
|
411
|
+
continue
|
|
412
|
+
_user_colls.append(coll)
|
|
413
|
+
|
|
414
|
+
if _user_colls:
|
|
415
|
+
lines.append(
|
|
416
|
+
f"\nfrom matplotlib.collections import "
|
|
417
|
+
f"PathCollection as _PC")
|
|
418
|
+
lines.append(
|
|
419
|
+
f"_scatter = [c for c in {ax_var}.collections "
|
|
420
|
+
f"if isinstance(c, _PC)]")
|
|
421
|
+
for j, coll in enumerate(_user_colls):
|
|
422
|
+
acc = f"_scatter[{j}]"
|
|
423
|
+
try:
|
|
424
|
+
fc = coll.get_facecolor()
|
|
425
|
+
if len(fc) > 0:
|
|
426
|
+
lines.append(
|
|
427
|
+
f"{acc}.set_facecolor({_fmt(to_hex(fc[0]))})")
|
|
428
|
+
ec = coll.get_edgecolor()
|
|
429
|
+
if len(ec) > 0:
|
|
430
|
+
_ec_hex = to_hex(ec[0])
|
|
431
|
+
if _ec_hex != "#000000":
|
|
432
|
+
lines.append(
|
|
433
|
+
f"{acc}.set_edgecolor({_fmt(_ec_hex)})")
|
|
434
|
+
alpha = coll.get_alpha()
|
|
435
|
+
if alpha is not None:
|
|
436
|
+
lines.append(f"{acc}.set_alpha({_fmt(alpha)})")
|
|
437
|
+
sizes = coll.get_sizes()
|
|
438
|
+
if len(sizes) > 0:
|
|
439
|
+
lines.append(
|
|
440
|
+
f"{acc}.set_sizes([{_fmt(float(sizes[0]))}])")
|
|
441
|
+
except Exception:
|
|
442
|
+
pass
|
|
443
|
+
|
|
444
|
+
# --- Text labels ---
|
|
445
|
+
title = ax.get_title()
|
|
446
|
+
xlabel = ax.get_xlabel()
|
|
447
|
+
ylabel = ax.get_ylabel()
|
|
448
|
+
if title or xlabel or ylabel:
|
|
449
|
+
lines.append(f"\n# Labels")
|
|
450
|
+
if title:
|
|
451
|
+
lines.append(f"{ax_var}.set_title({_fmt(title)})")
|
|
452
|
+
if xlabel:
|
|
453
|
+
lines.append(f"{ax_var}.set_xlabel({_fmt(xlabel)})")
|
|
454
|
+
if ylabel:
|
|
455
|
+
lines.append(f"{ax_var}.set_ylabel({_fmt(ylabel)})")
|
|
456
|
+
|
|
457
|
+
# --- Axis limits ---
|
|
458
|
+
xlim = ax.get_xlim()
|
|
459
|
+
ylim = ax.get_ylim()
|
|
460
|
+
lines.append(f"\n# Limits")
|
|
461
|
+
lines.append(f"{ax_var}.set_xlim({_fmt(round(float(xlim[0]), 4))}, "
|
|
462
|
+
f"{_fmt(round(float(xlim[1]), 4))})")
|
|
463
|
+
lines.append(f"{ax_var}.set_ylim({_fmt(round(float(ylim[0]), 4))}, "
|
|
464
|
+
f"{_fmt(round(float(ylim[1]), 4))})")
|
|
465
|
+
|
|
466
|
+
# --- Axis scale ---
|
|
467
|
+
xscale = ax.get_xscale()
|
|
468
|
+
yscale = ax.get_yscale()
|
|
469
|
+
if xscale != "linear" or yscale != "linear":
|
|
470
|
+
lines.append(f"\n# Scale")
|
|
471
|
+
if xscale != "linear":
|
|
472
|
+
lines.append(f"{ax_var}.set_xscale({_fmt(xscale)})")
|
|
473
|
+
if yscale != "linear":
|
|
474
|
+
lines.append(f"{ax_var}.set_yscale({_fmt(yscale)})")
|
|
475
|
+
|
|
476
|
+
# --- Fonts ---
|
|
477
|
+
title_size = ax.title.get_fontsize()
|
|
478
|
+
xlabel_size = ax.xaxis.label.get_fontsize()
|
|
479
|
+
ylabel_size = ax.yaxis.label.get_fontsize()
|
|
480
|
+
title_family = ax.title.get_fontfamily()
|
|
481
|
+
if isinstance(title_family, list):
|
|
482
|
+
title_family = title_family[0] if title_family else "Arial"
|
|
483
|
+
|
|
484
|
+
lines.append(f"\n# Fonts")
|
|
485
|
+
lines.append(f"{ax_var}.title.set_fontsize({_fmt(title_size)})")
|
|
486
|
+
lines.append(f"{ax_var}.title.set_fontfamily({_fmt(title_family)})")
|
|
487
|
+
lines.append(f"{ax_var}.xaxis.label.set_fontsize({_fmt(xlabel_size)})")
|
|
488
|
+
lines.append(f"{ax_var}.yaxis.label.set_fontsize({_fmt(ylabel_size)})")
|
|
489
|
+
|
|
490
|
+
# Title pad (distance from axes top, in points)
|
|
491
|
+
try:
|
|
492
|
+
_tpad_disp = ax.titleOffsetTrans.get_matrix()[1, 2]
|
|
493
|
+
_tpad_pts = round(_tpad_disp / fig.dpi * 72, 1)
|
|
494
|
+
if abs(_tpad_pts - 6.0) > 0.5:
|
|
495
|
+
lines.append(
|
|
496
|
+
f"{ax_var}.set_title({ax_var}.get_title(), "
|
|
497
|
+
f"pad={_fmt(_tpad_pts)})")
|
|
498
|
+
except Exception:
|
|
499
|
+
pass
|
|
500
|
+
|
|
501
|
+
# Text color / weight / style
|
|
502
|
+
for _tname, _tobj in [("title", ax.title),
|
|
503
|
+
("xaxis.label", ax.xaxis.label),
|
|
504
|
+
("yaxis.label", ax.yaxis.label)]:
|
|
505
|
+
try:
|
|
506
|
+
_tc = to_hex(_tobj.get_color())
|
|
507
|
+
if _tc != "#000000":
|
|
508
|
+
lines.append(f"{ax_var}.{_tname}.set_color({_fmt(_tc)})")
|
|
509
|
+
except Exception:
|
|
510
|
+
pass
|
|
511
|
+
_fw = _tobj.get_fontweight()
|
|
512
|
+
if _fw == 'bold' or (isinstance(_fw, (int, float)) and _fw >= 600):
|
|
513
|
+
lines.append(f"{ax_var}.{_tname}.set_fontweight('bold')")
|
|
514
|
+
if _tobj.get_fontstyle() == 'italic':
|
|
515
|
+
lines.append(f"{ax_var}.{_tname}.set_fontstyle('italic')")
|
|
516
|
+
|
|
517
|
+
# Tick label font size (use first tick as representative)
|
|
518
|
+
xticks = ax.get_xticklabels()
|
|
519
|
+
yticks = ax.get_yticklabels()
|
|
520
|
+
if xticks:
|
|
521
|
+
tick_sz = xticks[0].get_fontsize()
|
|
522
|
+
lines.append(f"{ax_var}.tick_params(labelsize={_fmt(tick_sz)})")
|
|
523
|
+
|
|
524
|
+
# Tick label colors
|
|
525
|
+
_xtl = ax.get_xticklabels()
|
|
526
|
+
if _xtl:
|
|
527
|
+
try:
|
|
528
|
+
_xtc = to_hex(_xtl[0].get_color())
|
|
529
|
+
if _xtc != "#000000":
|
|
530
|
+
lines.append(f"{ax_var}.tick_params(axis='x', "
|
|
531
|
+
f"labelcolor={_fmt(_xtc)})")
|
|
532
|
+
except Exception:
|
|
533
|
+
pass
|
|
534
|
+
_ytl = ax.get_yticklabels()
|
|
535
|
+
if _ytl:
|
|
536
|
+
try:
|
|
537
|
+
_ytc = to_hex(_ytl[0].get_color())
|
|
538
|
+
if _ytc != "#000000":
|
|
539
|
+
lines.append(f"{ax_var}.tick_params(axis='y', "
|
|
540
|
+
f"labelcolor={_fmt(_ytc)})")
|
|
541
|
+
except Exception:
|
|
542
|
+
pass
|
|
543
|
+
|
|
544
|
+
# --- Tick params ---
|
|
545
|
+
# Read from the actual tick objects
|
|
546
|
+
xtick_objs = ax.xaxis.get_major_ticks()
|
|
547
|
+
if xtick_objs:
|
|
548
|
+
tick = xtick_objs[0]
|
|
549
|
+
tick_line = tick.tick1line
|
|
550
|
+
direction = "in" if tick_line.get_marker() == 2 else "out"
|
|
551
|
+
# Check inout: tick2line visible means inout
|
|
552
|
+
if tick.tick2line.get_visible() and not ax.spines["top"].get_visible():
|
|
553
|
+
direction = "inout"
|
|
554
|
+
tick_len = tick_line.get_markersize()
|
|
555
|
+
tick_width = tick_line.get_markeredgewidth()
|
|
556
|
+
lines.append(f"{ax_var}.tick_params(direction={_fmt(direction)}, "
|
|
557
|
+
f"length={_fmt(tick_len)}, width={_fmt(tick_width)})")
|
|
558
|
+
|
|
559
|
+
# --- Tick spacing ---
|
|
560
|
+
from matplotlib.ticker import MultipleLocator as _ML
|
|
561
|
+
x_loc = ax.xaxis.get_major_locator()
|
|
562
|
+
y_loc = ax.yaxis.get_major_locator()
|
|
563
|
+
if isinstance(x_loc, _ML) or isinstance(y_loc, _ML):
|
|
564
|
+
if isinstance(x_loc, _ML):
|
|
565
|
+
xt = list(ax.get_xticks())
|
|
566
|
+
if len(xt) >= 2:
|
|
567
|
+
step = round(xt[1] - xt[0], 6)
|
|
568
|
+
if step > 0:
|
|
569
|
+
lines.append(
|
|
570
|
+
f"{ax_var}.xaxis.set_major_locator("
|
|
571
|
+
f"matplotlib.ticker.MultipleLocator({_fmt(step)}))")
|
|
572
|
+
if isinstance(y_loc, _ML):
|
|
573
|
+
yt = list(ax.get_yticks())
|
|
574
|
+
if len(yt) >= 2:
|
|
575
|
+
step = round(yt[1] - yt[0], 6)
|
|
576
|
+
if step > 0:
|
|
577
|
+
lines.append(
|
|
578
|
+
f"{ax_var}.yaxis.set_major_locator("
|
|
579
|
+
f"matplotlib.ticker.MultipleLocator({_fmt(step)}))")
|
|
580
|
+
|
|
581
|
+
# --- Spines ---
|
|
582
|
+
spine_changes = []
|
|
583
|
+
for name in ("top", "right", "bottom", "left"):
|
|
584
|
+
sp = ax.spines[name]
|
|
585
|
+
if not sp.get_visible():
|
|
586
|
+
spine_changes.append(
|
|
587
|
+
f"{ax_var}.spines[{_fmt(name)}].set_visible(False)")
|
|
588
|
+
lw = sp.get_linewidth()
|
|
589
|
+
if lw != 1.0:
|
|
590
|
+
spine_changes.append(
|
|
591
|
+
f"{ax_var}.spines[{_fmt(name)}].set_linewidth({_fmt(lw)})")
|
|
592
|
+
if spine_changes:
|
|
593
|
+
lines.append(f"\n# Spines")
|
|
594
|
+
lines.extend(spine_changes)
|
|
595
|
+
|
|
596
|
+
# --- Grid ---
|
|
597
|
+
# Check if grid lines are visible
|
|
598
|
+
has_grid = any(gl.get_visible() for gl in ax.xaxis.get_gridlines())
|
|
599
|
+
if has_grid:
|
|
600
|
+
gl = ax.xaxis.get_gridlines()[0]
|
|
601
|
+
lines.append(f"\n# Grid")
|
|
602
|
+
lines.append(
|
|
603
|
+
f"{ax_var}.grid(True, alpha={_fmt(gl.get_alpha() or 0.5)}, "
|
|
604
|
+
f"linewidth={_fmt(gl.get_linewidth())}, "
|
|
605
|
+
f"linestyle={_fmt(gl.get_linestyle())})")
|
|
606
|
+
|
|
607
|
+
# --- Legend ---
|
|
608
|
+
leg = ax.get_legend()
|
|
609
|
+
_leg_handles, _leg_labels = ax.get_legend_handles_labels()
|
|
610
|
+
if leg is None and _leg_handles:
|
|
611
|
+
# Legend was removed by user
|
|
612
|
+
lines.append(f"\n# Legend removed")
|
|
613
|
+
lines.append(f"if {ax_var}.get_legend() is not None:")
|
|
614
|
+
lines.append(f" {ax_var}.get_legend().remove()")
|
|
615
|
+
elif leg is not None:
|
|
616
|
+
if _leg_handles:
|
|
617
|
+
lines.append(f"\n# Legend")
|
|
618
|
+
|
|
619
|
+
# Get legend properties
|
|
620
|
+
leg_kwargs = []
|
|
621
|
+
try:
|
|
622
|
+
loc = leg._loc
|
|
623
|
+
_loc_names = {
|
|
624
|
+
0: "upper right", 1: "upper right", 2: "upper left",
|
|
625
|
+
3: "lower left", 4: "lower right", 5: "right",
|
|
626
|
+
6: "center left", 7: "center right",
|
|
627
|
+
8: "lower center", 9: "upper center", 10: "center",
|
|
628
|
+
}
|
|
629
|
+
loc_str = _loc_names.get(loc, "upper right")
|
|
630
|
+
leg_kwargs.append(f"loc={_fmt(loc_str)}")
|
|
631
|
+
except Exception:
|
|
632
|
+
pass
|
|
633
|
+
|
|
634
|
+
frameon = leg.get_frame().get_visible()
|
|
635
|
+
leg_kwargs.append(f"frameon={_fmt(frameon)}")
|
|
636
|
+
|
|
637
|
+
fontsize = leg._fontsize
|
|
638
|
+
if fontsize:
|
|
639
|
+
leg_kwargs.append(f"fontsize={_fmt(fontsize)}")
|
|
640
|
+
|
|
641
|
+
ncols = leg._ncols
|
|
642
|
+
if ncols != 1:
|
|
643
|
+
leg_kwargs.append(f"ncol={_fmt(ncols)}")
|
|
644
|
+
|
|
645
|
+
markerfirst = getattr(leg, '_markerfirst', True)
|
|
646
|
+
if not markerfirst:
|
|
647
|
+
leg_kwargs.append("markerfirst=False")
|
|
648
|
+
|
|
649
|
+
handletextpad = getattr(leg, 'handletextpad', 0.8)
|
|
650
|
+
if round(handletextpad, 2) != 0.8:
|
|
651
|
+
leg_kwargs.append(f"handletextpad={_fmt(round(handletextpad, 2))}")
|
|
652
|
+
|
|
653
|
+
handleheight = getattr(leg, 'handleheight', 0.7)
|
|
654
|
+
if round(handleheight, 2) != 0.7:
|
|
655
|
+
leg_kwargs.append(f"handleheight={_fmt(round(handleheight, 2))}")
|
|
656
|
+
|
|
657
|
+
# bbox_to_anchor
|
|
658
|
+
if hasattr(leg, "_bbox_to_anchor") and leg._bbox_to_anchor is not None:
|
|
659
|
+
try:
|
|
660
|
+
bbox = leg._bbox_to_anchor
|
|
661
|
+
if bbox.width < 1 and bbox.height < 1:
|
|
662
|
+
inv = ax.transAxes.inverted()
|
|
663
|
+
x_ax, y_ax = inv.transform((bbox.x0, bbox.y0))
|
|
664
|
+
leg_kwargs.append(
|
|
665
|
+
f"bbox_to_anchor=({_fmt(round(x_ax, 2))}, "
|
|
666
|
+
f"{_fmt(round(y_ax, 2))})")
|
|
667
|
+
except Exception:
|
|
668
|
+
pass
|
|
669
|
+
|
|
670
|
+
kwargs_str = ", ".join(leg_kwargs)
|
|
671
|
+
lines.append(f"{ax_var}.legend({kwargs_str})")
|
|
672
|
+
|
|
673
|
+
# Legend text colors
|
|
674
|
+
for _li, _lt in enumerate(leg.get_texts()):
|
|
675
|
+
try:
|
|
676
|
+
_ltc = to_hex(_lt.get_color())
|
|
677
|
+
if _ltc != "#000000":
|
|
678
|
+
lines.append(
|
|
679
|
+
f"{ax_var}.get_legend().get_texts()"
|
|
680
|
+
f"[{_li}].set_color({_fmt(_ltc)})")
|
|
681
|
+
except Exception:
|
|
682
|
+
pass
|
|
683
|
+
|
|
684
|
+
# --- Layout (figure-level, emitted once) ---
|
|
685
|
+
w, h = fig.get_size_inches()
|
|
686
|
+
lines.append(f"\n# Layout")
|
|
687
|
+
lines.append(f"fig.set_size_inches({_fmt(round(w, 2))}, {_fmt(round(h, 2))})")
|
|
688
|
+
|
|
689
|
+
fc = to_hex(fig.get_facecolor())
|
|
690
|
+
if fc != "#ffffff":
|
|
691
|
+
lines.append(f"fig.set_facecolor({_fmt(fc)})")
|
|
692
|
+
|
|
693
|
+
lines.append(f"fig.tight_layout()")
|
|
694
|
+
|
|
695
|
+
# --- Subplot spacing (multi-subplot only) ---
|
|
696
|
+
if multi:
|
|
697
|
+
try:
|
|
698
|
+
sp = fig.subplotpars
|
|
699
|
+
_h = round(sp.hspace, 2) if sp.hspace else 0.0
|
|
700
|
+
_w = round(sp.wspace, 2) if sp.wspace else 0.0
|
|
701
|
+
if _h > 0 or _w > 0:
|
|
702
|
+
lines.append(f"fig.subplots_adjust(hspace={_fmt(_h)}, wspace={_fmt(_w)})")
|
|
703
|
+
except Exception:
|
|
704
|
+
pass
|
|
705
|
+
|
|
706
|
+
# --- Marginal histograms ---
|
|
707
|
+
if marginal_axes:
|
|
708
|
+
lines.append("\n# Marginal histograms")
|
|
709
|
+
|
|
710
|
+
# Pre-scan: determine space needed per parent axes
|
|
711
|
+
_marg_infos = [] # (m_ax, info) pairs
|
|
712
|
+
_parent_space: dict[int, list] = {} # pidx -> [(axis, pos, height, pad)]
|
|
713
|
+
for m_ax in marginal_axes:
|
|
714
|
+
info = getattr(m_ax, '_matplotly_marginal_info', None)
|
|
715
|
+
if not info:
|
|
716
|
+
continue
|
|
717
|
+
_marg_infos.append((m_ax, info))
|
|
718
|
+
pidx = info.get('parent_ax_index', 0)
|
|
719
|
+
axis = info.get('axis', 'x')
|
|
720
|
+
position = info.get('position', 'top' if axis == 'x' else 'right')
|
|
721
|
+
m_height = info.get('height', 0.8)
|
|
722
|
+
m_pad = info.get('pad', 0.05)
|
|
723
|
+
_parent_space.setdefault(pidx, []).append(
|
|
724
|
+
(axis, position, m_height, m_pad))
|
|
725
|
+
|
|
726
|
+
# Shrink parent axes to make room for marginals
|
|
727
|
+
for pidx, needs in sorted(_parent_space.items()):
|
|
728
|
+
p_var = f"axes[{pidx}]" if multi else "ax"
|
|
729
|
+
lines.append(f"_pos = {p_var}.get_position()")
|
|
730
|
+
lines.append("_fig_w, _fig_h = fig.get_size_inches()")
|
|
731
|
+
lines.append(
|
|
732
|
+
"_x0, _y0, _w, _h = _pos.x0, _pos.y0, "
|
|
733
|
+
"_pos.width, _pos.height")
|
|
734
|
+
for axis, position, m_height, m_pad in needs:
|
|
735
|
+
if axis == 'x' and position == 'top':
|
|
736
|
+
lines.append(
|
|
737
|
+
f"_h -= ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_h")
|
|
738
|
+
elif axis == 'x' and position == 'bottom':
|
|
739
|
+
lines.append(
|
|
740
|
+
f"_bump = ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_h")
|
|
741
|
+
lines.append("_y0 += _bump; _h -= _bump")
|
|
742
|
+
elif axis == 'y' and position == 'right':
|
|
743
|
+
lines.append(
|
|
744
|
+
f"_w -= ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_w")
|
|
745
|
+
elif axis == 'y' and position == 'left':
|
|
746
|
+
lines.append(
|
|
747
|
+
f"_bump = ({_fmt(m_height)} + {_fmt(m_pad)}) / _fig_w")
|
|
748
|
+
lines.append("_x0 += _bump; _w -= _bump")
|
|
749
|
+
lines.append(f"{p_var}.set_position([_x0, _y0, _w, _h])")
|
|
750
|
+
|
|
751
|
+
# Now create each marginal axes in the freed space
|
|
752
|
+
for m_ax, info in _marg_infos:
|
|
753
|
+
parent_idx = info.get('parent_ax_index', 0)
|
|
754
|
+
axis = info.get('axis', 'x')
|
|
755
|
+
position = info.get('position', 'top' if axis == 'x' else 'right')
|
|
756
|
+
m_height = info.get('height', 0.8)
|
|
757
|
+
m_pad = info.get('pad', 0.05)
|
|
758
|
+
mode = info.get('mode', 'overlay')
|
|
759
|
+
n_bins = info.get('bins', 20)
|
|
760
|
+
alpha = info.get('alpha', 0.5)
|
|
761
|
+
separation = info.get('separation', 0.1)
|
|
762
|
+
inverted = info.get('inverted', False)
|
|
763
|
+
tick_side = info.get('tick_side',
|
|
764
|
+
'left' if axis == 'x' else 'bottom')
|
|
765
|
+
tick_fs = info.get('tick_fontsize', 8)
|
|
766
|
+
tick_step = info.get('tick_step', 0)
|
|
767
|
+
range_min = info.get('range_min', 0)
|
|
768
|
+
range_max = info.get('range_max', 0)
|
|
769
|
+
label = info.get('label', '')
|
|
770
|
+
label_fs = info.get('label_fontsize', 8)
|
|
771
|
+
label_bold = info.get('label_bold', False)
|
|
772
|
+
label_italic = info.get('label_italic', False)
|
|
773
|
+
label_color = info.get('label_color', '#000000')
|
|
774
|
+
m_title = info.get('title', '')
|
|
775
|
+
title_fs = info.get('title_fontsize', 8)
|
|
776
|
+
title_bold = info.get('title_bold', False)
|
|
777
|
+
title_italic = info.get('title_italic', False)
|
|
778
|
+
title_color = info.get('title_color', '#000000')
|
|
779
|
+
colls = info.get('collections', [])
|
|
780
|
+
share = 'sharex' if axis == 'x' else 'sharey'
|
|
781
|
+
data_col = 0 if axis == 'x' else 1
|
|
782
|
+
p_var = f"axes[{parent_idx}]" if multi else "ax"
|
|
783
|
+
|
|
784
|
+
# Collect data + compute global bins
|
|
785
|
+
data_exprs = []
|
|
786
|
+
colors = []
|
|
787
|
+
for ci in colls:
|
|
788
|
+
idx = ci['coll_index']
|
|
789
|
+
data_exprs.append(
|
|
790
|
+
f"{p_var}.collections[{idx}].get_offsets()[:, {data_col}]")
|
|
791
|
+
colors.append(ci['color'])
|
|
792
|
+
|
|
793
|
+
lines.append(
|
|
794
|
+
f"_m_data = [{', '.join(data_exprs)}]")
|
|
795
|
+
lines.append(
|
|
796
|
+
f"_m_bins = np.histogram_bin_edges("
|
|
797
|
+
f"np.concatenate(_m_data), bins={_fmt(n_bins)})")
|
|
798
|
+
|
|
799
|
+
# Compute rect from adjusted main axes position
|
|
800
|
+
lines.append(f"_pos = {p_var}.get_position()")
|
|
801
|
+
lines.append("_fig_w, _fig_h = fig.get_size_inches()")
|
|
802
|
+
if axis == 'x':
|
|
803
|
+
lines.append(
|
|
804
|
+
f"_h_frac = {_fmt(m_height)} / _fig_h")
|
|
805
|
+
lines.append(
|
|
806
|
+
f"_pad = {_fmt(m_pad)} / _fig_h")
|
|
807
|
+
if position == 'top':
|
|
808
|
+
lines.append(
|
|
809
|
+
"_m_rect = [_pos.x0, _pos.y1 + _pad, "
|
|
810
|
+
"_pos.width, _h_frac]")
|
|
811
|
+
else:
|
|
812
|
+
lines.append(
|
|
813
|
+
"_m_rect = [_pos.x0, _pos.y0 - _h_frac - _pad, "
|
|
814
|
+
"_pos.width, _h_frac]")
|
|
815
|
+
else:
|
|
816
|
+
lines.append(
|
|
817
|
+
f"_w_frac = {_fmt(m_height)} / _fig_w")
|
|
818
|
+
lines.append(
|
|
819
|
+
f"_pad = {_fmt(m_pad)} / _fig_w")
|
|
820
|
+
if position == 'right':
|
|
821
|
+
lines.append(
|
|
822
|
+
"_m_rect = [_pos.x1 + _pad, _pos.y0, "
|
|
823
|
+
"_w_frac, _pos.height]")
|
|
824
|
+
else:
|
|
825
|
+
lines.append(
|
|
826
|
+
"_m_rect = [_pos.x0 - _w_frac - _pad, _pos.y0, "
|
|
827
|
+
"_w_frac, _pos.height]")
|
|
828
|
+
|
|
829
|
+
lines.append(
|
|
830
|
+
f"ax_m = fig.add_axes(_m_rect, {share}={p_var})")
|
|
831
|
+
|
|
832
|
+
if mode == 'overlay':
|
|
833
|
+
for i, color in enumerate(colors):
|
|
834
|
+
orient = (", orientation='horizontal'"
|
|
835
|
+
if axis == 'y' else "")
|
|
836
|
+
lines.append(
|
|
837
|
+
f"ax_m.hist(_m_data[{i}], bins=_m_bins, "
|
|
838
|
+
f"color={_fmt(color)}, alpha={_fmt(alpha)}, "
|
|
839
|
+
f"edgecolor='none'{orient})")
|
|
840
|
+
else: # dodge
|
|
841
|
+
n = len(colls)
|
|
842
|
+
lines.append(f"_n = {n}")
|
|
843
|
+
lines.append(f"_bw = _m_bins[1] - _m_bins[0]")
|
|
844
|
+
lines.append(f"_sub = _bw / _n")
|
|
845
|
+
lines.append(
|
|
846
|
+
f"_bar = _sub * {_fmt(1 - separation)}")
|
|
847
|
+
lines.append(
|
|
848
|
+
f"_ctrs = (_m_bins[:-1] + _m_bins[1:]) / 2")
|
|
849
|
+
if axis == 'x':
|
|
850
|
+
lines.append(
|
|
851
|
+
f"for _i, (_d, _c) in enumerate("
|
|
852
|
+
f"zip(_m_data, {colors!r})):")
|
|
853
|
+
lines.append(
|
|
854
|
+
f" _cnt, _ = np.histogram(_d, bins=_m_bins)")
|
|
855
|
+
lines.append(
|
|
856
|
+
f" _off = _sub * (_i - (_n - 1) / 2)")
|
|
857
|
+
lines.append(
|
|
858
|
+
f" ax_m.bar(_ctrs + _off, _cnt, width=_bar, "
|
|
859
|
+
f"color=_c, alpha={_fmt(alpha)}, edgecolor='none')")
|
|
860
|
+
else: # y
|
|
861
|
+
lines.append(
|
|
862
|
+
f"for _i, (_d, _c) in enumerate("
|
|
863
|
+
f"zip(_m_data, {colors!r})):")
|
|
864
|
+
lines.append(
|
|
865
|
+
f" _cnt, _ = np.histogram(_d, bins=_m_bins)")
|
|
866
|
+
lines.append(
|
|
867
|
+
f" _off = _sub * (_i - (_n - 1) / 2)")
|
|
868
|
+
lines.append(
|
|
869
|
+
f" ax_m.barh(_ctrs + _off, _cnt, height=_bar, "
|
|
870
|
+
f"color=_c, alpha={_fmt(alpha)}, edgecolor='none')")
|
|
871
|
+
|
|
872
|
+
# Inversion
|
|
873
|
+
if inverted:
|
|
874
|
+
if axis == 'x':
|
|
875
|
+
lines.append("ax_m.invert_yaxis()")
|
|
876
|
+
else:
|
|
877
|
+
lines.append("ax_m.invert_xaxis()")
|
|
878
|
+
|
|
879
|
+
# Spines
|
|
880
|
+
for spine_name in ('top', 'right', 'bottom', 'left'):
|
|
881
|
+
if not m_ax.spines[spine_name].get_visible():
|
|
882
|
+
lines.append(
|
|
883
|
+
f"ax_m.spines[{_fmt(spine_name)}].set_visible(False)")
|
|
884
|
+
|
|
885
|
+
# Tick configuration
|
|
886
|
+
_tfs = _fmt(tick_fs)
|
|
887
|
+
if axis == 'x':
|
|
888
|
+
lines.append(
|
|
889
|
+
"ax_m.tick_params(axis='x', bottom=False, top=False, "
|
|
890
|
+
"labelbottom=False, labeltop=False)")
|
|
891
|
+
if tick_side == 'left':
|
|
892
|
+
lines.append(
|
|
893
|
+
f"ax_m.tick_params(axis='y', left=True, "
|
|
894
|
+
f"labelleft=True, right=False, labelright=False, "
|
|
895
|
+
f"labelsize={_tfs})")
|
|
896
|
+
elif tick_side == 'right':
|
|
897
|
+
lines.append(
|
|
898
|
+
f"ax_m.tick_params(axis='y', left=False, "
|
|
899
|
+
f"labelleft=False, right=True, labelright=True, "
|
|
900
|
+
f"labelsize={_tfs})")
|
|
901
|
+
else:
|
|
902
|
+
lines.append(
|
|
903
|
+
"ax_m.tick_params(axis='y', left=False, "
|
|
904
|
+
"labelleft=False, right=False, labelright=False)")
|
|
905
|
+
else:
|
|
906
|
+
lines.append(
|
|
907
|
+
"ax_m.tick_params(axis='y', left=False, right=False, "
|
|
908
|
+
"labelleft=False, labelright=False)")
|
|
909
|
+
if tick_side == 'bottom':
|
|
910
|
+
lines.append(
|
|
911
|
+
f"ax_m.tick_params(axis='x', bottom=True, "
|
|
912
|
+
f"labelbottom=True, top=False, labeltop=False, "
|
|
913
|
+
f"labelsize={_tfs})")
|
|
914
|
+
elif tick_side == 'top':
|
|
915
|
+
lines.append(
|
|
916
|
+
f"ax_m.tick_params(axis='x', bottom=False, "
|
|
917
|
+
f"labelbottom=False, top=True, labeltop=True, "
|
|
918
|
+
f"labelsize={_tfs})")
|
|
919
|
+
else:
|
|
920
|
+
lines.append(
|
|
921
|
+
"ax_m.tick_params(axis='x', bottom=False, "
|
|
922
|
+
"labelbottom=False, top=False, labeltop=False)")
|
|
923
|
+
|
|
924
|
+
# Tick step
|
|
925
|
+
if tick_step > 0:
|
|
926
|
+
count_axis = 'yaxis' if axis == 'x' else 'xaxis'
|
|
927
|
+
lines.append(
|
|
928
|
+
f"ax_m.{count_axis}.set_major_locator("
|
|
929
|
+
f"matplotlib.ticker.MultipleLocator({_fmt(tick_step)}))")
|
|
930
|
+
|
|
931
|
+
# Range (count-axis limits)
|
|
932
|
+
if range_max > 0:
|
|
933
|
+
if axis == 'x':
|
|
934
|
+
if inverted:
|
|
935
|
+
lines.append(
|
|
936
|
+
f"ax_m.set_ylim({_fmt(range_max)}, "
|
|
937
|
+
f"{_fmt(range_min)})")
|
|
938
|
+
else:
|
|
939
|
+
lines.append(
|
|
940
|
+
f"ax_m.set_ylim({_fmt(range_min)}, "
|
|
941
|
+
f"{_fmt(range_max)})")
|
|
942
|
+
else:
|
|
943
|
+
if inverted:
|
|
944
|
+
lines.append(
|
|
945
|
+
f"ax_m.set_xlim({_fmt(range_max)}, "
|
|
946
|
+
f"{_fmt(range_min)})")
|
|
947
|
+
else:
|
|
948
|
+
lines.append(
|
|
949
|
+
f"ax_m.set_xlim({_fmt(range_min)}, "
|
|
950
|
+
f"{_fmt(range_max)})")
|
|
951
|
+
|
|
952
|
+
# Label
|
|
953
|
+
if label:
|
|
954
|
+
_lfs = _fmt(label_fs)
|
|
955
|
+
_lw = "'bold'" if label_bold else "'normal'"
|
|
956
|
+
_ls = "'italic'" if label_italic else "'normal'"
|
|
957
|
+
_lc_arg = ""
|
|
958
|
+
if label_color != '#000000':
|
|
959
|
+
_lc_arg = f", color={_fmt(label_color)}"
|
|
960
|
+
if axis == 'x':
|
|
961
|
+
if tick_side == 'right':
|
|
962
|
+
lines.append(
|
|
963
|
+
"ax_m.yaxis.set_label_position('right')")
|
|
964
|
+
lines.append(
|
|
965
|
+
f"ax_m.set_ylabel({_fmt(label)}, fontsize={_lfs}, "
|
|
966
|
+
f"fontweight={_lw}, fontstyle={_ls}{_lc_arg})")
|
|
967
|
+
else:
|
|
968
|
+
if tick_side == 'top':
|
|
969
|
+
lines.append(
|
|
970
|
+
"ax_m.xaxis.set_label_position('top')")
|
|
971
|
+
lines.append(
|
|
972
|
+
f"ax_m.set_xlabel({_fmt(label)}, fontsize={_lfs}, "
|
|
973
|
+
f"fontweight={_lw}, fontstyle={_ls}{_lc_arg})")
|
|
974
|
+
|
|
975
|
+
# Title
|
|
976
|
+
if m_title:
|
|
977
|
+
_tifs = _fmt(title_fs)
|
|
978
|
+
_tw = "'bold'" if title_bold else "'normal'"
|
|
979
|
+
_ts = "'italic'" if title_italic else "'normal'"
|
|
980
|
+
_tc_arg = ""
|
|
981
|
+
if title_color != '#000000':
|
|
982
|
+
_tc_arg = f", color={_fmt(title_color)}"
|
|
983
|
+
lines.append(
|
|
984
|
+
f"ax_m.set_title({_fmt(m_title)}, fontsize={_tifs}, "
|
|
985
|
+
f"fontweight={_tw}, fontstyle={_ts}{_tc_arg})")
|
|
986
|
+
|
|
987
|
+
return "\n".join(lines)
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
def _bxp_stats_from_original(orig_stats, dinfo):
    """Convert original introspected box_stats to bxp-compatible format.

    The introspector stores stats with a 'median' key; bxp() expects 'med'.
    Using original stats avoids recomputing from reconstructed data, which
    would degrade the violin KDE shape (double reconstruction).

    Parameters
    ----------
    orig_stats : iterable of dict
        Per-box statistics. Each dict must provide ``'q1'``, ``'q3'``,
        ``'whislo'`` and ``'whishi'``. The median is read from ``'med'``
        or, when that is missing/None, from ``'median'`` (default 0).
        ``'fliers'`` is optional (missing or None means no fliers).
        ``'cilo'``/``'cihi'`` are used for notches when present.
    dinfo : dict
        Distribution style info; only the ``'notch'`` flag is read here.

    Returns
    -------
    list of dict
        One dict per input entry, values rounded to 4 decimals, plus
        ``'cilo'``/``'cihi'`` when notches are requested.
    """
    bxp_stats = []
    for s in orig_stats:
        # Explicit None checks: ``s.get('med') or ...`` would discard a
        # legitimate median of 0.0 and fall through to 'median'.
        med = s.get('med')
        if med is None:
            med = s.get('median', 0)
        # Not ``or []``: fliers may be a numpy array (ambiguous truth value).
        fliers = s.get('fliers')
        if fliers is None:
            fliers = []
        stat = {
            'med': round(float(med), 4),
            'q1': round(float(s['q1']), 4),
            'q3': round(float(s['q3']), 4),
            'whislo': round(float(s['whislo']), 4),
            'whishi': round(float(s['whishi']), 4),
            'fliers': [round(float(f), 4) for f in fliers],
        }
        if dinfo.get('notch', False):
            cilo, cihi = s.get('cilo'), s.get('cihi')
            if cilo is not None and cihi is not None:
                # Prefer the confidence interval recorded with the stats.
                stat['cilo'] = round(float(cilo), 4)
                stat['cihi'] = round(float(cihi), 4)
            else:
                # The sample size is unknown here, so N = 100 is used as a
                # proxy in matplotlib's notch formula med +/- 1.57*IQR/sqrt(N)
                # (sqrt(100) == 10).
                iqr = stat['q3'] - stat['q1']
                half = 1.57 * iqr / 10
                stat['cilo'] = round(stat['med'] - half, 4)
                stat['cihi'] = round(stat['med'] + half, 4)
        bxp_stats.append(stat)
    return bxp_stats
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
def _compute_bxp_stats(raw_data, dinfo):
    """Compute bxp-compatible stats from raw data arrays.

    Follows ``matplotlib.cbook.boxplot_stats`` with the default
    ``whis=1.5``:
    - quartiles come from ``np.percentile``;
    - whiskers reach the most extreme data point within 1.5 * IQR of the
      box, but are clamped so they never end inside the box;
    - every point beyond the whiskers is a flier.

    Parameters
    ----------
    raw_data : iterable of array-like
        One dataset per box. NaN values are dropped, because a ``nan``
        written into the generated stats literal would not be valid Python.
        Datasets left empty are skipped, so the result can be shorter than
        *raw_data*.
    dinfo : dict
        Distribution style info; only the ``'notch'`` flag is read here.

    Returns
    -------
    list of dict
        Keys ``med``, ``q1``, ``q3``, ``whislo``, ``whishi``, ``fliers``
        (plus ``cilo``/``cihi`` when notched). All values are plain Python
        floats rounded to 4 decimals, so ``repr()`` yields clean literals
        (``round`` on ``np.float64`` would keep the numpy type).
    """
    whis = 1.5
    bxp_stats = []
    for rd in raw_data:
        arr = np.asarray(rd, dtype=float).ravel()
        arr = arr[~np.isnan(arr)]
        if arr.size == 0:
            continue
        q1, med, q3 = (float(v) for v in np.percentile(arr, [25, 50, 75]))
        iqr = q3 - q1

        # Most extreme points within the whisker reach, clamped to the box
        # edges exactly as matplotlib does (otherwise a whisker can end
        # inside the box when the quartile is interpolated).
        within_lo = arr[arr >= q1 - whis * iqr]
        if within_lo.size == 0 or within_lo.min() > q1:
            whislo = q1
        else:
            whislo = float(within_lo.min())
        within_hi = arr[arr <= q3 + whis * iqr]
        if within_hi.size == 0 or within_hi.max() < q3:
            whishi = q3
        else:
            whishi = float(within_hi.max())

        fliers = arr[(arr < whislo) | (arr > whishi)]
        stat = {
            'med': round(med, 4),
            'q1': round(q1, 4),
            'q3': round(q3, 4),
            'whislo': round(whislo, 4),
            'whishi': round(whishi, 4),
            'fliers': [round(float(f), 4) for f in fliers],
        }
        if dinfo.get('notch', False):
            # Matplotlib's non-bootstrapped notch: med +/- 1.57 * IQR / sqrt(N).
            ci = float(1.57 * iqr / np.sqrt(arr.size))
            stat['cilo'] = round(med - ci, 4)
            stat['cihi'] = round(med + ci, 4)
        bxp_stats.append(stat)
    return bxp_stats
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
def _emit_stats_var(lines, bxp_stats, group_idx, n_groups):
    """Append a list literal holding *bxp_stats* to the generated source.

    The variable is called ``_stats`` when there is a single group and
    ``_stats_<group_idx>`` otherwise. Each stats dict is written with
    ``repr()`` on a line of its own, followed by the closing bracket.

    Returns the name of the emitted variable.
    """
    name = "_stats" if n_groups <= 1 else f"_stats_{group_idx}"
    lines.append(name + " = [")
    lines.extend(f" {entry!r}," for entry in bxp_stats)
    lines.append("]")
    return name
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
def _emit_bxp_call(lines, ax_var, dinfo, svar, positions, label,
                   group_idx, n_groups, orient, width, mode):
    """Emit ax.bxp() call using an already-defined stats variable.

    Appends to *lines* a multi-line ``<bvar> = <ax_var>.bxp(<svar>, ...)``
    call. ``bvar`` is ``_bp`` for a single group and ``_bp_<group_idx>``
    otherwise. The call is followed by a line that sets *label* on the
    first box.

    Parameters
    ----------
    lines : list of str
        Output buffer of generated source lines (mutated in place).
    ax_var : str
        Name of the axes variable in the generated code.
    dinfo : dict
        Style info. Reads notch/show_mean/flier/box/median/whisker settings
        and falls back to the defaults shown below.
    svar : str
        Name of the generated variable holding the bxp stats list.
    positions : sequence of float
        Box positions, written rounded to 4 decimals.
    label : str
        Legend label applied to the first box patch.
    group_idx, n_groups : int
        Used only to choose a unique result-variable name.
    orient : str
        ``"vertical"`` gives ``vert=True``; anything else ``vert=False``.
    width : float
        Base group width.
    mode : str
        Display mode; when it contains ``'violin'`` the boxes are narrowed.
    """
    # Boxes overlaid on a violin are drawn at 30% width so the violin shows.
    box_width = width * 0.3 if 'violin' in mode else width
    is_vert = orient == "vertical"

    bxp_args = []
    pos_str = ", ".join(str(round(float(p), 4)) for p in positions)
    bxp_args.append(f"positions=[{pos_str}]")
    bxp_args.append(f"vert={is_vert}")
    bxp_args.append("patch_artist=True")
    bxp_args.append(f"widths={_fmt(box_width)}")
    if dinfo.get('notch', False):
        # Axes.bxp spells this flag ``shownotches``; ``notch`` belongs to
        # Axes.boxplot and makes bxp() raise TypeError.
        bxp_args.append("shownotches=True")
    if dinfo.get('show_mean', False):
        bxp_args.append("showmeans=True")
        bxp_args.append("meanline=True")
    flier_marker = dinfo.get('flier_marker', 'o')
    if not flier_marker:
        bxp_args.append("showfliers=False")

    # boxprops
    bp_kw = [
        f"facecolor={_fmt(dinfo.get('box_color', '#1f77b4'))}",
        f"edgecolor={_fmt(dinfo.get('box_edgecolor', '#000000'))}",
    ]
    blw = dinfo.get('box_lw', 1.0)
    if blw != 1.0:
        bp_kw.append(f"linewidth={_fmt(blw)}")
    ba = dinfo.get('box_alpha', 1.0)
    if ba != 1.0:
        bp_kw.append(f"alpha={_fmt(ba)}")
    if dinfo.get('box_hatch', ''):
        bp_kw.append(f"hatch={_fmt(dinfo['box_hatch'])}")
    bxp_args.append(f"boxprops=dict({', '.join(bp_kw)})")

    # medianprops
    mp_kw = [f"color={_fmt(dinfo.get('median_color', '#ff7f0e'))}"]
    mlw = dinfo.get('median_lw', 2.0)
    if mlw != 2.0:
        mp_kw.append(f"linewidth={_fmt(mlw)}")
    bxp_args.append(f"medianprops=dict({', '.join(mp_kw)})")

    # whiskerprops (only non-default values are emitted)
    wp_kw = []
    ws = dinfo.get('whisker_style', '-')
    if ws != '-':
        wp_kw.append(f"linestyle={_fmt(ws)}")
    wlw = dinfo.get('whisker_lw', 1.0)
    if wlw != 1.0:
        wp_kw.append(f"linewidth={_fmt(wlw)}")
    if wp_kw:
        bxp_args.append(f"whiskerprops=dict({', '.join(wp_kw)})")

    # capprops: caps share the whisker line width
    cp_kw = []
    if wlw != 1.0:
        cp_kw.append(f"linewidth={_fmt(wlw)}")
    if cp_kw:
        bxp_args.append(f"capprops=dict({', '.join(cp_kw)})")

    # flierprops
    if flier_marker:
        fl_kw = [
            f"marker={_fmt(flier_marker)}",
            f"markersize={_fmt(dinfo.get('flier_size', 6.0))}",
            f"markerfacecolor={_fmt(dinfo.get('flier_color', '#000000'))}",
        ]
        bxp_args.append(f"flierprops=dict({', '.join(fl_kw)})")

    # meanprops: the mean line reuses the median colour
    if dinfo.get('show_mean', False):
        ms = dinfo.get('mean_style', '--')
        mc = dinfo.get('median_color', '#ff7f0e')
        bxp_args.append(
            f"meanprops=dict(linestyle={_fmt(ms)}, color={_fmt(mc)})")

    bvar = f"_bp_{group_idx}" if n_groups > 1 else "_bp"
    lines.append(f"{bvar} = {ax_var}.bxp({svar},")
    for k, arg in enumerate(bxp_args):
        comma = "," if k < len(bxp_args) - 1 else ")"
        lines.append(f" {arg}{comma}")
    lines.append(f"{bvar}['boxes'][0].set_label({_fmt(label)})")
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
def _emit_data_from_stats(lines, svar, group_idx, n_groups):
    """Emit code that rebuilds approximate datasets from box statistics.

    For each stats dict in *svar*, the generated code draws roughly 100
    points with a fixed-seed ``RandomState(42)``:
    - equal counts uniformly inside each of the four quartile segments
      (whislo-q1, q1-med, med-q3, q3-whishi);
    - the remainder goes to the last segment;
    - the stored fliers are appended.

    This follows ``_reconstruct_data_from_stats`` in ``_introspect.py``
    (including fliers and adjusted sample counts), so the violin KDE
    shapes match the matplotly rendering.

    Returns the name of the generated variable (``_raw`` for one group,
    ``_raw_<group_idx>`` otherwise) holding the list of arrays.
    """
    raw_var = "_raw" if n_groups <= 1 else f"_raw_{group_idx}"
    header = ("# Approximate data from stats "
              f"(replace {raw_var} with your original data for best fidelity)")
    body = [
        f"{raw_var} = []",
        f"for _s in {svar}:",
        " _fl = _s.get('fliers', [])",
        " _n = max(100 - len(_fl), 20)",
        " _q = _n // 4",
        " _rem = _n - 4 * _q",
        " _rng = np.random.RandomState(42)",
        f" {raw_var}.append(np.concatenate([",
        " _rng.uniform(_s['whislo'], _s['q1'], _q),",
        " _rng.uniform(_s['q1'], _s['med'], _q),",
        " _rng.uniform(_s['med'], _s['q3'], _q),",
        " _rng.uniform(_s['q3'], _s['whishi'], _q + _rem),",
        " np.array(_fl, dtype=float),",
        " ]))",
    ]
    lines.append(header)
    lines.extend(body)
    return raw_var
|
|
1170
|
+
|
|
1171
|
+
|
|
1172
|
+
def _emit_violin(lines, ax_var, dinfo, raw_var, positions,
                 orient, width, mode, label):
    """Emit ax.violinplot() with full styling.

    Appends to *lines*:
    - a ``violinplot`` call over the generated data variable *raw_var*,
      bound to ``_vp``;
    - a loop that colours every violin body;
    - unless ``dinfo['violin_inner'] == 'none'``, a loop that styles the
      median/extrema line collections.

    Parameters
    ----------
    lines : list of str
        Output buffer of generated source lines (mutated in place).
    ax_var : str
        Name of the axes variable in the generated code.
    dinfo : dict
        Style info; reads ``violin_color``, ``violin_edgecolor``,
        ``violin_alpha`` and ``violin_inner``.
    raw_var : str
        Name of the generated variable holding the list of datasets.
    positions : sequence of float
        Violin positions, written rounded to 4 decimals.
    orient : str
        ``"vertical"`` gives ``vert=True``; anything else ``vert=False``.
    width : float
        Base group width; violins are drawn 1.5x wider.
    mode : str
        Display mode. When it contains ``'box'``, the box plot carries the
        legend label instead of the violin.
    label : str
        Legend label for the first violin body.
    """
    vc = dinfo.get('violin_color', '#1f77b4')
    vec = dinfo.get('violin_edgecolor', '#000000')
    va = dinfo.get('violin_alpha', 0.3)
    vi = dinfo.get('violin_inner', 'box')
    vw = width * 1.5
    is_vert = orient == "vertical"
    show_stats = (vi != 'none')
    pos_str = ", ".join(str(round(float(p), 4)) for p in positions)

    lines.append(f"_vp = {ax_var}.violinplot({raw_var},")
    lines.append(f" positions=[{pos_str}],")
    lines.append(f" widths={_fmt(vw)}, vert={is_vert},")
    lines.append(
        f" showmeans=False, showmedians={show_stats}, "
        f"showextrema={show_stats})")
    lines.append(f"for _body in _vp['bodies']:")
    lines.append(f" _body.set_facecolor({_fmt(vc)})")
    lines.append(f" _body.set_edgecolor({_fmt(vec)})")
    lines.append(f" _body.set_alpha({_fmt(va)})")
    if show_stats:
        # The styling statements must be nested one level deeper than the
        # ``if``; with equal indentation the generated code is an
        # IndentationError ("expected an indented block").
        lines.append(
            "for _key in ('cmedians', 'cmins', 'cmaxes', 'cbars'):")
        lines.append("    if _key in _vp:")
        lines.append(f"        _vp[_key].set_color({_fmt(vec)})")
        lines.append("        _vp[_key].set_alpha(0.8)")
    # Legend label if violin is the primary plot component
    if 'box' not in mode:
        lines.append(f"_vp['bodies'][0].set_label({_fmt(label)})")
|
|
1203
|
+
|
|
1204
|
+
|
|
1205
|
+
def _emit_jitter(lines, ax_var, dinfo, raw_var, positions,
                 orient, mode, label):
    """Emit a strip-plot loop of ``ax.scatter()`` calls.

    The generated code seeds ``RandomState(42)`` once and, for each
    dataset in *raw_var* paired with its position, scatters the points
    with a uniform offset in ``[-spread, spread]`` along the category axis.
    The offset goes on x for ``orient == "vertical"`` and on y otherwise.
    A legend label is attached only when *mode* contains neither a box
    nor a violin, and then only to the first dataset.
    """
    color = dinfo.get('jitter_color', '#1f77b4')
    opacity = dinfo.get('jitter_alpha', 0.5)
    size = dinfo.get('jitter_size', 3.0)
    marker = dinfo.get('jitter_marker', 'o')
    spread = dinfo.get('jitter_spread', 0.2)
    pos_list = ", ".join(str(round(float(p), 4)) for p in positions)

    if orient == "vertical":
        x_expr, y_expr = "_pos + _jitter", "_d"
    else:
        x_expr, y_expr = "_d", "_pos + _jitter"

    # scatter's ``s`` is an area, so the marker size is squared.
    scatter_kw = [
        f"s={_fmt(size ** 2)}",
        f"c={_fmt(color)}",
        f"alpha={_fmt(opacity)}",
        f"marker={_fmt(marker)}",
        "zorder=3",
    ]
    if 'box' not in mode and 'violin' not in mode:
        scatter_kw.append(
            f"label=({_fmt(label)} if _i == 0 else '_nolegend_')")

    lines.extend([
        "_rng_j = np.random.RandomState(42)",
        f"for _i, (_d, _pos) in enumerate(zip({raw_var}, [{pos_list}])):",
        f" _jitter = _rng_j.uniform({_fmt(-spread)}, {_fmt(spread)}, len(_d))",
        f" {ax_var}.scatter({x_expr}, {y_expr}, {', '.join(scatter_kw)})",
    ])
|
|
1236
|
+
|
|
1237
|
+
|
|
1238
|
+
def _emit_dist_with_data(lines, ax_var, dist_infos, data_vars,
                         orient, width, n_groups):
    """Emit box/violin/jitter code that plots the user's real data.

    Nothing is fabricated and no artists are cleaned up. The generated
    code calls ``ax.violinplot()``, ``ax.boxplot()`` and the jitter
    scatter directly on the user's data variable(s).

    Two layouts are supported:
    - one data variable per group (``len(data_vars) == n_groups > 1``):
      each group's variable is used as-is;
    - otherwise, ``data_vars[0]`` is treated as one sequence of datasets
      and each group selects its entries by index. The index comes from
      where each group's original positions fall in the sorted set of
      all original positions.
    """
    lines.append("\n# Distribution plot")

    one_var_per_group = n_groups > 1 and len(data_vars) == n_groups

    if not one_var_per_group:
        shared_var = data_vars[0]
        every_pos = {
            p
            for info in dist_infos
            for p in info.get('original_positions', info['positions'])
        }
        slot_of = {p: i for i, p in enumerate(sorted(every_pos))}

    for gi, info in enumerate(dist_infos):
        mode = info.get('display_mode', 'box')
        label = info.get('label', f'Group {gi}')
        positions = info.get('positions', [])
        orig_pos = info.get('original_positions', positions)

        dvar = "_data" if n_groups <= 1 else f"_data_{gi}"
        lines.append(f"\n# {label}")

        if one_var_per_group:
            lines.append(f"{dvar} = {data_vars[gi]}")
        else:
            picked = [slot_of[p] for p in orig_pos if p in slot_of]
            lines.append(f"{dvar} = [{shared_var}[i] for i in {picked!r}]")

        # Drawing order: violin underneath, then box, then jitter on top.
        if 'violin' in mode:
            _emit_violin(lines, ax_var, info, dvar, positions,
                         orient, width, mode, label)
        if 'box' in mode:
            _emit_boxplot_call(lines, ax_var, info, dvar, positions,
                               label, gi, n_groups, orient, width, mode)
        if 'jitter' in mode:
            _emit_jitter(lines, ax_var, info, dvar, positions,
                         orient, mode, label)
|
|
1296
|
+
|
|
1297
|
+
|
|
1298
|
+
def _emit_boxplot_call(lines, ax_var, dinfo, data_var, positions,
                       label, group_idx, n_groups, orient, width, mode):
    """Emit an ``ax.boxplot()`` call on the user's real data (Apply path).

    Appends to *lines* a multi-line ``<result> = <ax_var>.boxplot(...)``
    call. ``result`` is ``_bp`` for one group and ``_bp_<group_idx>``
    otherwise. The call is followed by a line that sets *label* on the
    first box. Only non-default whisker/cap/box line widths and styles
    are written out; boxes overlaying a violin are drawn at 30% width.
    """
    def props(name, parts):
        # Render one ``name=dict(k=v, ...)`` keyword argument.
        return f"{name}=dict({', '.join(parts)})"

    marker = dinfo.get('flier_marker', 'o')
    show_mean = dinfo.get('show_mean', False)
    whisker_lw = dinfo.get('whisker_lw', 1.0)
    box_w = width * 0.3 if 'violin' in mode else width
    position_list = ", ".join(str(round(float(p), 4)) for p in positions)

    args = [
        f"positions=[{position_list}]",
        f"vert={orient == 'vertical'}",
        "patch_artist=True",
        f"widths={_fmt(box_w)}",
    ]
    if dinfo.get('notch', False):
        args.append("notch=True")
    if show_mean:
        args += ["showmeans=True", "meanline=True"]
    if not marker:
        args.append("showfliers=False")

    box_kw = [
        f"facecolor={_fmt(dinfo.get('box_color', '#1f77b4'))}",
        f"edgecolor={_fmt(dinfo.get('box_edgecolor', '#000000'))}",
    ]
    box_lw = dinfo.get('box_lw', 1.0)
    if box_lw != 1.0:
        box_kw.append(f"linewidth={_fmt(box_lw)}")
    box_alpha = dinfo.get('box_alpha', 1.0)
    if box_alpha != 1.0:
        box_kw.append(f"alpha={_fmt(box_alpha)}")
    if dinfo.get('box_hatch', ''):
        box_kw.append(f"hatch={_fmt(dinfo['box_hatch'])}")
    args.append(props("boxprops", box_kw))

    median_kw = [f"color={_fmt(dinfo.get('median_color', '#ff7f0e'))}"]
    median_lw = dinfo.get('median_lw', 2.0)
    if median_lw != 2.0:
        median_kw.append(f"linewidth={_fmt(median_lw)}")
    args.append(props("medianprops", median_kw))

    whisker_kw = []
    whisker_style = dinfo.get('whisker_style', '-')
    if whisker_style != '-':
        whisker_kw.append(f"linestyle={_fmt(whisker_style)}")
    if whisker_lw != 1.0:
        whisker_kw.append(f"linewidth={_fmt(whisker_lw)}")
    if whisker_kw:
        args.append(props("whiskerprops", whisker_kw))

    # Caps follow the whisker line width.
    if whisker_lw != 1.0:
        args.append(props("capprops", [f"linewidth={_fmt(whisker_lw)}"]))

    if marker:
        args.append(props("flierprops", [
            f"marker={_fmt(marker)}",
            f"markersize={_fmt(dinfo.get('flier_size', 6.0))}",
            f"markerfacecolor={_fmt(dinfo.get('flier_color', '#000000'))}",
        ]))

    if show_mean:
        # The mean line reuses the median colour.
        args.append(props("meanprops", [
            f"linestyle={_fmt(dinfo.get('mean_style', '--'))}",
            f"color={_fmt(dinfo.get('median_color', '#ff7f0e'))}",
        ]))

    result = "_bp" if n_groups <= 1 else f"_bp_{group_idx}"
    lines.append(f"{result} = {ax_var}.boxplot({data_var},")
    last = len(args) - 1
    lines.extend(
        f" {a}{')' if i == last else ','}" for i, a in enumerate(args))
    lines.append(f"{result}['boxes'][0].set_label({_fmt(label)})")
|
|
1381
|
+
|
|
1382
|
+
|
|
1383
|
+
def _emit_hist_merged(lines, ax_var, hist_infos, data_vars=None):
    """Emit merged histogram code.

    When *data_vars* are provided (Apply path — AST-extracted from the
    user's cell), emits an executable ``ax.hist()`` call that replaces
    the user's original separate calls (which get commented out by
    ``_close_apply``). Returns ``True`` so the caller can skip the
    redundant style-only loop.

    Without *data_vars* (Copy / fallback), emits a comment template
    and returns ``False`` so the caller still emits style-only code.

    Shared keyword arguments (bins, histtype, density, cumulative,
    orientation, edgecolor, linewidth, alpha, rwidth) come from the first
    entry of *hist_infos*; colours and labels are collected from every
    entry.

    Per-dataset overrides (differing edgecolor/alpha, hatch) are applied
    through the patches returned by ``hist()``, bound to ``_h_patches``.
    They are not looked up via ``ax.containers[i]``, for two reasons:
    ``hist()`` registers containers only for bar-type histtypes (the
    'step' types add plain polygons), and the axes may already hold
    unrelated containers. So ``containers[i]`` is not reliably the i-th
    dataset.
    """
    ref = hist_infos[0]
    n_bins = ref.get('bins', 20)
    n_hist = len(hist_infos)

    colors = [hi.get('color', '#1f77b4') for hi in hist_infos]
    labels = [hi.get('label', f'Hist {i}') for i, hi in enumerate(hist_infos)]

    histtype = ref.get('histtype', 'bar')
    orientation = ref.get('orientation', 'vertical')
    cumulative = ref.get('cumulative', False)
    mode = ref.get('mode', 'count')
    density = mode == 'density'
    rwidth = ref.get('rwidth', 0.8)
    ec = ref.get('edgecolor', '#000000')
    lw = ref.get('linewidth', 1.0)
    alpha = ref.get('alpha', 0.7)

    kw_parts = [f"bins={_fmt(n_bins)}", f"histtype={_fmt(histtype)}"]
    if density:
        kw_parts.append("density=True")
    if cumulative:
        kw_parts.append("cumulative=True")
    if orientation != 'vertical':
        kw_parts.append(f"orientation={_fmt(orientation)}")
    kw_parts.append(f"color={colors!r}")
    kw_parts.append(f"edgecolor={_fmt(ec)}")
    kw_parts.append(f"linewidth={_fmt(lw)}")
    kw_parts.append(f"alpha={_fmt(alpha)}")
    kw_parts.append(f"label={labels!r}")
    kw_parts.append(f"rwidth={_fmt(rwidth)}")

    # Placeholder names such as "<label>" mean the data could not be
    # recovered from the user's cell.
    have_vars = (data_vars
                 and len(data_vars) == n_hist
                 and all(not v.startswith('<') for v in data_vars))

    if have_vars:
        # --- Executable merged call (Apply path) ---
        # Per-histogram post-style (hatch, differing edgecolors, etc.)
        ecs = [hi.get('edgecolor', ec) for hi in hist_infos]
        alphas = [hi.get('alpha', alpha) for hi in hist_infos]
        fixups_per_hist = []
        for i, hi in enumerate(hist_infos):
            fixups = []
            if ecs[i] != ec:
                fixups.append(f"_p.set_edgecolor({_fmt(ecs[i])})")
            if alphas[i] != alpha:
                fixups.append(f"_p.set_alpha({_fmt(alphas[i])})")
            hatch = hi.get('hatch', '')
            if hatch:
                fixups.append(f"_p.set_hatch({_fmt(hatch)})")
            fixups_per_hist.append(fixups)

        need_patches = any(fixups_per_hist)
        target = "_h_n, _h_bins, _h_patches = " if need_patches else ""
        data_list = ", ".join(data_vars)
        lines.append(f"\n{target}{ax_var}.hist([{data_list}],")
        for k, kw in enumerate(kw_parts):
            comma = "," if k < len(kw_parts) - 1 else ")"
            lines.append(f" {kw}{comma}")
        if need_patches and n_hist == 1:
            # With a single dataset hist() returns that dataset's patches
            # directly rather than a list of per-dataset patch groups.
            lines.append("_h_patches = [_h_patches]")
        for i, fixups in enumerate(fixups_per_hist):
            if fixups:
                lines.append(f"for _p in _h_patches[{i}]:")
                for f in fixups:
                    lines.append(f" {f}")
        return True  # skip style-only loop

    # --- Comment template (Copy / fallback) ---
    hatches = [hi.get('hatch', '') for hi in hist_infos]
    need_patches = any(hatches)
    target = "_h_n, _h_bins, _h_patches = " if need_patches else ""
    lines.append(
        f"\n# Merged histograms \u2014 replace your separate "
        f"ax.hist() calls with:")
    data_hint = ", ".join(f"<{lbl}>" for lbl in labels)
    lines.append(f"# {target}{ax_var}.hist([{data_hint}],")
    for k, kw in enumerate(kw_parts):
        comma = "," if k < len(kw_parts) - 1 else ")"
        lines.append(f"# {kw}{comma}")
    if need_patches and n_hist == 1:
        lines.append("# _h_patches = [_h_patches]")
    for i, hatch in enumerate(hatches):
        if hatch:
            lines.append(f"# for _p in _h_patches[{i}]:")
            lines.append(f"# _p.set_hatch({_fmt(hatch)})")
    return False  # emit style-only loop
|
|
1475
|
+
|
|
1476
|
+
|
|
1477
|
+
def _emit_heatmap(lines, ax_var, heatmap_infos):
    """Emit code for heatmap styling (cmap, norm, clim, etc.).

    For the idx-th dict in *heatmap_infos* the generated code fetches the
    heatmap artist by position (``<ax>.images[idx]`` when ``heatmap_type``
    is ``'imshow'``, otherwise ``<ax>.collections[idx]``) and re-applies
    its colormap, normalization / color limits, interpolation, aspect,
    alpha, optional per-cell annotations, optional cell grid lines and
    tick-label settings.

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Name of the Axes variable in the generated code. The generated
        code also uses ``np`` (numpy) for the grid lines.
    heatmap_infos : list of dict
        One settings dict per heatmap; missing keys fall back to the
        defaults read below.
    """
    for idx, info in enumerate(heatmap_infos):
        htype = info.get('heatmap_type', 'imshow')
        cmap = info.get('cmap', 'viridis')
        vmin = info.get('vmin')
        vmax = info.get('vmax')
        norm_type = info.get('norm_type', 'linear')
        alpha = info.get('alpha', 1.0)
        interp = info.get('interpolation')
        aspect = info.get('aspect')
        annot_enabled = info.get('annot_enabled', False)
        annot_fmt = info.get('annot_fmt', '.2f')
        annot_fontsize = info.get('annot_fontsize', 8.0)
        annot_color = info.get('annot_color', 'auto')
        grid_enabled = info.get('grid_enabled', False)
        grid_lw = info.get('grid_lw', 1.0)
        grid_color = info.get('grid_color', '#ffffff')
        data = info.get('data')

        # The data is only inspected when annotations or grid need the cell
        # layout. (rows, cols) come from the first two axes, so RGB(A) data
        # with a trailing channel axis no longer breaks tuple unpacking.
        data_ndim = 0
        cell_shape = None
        if (annot_enabled or grid_enabled) and data is not None:
            data_arr = np.asarray(data)
            data_ndim = data_arr.ndim
            if data_ndim >= 2:
                cell_shape = data_arr.shape[:2]

        lines.append(f"\n# Heatmap ({htype})")

        if htype == 'imshow':
            acc = "_im"
            lines.append(f"_im = {ax_var}.images[{idx}]")
        else:
            acc = "_qm"
            lines.append(f"_qm = {ax_var}.collections[{idx}]")

        lines.append(f"{acc}.set_cmap({_fmt(cmap)})")
        _emit_heatmap_norm(lines, acc, norm_type, vmin, vmax)

        if htype == 'imshow' and interp:
            lines.append(f"{acc}.set_interpolation({_fmt(interp)})")

        # 'equal' is imshow's default aspect (rcParams['image.aspect']),
        # so it is not emitted.
        if htype == 'imshow' and aspect and str(aspect) != 'equal':
            lines.append(f"{ax_var}.set_aspect({_fmt(str(aspect))})")

        if alpha is not None and round(alpha, 2) != 1.0:
            lines.append(f"{acc}.set_alpha({_fmt(alpha)})")

        # Annotations: the emitted loop formats one scalar per cell, so only
        # plain 2-D data qualifies (RGB(A) pixels are vectors).
        if annot_enabled and data_ndim == 2:
            nrows, ncols = cell_shape
            lines.append("\n# Annotations")
            lines.append(f"_data = {acc}.get_array()")
            lines.append("if hasattr(_data, 'reshape'):")
            lines.append(f"    _data = _data.reshape({nrows}, {ncols})")
            lines.append(f"_vmin, _vmax = {acc}.get_clim()")
            lines.append("_vmid = (_vmin + _vmax) / 2.0")
            lines.append(f"for _i in range({nrows}):")
            lines.append(f"    for _j in range({ncols}):")
            lines.append("        _val = _data[_i, _j]")
            if annot_color == 'auto':
                # Heuristic: white text above the clim midpoint, else black.
                lines.append(
                    "        _c = 'white' if _val > _vmid else 'black'")
            else:
                lines.append(f"        _c = {_fmt(annot_color)}")
            lines.append(
                f"        {ax_var}.text(_j, _i, "
                f"format(_val, {_fmt(annot_fmt)}), "
                f"ha='center', va='center', "
                f"fontsize={_fmt(annot_fontsize)}, color=_c)")

        # Grid: minor ticks on the half-integer cell edges, drawn as lines.
        if grid_enabled and cell_shape is not None:
            nrows, ncols = cell_shape
            lines.append("\n# Grid lines")
            lines.append(
                f"{ax_var}.set_xticks("
                f"np.arange(-0.5, {ncols}, 1), minor=True)")
            lines.append(
                f"{ax_var}.set_yticks("
                f"np.arange(-0.5, {nrows}, 1), minor=True)")
            lines.append(
                f"{ax_var}.grid(which='minor', "
                f"color={_fmt(grid_color)}, "
                f"linewidth={_fmt(grid_lw)}, linestyle='-')")
            lines.append(
                f"{ax_var}.tick_params(which='minor', length=0)")

        # Tick labels. `or ''` also covers keys present with a None value.
        xtick_show = info.get('xtick_show', True)
        ytick_show = info.get('ytick_show', True)
        xtick_labels_str = info.get('xtick_labels') or ''
        ytick_labels_str = info.get('ytick_labels') or ''

        # Only emit the section header when at least one line will follow.
        if (not xtick_show or not ytick_show
                or xtick_labels_str.strip() or ytick_labels_str.strip()):
            lines.append("\n# Tick labels")
            _emit_heatmap_ticks(lines, ax_var, 'x', xtick_show,
                                xtick_labels_str)
            _emit_heatmap_ticks(lines, ax_var, 'y', ytick_show,
                                ytick_labels_str)


def _emit_heatmap_norm(lines, acc, norm_type, vmin, vmax):
    """Append normalization / color-limit code for heatmap artist *acc*."""
    if norm_type == 'log':
        # LogNorm needs a strictly positive vmin. Both bounds are written at
        # full precision instead of through _fmt: _fmt rounds floats to two
        # decimals, which can turn small positive bounds (the 1e-10 floor,
        # 0.001, ...) into an invalid 0.0 or otherwise distort them.
        log_vmin = max(vmin, 1e-10) if vmin is not None else 1e-10
        vmax_src = 'None' if vmax is None else repr(float(vmax))
        lines.append("from matplotlib.colors import LogNorm")
        lines.append(
            f"{acc}.set_norm(LogNorm("
            f"vmin={float(log_vmin)!r}, vmax={vmax_src}))")
    elif norm_type == 'symlog':
        lines.append("from matplotlib.colors import SymLogNorm")
        lines.append(
            f"{acc}.set_norm(SymLogNorm("
            f"linthresh=1.0, vmin={_fmt(vmin)}, vmax={_fmt(vmax)}))")
    elif norm_type == 'centered':
        lines.append("from matplotlib.colors import CenteredNorm")
        lines.append(f"{acc}.set_norm(CenteredNorm(vcenter=0))")
    else:
        lines.append(f"{acc}.set_clim({_fmt(vmin)}, {_fmt(vmax)})")


def _emit_heatmap_ticks(lines, ax_var, axis, show, labels_str):
    """Append tick-hiding or custom tick-label code for axis 'x' or 'y'."""
    side = 'bottom' if axis == 'x' else 'left'
    if not show:
        lines.append(
            f"{ax_var}.tick_params(axis='{axis}', "
            f"{side}=False, label{side}=False)")
    elif labels_str.strip():
        # Comma-separated labels are placed at cell centres 0, 1, 2, ...
        labels = [part.strip() for part in labels_str.split(',')]
        positions = list(range(len(labels)))
        lines.append(f"{ax_var}.set_{axis}ticks({positions})")
        lines.append(f"{ax_var}.set_{axis}ticklabels({labels!r})")
|
|
1612
|
+
|
|
1613
|
+
|
|
1614
|
+
def _emit_colorbar(lines, ax_var, cbar_info, heatmap_infos):
    """Emit colorbar creation code.

    The colorbar is tied to the first heatmap on the axes:
    ``<ax>.images[0]`` when its ``heatmap_type`` is ``'imshow'``, otherwise
    ``<ax>.collections[0]``. If location, shrink or pad differ from
    'right' / 1.0 / 0.05, any existing colorbar is removed and a new one
    is created with those options. Otherwise an existing colorbar is reused
    (refreshed via ``update_normal``), or created if missing. The label and
    tick-label size are applied afterwards in both cases.

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Name of the Axes variable in the generated code. The generated
        code also expects a ``fig`` variable holding the Figure.
    cbar_info : dict
        Colorbar settings (location, shrink, pad, label, label_fontsize,
        tick_fontsize); missing keys use the defaults below.
    heatmap_infos : list of dict
        Heatmap settings; only the first entry is consulted. When empty,
        nothing is emitted.
    """
    if not heatmap_infos:
        # No heatmap artist to attach a colorbar to.
        return

    htype = heatmap_infos[0].get('heatmap_type', 'imshow')
    if htype == 'imshow':
        mappable_var = f"{ax_var}.images[0]"
    else:
        mappable_var = f"{ax_var}.collections[0]"

    loc = cbar_info.get('location', 'right')
    shrink = cbar_info.get('shrink', 1.0)
    pad = cbar_info.get('pad', 0.05)
    label = cbar_info.get('label', '')
    label_fs = cbar_info.get('label_fontsize', 12.0)
    tick_fs = cbar_info.get('tick_fontsize', 10.0)

    lines.append("\n# Colorbar")

    # When location/shrink/pad differ from defaults, must remove + recreate
    shrink_changed = round(shrink, 2) != 1.0
    pad_changed = round(pad, 2) != 0.05
    needs_recreate = loc != 'right' or shrink_changed or pad_changed

    if needs_recreate:
        lines.append(
            f"if getattr({mappable_var}, 'colorbar', None) is not None:")
        lines.append(f"    {mappable_var}.colorbar.remove()")
        cbar_args = [f"{mappable_var}", f"ax={ax_var}"]
        if loc != 'right':
            cbar_args.append(f"location={_fmt(loc)}")
        if shrink_changed:
            cbar_args.append(f"shrink={_fmt(shrink)}")
        # matplotlib's own pad default is 0.05 only for vertical colorbars
        # (0.15 for horizontal ones), so a top/bottom colorbar must state
        # its pad explicitly even when it equals 0.05.
        if pad_changed or loc in ('top', 'bottom'):
            cbar_args.append(f"pad={_fmt(pad)}")
        lines.append(f"_cbar = fig.colorbar({', '.join(cbar_args)})")
    else:
        # Reuse existing colorbar when possible (avoids remove/recreate)
        lines.append(
            f"_cbar = getattr({mappable_var}, 'colorbar', None)")
        lines.append("if _cbar is None:")
        lines.append(
            f"    _cbar = fig.colorbar({mappable_var}, ax={ax_var})")
        lines.append("else:")
        lines.append(f"    _cbar.update_normal({mappable_var})")

    if label:
        lines.append(
            f"_cbar.set_label({_fmt(label)}, fontsize={_fmt(label_fs)})")
    if round(tick_fs, 1) != 10.0:
        lines.append(f"_cbar.ax.tick_params(labelsize={_fmt(tick_fs)})")
|
|
1663
|
+
|
|
1664
|
+
|
|
1665
|
+
def _emit_errorbars(lines, ax_var, errorbar_infos):
    """Emit code for errorbar styling.

    Always uses ErrorbarContainer (ax.errorbar()). Four toggles control
    visibility: error bars, markers, connecting line, shaded region.
    Each section has its own color.

    For the idx-th dict in *errorbar_infos*, the generated code selects the
    idx-th ``ErrorbarContainer`` among ``<ax>.containers``. It then
    restyles the container's data line (``_eb[0]``, which may be None), its
    cap markers (``_eb[1]``) and its bar line collections (``_eb[2]``).
    With ``show_shaded`` and ``has_yerr`` set, it also adds a
    ``fill_between`` band spanning the y error bars.

    Parameters
    ----------
    lines : list of str
        Output buffer; generated source lines are appended in place.
    ax_var : str
        Name of the Axes variable in the generated code. The generated
        shading code also uses ``np`` (numpy).
    errorbar_infos : list of dict
        One settings dict per errorbar container, in container order;
        missing keys fall back to the defaults read below.
    """
    for idx, info in enumerate(errorbar_infos):
        show_bars = info.get('show_bars', True)
        show_line = info.get('show_line', True)
        show_markers = info.get('show_markers', False)
        show_shaded = info.get('show_shaded', False)
        bar_color = info.get('bar_color', info.get('color', '#1f77b4'))
        marker_color = info.get('marker_color', bar_color)
        line_color = info.get('line_color', bar_color)
        shade_color = info.get('shade_color', bar_color)
        bar_alpha = info.get('bar_alpha', 1.0)
        marker_alpha = info.get('marker_alpha', 1.0)
        line_alpha = info.get('line_alpha', 1.0)
        lw = info.get('line_width', 1.5)
        ls = info.get('line_style', '-')
        bar_lw = info.get('bar_lw', 1.5)
        cap_size = info.get('cap_size', 3.0)
        marker = info.get('marker', '')
        marker_size = info.get('marker_size', 6.0)
        label = info.get('label', f'Errorbar {idx}')

        # Collapse line breaks: a multi-line label would otherwise end the
        # comment early and leak the rest of the label into the code.
        label_text = ' '.join(str(label).splitlines())
        lines.append(f"\n# Errorbar: {label_text}")

        # Find the ErrorbarContainer by index
        lines.append(
            "from matplotlib.container import ErrorbarContainer as _EBC")
        lines.append(
            f"_eb_containers = [c for c in {ax_var}.containers "
            f"if isinstance(c, _EBC)]")
        lines.append(f"_eb = _eb_containers[{idx}]")

        # Data line styling (the data line is None for fmt='none').
        lines.append("if _eb[0] is not None:")
        if show_line:
            lines.append(f"    _eb[0].set_color({_fmt(line_color)})")
            lines.append(f"    _eb[0].set_linewidth({_fmt(lw)})")
            lines.append(f"    _eb[0].set_linestyle({_fmt(ls)})")
            if line_alpha != 1.0:
                lines.append(f"    _eb[0].set_alpha({_fmt(line_alpha)})")
        else:
            lines.append("    _eb[0].set_linestyle('none')")
            lines.append("    _eb[0].set_linewidth(0)")

        if show_markers and marker and marker not in ('', 'None', 'none'):
            lines.append(f"    _eb[0].set_marker({_fmt(marker)})")
            lines.append(
                f"    _eb[0].set_markersize({_fmt(marker_size)})")
            if marker_alpha != 1.0:
                # Apply marker alpha via RGBA to avoid affecting line
                lines.append("    import matplotlib.colors as _mc")
                lines.append(
                    f"    _mkr_rgba = list("
                    f"_mc.to_rgba({_fmt(marker_color)}))")
                lines.append(f"    _mkr_rgba[3] = {_fmt(marker_alpha)}")
                lines.append("    _eb[0].set_markerfacecolor(_mkr_rgba)")
                lines.append("    _eb[0].set_markeredgecolor(_mkr_rgba)")
            else:
                lines.append(
                    f"    _eb[0].set_markerfacecolor("
                    f"{_fmt(marker_color)})")
                lines.append(
                    f"    _eb[0].set_markeredgecolor("
                    f"{_fmt(marker_color)})")

        # Cap markers: hidden (size 0) whenever the error bars are switched
        # off or no cap size is requested, so caps never outlive their bars.
        lines.append("for _cap in _eb[1]:")
        lines.append(f"    _cap.set_color({_fmt(bar_color)})")
        if show_bars and cap_size > 0:
            lines.append(f"    _cap.set_markersize({_fmt(cap_size)})")
        else:
            lines.append("    _cap.set_markersize(0)")
        if bar_alpha != 1.0:
            lines.append(f"    _cap.set_alpha({_fmt(bar_alpha)})")

        # Bar line collections: width 0 hides them when bars are off.
        lines.append("for _bar in _eb[2]:")
        lines.append(f"    _bar.set_color({_fmt(bar_color)})")
        bar_width_src = _fmt(bar_lw) if show_bars else '0'
        lines.append(f"    _bar.set_linewidth({bar_width_src})")
        if bar_alpha != 1.0:
            lines.append(f"    _bar.set_alpha({_fmt(bar_alpha)})")

        # Shaded region — create fill_between from bar segment data.
        # Axes.errorbar adds the x-error collection before the y-error one,
        # so with yerr present the y bars are the LAST entry of _eb[2]
        # (index 0 would be the horizontal x bars when xerr is also set).
        if show_shaded and info.get('has_yerr', False):
            shade_alpha = info.get('shade_alpha', 0.3)
            lines.append("_segs = _eb[2][-1].get_segments()")
            lines.append("if _segs:")
            lines.append("    _x_s = np.array([s[0][0] for s in _segs])")
            lines.append("    _y_lo = np.array([s[0][1] for s in _segs])")
            lines.append("    _y_hi = np.array([s[1][1] for s in _segs])")
            lines.append(
                f"    {ax_var}.fill_between(_x_s, _y_lo, _y_hi, "
                f"color={_fmt(shade_color)}, "
                f"alpha={_fmt(shade_alpha)})")
|
|
1777
|
+
|
|
1778
|
+
|
|
1779
|
+
def _fmt(val: Any) -> str:
    """Format a value as a Python literal.

    The returned text is meant to be pasted into generated source code:

    - ``bool`` / ``numpy.bool_`` -> ``True`` / ``False``
    - integers, including numpy integer types -> plain ``int`` literal
    - strings -> quoted string literal, even when the text looks numeric
      (matplotlib gives ``'8'`` meaning as a marker and ``'0.5'`` as a
      gray color, so they must stay strings)
    - other float-convertible values -> rounded to 2 decimals. A nonzero
      value that would round to 0 keeps 3 significant digits instead.
      NaN and infinities become ``float('nan')`` / ``float('inf')`` so the
      output still evaluates.
    - anything else -> ``repr(val)``
    """
    import math
    import numbers

    if isinstance(val, (bool, np.bool_)):
        return repr(bool(val))
    # Preserve integers (including numpy int types)
    if isinstance(val, numbers.Integral):
        return repr(int(val))
    # Strings are literals as-is; converting to the builtin type keeps the
    # repr portable (numpy 2 reprs np.str_ as "np.str_('...')").
    if isinstance(val, str):
        return repr(str(val))
    if isinstance(val, bytes):
        return repr(bytes(val))
    # Handle numpy floats and regular floats
    try:
        f = float(val)
    except (TypeError, ValueError):
        return repr(val)
    if not math.isfinite(f):
        # Bare `nan` / `inf` are not valid names in generated code.
        return f"float({str(f)!r})"
    rounded = round(f, 2)
    if rounded == 0.0 and f != 0.0:
        # Do not collapse small nonzero values (e.g. 1e-10) to 0.0.
        rounded = float(f"{f:.3g}")
    return repr(rounded)
|