figrecipe 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,186 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Reproducibility validation for figrecipe recipes."""
4
+
5
+ import tempfile
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+
12
+
13
@dataclass
class ValidationResult:
    """Outcome of comparing an original figure with its reproduction.

    Attributes
    ----------
    valid : bool
        True if reproduction is considered valid (MSE below threshold).
    mse : float
        Mean squared error between original and reproduced images.
    psnr : float
        Peak signal-to-noise ratio (higher is better, inf if identical).
    max_diff : float
        Maximum pixel difference (0-255 scale).
    size_original : tuple
        (height, width) of original image.
    size_reproduced : tuple
        (height, width) of reproduced image.
    same_size : bool
        True if dimensions match exactly.
    file_size_diff : int
        Difference in file sizes (bytes).
    message : str
        Human-readable summary.
    """

    valid: bool
    mse: float
    psnr: float
    max_diff: float
    size_original: tuple
    size_reproduced: tuple
    same_size: bool
    file_size_diff: int
    message: str

    def __repr__(self) -> str:
        """Return a compact one-line description (status, MSE, size match)."""
        verdict = "VALID" if self.valid else "INVALID"
        size_state = "match" if self.same_size else "differ"
        return f"ValidationResult({verdict}, mse={self.mse:.2f}, size={size_state})"

    def summary(self) -> str:
        """Return a multi-line report of every metric.

        The explanatory ``message`` is appended only when validation failed.
        """
        outcome = "PASSED" if self.valid else "FAILED"
        size_note = "match" if self.same_size else "DIFFER"

        report = [f"Reproducibility Validation: {outcome}"]
        report.append(
            f" Dimensions: {self.size_original} vs {self.size_reproduced} ({size_note})"
        )
        report.append(f" Pixel MSE: {self.mse:.2f}")
        report.append(f" Max pixel diff: {self.max_diff:.1f}")

        # An infinite PSNR gets a fixed label instead of a dB figure.
        if np.isinf(self.psnr):
            report.append(" PSNR: inf (identical)")
        else:
            report.append(f" PSNR: {self.psnr:.1f} dB")

        report.append(f" File size diff: {self.file_size_diff:+d} bytes")

        if not self.valid:
            report.append(f" Note: {self.message}")
        return "\n".join(report)
70
+
71
+
72
def validate_recipe(
    fig,
    recipe_path: Union[str, Path],
    mse_threshold: float = 100.0,
    dpi: int = 150,
) -> ValidationResult:
    """Validate that a recipe can faithfully reproduce the original figure.

    Both the original figure and the figure rebuilt from the recipe are
    rendered to PNG files in a temporary directory and compared pixel-wise.

    Parameters
    ----------
    fig : RecordingFigure
        The original figure (with matplotlib figure accessible via fig.fig).
    recipe_path : str or Path
        Path to the saved recipe file.
    mse_threshold : float
        Maximum acceptable MSE for validation to pass (default: 100).
        Lower values require closer matches.
    dpi : int
        DPI for comparison images (default: 150).

    Returns
    -------
    ValidationResult
        Detailed comparison results. When the comparison reports NaN for
        ``mse`` or ``max_diff`` the result carries ``inf`` for that field.
    """
    import matplotlib.pyplot as plt
    from ._reproducer import reproduce
    from ._utils._image_diff import compare_images

    recipe_path = Path(recipe_path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Separate name instead of rebinding the context-manager target.
        workdir = Path(tmpdir)

        # Save original figure to temp image
        original_path = workdir / "original.png"
        fig.fig.savefig(original_path, dpi=dpi)

        # Reproduce from recipe
        reproduced_fig, _ = reproduce(recipe_path)

        # Save reproduced figure. The close happens in ``finally`` so the
        # figure is released even when rendering fails; it also prevents
        # double display in notebooks.
        reproduced_path = workdir / "reproduced.png"
        try:
            reproduced_fig.savefig(reproduced_path, dpi=dpi)
        finally:
            plt.close(reproduced_fig)

        # Compare images
        diff = compare_images(original_path, reproduced_path)

        # Determine validity
        mse = diff["mse"]
        mse_is_nan = np.isnan(mse)
        if mse_is_nan:
            # Different sizes - invalid
            valid = False
            message = f"Image dimensions differ: {diff['size1']} vs {diff['size2']}"
        elif mse > mse_threshold:
            valid = False
            message = f"MSE ({mse:.2f}) exceeds threshold ({mse_threshold})"
        else:
            valid = True
            message = "Reproduction matches original within threshold"

        max_diff = diff["max_diff"]

        return ValidationResult(
            valid=valid,
            mse=float("inf") if mse_is_nan else mse,
            psnr=diff["psnr"],
            max_diff=float("inf") if np.isnan(max_diff) else max_diff,
            size_original=diff["size1"],
            size_reproduced=diff["size2"],
            same_size=diff["same_size"],
            file_size_diff=diff["file_size2"] - diff["file_size1"],
            message=message,
        )
147
+
148
+
149
def validate_on_save(
    fig,
    recipe_path: Union[str, Path],
    mse_threshold: float = 100.0,
    dpi: int = 150,
    raise_on_failure: bool = False,
) -> Optional[ValidationResult]:
    """Run recipe validation right after a recipe has been written.

    Parameters
    ----------
    fig : RecordingFigure
        The original figure.
    recipe_path : str or Path
        Path where recipe was saved.
    mse_threshold : float
        Maximum acceptable MSE.
    dpi : int
        DPI for comparison.
    raise_on_failure : bool
        If True, raise ValueError when validation fails.

    Returns
    -------
    ValidationResult
        Validation results.

    Raises
    ------
    ValueError
        If raise_on_failure=True and validation fails.
    """
    outcome = validate_recipe(fig, recipe_path, mse_threshold, dpi)

    # Nothing to escalate: either the caller did not ask for an exception
    # or the reproduction passed.
    if not raise_on_failure or outcome.valid:
        return outcome

    raise ValueError(f"Recipe validation failed: {outcome.message}")
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Matplotlib object wrappers for recording."""
4
+
5
+ from ._axes import RecordingAxes
6
+ from ._figure import RecordingFigure
7
+
8
+ __all__ = ["RecordingAxes", "RecordingFigure"]
@@ -0,0 +1,327 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Wrapped Axes that records all plotting calls."""
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.axes import Axes
10
+
11
+ if TYPE_CHECKING:
12
+ from .._recorder import Recorder
13
+
14
+
15
class RecordingAxes:
    """Wrapper around matplotlib Axes that records all calls.

    Attribute access is forwarded to the wrapped axes. Callables whose name
    is listed in the recorder's ``PLOTTING_METHODS`` or ``DECORATION_METHODS``
    are wrapped so that each invocation is also reported to the recorder
    for later reproduction.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The underlying matplotlib axes.
    recorder : Recorder
        The recorder instance to log calls to.
    position : tuple
        (row, col) position in the figure grid.

    Examples
    --------
    >>> import figrecipe as ps
    >>> fig, ax = ps.subplots()
    >>> ax.plot([1, 2, 3], [4, 5, 6], color='red', id='my_line')
    >>> # The call is recorded automatically
    """

    # Attributes owned by the wrapper itself; never delegated to the axes.
    _OWN_ATTRIBUTES = ("_ax", "_recorder", "_position", "_track")

    # Methods whose automatically chosen colour is captured into the record.
    _COLOR_CAPTURE_METHODS = ("plot", "scatter", "bar", "barh", "step", "fill_between")

    def __init__(
        self,
        ax: "Axes",
        recorder: "Recorder",
        position: Tuple[int, int] = (0, 0),
    ):
        self._ax = ax
        self._recorder = recorder
        self._position = position
        self._track = True

    @property
    def ax(self) -> "Axes":
        """Get the underlying matplotlib axes."""
        return self._ax

    @property
    def position(self) -> Tuple[int, int]:
        """Get axes position in grid."""
        return self._position

    def __getattr__(self, name: str) -> Any:
        """Intercept attribute access to wrap methods.

        This is the core mechanism for recording calls. It is only invoked
        when normal attribute lookup fails.

        Raises
        ------
        AttributeError
            If ``name`` is one of the wrapper's own attributes and it has not
            been set, or if the wrapped axes has no such attribute.
        """
        # If one of our own attributes is missing (e.g. on a bare instance
        # created by copy/pickle before its state is restored), delegating
        # through ``self._ax`` would re-enter this method without bound.
        if name in self._OWN_ATTRIBUTES:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        attr = getattr(self._ax, name)

        # If it's a plotting or decoration method, wrap it
        if callable(attr) and name in (
            self._recorder.PLOTTING_METHODS | self._recorder.DECORATION_METHODS
        ):
            return self._create_recording_wrapper(name, attr)

        # For other methods/attributes, return as-is
        return attr

    def _create_recording_wrapper(self, method_name: str, method: callable):
        """Create a wrapper function that records the call.

        Parameters
        ----------
        method_name : str
            Name of the method.
        method : callable
            The original method.

        Returns
        -------
        callable
            Wrapped method that records calls. It accepts two extra keyword
            arguments that are not forwarded to ``method``: ``id`` (custom
            call ID) and ``track`` (set False to skip recording this call).
        """
        # ``id`` shadows the builtin but is part of the public keyword API.
        def wrapper(*args, id: Optional[str] = None, track: bool = True, **kwargs):
            # Call the original method first (without our custom kwargs)
            result = method(*args, **kwargs)

            # Record the call if tracking is enabled
            if self._track and track:
                # Capture actual colors from result for plotting methods
                # that use matplotlib's color cycle
                recorded_kwargs = kwargs.copy()
                if (
                    method_name in self._COLOR_CAPTURE_METHODS
                    and "color" not in recorded_kwargs
                    and "c" not in recorded_kwargs
                ):
                    actual_color = self._extract_color_from_result(method_name, result)
                    if actual_color is not None:
                        recorded_kwargs["color"] = actual_color

                self._recorder.record_call(
                    ax_position=self._position,
                    method_name=method_name,
                    args=args,
                    kwargs=recorded_kwargs,
                    call_id=id,
                )

            return result

        return wrapper

    def _extract_color_from_result(self, method_name: str, result) -> Optional[str]:
        """Extract actual color used from plot result.

        Only the first artist of the result is inspected.

        Parameters
        ----------
        method_name : str
            Name of the plotting method.
        result : Any
            Return value from the plotting method.

        Returns
        -------
        str or None
            The color used, or None if not extractable.
        """
        try:
            if method_name in ("plot", "step"):
                # plot()/step() return a list of Line2D
                if result and hasattr(result[0], "get_color"):
                    return result[0].get_color()
            elif method_name in ("scatter", "fill_between"):
                # scatter() returns PathCollection,
                # fill_between() returns PolyCollection
                if hasattr(result, "get_facecolor"):
                    fc = result.get_facecolor()
                    if len(fc) > 0:
                        # Convert RGBA to hex
                        import matplotlib.colors as mcolors

                        return mcolors.to_hex(fc[0])
            elif method_name in ("bar", "barh"):
                # bar() returns BarContainer
                if hasattr(result, "patches") and result.patches:
                    fc = result.patches[0].get_facecolor()
                    import matplotlib.colors as mcolors

                    return mcolors.to_hex(fc)
        except Exception:
            # Deliberately best-effort: failing to determine a colour must
            # never break the user's plotting call.
            pass
        return None

    def no_record(self):
        """Context manager to temporarily disable recording.

        Examples
        --------
        >>> with ax.no_record():
        ...     ax.plot([1, 2, 3], [4, 5, 6])  # Not recorded
        """
        return _NoRecordContext(self)

    def record_seaborn_call(
        self,
        func_name: str,
        args: tuple,
        kwargs: Dict[str, Any],
        data_arrays: Dict[str, np.ndarray],
        call_id: Optional[str] = None,
    ) -> None:
        """Record a seaborn plotting call.

        Does nothing while tracking is disabled.

        Parameters
        ----------
        func_name : str
            Name of the seaborn function (e.g., 'scatterplot').
        args : tuple
            Processed positional arguments. The string ``"__ARRAY__"`` marks
            a slot whose data lives in ``data_arrays`` under ``_arg_<i>``.
        kwargs : dict
            Processed keyword arguments. The string ``"__ARRAY__"`` marks a
            value whose data lives in ``data_arrays`` under ``_kwarg_<key>``.
        data_arrays : dict
            Dictionary of array data extracted from DataFrame/arrays.
            Keys starting with ``_col_`` are recorded as column entries.
        call_id : str, optional
            Custom ID for this call.
        """
        if not self._track:
            return

        from .._utils._numpy_io import should_store_inline, to_serializable

        placeholder = "__ARRAY__"

        def is_placeholder(value) -> bool:
            # Check the type first: ``==`` on array-like values is
            # element-wise and cannot be used directly in an ``if``.
            return isinstance(value, str) and value == placeholder

        def array_entry(base: Dict[str, Any], arr) -> Dict[str, Any]:
            # Small arrays are embedded; large ones are flagged for file
            # storage and carried along under "_array".
            entry = dict(base)
            if should_store_inline(arr):
                entry["data"] = to_serializable(arr)
                entry["dtype"] = str(arr.dtype)
            else:
                entry["data"] = "__FILE__"
                entry["dtype"] = str(arr.dtype)
                entry["_array"] = arr
            return entry

        # Generate call ID if not provided
        if call_id is None:
            call_id = self._recorder._generate_call_id(f"sns_{func_name}")

        # Process data arrays into args format
        processed_args = []
        for i, arg in enumerate(args):
            if is_placeholder(arg):
                key = f"_arg_{i}"
                # NOTE(review): a placeholder without a matching entry in
                # data_arrays is skipped, as in the original code.
                if key in data_arrays:
                    processed_args.append(
                        array_entry({"name": f"arg{i}"}, data_arrays[key])
                    )
            else:
                processed_args.append({
                    "name": f"arg{i}",
                    "data": arg,
                })

        # Process DataFrame column data
        for key, arr in data_arrays.items():
            if key.startswith("_col_"):
                param_name = key[5:]  # Remove "_col_" prefix
                col_name = data_arrays.get(f"_colname_{param_name}", param_name)
                processed_args.append(
                    array_entry({"name": col_name, "param": param_name}, arr)
                )

        # Process kwarg arrays
        processed_kwargs = dict(kwargs)
        for key, value in kwargs.items():
            if is_placeholder(value):
                arr_key = f"_kwarg_{key}"
                if arr_key in data_arrays:
                    arr = data_arrays[arr_key]
                    if should_store_inline(arr):
                        processed_kwargs[key] = to_serializable(arr)
                    else:
                        # Mark for file storage
                        processed_kwargs[key] = "__FILE__"
                        processed_kwargs[f"_array_{key}"] = arr

        # Create call record
        from .._recorder import CallRecord

        record = CallRecord(
            id=call_id,
            function=f"sns.{func_name}",
            args=processed_args,
            kwargs=processed_kwargs,
            ax_position=self._position,
        )

        # Add to axes record
        ax_record = self._recorder.figure_record.get_or_create_axes(*self._position)
        ax_record.add_call(record)

    # Expose common properties directly
    @property
    def figure(self):
        """Figure of the wrapped axes."""
        return self._ax.figure

    @property
    def xaxis(self):
        """X axis object of the wrapped axes."""
        return self._ax.xaxis

    @property
    def yaxis(self):
        """Y axis object of the wrapped axes."""
        return self._ax.yaxis

    # Methods that should not be recorded
    def get_xlim(self):
        """Return the x limits of the wrapped axes."""
        return self._ax.get_xlim()

    def get_ylim(self):
        """Return the y limits of the wrapped axes."""
        return self._ax.get_ylim()

    def get_xlabel(self):
        """Return the x label of the wrapped axes."""
        return self._ax.get_xlabel()

    def get_ylabel(self):
        """Return the y label of the wrapped axes."""
        return self._ax.get_ylabel()

    def get_title(self):
        """Return the title of the wrapped axes."""
        return self._ax.get_title()
312
+
313
+
314
class _NoRecordContext:
    """Context manager to temporarily disable recording.

    On entry the axes' current ``_track`` flag is saved and set to False;
    on exit the saved value is restored. State is saved at entry time (not
    at construction) and kept on a stack, so a context object that is
    created early, reused, or entered re-entrantly restores the right value.

    Parameters
    ----------
    axes : RecordingAxes
        The wrapper whose ``_track`` flag is toggled.
    """

    def __init__(self, axes: "RecordingAxes"):
        self._axes = axes
        # Snapshot at construction; used only as a fallback if __exit__ is
        # ever called without a matching __enter__.
        self._original_track = axes._track
        self._saved_states = []

    def __enter__(self):
        self._saved_states.append(self._axes._track)
        self._axes._track = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._saved_states:
            self._axes._track = self._saved_states.pop()
        else:
            self._axes._track = self._original_track
        # Never suppress exceptions raised inside the block.
        return False