figrecipe 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- figrecipe/__init__.py +1090 -0
- figrecipe/_recorder.py +435 -0
- figrecipe/_reproducer.py +358 -0
- figrecipe/_seaborn.py +305 -0
- figrecipe/_serializer.py +227 -0
- figrecipe/_signatures/__init__.py +7 -0
- figrecipe/_signatures/_loader.py +186 -0
- figrecipe/_utils/__init__.py +32 -0
- figrecipe/_utils/_crop.py +261 -0
- figrecipe/_utils/_diff.py +98 -0
- figrecipe/_utils/_image_diff.py +204 -0
- figrecipe/_utils/_numpy_io.py +204 -0
- figrecipe/_utils/_units.py +200 -0
- figrecipe/_validator.py +186 -0
- figrecipe/_wrappers/__init__.py +8 -0
- figrecipe/_wrappers/_axes.py +327 -0
- figrecipe/_wrappers/_figure.py +227 -0
- figrecipe/plt.py +12 -0
- figrecipe/pyplot.py +264 -0
- figrecipe/styles/__init__.py +50 -0
- figrecipe/styles/_style_applier.py +412 -0
- figrecipe/styles/_style_loader.py +450 -0
- figrecipe-0.5.0.dist-info/METADATA +336 -0
- figrecipe-0.5.0.dist-info/RECORD +26 -0
- figrecipe-0.5.0.dist-info/WHEEL +4 -0
- figrecipe-0.5.0.dist-info/licenses/LICENSE +661 -0
figrecipe/_reproducer.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Reproduce figures from recipe files."""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
from matplotlib.axes import Axes
|
|
11
|
+
from matplotlib.figure import Figure
|
|
12
|
+
|
|
13
|
+
from ._recorder import FigureRecord, CallRecord
|
|
14
|
+
from ._serializer import load_recipe
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def reproduce(
    path: Union[str, Path],
    calls: Optional[List[str]] = None,
    skip_decorations: bool = False,
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
    """Load a recipe file and rebuild the figure it describes.

    This is a convenience wrapper: the recipe is read with ``load_recipe``
    and the resulting record is passed to ``reproduce_from_record``
    together with the filtering options.

    Parameters
    ----------
    path : str or Path
        Path to .yaml recipe file.
    calls : list of str, optional
        If provided, only the calls whose IDs are listed are replayed.
    skip_decorations : bool
        If True, decoration calls (labels, legends, etc.) are not replayed.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Reproduced figure.
    axes : Axes or list of Axes
        Reproduced axes (single if 1x1, otherwise list).

    Examples
    --------
    >>> import figrecipe as ps
    >>> fig, ax = ps.reproduce("experiment_001.yaml")
    >>> plt.show()
    """
    return reproduce_from_record(
        load_recipe(path),
        calls=calls,
        skip_decorations=skip_decorations,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _axes_key_to_position(ax_key: str) -> Tuple[int, int]:
    """Return the ``(row, col)`` grid position encoded in an axes key.

    The key is split on underscores. With at least three parts, parts 1 and
    2 are parsed as the row and column indices; shorter keys map to
    ``(0, 0)``.
    """
    pieces = ax_key.split("_")
    if len(pieces) < 3:
        return 0, 0
    return int(pieces[1]), int(pieces[2])


def reproduce_from_record(
    record: FigureRecord,
    calls: Optional[List[str]] = None,
    skip_decorations: bool = False,
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
    """Rebuild a matplotlib figure from an in-memory FigureRecord.

    The subplot grid size is inferred from the axes keys of the record
    (``<prefix>_<row>_<col>``). The recorded style, if any, is applied to
    every axes before any call is replayed, then plotting calls and
    (optionally) decoration calls are replayed per axes.

    Parameters
    ----------
    record : FigureRecord
        The figure record to reproduce.
    calls : list of str, optional
        If provided, only the calls whose IDs are listed are replayed
        (applies to plotting calls and decorations alike).
    skip_decorations : bool
        If True, skip decoration calls.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Reproduced figure.
    axes : Axes or list of Axes
        A single Axes for a 1x1 grid, a flat list for a single row or a
        single column, and a nested list (one list per row) otherwise.
    """
    # Resolve every key once; the positions drive both the grid size and
    # the axes lookup during replay.
    positions = {key: _axes_key_to_position(key) for key in record.axes.keys()}
    nrows = max([0] + [row for row, _ in positions.values()]) + 1
    ncols = max([0] + [col for _, col in positions.values()]) + 1

    # squeeze=False makes subplots() always hand back a (nrows, ncols) array,
    # so no reshaping is needed for the 1xN / Nx1 / 1x1 cases.
    fig, axes_2d = plt.subplots(
        nrows,
        ncols,
        figsize=record.figsize,
        dpi=record.dpi,
        constrained_layout=record.constrained_layout,
        squeeze=False,
    )

    if record.layout is not None:
        fig.subplots_adjust(**record.layout)

    # Style goes first to mirror the original order: style is applied while
    # the subplots are created, the user's plots/decorations come afterwards.
    if record.style is not None:
        from .styles import apply_style_mm

        for styled_ax in axes_2d.flat:
            apply_style_mm(styled_ax, record.style)

    def _selected(call: CallRecord) -> bool:
        # Without a filter every call is replayed.
        return calls is None or call.id in calls

    for ax_key, ax_record in record.axes.items():
        row, col = positions[ax_key]
        ax = axes_2d[row, col]

        for call in ax_record.calls:
            if _selected(call):
                _replay_call(ax, call)

        if skip_decorations:
            continue
        for call in ax_record.decorations:
            if _selected(call):
                _replay_call(ax, call)

    # Shape the return value like the squeezed result of plt.subplots(),
    # but with plain lists instead of arrays.
    if nrows == 1 and ncols == 1:
        return fig, axes_2d[0, 0]
    if nrows == 1:
        return fig, list(axes_2d[0])
    if ncols == 1:
        return fig, list(axes_2d[:, 0])
    return fig, axes_2d.tolist()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _replay_call(ax: Axes, call: CallRecord) -> Any:
    """Replay a single recorded call on an axes.

    Calls whose function name starts with ``"sns."`` are delegated to
    ``_replay_seaborn_call``. Any other name is looked up as an attribute
    of ``ax`` and invoked with the reconstructed positional arguments and
    a copy of the recorded keyword arguments.

    Replay is best-effort: an unknown method or a failing call emits a
    warning and yields ``None`` instead of raising.

    Parameters
    ----------
    ax : Axes
        The matplotlib axes.
    call : CallRecord
        The call to replay.

    Returns
    -------
    Any
        Result of the matplotlib call, or ``None`` if the call was skipped
        or failed.
    """
    import warnings

    method_name = call.function

    # Seaborn calls are dispatched to their own replay routine.
    if method_name.startswith("sns."):
        return _replay_seaborn_call(ax, call)

    method = getattr(ax, method_name, None)
    if method is None:
        # Warn instead of skipping silently, matching how the seaborn
        # replay path reports a function it cannot find.
        warnings.warn(f"Axes method {method_name} not found; call skipped")
        return None

    args = [_reconstruct_value(arg_data) for arg_data in call.args]
    # Copy so the callee cannot mutate the recorded kwargs.
    kwargs = call.kwargs.copy()

    try:
        return method(*args, **kwargs)
    except Exception as e:
        # Keep replaying the remaining calls even if this one fails.
        warnings.warn(f"Failed to replay {method_name}: {e}")
        return None
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _replay_seaborn_call(ax: Axes, call: CallRecord) -> Any:
    """Replay a recorded seaborn call on an axes.

    Recorded args that carry a ``"param"`` field are treated as DataFrame
    columns: they are collected into a new DataFrame that is passed as
    ``data=``, and the corresponding keyword (``x``, ``y``, ``hue``, ...)
    is pointed at the column name. Keyword arguments whose names start
    with an underscore are internal bookkeeping and are dropped.

    Replay is best-effort: missing seaborn/pandas, an unknown seaborn
    function, or any failure while building the data or calling seaborn
    emits a warning and yields ``None``.

    Parameters
    ----------
    ax : Axes
        The matplotlib axes.
    call : CallRecord
        The seaborn call to replay.

    Returns
    -------
    Any
        Result of the seaborn call, or ``None`` if it could not be replayed.
    """
    import warnings

    try:
        import seaborn as sns
        import pandas as pd
    except ImportError:
        warnings.warn("seaborn/pandas required to replay seaborn calls")
        return None

    # Strip the "sns." prefix to get the seaborn attribute name.
    func_name = call.function[len("sns."):]
    func = getattr(sns, func_name, None)
    if func is None:
        warnings.warn(f"Seaborn function {func_name} not found")
        return None

    # Parameters that reference a DataFrame column by name. This mirrors the
    # set of column parameters captured at record time, including
    # weight/weights.
    column_params = (
        "x", "y", "hue", "size", "style", "row", "col", "weight", "weights",
    )

    data_dict = {}
    param_mapping = {}  # parameter name -> column name in the rebuilt frame
    for arg_data in call.args:
        param = arg_data.get("param")
        if param is None:
            continue  # not a DataFrame column
        col_name = arg_data.get("name") or param
        data_dict[col_name] = _reconstruct_value(arg_data)
        param_mapping[param] = col_name

    # Drop internal (underscore-prefixed) keys from the recorded kwargs.
    kwargs = {
        key: value
        for key, value in call.kwargs.items()
        if not key.startswith("_")
    }

    # YAML serializes tuples as lists; seaborn expects a (min, max) tuple
    # for a 'sizes' range, so convert it back.
    if "sizes" in kwargs and isinstance(kwargs["sizes"], list):
        kwargs["sizes"] = tuple(kwargs["sizes"])

    kwargs["ax"] = ax

    try:
        # Building the frame sits inside the guarded region so that
        # inconsistent column data produces a warning, not an exception.
        if data_dict:
            kwargs["data"] = pd.DataFrame(data_dict)
            for param, col_name in param_mapping.items():
                if param in column_params:
                    kwargs[param] = col_name
        return func(**kwargs)
    except Exception as e:
        warnings.warn(f"Failed to replay sns.{func_name}: {e}")
        return None
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _reconstruct_value(arg_data: Dict[str, Any]) -> Any:
|
|
284
|
+
"""Reconstruct a value from serialized arg data.
|
|
285
|
+
|
|
286
|
+
Parameters
|
|
287
|
+
----------
|
|
288
|
+
arg_data : dict
|
|
289
|
+
Serialized argument data.
|
|
290
|
+
|
|
291
|
+
Returns
|
|
292
|
+
-------
|
|
293
|
+
Any
|
|
294
|
+
Reconstructed value.
|
|
295
|
+
"""
|
|
296
|
+
# Check if we have a pre-loaded array
|
|
297
|
+
if "_loaded_array" in arg_data:
|
|
298
|
+
return arg_data["_loaded_array"]
|
|
299
|
+
|
|
300
|
+
data = arg_data.get("data")
|
|
301
|
+
|
|
302
|
+
# If data is a list, convert to numpy array
|
|
303
|
+
if isinstance(data, list):
|
|
304
|
+
dtype = arg_data.get("dtype")
|
|
305
|
+
try:
|
|
306
|
+
return np.array(data, dtype=dtype if dtype else None)
|
|
307
|
+
except (TypeError, ValueError):
|
|
308
|
+
return np.array(data)
|
|
309
|
+
|
|
310
|
+
return data
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def get_recipe_info(path: Union[str, Path]) -> Dict[str, Any]:
    """Summarize a recipe file without reproducing the figure.

    Parameters
    ----------
    path : str or Path
        Path to .yaml recipe file.

    Returns
    -------
    dict
        Recipe information including:
        - id: Figure ID
        - created: Creation timestamp
        - matplotlib_version: Version used
        - figsize: Figure size
        - dpi: Figure DPI
        - n_axes: Number of axes
        - calls: List of per-call summaries. Plotting calls carry ``id``,
          ``function``, ``n_args`` and ``kwargs`` (keyword names);
          decorations carry ``id``, ``function`` and
          ``type == "decoration"``.
    """
    record = load_recipe(path)

    call_summaries: List[Dict[str, Any]] = []
    for ax_record in record.axes.values():
        # Per axes: plotting calls first, then that axes' decorations.
        call_summaries.extend(
            {
                "id": plot_call.id,
                "function": plot_call.function,
                "n_args": len(plot_call.args),
                "kwargs": list(plot_call.kwargs.keys()),
            }
            for plot_call in ax_record.calls
        )
        call_summaries.extend(
            {
                "id": decoration.id,
                "function": decoration.function,
                "type": "decoration",
            }
            for decoration in ax_record.decorations
        )

    summary: Dict[str, Any] = {
        "id": record.id,
        "created": record.created,
        "matplotlib_version": record.matplotlib_version,
        "figsize": record.figsize,
        "dpi": record.dpi,
        "n_axes": len(record.axes),
        "calls": call_summaries,
    }
    return summary
|
figrecipe/_seaborn.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Seaborn wrapper for figrecipe recording."""
|
|
4
|
+
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import seaborn as sns
|
|
12
|
+
import pandas as pd
|
|
13
|
+
HAS_SEABORN = True
|
|
14
|
+
except ImportError:
|
|
15
|
+
HAS_SEABORN = False
|
|
16
|
+
sns = None
|
|
17
|
+
pd = None
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from ._wrappers._axes import RecordingAxes
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Seaborn axes-level plotting functions to wrap
|
|
24
|
+
# Names of the seaborn plotting functions that SeabornRecorder wraps;
# any other seaborn attribute is handed out unwrapped.
SEABORN_PLOT_FUNCTIONS = {
    # Relational
    "scatterplot", "lineplot",
    # Distribution
    "histplot", "kdeplot", "ecdfplot", "rugplot",
    # Categorical
    "stripplot", "swarmplot", "boxplot", "violinplot",
    "boxenplot", "pointplot", "barplot", "countplot",
    # Regression
    "regplot", "residplot",
    # Matrix
    "heatmap", "clustermap",
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _check_seaborn():
    """Raise ImportError unless the optional seaborn import succeeded.

    Relies on the module-level ``HAS_SEABORN`` flag set by the guarded
    import at the top of the module.
    """
    if HAS_SEABORN:
        return
    message = (
        "seaborn is required for this functionality. "
        "Install it with: pip install seaborn"
    )
    raise ImportError(message)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _extract_data_from_dataframe(
|
|
61
|
+
data: Optional["pd.DataFrame"],
|
|
62
|
+
x: Optional[str] = None,
|
|
63
|
+
y: Optional[str] = None,
|
|
64
|
+
hue: Optional[str] = None,
|
|
65
|
+
size: Optional[str] = None,
|
|
66
|
+
style: Optional[str] = None,
|
|
67
|
+
row: Optional[str] = None,
|
|
68
|
+
col: Optional[str] = None,
|
|
69
|
+
weight: Optional[str] = None,
|
|
70
|
+
weights: Optional[str] = None,
|
|
71
|
+
) -> Dict[str, Any]:
|
|
72
|
+
"""Extract relevant columns from DataFrame for serialization.
|
|
73
|
+
|
|
74
|
+
Parameters
|
|
75
|
+
----------
|
|
76
|
+
data : DataFrame or None
|
|
77
|
+
The data source.
|
|
78
|
+
x, y, hue, size, style, row, col, weight, weights : str or None
|
|
79
|
+
Column names to extract.
|
|
80
|
+
|
|
81
|
+
Returns
|
|
82
|
+
-------
|
|
83
|
+
dict
|
|
84
|
+
Extracted data with column arrays.
|
|
85
|
+
"""
|
|
86
|
+
if data is None:
|
|
87
|
+
return {}
|
|
88
|
+
|
|
89
|
+
extracted = {}
|
|
90
|
+
columns_to_extract = []
|
|
91
|
+
|
|
92
|
+
# All column parameters
|
|
93
|
+
param_values = [
|
|
94
|
+
("x", x), ("y", y), ("hue", hue), ("size", size), ("style", style),
|
|
95
|
+
("row", row), ("col", col), ("weight", weight), ("weights", weights),
|
|
96
|
+
]
|
|
97
|
+
|
|
98
|
+
for param_name, col_name in param_values:
|
|
99
|
+
if col_name is not None and isinstance(col_name, str):
|
|
100
|
+
if col_name in data.columns:
|
|
101
|
+
columns_to_extract.append((param_name, col_name))
|
|
102
|
+
|
|
103
|
+
# Extract columns
|
|
104
|
+
for param_name, col_name in columns_to_extract:
|
|
105
|
+
arr = data[col_name].values
|
|
106
|
+
extracted[f"_col_{param_name}"] = arr
|
|
107
|
+
extracted[f"_colname_{param_name}"] = col_name
|
|
108
|
+
|
|
109
|
+
return extracted
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _serialize_seaborn_args(
|
|
113
|
+
func_name: str,
|
|
114
|
+
args: tuple,
|
|
115
|
+
kwargs: Dict[str, Any],
|
|
116
|
+
) -> tuple:
|
|
117
|
+
"""Serialize seaborn function arguments.
|
|
118
|
+
|
|
119
|
+
Parameters
|
|
120
|
+
----------
|
|
121
|
+
func_name : str
|
|
122
|
+
Name of seaborn function.
|
|
123
|
+
args : tuple
|
|
124
|
+
Positional arguments.
|
|
125
|
+
kwargs : dict
|
|
126
|
+
Keyword arguments.
|
|
127
|
+
|
|
128
|
+
Returns
|
|
129
|
+
-------
|
|
130
|
+
tuple
|
|
131
|
+
(processed_args, processed_kwargs, data_arrays)
|
|
132
|
+
"""
|
|
133
|
+
processed_kwargs = {}
|
|
134
|
+
data_arrays = {}
|
|
135
|
+
|
|
136
|
+
# Handle 'data' parameter (DataFrame)
|
|
137
|
+
data = kwargs.get("data")
|
|
138
|
+
if data is not None and hasattr(data, "columns"):
|
|
139
|
+
# Extract column data
|
|
140
|
+
extracted = _extract_data_from_dataframe(
|
|
141
|
+
data,
|
|
142
|
+
x=kwargs.get("x"),
|
|
143
|
+
y=kwargs.get("y"),
|
|
144
|
+
hue=kwargs.get("hue"),
|
|
145
|
+
size=kwargs.get("size"),
|
|
146
|
+
style=kwargs.get("style"),
|
|
147
|
+
row=kwargs.get("row"),
|
|
148
|
+
col=kwargs.get("col"),
|
|
149
|
+
weight=kwargs.get("weight"),
|
|
150
|
+
weights=kwargs.get("weights"),
|
|
151
|
+
)
|
|
152
|
+
data_arrays.update(extracted)
|
|
153
|
+
|
|
154
|
+
# Store column names (not the DataFrame itself)
|
|
155
|
+
processed_kwargs["_has_dataframe"] = True
|
|
156
|
+
|
|
157
|
+
# Process other kwargs
|
|
158
|
+
for key, value in kwargs.items():
|
|
159
|
+
if key == "data":
|
|
160
|
+
continue # Handled above
|
|
161
|
+
elif key == "ax":
|
|
162
|
+
continue # Will be handled separately
|
|
163
|
+
elif isinstance(value, np.ndarray):
|
|
164
|
+
data_arrays[f"_kwarg_{key}"] = value
|
|
165
|
+
processed_kwargs[key] = "__ARRAY__"
|
|
166
|
+
elif hasattr(value, "values"): # pandas Series
|
|
167
|
+
data_arrays[f"_kwarg_{key}"] = np.asarray(value)
|
|
168
|
+
processed_kwargs[key] = "__ARRAY__"
|
|
169
|
+
elif _is_serializable(value):
|
|
170
|
+
processed_kwargs[key] = value
|
|
171
|
+
else:
|
|
172
|
+
try:
|
|
173
|
+
processed_kwargs[key] = str(value)
|
|
174
|
+
except Exception:
|
|
175
|
+
pass
|
|
176
|
+
|
|
177
|
+
# Process positional args (less common for seaborn)
|
|
178
|
+
processed_args = []
|
|
179
|
+
for i, arg in enumerate(args):
|
|
180
|
+
if isinstance(arg, np.ndarray):
|
|
181
|
+
data_arrays[f"_arg_{i}"] = arg
|
|
182
|
+
processed_args.append("__ARRAY__")
|
|
183
|
+
elif hasattr(arg, "values"):
|
|
184
|
+
data_arrays[f"_arg_{i}"] = np.asarray(arg)
|
|
185
|
+
processed_args.append("__ARRAY__")
|
|
186
|
+
elif _is_serializable(arg):
|
|
187
|
+
processed_args.append(arg)
|
|
188
|
+
else:
|
|
189
|
+
processed_args.append(str(arg))
|
|
190
|
+
|
|
191
|
+
return tuple(processed_args), processed_kwargs, data_arrays
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _is_serializable(value: Any) -> bool:
|
|
195
|
+
"""Check if value is directly serializable."""
|
|
196
|
+
if value is None:
|
|
197
|
+
return True
|
|
198
|
+
if isinstance(value, (bool, int, float, str)):
|
|
199
|
+
return True
|
|
200
|
+
if isinstance(value, (list, tuple)):
|
|
201
|
+
return all(_is_serializable(v) for v in value)
|
|
202
|
+
if isinstance(value, dict):
|
|
203
|
+
return all(
|
|
204
|
+
isinstance(k, str) and _is_serializable(v)
|
|
205
|
+
for k, v in value.items()
|
|
206
|
+
)
|
|
207
|
+
return False
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class SeabornRecorder:
    """Proxy for the seaborn module that records plotting calls.

    Attribute access is forwarded to seaborn. Functions listed in
    ``SEABORN_PLOT_FUNCTIONS`` come back wrapped so that a call made with a
    ``RecordingAxes`` as ``ax=`` is stored on that axes; everything else is
    returned untouched.
    """

    def __init__(self):
        # Fail early if seaborn is not installed.
        _check_seaborn()

    def __getattr__(self, name: str) -> Callable:
        """Look up ``name`` on seaborn, wrapping it if it is a plot function.

        Raises
        ------
        AttributeError
            For private/dunder names and for names seaborn does not define.
        """
        if name.startswith("_"):
            raise AttributeError(name)

        not_found = object()
        target = getattr(sns, name, not_found)
        if target is not_found:
            raise AttributeError(f"seaborn has no attribute '{name}'")

        if name in SEABORN_PLOT_FUNCTIONS:

            @wraps(target)
            def recording_proxy(*args, **kwargs):
                return self._record_and_call(name, target, args, kwargs)

            return recording_proxy

        # Non-plotting attributes are passed through unwrapped.
        return target

    def _record_and_call(
        self,
        func_name: str,
        func: Callable,
        args: tuple,
        kwargs: Dict[str, Any],
    ) -> Any:
        """Execute a seaborn function and record it on its RecordingAxes.

        Parameters
        ----------
        func_name : str
            Name of the seaborn function.
        func : callable
            The actual seaborn function.
        args : tuple
            Positional arguments.
        kwargs : dict
            Keyword arguments. An ``id`` entry, if present, is removed and
            used as the custom call ID; ``ax`` selects the axes to record on.

        Returns
        -------
        Any
            Result from the seaborn function.
        """
        from ._wrappers._axes import RecordingAxes

        # 'id' is a recording option, not a seaborn argument.
        call_id = kwargs.pop("id", None)
        target_ax = kwargs.get("ax")

        if not isinstance(target_ax, RecordingAxes):
            # Nothing to record on; behave exactly like plain seaborn.
            return func(*args, **kwargs)

        # Recording is suspended while seaborn runs so the matplotlib calls
        # it makes internally (e.g., scatter) are not recorded as well; only
        # the seaborn call itself should end up in the record.
        with target_ax.no_record():
            result = func(*args, **kwargs)

        proc_args, proc_kwargs, data_arrays = _serialize_seaborn_args(
            func_name, args, kwargs
        )

        # Recorded outside the no_record context.
        target_ax.record_seaborn_call(
            func_name=func_name,
            args=proc_args,
            kwargs=proc_kwargs,
            data_arrays=data_arrays,
            call_id=call_id,
        )
        return result
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# Module-level instance for convenient access
|
|
297
|
+
# Shared recorder; stays None until get_seaborn_recorder() is first called.
_recorder: Optional[SeabornRecorder] = None


def get_seaborn_recorder() -> SeabornRecorder:
    """Return the shared SeabornRecorder, creating it on first use."""
    global _recorder
    if _recorder is not None:
        return _recorder
    _recorder = SeabornRecorder()
    return _recorder
|