figrecipe 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,358 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Reproduce figures from recipe files."""
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from matplotlib.axes import Axes
11
+ from matplotlib.figure import Figure
12
+
13
+ from ._recorder import FigureRecord, CallRecord
14
+ from ._serializer import load_recipe
15
+
16
+
17
def reproduce(
    path: Union[str, Path],
    calls: Optional[List[str]] = None,
    skip_decorations: bool = False,
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
    """Load a recipe file and rebuild the figure it describes.

    The recipe is read with ``load_recipe`` and the resulting record is
    handed to ``reproduce_from_record`` together with the filtering options.

    Parameters
    ----------
    path : str or Path
        Path to .yaml recipe file.
    calls : list of str, optional
        When given, only calls whose ID is in this list are replayed.
    skip_decorations : bool
        When True, decoration calls (labels, legends, etc.) are not replayed.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Reproduced figure.
    axes : Axes or list of Axes
        Reproduced axes (single if 1x1, otherwise list).

    Examples
    --------
    >>> import figrecipe as ps
    >>> fig, ax = ps.reproduce("experiment_001.yaml")
    >>> plt.show()
    """
    loaded_record = load_recipe(path)
    replay_options = {"calls": calls, "skip_decorations": skip_decorations}
    return reproduce_from_record(loaded_record, **replay_options)
52
+
53
+
54
def reproduce_from_record(
    record: FigureRecord,
    calls: Optional[List[str]] = None,
    skip_decorations: bool = False,
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
    """Reproduce a figure from a FigureRecord.

    The subplot grid is sized from the keys of ``record.axes``, a new figure
    is created with the recorded figure options, the recorded style (if any)
    is applied to every axes, and finally the recorded calls are replayed on
    the axes they belong to.

    Parameters
    ----------
    record : FigureRecord
        The figure record to reproduce.
    calls : list of str, optional
        If provided, only reproduce these specific call IDs.
    skip_decorations : bool
        If True, skip decoration calls.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Reproduced figure.
    axes : Axes or list of Axes
        A single Axes for a 1x1 grid, a flat list for a single row or a
        single column, and a nested list (rows of Axes) otherwise.
    """
    # Resolve each axes key to its grid cell once; the same mapping is used
    # both to size the grid and to pick the target axes during replay.
    positions = {ax_key: _axes_grid_position(ax_key) for ax_key in record.axes}
    nrows = max([0] + [row for row, _ in positions.values()]) + 1
    ncols = max([0] + [col for _, col in positions.values()]) + 1

    # squeeze=False makes subplots() return a 2D (nrows, ncols) array for
    # every grid shape, so no manual reshaping is needed.
    fig, axes_2d = plt.subplots(
        nrows,
        ncols,
        figsize=record.figsize,
        dpi=record.dpi,
        constrained_layout=record.constrained_layout,
        squeeze=False,
    )

    # Apply layout if recorded
    if record.layout is not None:
        fig.subplots_adjust(**record.layout)

    # Apply style BEFORE replaying calls (to match original order:
    # style is applied during subplots(), then user creates plots/decorations)
    if record.style is not None:
        from .styles import apply_style_mm

        for styled_ax in axes_2d.flat:  # row-major, same order as a row/col loop
            apply_style_mm(styled_ax, record.style)

    # A set gives constant-time membership tests for the optional ID filter.
    wanted_ids = None if calls is None else set(calls)

    # Replay calls on each axes: plotting calls first, then decorations.
    for ax_key, ax_record in record.axes.items():
        row, col = positions[ax_key]
        ax = axes_2d[row, col]

        call_groups = [ax_record.calls]
        if not skip_decorations:
            call_groups.append(ax_record.decorations)

        for group in call_groups:
            for call in group:
                if wanted_ids is None or call.id in wanted_ids:
                    _replay_call(ax, call)

    # Return in appropriate format
    if nrows == 1 and ncols == 1:
        return fig, axes_2d[0, 0]
    if nrows == 1:
        return fig, list(axes_2d[0])
    if ncols == 1:
        return fig, list(axes_2d[:, 0])
    return fig, axes_2d.tolist()


def _axes_grid_position(ax_key: str) -> Tuple[int, int]:
    """Return the (row, col) encoded in an axes key, or (0, 0) for short keys.

    The key is split on ``"_"``; when it has at least three parts, the second
    and third parts are converted with ``int`` and used as row and column.
    """
    parts = ax_key.split("_")
    if len(parts) >= 3:
        return int(parts[1]), int(parts[2])
    return 0, 0
152
+
153
+
154
def _replay_call(ax: Axes, call: CallRecord) -> Any:
    """Replay a single call on an axes.

    Calls whose function name starts with ``"sns."`` are delegated to
    ``_replay_seaborn_call``. Any other name is looked up as an attribute of
    ``ax`` and invoked with the reconstructed positional arguments and a copy
    of the recorded keyword arguments.

    Replay is best-effort: an unknown method or a call that raises produces a
    warning and ``None`` instead of an exception.

    Parameters
    ----------
    ax : Axes
        The matplotlib axes.
    call : CallRecord
        The call to replay.

    Returns
    -------
    Any
        Result of the matplotlib call, or None if it could not be replayed.
    """
    import warnings

    method_name = call.function

    # Check if it's a seaborn call
    if method_name.startswith("sns."):
        return _replay_seaborn_call(ax, call)

    method = getattr(ax, method_name, None)
    if method is None:
        # Warn rather than skip silently (as the seaborn path does), so a
        # dropped call does not go unnoticed in the reproduced figure.
        warnings.warn(f"Method {method_name} not found on axes; call skipped")
        return None

    # Reconstruct positional args; kwargs are copied so the record stays intact
    args = [_reconstruct_value(arg_data) for arg_data in call.args]
    kwargs = call.kwargs.copy()

    try:
        return method(*args, **kwargs)
    except Exception as e:
        # Log warning but continue with the remaining calls
        warnings.warn(f"Failed to replay {method_name}: {e}")
        return None
198
+
199
+
200
def _replay_seaborn_call(ax: Axes, call: CallRecord) -> Any:
    """Replay a seaborn call on an axes.

    The recorded args are expected to be dicts; every arg carrying a
    ``"param"`` entry is treated as a DataFrame column, named by its
    ``"name"`` entry (falling back to the parameter name). The columns are
    assembled into a DataFrame passed as ``data=``, and the column
    parameters in kwargs are pointed at the matching column names.

    Replay is best-effort: missing dependencies, an unknown seaborn function,
    or any failure while building the DataFrame or calling seaborn results
    in a warning and ``None``.

    Parameters
    ----------
    ax : Axes
        The matplotlib axes.
    call : CallRecord
        The seaborn call to replay.

    Returns
    -------
    Any
        Result of the seaborn call, or None if it could not be replayed.
    """
    import warnings

    try:
        import seaborn as sns
        import pandas as pd
    except ImportError:
        warnings.warn("seaborn/pandas required to replay seaborn calls")
        return None

    # Get the seaborn function name (remove "sns." prefix)
    func_name = call.function[4:]
    func = getattr(sns, func_name, None)
    if func is None:
        warnings.warn(f"Seaborn function {func_name} not found")
        return None

    # Parameters whose value names a DataFrame column. "weight"/"weights" are
    # included because the recorder-side extractor captures those columns too.
    column_params = (
        "x", "y", "hue", "size", "style", "row", "col", "weight", "weights",
    )

    # Collect column data; args without a "param" entry are not columns.
    data_dict = {}
    param_mapping = {}  # Maps param name to column name
    for arg_data in call.args:
        param = arg_data.get("param")
        if param is None:
            continue
        col_name = arg_data.get("name") or param
        data_dict[col_name] = _reconstruct_value(arg_data)
        param_mapping[param] = col_name

    # Copy kwargs, dropping internal (underscore-prefixed) bookkeeping keys
    kwargs = {
        key: value
        for key, value in call.kwargs.items()
        if not key.startswith("_")
    }

    try:
        # DataFrame construction is guarded as well: columns of unequal
        # length raise, and that should degrade to a warning like any other
        # replay failure.
        if data_dict:
            kwargs["data"] = pd.DataFrame(data_dict)
            for param, col_name in param_mapping.items():
                if param in column_params:
                    kwargs[param] = col_name

        kwargs["ax"] = ax

        # YAML serializes tuples as lists; seaborn expects 'sizes' given as a
        # (min, max) range to be a tuple, not a list.
        if isinstance(kwargs.get("sizes"), list):
            kwargs["sizes"] = tuple(kwargs["sizes"])

        return func(**kwargs)
    except Exception as e:
        warnings.warn(f"Failed to replay sns.{func_name}: {e}")
        return None
281
+
282
+
283
def _reconstruct_value(arg_data: Dict[str, Any]) -> Any:
    """Turn one serialized argument entry back into a Python value.

    Parameters
    ----------
    arg_data : dict
        Serialized argument data. The keys read here are ``"_loaded_array"``,
        ``"data"`` and ``"dtype"``.

    Returns
    -------
    Any
        The ``"_loaded_array"`` entry when present; otherwise the ``"data"``
        entry, converted to a numpy array when it is a list.
    """
    # An array that was already loaded takes precedence over inline data
    if "_loaded_array" in arg_data:
        return arg_data["_loaded_array"]

    payload = arg_data.get("data")

    # Only lists are converted; every other payload is returned untouched
    if not isinstance(payload, list):
        return payload

    requested_dtype = arg_data.get("dtype") or None
    try:
        return np.array(payload, dtype=requested_dtype)
    except (TypeError, ValueError):
        # The recorded dtype does not fit the data; let numpy infer one
        return np.array(payload)
311
+
312
+
313
def get_recipe_info(path: Union[str, Path]) -> Dict[str, Any]:
    """Summarize a recipe file without reproducing the figure.

    Parameters
    ----------
    path : str or Path
        Path to .yaml recipe file.

    Returns
    -------
    dict
        Recipe information including:
        - id: Figure ID
        - created: Creation timestamp
        - matplotlib_version: Version used
        - figsize: Figure size
        - dpi: Figure DPI
        - n_axes: Number of axes
        - calls: One summary dict per call. Plotting calls carry ``id``,
          ``function``, ``n_args`` and ``kwargs`` (keyword names);
          decoration calls carry ``id``, ``function`` and
          ``type="decoration"``.
    """
    record = load_recipe(path)

    # Per axes: plotting calls first, followed by that axes' decorations
    summaries = []
    for ax_record in record.axes.values():
        summaries.extend(
            {
                "id": plot_call.id,
                "function": plot_call.function,
                "n_args": len(plot_call.args),
                "kwargs": list(plot_call.kwargs.keys()),
            }
            for plot_call in ax_record.calls
        )
        summaries.extend(
            {
                "id": decoration.id,
                "function": decoration.function,
                "type": "decoration",
            }
            for decoration in ax_record.decorations
        )

    info = {
        "id": record.id,
        "created": record.created,
        "matplotlib_version": record.matplotlib_version,
        "figsize": record.figsize,
        "dpi": record.dpi,
        "n_axes": len(record.axes),
        "calls": summaries,
    }
    return info
figrecipe/_seaborn.py ADDED
@@ -0,0 +1,305 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Seaborn wrapper for figrecipe recording."""
4
+
5
+ from functools import wraps
6
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
7
+
8
+ import numpy as np
9
+
10
+ try:
11
+ import seaborn as sns
12
+ import pandas as pd
13
+ HAS_SEABORN = True
14
+ except ImportError:
15
+ HAS_SEABORN = False
16
+ sns = None
17
+ pd = None
18
+
19
+ if TYPE_CHECKING:
20
+ from ._wrappers._axes import RecordingAxes
21
+
22
+
23
# Seaborn axes-level plotting functions to wrap.
# Membership in this set decides whether SeabornRecorder.__getattr__ returns a
# recording wrapper or the original seaborn attribute unchanged.
# NOTE(review): "clustermap" appears to be a figure-level function in seaborn
# (it builds its own figure rather than drawing on a given `ax`) — confirm
# whether it should be listed here.
SEABORN_PLOT_FUNCTIONS = {
    # Relational
    "scatterplot",
    "lineplot",
    # Distribution
    "histplot",
    "kdeplot",
    "ecdfplot",
    "rugplot",
    # Categorical
    "stripplot",
    "swarmplot",
    "boxplot",
    "violinplot",
    "boxenplot",
    "pointplot",
    "barplot",
    "countplot",
    # Regression
    "regplot",
    "residplot",
    # Matrix
    "heatmap",
    "clustermap",
}
49
+
50
+
51
def _check_seaborn():
    """Raise ImportError unless the optional seaborn import succeeded."""
    if HAS_SEABORN:
        return

    message = (
        "seaborn is required for this functionality. "
        "Install it with: pip install seaborn"
    )
    raise ImportError(message)
