figrecipe 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- figrecipe/__init__.py +361 -93
- figrecipe/_dev/__init__.py +120 -0
- figrecipe/_dev/demo_plotters/__init__.py +195 -0
- figrecipe/_dev/demo_plotters/plot_acorr.py +24 -0
- figrecipe/_dev/demo_plotters/plot_angle_spectrum.py +28 -0
- figrecipe/_dev/demo_plotters/plot_bar.py +25 -0
- figrecipe/_dev/demo_plotters/plot_barbs.py +30 -0
- figrecipe/_dev/demo_plotters/plot_barh.py +25 -0
- figrecipe/_dev/demo_plotters/plot_boxplot.py +24 -0
- figrecipe/_dev/demo_plotters/plot_cohere.py +29 -0
- figrecipe/_dev/demo_plotters/plot_contour.py +30 -0
- figrecipe/_dev/demo_plotters/plot_contourf.py +29 -0
- figrecipe/_dev/demo_plotters/plot_csd.py +29 -0
- figrecipe/_dev/demo_plotters/plot_ecdf.py +24 -0
- figrecipe/_dev/demo_plotters/plot_errorbar.py +28 -0
- figrecipe/_dev/demo_plotters/plot_eventplot.py +25 -0
- figrecipe/_dev/demo_plotters/plot_fill.py +29 -0
- figrecipe/_dev/demo_plotters/plot_fill_between.py +30 -0
- figrecipe/_dev/demo_plotters/plot_fill_betweenx.py +28 -0
- figrecipe/_dev/demo_plotters/plot_hexbin.py +25 -0
- figrecipe/_dev/demo_plotters/plot_hist.py +24 -0
- figrecipe/_dev/demo_plotters/plot_hist2d.py +25 -0
- figrecipe/_dev/demo_plotters/plot_imshow.py +23 -0
- figrecipe/_dev/demo_plotters/plot_loglog.py +27 -0
- figrecipe/_dev/demo_plotters/plot_magnitude_spectrum.py +28 -0
- figrecipe/_dev/demo_plotters/plot_matshow.py +23 -0
- figrecipe/_dev/demo_plotters/plot_pcolor.py +29 -0
- figrecipe/_dev/demo_plotters/plot_pcolormesh.py +29 -0
- figrecipe/_dev/demo_plotters/plot_phase_spectrum.py +28 -0
- figrecipe/_dev/demo_plotters/plot_pie.py +23 -0
- figrecipe/_dev/demo_plotters/plot_plot.py +27 -0
- figrecipe/_dev/demo_plotters/plot_psd.py +29 -0
- figrecipe/_dev/demo_plotters/plot_quiver.py +30 -0
- figrecipe/_dev/demo_plotters/plot_scatter.py +24 -0
- figrecipe/_dev/demo_plotters/plot_semilogx.py +27 -0
- figrecipe/_dev/demo_plotters/plot_semilogy.py +27 -0
- figrecipe/_dev/demo_plotters/plot_specgram.py +30 -0
- figrecipe/_dev/demo_plotters/plot_spy.py +29 -0
- figrecipe/_dev/demo_plotters/plot_stackplot.py +29 -0
- figrecipe/_dev/demo_plotters/plot_stairs.py +27 -0
- figrecipe/_dev/demo_plotters/plot_stem.py +27 -0
- figrecipe/_dev/demo_plotters/plot_step.py +27 -0
- figrecipe/_dev/demo_plotters/plot_streamplot.py +30 -0
- figrecipe/_dev/demo_plotters/plot_tricontour.py +28 -0
- figrecipe/_dev/demo_plotters/plot_tricontourf.py +28 -0
- figrecipe/_dev/demo_plotters/plot_tripcolor.py +29 -0
- figrecipe/_dev/demo_plotters/plot_triplot.py +25 -0
- figrecipe/_dev/demo_plotters/plot_violinplot.py +25 -0
- figrecipe/_dev/demo_plotters/plot_xcorr.py +25 -0
- figrecipe/_editor/__init__.py +230 -0
- figrecipe/_editor/_bbox.py +978 -0
- figrecipe/_editor/_flask_app.py +1229 -0
- figrecipe/_editor/_hitmap.py +937 -0
- figrecipe/_editor/_overrides.py +318 -0
- figrecipe/_editor/_renderer.py +349 -0
- figrecipe/_editor/_templates/__init__.py +75 -0
- figrecipe/_editor/_templates/_html.py +406 -0
- figrecipe/_editor/_templates/_scripts.py +2778 -0
- figrecipe/_editor/_templates/_styles.py +1326 -0
- figrecipe/_params/_DECORATION_METHODS.py +27 -0
- figrecipe/_params/_PLOTTING_METHODS.py +58 -0
- figrecipe/_params/__init__.py +9 -0
- figrecipe/_recorder.py +126 -73
- figrecipe/_reproducer.py +658 -41
- figrecipe/_seaborn.py +14 -9
- figrecipe/_serializer.py +2 -2
- figrecipe/_signatures/README.md +68 -0
- figrecipe/_signatures/__init__.py +12 -2
- figrecipe/_signatures/_loader.py +515 -56
- figrecipe/_utils/__init__.py +6 -4
- figrecipe/_utils/_crop.py +10 -4
- figrecipe/_utils/_image_diff.py +37 -33
- figrecipe/_utils/_numpy_io.py +0 -1
- figrecipe/_utils/_units.py +11 -3
- figrecipe/_validator.py +12 -3
- figrecipe/_wrappers/_axes.py +860 -46
- figrecipe/_wrappers/_figure.py +115 -18
- figrecipe/plt.py +0 -1
- figrecipe/pyplot.py +2 -1
- figrecipe/styles/__init__.py +9 -10
- figrecipe/styles/_style_applier.py +332 -28
- figrecipe/styles/_style_loader.py +172 -44
- figrecipe/styles/presets/MATPLOTLIB.yaml +94 -0
- figrecipe/styles/presets/SCITEX.yaml +176 -0
- figrecipe-0.6.0.dist-info/METADATA +394 -0
- figrecipe-0.6.0.dist-info/RECORD +90 -0
- figrecipe-0.5.0.dist-info/METADATA +0 -336
- figrecipe-0.5.0.dist-info/RECORD +0 -26
- {figrecipe-0.5.0.dist-info → figrecipe-0.6.0.dist-info}/WHEEL +0 -0
- {figrecipe-0.5.0.dist-info → figrecipe-0.6.0.dist-info}/licenses/LICENSE +0 -0
figrecipe/_wrappers/_axes.py
CHANGED
|
@@ -2,10 +2,9 @@
|
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
3
|
"""Wrapped Axes that records all plotting calls."""
|
|
4
4
|
|
|
5
|
-
from typing import Any, Dict,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
8
|
from matplotlib.axes import Axes
|
|
10
9
|
|
|
11
10
|
if TYPE_CHECKING:
|
|
@@ -35,6 +34,11 @@ class RecordingAxes:
|
|
|
35
34
|
>>> # The call is recorded automatically
|
|
36
35
|
"""
|
|
37
36
|
|
|
37
|
+
# Methods whose results can be referenced by other methods (e.g., clabel needs ContourSet)
|
|
38
|
+
RESULT_REFERENCEABLE_METHODS = {"contour", "contourf"}
|
|
39
|
+
# Methods that take results from other methods as arguments
|
|
40
|
+
RESULT_REFERENCING_METHODS = {"clabel"}
|
|
41
|
+
|
|
38
42
|
def __init__(
|
|
39
43
|
self,
|
|
40
44
|
ax: Axes,
|
|
@@ -45,6 +49,8 @@ class RecordingAxes:
|
|
|
45
49
|
self._recorder = recorder
|
|
46
50
|
self._position = position
|
|
47
51
|
self._track = True
|
|
52
|
+
# Map matplotlib result objects (by id) to their source call_id
|
|
53
|
+
self._result_refs: Dict[int, str] = {}
|
|
48
54
|
|
|
49
55
|
@property
|
|
50
56
|
def ax(self) -> Axes:
|
|
@@ -87,6 +93,7 @@ class RecordingAxes:
|
|
|
87
93
|
callable
|
|
88
94
|
Wrapped method that records calls.
|
|
89
95
|
"""
|
|
96
|
+
|
|
90
97
|
def wrapper(*args, id: Optional[str] = None, track: bool = True, **kwargs):
|
|
91
98
|
# Call the original method first (without our custom kwargs)
|
|
92
99
|
result = method(*args, **kwargs)
|
|
@@ -96,24 +103,104 @@ class RecordingAxes:
|
|
|
96
103
|
# Capture actual colors from result for plotting methods
|
|
97
104
|
# that use matplotlib's color cycle
|
|
98
105
|
recorded_kwargs = kwargs.copy()
|
|
99
|
-
if method_name in (
|
|
100
|
-
|
|
101
|
-
|
|
106
|
+
if method_name in (
|
|
107
|
+
"plot",
|
|
108
|
+
"scatter",
|
|
109
|
+
"bar",
|
|
110
|
+
"barh",
|
|
111
|
+
"step",
|
|
112
|
+
"fill_between",
|
|
113
|
+
):
|
|
114
|
+
# Check if fmt string already specifies color (e.g., "b-", "r--")
|
|
115
|
+
has_fmt_color = self._args_have_fmt_color(args)
|
|
116
|
+
if (
|
|
117
|
+
"color" not in recorded_kwargs
|
|
118
|
+
and "c" not in recorded_kwargs
|
|
119
|
+
and not has_fmt_color
|
|
120
|
+
):
|
|
121
|
+
actual_color = self._extract_color_from_result(
|
|
122
|
+
method_name, result
|
|
123
|
+
)
|
|
102
124
|
if actual_color is not None:
|
|
103
|
-
recorded_kwargs[
|
|
125
|
+
recorded_kwargs["color"] = actual_color
|
|
104
126
|
|
|
105
|
-
|
|
127
|
+
# Process args to detect result references (e.g., clabel's ContourSet)
|
|
128
|
+
processed_args = self._process_result_refs_in_args(args, method_name)
|
|
129
|
+
|
|
130
|
+
call_record = self._recorder.record_call(
|
|
106
131
|
ax_position=self._position,
|
|
107
132
|
method_name=method_name,
|
|
108
|
-
args=
|
|
133
|
+
args=processed_args,
|
|
109
134
|
kwargs=recorded_kwargs,
|
|
110
135
|
call_id=id,
|
|
111
136
|
)
|
|
112
137
|
|
|
138
|
+
# Store result reference for methods whose results can be used later
|
|
139
|
+
if method_name in self.RESULT_REFERENCEABLE_METHODS:
|
|
140
|
+
import builtins
|
|
141
|
+
|
|
142
|
+
self._result_refs[builtins.id(result)] = call_record.id
|
|
143
|
+
|
|
113
144
|
return result
|
|
114
145
|
|
|
115
146
|
return wrapper
|
|
116
147
|
|
|
148
|
+
def _process_result_refs_in_args(self, args: tuple, method_name: str) -> tuple:
|
|
149
|
+
"""Process args to replace matplotlib objects with references.
|
|
150
|
+
|
|
151
|
+
For methods like clabel that take a ContourSet as argument,
|
|
152
|
+
replace the object with a reference to the original call_id.
|
|
153
|
+
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
args : tuple
|
|
157
|
+
Original arguments.
|
|
158
|
+
method_name : str
|
|
159
|
+
Name of the method.
|
|
160
|
+
|
|
161
|
+
Returns
|
|
162
|
+
-------
|
|
163
|
+
tuple
|
|
164
|
+
Processed args with references.
|
|
165
|
+
"""
|
|
166
|
+
if method_name not in self.RESULT_REFERENCING_METHODS:
|
|
167
|
+
return args
|
|
168
|
+
|
|
169
|
+
import builtins
|
|
170
|
+
|
|
171
|
+
processed = []
|
|
172
|
+
for i, arg in enumerate(args):
|
|
173
|
+
obj_id = builtins.id(arg)
|
|
174
|
+
if obj_id in self._result_refs:
|
|
175
|
+
# This arg is a reference to a previous call's result
|
|
176
|
+
processed.append({"__ref__": self._result_refs[obj_id]})
|
|
177
|
+
else:
|
|
178
|
+
processed.append(arg)
|
|
179
|
+
return tuple(processed)
|
|
180
|
+
|
|
181
|
+
def _args_have_fmt_color(self, args: tuple) -> bool:
|
|
182
|
+
"""Check if args contain a matplotlib fmt string with color specifier.
|
|
183
|
+
|
|
184
|
+
Fmt strings like "b-", "r--", "go" contain color codes (b,g,r,c,m,y,k,w).
|
|
185
|
+
|
|
186
|
+
Parameters
|
|
187
|
+
----------
|
|
188
|
+
args : tuple
|
|
189
|
+
Arguments passed to plot method.
|
|
190
|
+
|
|
191
|
+
Returns
|
|
192
|
+
-------
|
|
193
|
+
bool
|
|
194
|
+
True if a fmt string with color is found.
|
|
195
|
+
"""
|
|
196
|
+
color_codes = set("bgrcmykw")
|
|
197
|
+
for arg in args:
|
|
198
|
+
if isinstance(arg, str) and len(arg) >= 1 and len(arg) <= 4:
|
|
199
|
+
# Fmt strings are short (e.g., "b-", "r--", "go", "k:")
|
|
200
|
+
if arg[0] in color_codes:
|
|
201
|
+
return True
|
|
202
|
+
return False
|
|
203
|
+
|
|
117
204
|
def _extract_color_from_result(self, method_name: str, result) -> Optional[str]:
|
|
118
205
|
"""Extract actual color used from plot result.
|
|
119
206
|
|
|
@@ -130,34 +217,37 @@ class RecordingAxes:
|
|
|
130
217
|
The color used, or None if not extractable.
|
|
131
218
|
"""
|
|
132
219
|
try:
|
|
133
|
-
if method_name ==
|
|
220
|
+
if method_name == "plot":
|
|
134
221
|
# plot() returns list of Line2D
|
|
135
|
-
if result and hasattr(result[0],
|
|
222
|
+
if result and hasattr(result[0], "get_color"):
|
|
136
223
|
return result[0].get_color()
|
|
137
|
-
elif method_name ==
|
|
224
|
+
elif method_name == "scatter":
|
|
138
225
|
# scatter() returns PathCollection
|
|
139
|
-
if hasattr(result,
|
|
226
|
+
if hasattr(result, "get_facecolor"):
|
|
140
227
|
fc = result.get_facecolor()
|
|
141
228
|
if len(fc) > 0:
|
|
142
229
|
# Convert RGBA to hex
|
|
143
230
|
import matplotlib.colors as mcolors
|
|
231
|
+
|
|
144
232
|
return mcolors.to_hex(fc[0])
|
|
145
|
-
elif method_name in (
|
|
233
|
+
elif method_name in ("bar", "barh"):
|
|
146
234
|
# bar() returns BarContainer
|
|
147
|
-
if hasattr(result,
|
|
235
|
+
if hasattr(result, "patches") and result.patches:
|
|
148
236
|
fc = result.patches[0].get_facecolor()
|
|
149
237
|
import matplotlib.colors as mcolors
|
|
238
|
+
|
|
150
239
|
return mcolors.to_hex(fc)
|
|
151
|
-
elif method_name ==
|
|
240
|
+
elif method_name == "step":
|
|
152
241
|
# step() returns list of Line2D
|
|
153
|
-
if result and hasattr(result[0],
|
|
242
|
+
if result and hasattr(result[0], "get_color"):
|
|
154
243
|
return result[0].get_color()
|
|
155
|
-
elif method_name ==
|
|
244
|
+
elif method_name == "fill_between":
|
|
156
245
|
# fill_between() returns PolyCollection
|
|
157
|
-
if hasattr(result,
|
|
246
|
+
if hasattr(result, "get_facecolor"):
|
|
158
247
|
fc = result.get_facecolor()
|
|
159
248
|
if len(fc) > 0:
|
|
160
249
|
import matplotlib.colors as mcolors
|
|
250
|
+
|
|
161
251
|
return mcolors.to_hex(fc[0])
|
|
162
252
|
except Exception:
|
|
163
253
|
pass
|
|
@@ -213,23 +303,29 @@ class RecordingAxes:
|
|
|
213
303
|
if key in data_arrays:
|
|
214
304
|
arr = data_arrays[key]
|
|
215
305
|
if should_store_inline(arr):
|
|
216
|
-
processed_args.append(
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
306
|
+
processed_args.append(
|
|
307
|
+
{
|
|
308
|
+
"name": f"arg{i}",
|
|
309
|
+
"data": to_serializable(arr),
|
|
310
|
+
"dtype": str(arr.dtype),
|
|
311
|
+
}
|
|
312
|
+
)
|
|
221
313
|
else:
|
|
222
|
-
processed_args.append(
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
314
|
+
processed_args.append(
|
|
315
|
+
{
|
|
316
|
+
"name": f"arg{i}",
|
|
317
|
+
"data": "__FILE__",
|
|
318
|
+
"dtype": str(arr.dtype),
|
|
319
|
+
"_array": arr,
|
|
320
|
+
}
|
|
321
|
+
)
|
|
228
322
|
else:
|
|
229
|
-
processed_args.append(
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
323
|
+
processed_args.append(
|
|
324
|
+
{
|
|
325
|
+
"name": f"arg{i}",
|
|
326
|
+
"data": arg,
|
|
327
|
+
}
|
|
328
|
+
)
|
|
233
329
|
|
|
234
330
|
# Process DataFrame column data
|
|
235
331
|
for key, arr in data_arrays.items():
|
|
@@ -237,20 +333,24 @@ class RecordingAxes:
|
|
|
237
333
|
param_name = key[5:] # Remove "_col_" prefix
|
|
238
334
|
col_name = data_arrays.get(f"_colname_{param_name}", param_name)
|
|
239
335
|
if should_store_inline(arr):
|
|
240
|
-
processed_args.append(
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
336
|
+
processed_args.append(
|
|
337
|
+
{
|
|
338
|
+
"name": col_name,
|
|
339
|
+
"param": param_name,
|
|
340
|
+
"data": to_serializable(arr),
|
|
341
|
+
"dtype": str(arr.dtype),
|
|
342
|
+
}
|
|
343
|
+
)
|
|
246
344
|
else:
|
|
247
|
-
processed_args.append(
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
345
|
+
processed_args.append(
|
|
346
|
+
{
|
|
347
|
+
"name": col_name,
|
|
348
|
+
"param": param_name,
|
|
349
|
+
"data": "__FILE__",
|
|
350
|
+
"dtype": str(arr.dtype),
|
|
351
|
+
"_array": arr,
|
|
352
|
+
}
|
|
353
|
+
)
|
|
254
354
|
|
|
255
355
|
# Process kwarg arrays
|
|
256
356
|
processed_kwargs = dict(kwargs)
|
|
@@ -294,6 +394,379 @@ class RecordingAxes:
|
|
|
294
394
|
def yaxis(self):
|
|
295
395
|
return self._ax.yaxis
|
|
296
396
|
|
|
397
|
+
def pie(
    self,
    x,
    *,
    id: Optional[str] = None,
    track: bool = True,
    **kwargs,
):
    """Draw a pie chart, apply the active style, and record the call.

    Parameters
    ----------
    x : array-like
        Wedge sizes, forwarded to matplotlib's ``pie``.
    id : str, optional
        Custom ID stored with the recorded call.
    track : bool, optional
        Set to False to skip recording this call (default: True).
    **kwargs
        Forwarded unchanged to matplotlib's ``pie`` and stored in the record.

    Returns
    -------
    tuple
        Whatever matplotlib's ``pie`` returned: (patches, texts) or
        (patches, texts, autotexts) if autopct is set.
    """
    from ..styles import get_style
    from ..styles._style_applier import check_font

    axes = self._ax
    pie_artists = axes.pie(x, **kwargs)

    active_style = get_style()
    if active_style:
        pie_cfg = active_style.get("pie", {})
        font_pt = pie_cfg.get("text_pt", 6)
        keep_axes = pie_cfg.get("show_axes", False)
        family = check_font(active_style.get("fonts", {}).get("family", "Arial"))

        # NOTE: this walks every text artist on the axes, not only the
        # labels/percentages created by this pie call.
        for label in axes.texts:
            label.set_fontsize(font_pt)
            label.set_fontfamily(family)

        # Pie charts hide ticks and spines unless the style asks to keep them.
        if not keep_axes:
            axes.set_xticks([])
            axes.set_yticks([])
            axes.set_xticklabels([])
            axes.set_yticklabels([])
            for edge in axes.spines.values():
                edge.set_visible(False)

    # Both the per-axes switch and the per-call flag must allow recording.
    if self._track and track:
        self._recorder.record_call(
            ax_position=self._position,
            method_name="pie",
            args=(x,),
            kwargs=kwargs,
            call_id=id,
        )

    return pie_artists
|
|
463
|
+
|
|
464
|
+
def imshow(
    self,
    X,
    *,
    id: Optional[str] = None,
    track: bool = True,
    **kwargs,
):
    """Show an image, apply the active style, and record the call.

    Parameters
    ----------
    X : array-like
        Image data, forwarded to matplotlib's ``imshow``.
    id : str, optional
        Custom ID stored with the recorded call.
    track : bool, optional
        Set to False to skip recording this call (default: True).
    **kwargs
        Forwarded unchanged to matplotlib's ``imshow`` and stored in the record.

    Returns
    -------
    AxesImage
        The object returned by matplotlib's ``imshow``.
    """
    from ..styles import get_style

    axes = self._ax
    image = axes.imshow(X, **kwargs)

    active_style = get_style()
    if active_style:
        imshow_cfg = active_style.get("imshow", {})
        keep_axes = imshow_cfg.get("show_axes", True)
        keep_labels = imshow_cfg.get("show_labels", True)

        # Ticks, tick labels and spines are removed together.
        if not keep_axes:
            axes.set_xticks([])
            axes.set_yticks([])
            axes.set_xticklabels([])
            axes.set_yticklabels([])
            for edge in axes.spines.values():
                edge.set_visible(False)

        if not keep_labels:
            axes.set_xlabel("")
            axes.set_ylabel("")

    # Both the per-axes switch and the per-call flag must allow recording.
    if self._track and track:
        self._recorder.record_call(
            ax_position=self._position,
            method_name="imshow",
            args=(X,),
            kwargs=kwargs,
            call_id=id,
        )

    return image
|
|
527
|
+
|
|
528
|
+
def violinplot(
    self,
    dataset,
    positions=None,
    *,
    id: Optional[str] = None,
    track: bool = True,
    inner: Optional[str] = None,
    **kwargs,
):
    """Violin plot with support for inner display options.

    Parameters
    ----------
    dataset : array-like
        Data to plot.
    positions : array-like, optional
        Position of each violin on x-axis. When omitted, 1..len(dataset)
        is used for the inner overlays.
    id : str, optional
        Custom ID for this call.
    track : bool, optional
        Whether to record this call (default: True).
    inner : str, optional
        Inner display type: "box", "quartile", "stick", "point", "swarm", or None.
        When None, the value comes from the style config ("box" if unset).
    **kwargs
        Additional arguments passed to matplotlib's violinplot.
        ``showmeans``, ``showmedians`` and ``showextrema`` fall back to the
        style config when not given.

    Returns
    -------
    dict
        Dictionary with violin parts (bodies, cbars, cmins, cmaxes, cmeans, cmedians).
    """
    from ..styles import get_style

    # Get style settings
    style = get_style()
    violin_style = style.get("violinplot", {}) if style else {}

    # Determine inner type (user kwarg > style config > default)
    if inner is None:
        inner = violin_style.get("inner", "box")

    # Get violin display options (user kwarg > style config > default)
    showmeans = kwargs.pop("showmeans", violin_style.get("showmeans", False))
    showmedians = kwargs.pop("showmedians", violin_style.get("showmedians", True))
    showextrema = kwargs.pop("showextrema", violin_style.get("showextrema", False))

    # "box" and "swarm" draw their own inner markers, so matplotlib's
    # median/extrema lines are switched off for them.
    custom_inner = inner in ("box", "swarm")

    result = self._ax.violinplot(
        dataset,
        positions=positions,
        showmeans=showmeans,
        showmedians=False if custom_inner else showmedians,
        showextrema=False if custom_inner else showextrema,
        **kwargs,
    )

    # Apply alpha from style to violin bodies
    alpha = violin_style.get("alpha", 0.7)
    if "bodies" in result:
        for body in result["bodies"]:
            body.set_alpha(alpha)

    # Remember what the caller passed before filling in the default, so the
    # record reflects the caller's request rather than the derived list.
    user_positions = positions
    if positions is None:
        positions = list(range(1, len(dataset) + 1))

    # Overlay inner elements based on inner type
    if inner == "box":
        self._add_violin_inner_box(dataset, positions, violin_style)
    elif inner == "swarm":
        self._add_violin_inner_swarm(dataset, positions, violin_style)
    elif inner == "quartile":
        # No overlay: only matplotlib's own showmedians/showextrema lines apply.
        pass
    elif inner == "stick":
        self._add_violin_inner_stick(dataset, positions, violin_style)
    elif inner == "point":
        self._add_violin_inner_point(dataset, positions, violin_style)

    # Record the call if tracking is enabled
    if self._track and track:
        recorded_kwargs = kwargs.copy()
        # Explicit positions must be recorded, otherwise a replay would
        # place the violins at the default 1..N positions.
        if user_positions is not None:
            recorded_kwargs["positions"] = user_positions
        recorded_kwargs["inner"] = inner
        recorded_kwargs["showmeans"] = showmeans
        recorded_kwargs["showmedians"] = showmedians
        recorded_kwargs["showextrema"] = showextrema

        self._recorder.record_call(
            ax_position=self._position,
            method_name="violinplot",
            args=(dataset,),
            kwargs=recorded_kwargs,
            call_id=id,
        )

    return result
|
|
625
|
+
|
|
626
|
+
def _add_violin_inner_box(self, dataset, positions, style: Dict[str, Any]) -> None:
    """Overlay a slim box plot (box, whiskers, median dot) on each violin.

    Parameters
    ----------
    dataset : array-like
        One data array per violin.
    positions : array-like
        X position of each violin, paired with ``dataset`` by order.
    style : dict
        Violin style configuration; ``whisker_mm`` and ``median_mm`` are read.
    """
    from ..styles._style_applier import mm_to_pt

    box_lw = mm_to_pt(style.get("whisker_mm", 0.2))
    dot_diameter = mm_to_pt(style.get("median_mm", 0.8))
    thin_lw = box_lw * 0.5

    for values, x_pos in zip(dataset, positions):
        values = np.asarray(values)
        lower_q, mid, upper_q = np.percentile(values, [25, 50, 75])
        spread = upper_q - lower_q
        # Whiskers reach 1.5 * IQR beyond the box, clipped to the data range.
        low_end = max(values.min(), lower_q - 1.5 * spread)
        high_end = min(values.max(), upper_q + 1.5 * spread)

        # Box: Q1 to Q3 at full line width.
        self._ax.vlines(
            x_pos, lower_q, upper_q, colors="black", linewidths=box_lw, zorder=3
        )
        # Whiskers: half the line width, below and above the box.
        self._ax.vlines(
            x_pos, low_end, lower_q, colors="black", linewidths=thin_lw, zorder=3
        )
        self._ax.vlines(
            x_pos, upper_q, high_end, colors="black", linewidths=thin_lw, zorder=3
        )
        # Median: white dot with black edge, drawn above the lines.
        self._ax.scatter(
            [x_pos],
            [mid],
            s=dot_diameter**2,
            c="white",
            edgecolors="black",
            linewidths=box_lw,
            zorder=4,
        )
|
|
681
|
+
|
|
682
|
+
def _add_violin_inner_swarm(
    self, dataset, positions, style: Dict[str, Any]
) -> None:
    """Scatter each violin's data points with horizontal jitter.

    Parameters
    ----------
    dataset : array-like
        One data array per violin.
    positions : array-like
        X position of each violin, paired with ``dataset`` by order.
    style : dict
        Violin style configuration; ``median_mm`` sets the marker size.
    """
    from ..styles._style_applier import mm_to_pt

    marker_diameter = mm_to_pt(style.get("median_mm", 0.8))

    for values, x_pos in zip(dataset, positions):
        values = np.asarray(values)
        count = len(values)

        # Plain uniform jitter, not an overlap-avoiding swarm layout.
        # The generator is re-created with seed 42 for every violin, so the
        # offsets are deterministic and repeat from violin to violin.
        offsets = np.random.default_rng(42).uniform(-0.15, 0.15, count)

        self._ax.scatter(
            x_pos + offsets,
            values,
            s=marker_diameter**2,
            c="black",
            alpha=0.5,
            zorder=3,
        )
|
|
712
|
+
|
|
713
|
+
def _add_violin_inner_stick(
    self, dataset, positions, style: Dict[str, Any]
) -> None:
    """Add stick (line) markers inside violin for each data point.

    Parameters
    ----------
    dataset : array-like
        Data arrays for each violin.
    positions : array-like
        X positions of violins.
    style : dict
        Violin style configuration; ``whisker_mm`` sets the line width.
    """
    from ..styles._style_applier import mm_to_pt

    lw = mm_to_pt(style.get("whisker_mm", 0.2))

    for data, pos in zip(dataset, positions):
        data = np.asarray(data)
        if data.size == 0:
            # Nothing to draw; avoid adding an empty collection.
            continue
        # One vectorized call draws a short horizontal line at every data
        # point, instead of creating a separate collection per point.
        self._ax.hlines(
            data,
            pos - 0.05,
            pos + 0.05,
            colors="black",
            linewidths=lw * 0.5,
            alpha=0.3,
            zorder=3,
        )
|
|
744
|
+
|
|
745
|
+
def _add_violin_inner_point(
    self, dataset, positions, style: Dict[str, Any]
) -> None:
    """Add point markers inside violin for each data point.

    Parameters
    ----------
    dataset : array-like
        Data arrays for each violin.
    positions : array-like
        X positions of violins.
    style : dict
        Violin style configuration; ``median_mm`` sets the marker size.
    """
    from ..styles._style_applier import mm_to_pt

    # Points are drawn at half the diameter used for the median marker.
    point_size = mm_to_pt(style.get("median_mm", 0.8)) * 0.5

    for data, pos in zip(dataset, positions):
        data = np.asarray(data)
        # Build the x coordinates as floats. np.full_like would inherit the
        # dtype of ``data`` and truncate a fractional position (e.g. 1.5)
        # whenever the data is integer-typed.
        x_positions = np.full(data.shape, pos, dtype=float)
        self._ax.scatter(
            x_positions, data, s=point_size**2, c="black", alpha=0.3, zorder=3
        )
|
|
769
|
+
|
|
297
770
|
# Methods that should not be recorded
|
|
298
771
|
def get_xlim(self):
|
|
299
772
|
return self._ax.get_xlim()
|
|
@@ -310,6 +783,347 @@ class RecordingAxes:
|
|
|
310
783
|
def get_title(self):
|
|
311
784
|
return self._ax.get_title()
|
|
312
785
|
|
|
786
|
+
def joyplot(
    self,
    arrays,
    *,
    overlap: float = 0.5,
    fill_alpha: float = 0.7,
    line_alpha: float = 1.0,
    colors=None,
    labels=None,
    id: Optional[str] = None,
    track: bool = True,
    **kwargs,
):
    """Create a joyplot (ridgeline plot) for distribution comparison.

    Parameters
    ----------
    arrays : list of array-like or dict
        List of 1D arrays for each ridge. If dict, uses values (and the
        keys as labels when ``labels`` is not given).
    overlap : float, default 0.5
        Amount of overlap between ridges (0 = no overlap, 1 = full overlap).
    fill_alpha : float, default 0.7
        Alpha for the filled KDE area.
    line_alpha : float, default 1.0
        Alpha for the KDE line.
    colors : list, optional
        Colors for each ridge. If None, uses the style palette when one is
        configured, otherwise matplotlib's color cycle.
    labels : list of str, optional
        Labels for each ridge (for y-axis).
    id : str, optional
        Custom ID for this call.
    track : bool, optional
        Whether to record this call (default: True).
    **kwargs
        Accepted for signature compatibility; not used by this method.

    Returns
    -------
    RecordingAxes
        Self for method chaining.

    Raises
    ------
    ValueError
        If ``arrays`` is empty.

    Examples
    --------
    >>> ax.joyplot([data1, data2, data3], overlap=0.5)
    >>> ax.joyplot({"A": arr_a, "B": arr_b}, labels=["A", "B"])
    """
    from scipy import stats

    from .._utils._units import mm_to_pt
    from ..styles import get_style

    # Convert dict to list of arrays
    if isinstance(arrays, dict):
        if labels is None:
            labels = list(arrays.keys())
        arrays = list(arrays.values())

    n_ridges = len(arrays)
    if n_ridges == 0:
        # np.concatenate would raise the same exception type further down,
        # but with a message that does not mention joyplot.
        raise ValueError("joyplot requires at least one array")

    style = get_style()

    # Keep the caller's choice so it can be recorded; the derived palette
    # below is intentionally not recorded.
    user_colors = colors

    def _normalize_color(c):
        # RGB(A) sequences given on a 0-255 scale are rescaled to 0-1;
        # anything else (names, hex strings, 0-1 tuples) passes through.
        if isinstance(c, (list, tuple)) and len(c) >= 3:
            if all(v <= 1.0 for v in c):
                return tuple(c)
            return tuple(v / 255.0 for v in c)
        return c

    # Get colors from style or use default cycle
    if colors is None:
        if style and "colors" in style and "palette" in style.colors:
            colors = [_normalize_color(c) for c in style.colors.palette]
        else:
            # Matplotlib default color cycle
            import matplotlib.pyplot as plt

            colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]

    # Shared x grid: global data range padded by 10% on each side
    all_data = np.concatenate([np.asarray(arr) for arr in arrays])
    x_min, x_max = np.min(all_data), np.max(all_data)
    x_padding = (x_max - x_min) * 0.1
    x = np.linspace(x_min - x_padding, x_max + x_padding, 200)

    # Calculate KDEs and find max density for scaling
    kdes = []
    max_density = 0
    for arr in arrays:
        arr = np.asarray(arr)
        if len(arr) > 1:
            density = stats.gaussian_kde(arr)(x)
            kdes.append(density)
            max_density = max(max_density, np.max(density))
        else:
            # A KDE needs at least two samples; draw a flat ridge instead.
            kdes.append(np.zeros_like(x))

    # Scale factor for ridge height
    ridge_height = 1.0 / (1.0 - overlap * 0.5) if overlap < 1 else 2.0

    # Get line width from style
    lw = mm_to_pt(0.2)  # Default
    if style and "lines" in style:
        lw = mm_to_pt(style.lines.get("trace_mm", 0.2))

    # Plot each ridge from back to front (highest baseline first)
    for i in reversed(range(n_ridges)):
        color = colors[i % len(colors)]
        baseline = i * (1.0 - overlap)

        # Scale so the tallest ridge reaches ridge_height
        scaled_density = (
            kdes[i] / max_density * ridge_height if max_density > 0 else kdes[i]
        )

        # Fill
        self._ax.fill_between(
            x,
            baseline,
            baseline + scaled_density,
            facecolor=color,
            edgecolor="none",
            alpha=fill_alpha,
        )
        # Line on top
        self._ax.plot(
            x,
            baseline + scaled_density,
            color=color,
            alpha=line_alpha,
            linewidth=lw,
        )

    # Set y limits
    self._ax.set_ylim(-0.1, n_ridges * (1.0 - overlap) + ridge_height)

    # Set y-axis labels if provided
    if labels:
        y_positions = [(i * (1.0 - overlap)) + 0.3 for i in range(n_ridges)]
        self._ax.set_yticks(y_positions)
        self._ax.set_yticklabels(labels)
    else:
        # Hide y-axis ticks for cleaner look
        self._ax.set_yticks([])

    # Record the call if tracking is enabled
    if self._track and track:
        recorded_kwargs = {
            "overlap": overlap,
            "fill_alpha": fill_alpha,
            "line_alpha": line_alpha,
            "labels": labels,
        }
        # Explicit colors must be recorded, otherwise a replay would fall
        # back to the style palette or the default color cycle.
        if user_colors is not None:
            recorded_kwargs["colors"] = user_colors

        self._recorder.record_call(
            ax_position=self._position,
            method_name="joyplot",
            args=(arrays,),
            kwargs=recorded_kwargs,
            call_id=id,
        )

    return self
|
|
951
|
+
|
|
952
|
+
def swarmplot(
    self,
    data,
    positions=None,
    *,
    size: Optional[float] = None,
    color=None,
    alpha: float = 0.7,
    jitter: float = 0.3,
    id: Optional[str] = None,
    track: bool = True,
    **kwargs,
):
    """Create a swarm plot (beeswarm plot) showing individual data points.

    Each array in ``data`` is drawn with one ``scatter`` call; the x
    coordinate of every point is the swarm position plus an offset
    computed by ``_beeswarm_positions``.

    Parameters
    ----------
    data : list of array-like
        List of 1D arrays to plot.
    positions : array-like, optional
        X positions for each swarm. Default is 1, 2, 3, ...
        Swarms are paired with positions via ``zip``, so if the lengths
        differ the extra entries of the longer one are ignored.
    size : float, optional
        Marker size in mm. Default is the ``scatter_mm`` entry of the
        style's ``markers`` section, or 0.8 when no such style is set.
    color : color or list of colors, optional
        A ``list`` is used as one color per swarm (cycled if shorter
        than ``data``); any other value is used for every swarm.
        Default is the style palette, or the matplotlib property cycle
        when the style defines no palette.
    alpha : float, default 0.7
        Transparency of markers.
    jitter : float, default 0.3
        Width of jitter spread (in data units).
    id : str, optional
        Custom ID for this call.
    track : bool, optional
        Whether to record this call (default: True).
    **kwargs
        Additional arguments passed to scatter.

    Returns
    -------
    list
        List of PathCollection objects, one per plotted swarm.

    Examples
    --------
    >>> ax.swarmplot([data1, data2, data3])
    >>> ax.swarmplot([arr1, arr2], positions=[0, 1], color=['red', 'blue'])
    """
    from .._utils._units import mm_to_pt
    from ..styles import get_style

    style = get_style()

    # Default marker size from style
    if size is None:
        if style and "markers" in style:
            size = style.markers.get("scatter_mm", 0.8)
        else:
            size = 0.8
    # scatter's ``s`` is an area (points ** 2), hence the square.
    size_pt = mm_to_pt(size) ** 2

    # Resolve the per-swarm color list
    if color is None:
        if style and "colors" in style and "palette" in style.colors:
            colors = []
            for entry in style.colors.palette:
                if isinstance(entry, (list, tuple)) and len(entry) >= 3:
                    if all(v <= 1.0 for v in entry):
                        # Already expressed as 0-1 floats.
                        colors.append(tuple(entry))
                    else:
                        # Treated as 0-255 channels; rescale to 0-1.
                        colors.append(tuple(v / 255.0 for v in entry))
                else:
                    colors.append(entry)
        else:
            import matplotlib.pyplot as plt

            colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]]
    elif isinstance(color, list):
        colors = color
    else:
        colors = [color] * len(data)

    # Default positions
    if positions is None:
        positions = list(range(1, len(data) + 1))

    # A generator with a fixed seed is created on every call, so the
    # jitter is identical for identical input.
    rng = np.random.default_rng(42)

    results = []
    for i, (arr, pos) in enumerate(zip(data, positions)):
        arr = np.asarray(arr)

        # Horizontal offsets from the simplified beeswarm layout
        x_positions = pos + self._beeswarm_positions(arr, jitter, rng)

        swarm_color = colors[i % len(colors)]
        # The color is wrapped in a list so that an RGB(A) tuple is read
        # as one color rather than as a sequence of values.
        result = self._ax.scatter(
            x_positions, arr, s=size_pt, c=[swarm_color], alpha=alpha, **kwargs
        )
        results.append(result)

    # Record the call if tracking is enabled
    if self._track and track:
        recorded_kwargs = {
            "positions": positions,
            "size": size,
            "alpha": alpha,
            "jitter": jitter,
        }
        # Record user-supplied colors and scatter options as well, so the
        # recorded call carries everything the caller passed. They are only
        # added when given, which keeps default recordings unchanged.
        if color is not None:
            recorded_kwargs["color"] = color
        recorded_kwargs.update(kwargs)
        self._recorder.record_call(
            ax_position=self._position,
            method_name="swarmplot",
            args=(data,),
            kwargs=recorded_kwargs,
            call_id=id,
        )

    return results
|
|
1071
|
+
|
|
1072
|
+
def _beeswarm_positions(
    self,
    data: np.ndarray,
    width: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Calculate beeswarm-style x positions to minimize overlap.

    This is a simplified beeswarm that uses binning and jittering.
    For a true beeswarm, we'd need to iteratively place points.

    The values are split into ``int(sqrt(n))`` quantile bins. Inside each
    bin the points are spread evenly over ``[-width / 2, width / 2]`` in
    order of increasing value, and uniform noise of up to ``0.1 * width``
    is added, so an offset can exceed ``width / 2`` by at most that much.
    A bin holding a single point keeps it centred (noise only).

    Parameters
    ----------
    data : array
        Y values of points (1D).
    width : float
        Jitter width over which each bin is spread.
    rng : Generator
        Random number generator.

    Returns
    -------
    array
        X offsets for each point, in the same order as ``data``.
    """
    data = np.asarray(data)
    n = len(data)
    if n == 0:
        return np.array([])

    # Work on the sorted values; ``order`` maps them back afterwards.
    order = np.argsort(data)
    sorted_data = data[order]

    # Simple approach: bin by quantiles and spread within each bin
    n_bins = max(1, int(np.sqrt(n)))
    bin_edges = np.percentile(sorted_data, np.linspace(0, 100, n_bins + 1))

    sorted_offsets = np.zeros(n)
    last_bin = n_bins - 1
    for i in range(n_bins):
        lower = bin_edges[i]
        upper = bin_edges[i + 1]
        # Bins are half-open so a value sitting exactly on a shared edge
        # belongs to one bin only; the last bin is closed to keep the
        # maximum.
        if i == last_bin:
            in_bin = (sorted_data >= lower) & (sorted_data <= upper)
        else:
            in_bin = (sorted_data >= lower) & (sorted_data < upper)

        count = int(in_bin.sum())
        if count == 0:
            continue

        if count == 1:
            # linspace with a single sample returns only the start value,
            # which would push a lone point to the left edge.
            base = np.zeros(1)
        else:
            base = np.linspace(-width / 2, width / 2, count)

        # Add small random noise
        noise = rng.uniform(-width * 0.1, width * 0.1, count)
        sorted_offsets[in_bin] = base + noise

    # Restore original order
    result = np.zeros(n)
    result[order] = sorted_offsets
    return result
|
|
1126
|
+
|
|
313
1127
|
|
|
314
1128
|
class _NoRecordContext:
|
|
315
1129
|
"""Context manager to temporarily disable recording."""
|