58
+
59
+
60
def _extract_data_from_dataframe(
    data: Optional["pd.DataFrame"],
    x: Optional[str] = None,
    y: Optional[str] = None,
    hue: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    row: Optional[str] = None,
    col: Optional[str] = None,
    weight: Optional[str] = None,
    weights: Optional[str] = None,
) -> Dict[str, Any]:
    """Pull the columns referenced by a seaborn call out of a DataFrame.

    Parameters
    ----------
    data : DataFrame or None
        The data source.
    x, y, hue, size, style, row, col, weight, weights : str or None
        Column names to extract. Values that are not strings, or that do not
        name a column of ``data``, are ignored.

    Returns
    -------
    dict
        For each extracted parameter ``p``: ``"_col_p"`` holds the column's
        ``.values`` and ``"_colname_p"`` holds the column name. Empty when
        ``data`` is None.
    """
    if data is None:
        return {}

    # Keyed by parameter name; insertion order fixes the order of the result
    requested = {
        "x": x,
        "y": y,
        "hue": hue,
        "size": size,
        "style": style,
        "row": row,
        "col": col,
        "weight": weight,
        "weights": weights,
    }

    result = {}
    for param_name, column in requested.items():
        if isinstance(column, str) and column in data.columns:
            result[f"_col_{param_name}"] = data[column].values
            result[f"_colname_{param_name}"] = column
    return result
110
+
111
+
112
def _serialize_seaborn_args(
    func_name: str,
    args: tuple,
    kwargs: Dict[str, Any],
) -> tuple:
    """Serialize seaborn function arguments.

    Array-like values (numpy arrays and pandas-like objects) are moved into a
    separate mapping and replaced by the placeholder ``"__ARRAY__"``; directly
    serializable values are kept as they are; anything else is stored as its
    ``str()`` representation.

    Parameters
    ----------
    func_name : str
        Name of seaborn function (currently not used by the serialization).
    args : tuple
        Positional arguments.
    kwargs : dict
        Keyword arguments. ``data`` and ``ax`` are never copied into the
        processed kwargs.

    Returns
    -------
    tuple
        (processed_args, processed_kwargs, data_arrays)
    """
    processed_kwargs = {}
    data_arrays = {}

    # Handle 'data' parameter (DataFrame): keep only the referenced columns
    data = kwargs.get("data")
    if data is not None and hasattr(data, "columns"):
        extracted = _extract_data_from_dataframe(
            data,
            x=kwargs.get("x"),
            y=kwargs.get("y"),
            hue=kwargs.get("hue"),
            size=kwargs.get("size"),
            style=kwargs.get("style"),
            row=kwargs.get("row"),
            col=kwargs.get("col"),
            weight=kwargs.get("weight"),
            weights=kwargs.get("weights"),
        )
        data_arrays.update(extracted)

        # Flag that a DataFrame was supplied (the DataFrame itself is not stored)
        processed_kwargs["_has_dataframe"] = True

    # Process other kwargs
    for key, value in kwargs.items():
        if key in ("data", "ax"):
            continue  # 'data' handled above; 'ax' is handled by the caller
        if isinstance(value, np.ndarray):
            data_arrays[f"_kwarg_{key}"] = value
            processed_kwargs[key] = "__ARRAY__"
        elif _has_array_values(value):  # pandas Series / DataFrame
            data_arrays[f"_kwarg_{key}"] = np.asarray(value)
            processed_kwargs[key] = "__ARRAY__"
        elif _is_serializable(value):
            processed_kwargs[key] = value
        else:
            try:
                processed_kwargs[key] = str(value)
            except Exception:
                # Best effort: a value that cannot even be stringified is
                # left out of the record.
                pass

    # Process positional args (less common for seaborn)
    processed_args = []
    for i, arg in enumerate(args):
        if isinstance(arg, np.ndarray):
            data_arrays[f"_arg_{i}"] = arg
            processed_args.append("__ARRAY__")
        elif _has_array_values(arg):
            data_arrays[f"_arg_{i}"] = np.asarray(arg)
            processed_args.append("__ARRAY__")
        elif _is_serializable(arg):
            processed_args.append(arg)
        else:
            processed_args.append(str(arg))

    return tuple(processed_args), processed_kwargs, data_arrays


def _has_array_values(value: Any) -> bool:
    """Return True when ``value.values`` exists and is data rather than a method.

    A plain ``hasattr(value, "values")`` test also matches dicts (through
    ``dict.values``), which would turn dict arguments such as a palette
    mapping into ``"__ARRAY__"`` placeholders. Requiring a non-callable
    attribute keeps dicts out of the array path.
    """
    values = getattr(value, "values", None)
    return values is not None and not callable(values)


def _is_serializable(value: Any) -> bool:
    """Check if value is directly serializable.

    Accepted are None, bool, int, float and str, lists/tuples whose items are
    all serializable, and dicts with str keys and serializable values.
    """
    if value is None:
        return True
    if isinstance(value, (bool, int, float, str)):
        return True
    if isinstance(value, (list, tuple)):
        return all(_is_serializable(v) for v in value)
    if isinstance(value, dict):
        return all(
            isinstance(k, str) and _is_serializable(v)
            for k, v in value.items()
        )
    return False
208
+
209
+
210
class SeabornRecorder:
    """Stand-in for the seaborn module that records wrapped plotting calls.

    Attribute access is forwarded to seaborn. Functions listed in
    ``SEABORN_PLOT_FUNCTIONS`` come back wrapped so that a call drawing on a
    ``RecordingAxes`` is recorded; everything else is returned as is.
    """

    def __init__(self):
        # Fail at construction time when seaborn is not importable
        _check_seaborn()

    def __getattr__(self, name: str) -> Callable:
        """Look up ``name`` on seaborn, wrapping it if it is a plot function."""
        if name.startswith("_"):
            raise AttributeError(name)

        if not hasattr(sns, name):
            raise AttributeError(f"seaborn has no attribute '{name}'")

        target = getattr(sns, name)

        if name in SEABORN_PLOT_FUNCTIONS:

            @wraps(target)
            def recording_proxy(*args, **kwargs):
                return self._record_and_call(name, target, args, kwargs)

            return recording_proxy

        # Non-plotting attributes are handed back unwrapped
        return target

    def _record_and_call(
        self,
        func_name: str,
        func: Callable,
        args: tuple,
        kwargs: Dict[str, Any],
    ) -> Any:
        """Execute a seaborn function and record it when possible.

        Parameters
        ----------
        func_name : str
            Name of the seaborn function.
        func : callable
            The actual seaborn function.
        args : tuple
            Positional arguments.
        kwargs : dict
            Keyword arguments. An ``id`` entry, if present, is removed and
            used as the call ID of the record.

        Returns
        -------
        Any
            Result from the seaborn function.
        """
        from ._wrappers._axes import RecordingAxes

        # A custom ID is consumed here and never reaches seaborn
        call_id = kwargs.pop("id", None)
        target_ax = kwargs.get("ax")

        # Without a recording axes there is nothing to record: just call
        if not isinstance(target_ax, RecordingAxes):
            return func(*args, **kwargs)

        # Recording is switched off while seaborn runs so that the matplotlib
        # calls it makes internally (e.g., scatter) are not recorded; only
        # the seaborn call itself should end up in the record.
        with target_ax.no_record():
            outcome = func(*args, **kwargs)

        serialized = _serialize_seaborn_args(func_name, args, kwargs)
        proc_args, proc_kwargs, data_arrays = serialized

        # Recorded after leaving the no_record context
        target_ax.record_seaborn_call(
            func_name=func_name,
            args=proc_args,
            kwargs=proc_kwargs,
            data_arrays=data_arrays,
            call_id=call_id,
        )
        return outcome
294
+
295
+
296
# Shared recorder instance, created on first request
_recorder: Optional[SeabornRecorder] = None


def get_seaborn_recorder() -> SeabornRecorder:
    """Return the module-level SeabornRecorder, creating it on first use."""
    global _recorder
    recorder = _recorder
    if recorder is None:
        recorder = SeabornRecorder()
        _recorder = recorder
    return recorder