figrecipe 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- figrecipe/__init__.py +1090 -0
- figrecipe/_recorder.py +435 -0
- figrecipe/_reproducer.py +358 -0
- figrecipe/_seaborn.py +305 -0
- figrecipe/_serializer.py +227 -0
- figrecipe/_signatures/__init__.py +7 -0
- figrecipe/_signatures/_loader.py +186 -0
- figrecipe/_utils/__init__.py +32 -0
- figrecipe/_utils/_crop.py +261 -0
- figrecipe/_utils/_diff.py +98 -0
- figrecipe/_utils/_image_diff.py +204 -0
- figrecipe/_utils/_numpy_io.py +204 -0
- figrecipe/_utils/_units.py +200 -0
- figrecipe/_validator.py +186 -0
- figrecipe/_wrappers/__init__.py +8 -0
- figrecipe/_wrappers/_axes.py +327 -0
- figrecipe/_wrappers/_figure.py +227 -0
- figrecipe/plt.py +12 -0
- figrecipe/pyplot.py +264 -0
- figrecipe/styles/__init__.py +50 -0
- figrecipe/styles/_style_applier.py +412 -0
- figrecipe/styles/_style_loader.py +450 -0
- figrecipe-0.5.0.dist-info/METADATA +336 -0
- figrecipe-0.5.0.dist-info/RECORD +26 -0
- figrecipe-0.5.0.dist-info/WHEEL +4 -0
- figrecipe-0.5.0.dist-info/licenses/LICENSE +661 -0
figrecipe/_recorder.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""Core recording functionality for figrecipe."""
|
|
4
|
+
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
9
|
+
import uuid
|
|
10
|
+
|
|
11
|
+
import matplotlib
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class CallRecord:
    """One recorded Axes method call (plotting or decoration).

    Attributes are stored exactly as given. ``timestamp`` defaults to the
    local time of construction in ISO-8601 form. ``ax_position`` is the
    (row, col) of the owning axes; it is deliberately left out of
    ``to_dict`` because the enclosing axes key already carries it.
    """

    id: str
    function: str
    args: List[Dict[str, Any]]
    kwargs: Dict[str, Any]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    ax_position: Tuple[int, int] = (0, 0)

    def to_dict(self) -> Dict[str, Any]:
        """Return the serializable fields (everything but ``ax_position``) as a dict."""
        exported = ("id", "function", "args", "kwargs", "timestamp")
        return {attr: getattr(self, attr) for attr in exported}

    @classmethod
    def from_dict(cls, data: Dict[str, Any], ax_position: Tuple[int, int] = (0, 0)) -> "CallRecord":
        """Rebuild a record from ``to_dict`` output.

        ``id``, ``function``, ``args`` and ``kwargs`` are mandatory (a missing
        one raises ``KeyError``); a missing ``timestamp`` becomes ``""``.
        ``ax_position`` is supplied by the caller since it is not serialized.
        """
        mandatory = {attr: data[attr] for attr in ("id", "function", "args", "kwargs")}
        return cls(
            **mandatory,
            timestamp=data.get("timestamp", ""),
            ax_position=ax_position,
        )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class AxesRecord:
    """Everything recorded for a single axes.

    Plotting calls and decoration calls (labels, limits, legends, ...) are
    kept in two separate lists. ``position`` is the (row, col) of the axes;
    it is not part of ``to_dict`` output because the enclosing figure keys
    axes by position.
    """

    position: Tuple[int, int]
    calls: List[CallRecord] = field(default_factory=list)
    decorations: List[CallRecord] = field(default_factory=list)

    def add_call(self, record: CallRecord) -> None:
        """Append ``record`` to the plotting-call list."""
        self.calls.extend([record])

    def add_decoration(self, record: CallRecord) -> None:
        """Append ``record`` to the decoration-call list (set_xlabel, etc.)."""
        self.decorations.extend([record])

    def to_dict(self) -> Dict[str, Any]:
        """Serialize both call lists (``position`` is omitted)."""
        serialized: Dict[str, Any] = {}
        for section, entries in (("calls", self.calls), ("decorations", self.decorations)):
            serialized[section] = [entry.to_dict() for entry in entries]
        return serialized
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class FigureRecord:
    """Record of an entire figure.

    Holds figure-level metadata (id, creation time, matplotlib version,
    figsize, dpi and optional layout/style settings) plus one AxesRecord
    per axes, keyed by ``get_axes_key(row, col)`` (``"ax_{row}_{col}"``).
    """

    id: str = field(default_factory=lambda: f"fig_{uuid.uuid4().hex[:8]}")
    created: str = field(default_factory=lambda: datetime.now().isoformat())
    matplotlib_version: str = field(default_factory=lambda: matplotlib.__version__)
    figsize: Tuple[float, float] = (6.4, 4.8)
    dpi: int = 300
    # Axes records keyed by get_axes_key(row, col)
    axes: Dict[str, AxesRecord] = field(default_factory=dict)
    # Layout parameters (subplots_adjust)
    layout: Optional[Dict[str, float]] = None
    # Style parameters
    style: Optional[Dict[str, Any]] = None
    # Constrained layout flag
    constrained_layout: bool = False

    def get_axes_key(self, row: int, col: int) -> str:
        """Get dictionary key for axes at position."""
        return f"ax_{row}_{col}"

    def get_or_create_axes(self, row: int, col: int) -> AxesRecord:
        """Return the AxesRecord at (row, col), creating an empty one if absent."""
        key = self.get_axes_key(row, col)
        if key not in self.axes:
            self.axes[key] = AxesRecord(position=(row, col))
        return self.axes[key]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        ``layout`` and ``style`` are emitted under ``"figure"`` only when set,
        and ``constrained_layout`` only when True.
        """
        result = {
            "figrecipe": "1.0",
            "id": self.id,
            "created": self.created,
            "matplotlib_version": self.matplotlib_version,
            "figure": {
                "figsize": list(self.figsize),
                "dpi": self.dpi,
            },
            "axes": {k: v.to_dict() for k, v in self.axes.items()},
        }
        # Add layout if set
        if self.layout is not None:
            result["figure"]["layout"] = self.layout
        # Add style if set
        if self.style is not None:
            result["figure"]["style"] = self.style
        # Add constrained_layout if True
        if self.constrained_layout:
            result["figure"]["constrained_layout"] = True
        return result

    @staticmethod
    def _parse_axes_key(ax_key: Any) -> Tuple[int, int]:
        """Recover (row, col) from a key like ``"ax_0_1"``.

        Keys that do not follow that pattern (too few parts, or non-integer
        parts such as a hand-edited ``"ax_top_left"``) map to (0, 0) instead
        of raising.
        """
        parts = str(ax_key).split("_")
        if len(parts) >= 3:
            try:
                return int(parts[1]), int(parts[2])
            except ValueError:
                pass
        return 0, 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FigureRecord":
        """Create from dictionary (the inverse of ``to_dict``).

        Sections that are missing *or explicitly null* (e.g. an empty
        ``axes:`` or ``calls:`` entry in a hand-edited YAML recipe, which
        loads as None) are treated as empty rather than crashing.
        """
        # `or {}` also covers keys that are present but null
        fig_data = data.get("figure") or {}
        record = cls(
            id=data.get("id", f"fig_{uuid.uuid4().hex[:8]}"),
            created=data.get("created", ""),
            matplotlib_version=data.get("matplotlib_version", ""),
            figsize=tuple(fig_data.get("figsize", [6.4, 4.8])),
            dpi=fig_data.get("dpi", 300),
            layout=fig_data.get("layout"),
            style=fig_data.get("style"),
            constrained_layout=fig_data.get("constrained_layout", False),
        )

        # Reconstruct axes
        for ax_key, ax_data in (data.get("axes") or {}).items():
            row, col = cls._parse_axes_key(ax_key)
            ax_data = ax_data or {}

            ax_record = AxesRecord(position=(row, col))
            for call_data in ax_data.get("calls") or []:
                ax_record.calls.append(CallRecord.from_dict(call_data, (row, col)))
            for dec_data in ax_data.get("decorations") or []:
                ax_record.decorations.append(CallRecord.from_dict(dec_data, (row, col)))

            record.axes[ax_key] = ax_record

        return record
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class Recorder:
    """Central recorder for tracking matplotlib calls.

    Owns the FigureRecord currently being built, hands out per-method call
    IDs, and converts call arguments into plain builtin values before
    storing them. Large arrays are tagged ``"__FILE__"`` for the serializer.
    """

    # Plotting methods that create artists
    PLOTTING_METHODS = {
        "plot", "scatter", "bar", "barh", "hist", "hist2d",
        "boxplot", "violinplot", "pie", "errorbar", "fill",
        "fill_between", "fill_betweenx", "stackplot", "stem",
        "step", "imshow", "pcolor", "pcolormesh", "contour",
        "contourf", "quiver", "barbs", "streamplot", "hexbin",
        "tripcolor", "triplot", "tricontour", "tricontourf",
        "eventplot", "stairs", "ecdf", "matshow", "spy",
        "loglog", "semilogx", "semilogy", "acorr", "xcorr",
        "specgram", "psd", "csd", "cohere", "angle_spectrum",
        "magnitude_spectrum", "phase_spectrum",
    }

    # Decoration methods (stored in AxesRecord.decorations, not calls)
    DECORATION_METHODS = {
        "set_xlabel", "set_ylabel", "set_title", "set_xlim",
        "set_ylim", "legend", "grid", "axhline", "axvline",
        "axhspan", "axvspan", "text", "annotate",
    }

    # Positional-argument names for common methods. Unlisted methods and
    # surplus positions fall back to "arg{i}". Built once, not per call.
    _ARG_NAME_PATTERNS = {
        "plot": ("x", "y", "fmt"),
        "scatter": ("x", "y", "s", "c"),
        "bar": ("x", "height", "width", "bottom"),
        "barh": ("y", "width", "height", "left"),
        "hist": ("x", "bins"),
        "imshow": ("X",),
        "contour": ("X", "Y", "Z", "levels"),
        "contourf": ("X", "Y", "Z", "levels"),
        "fill_between": ("x", "y1", "y2"),
        "errorbar": ("x", "y", "yerr", "xerr"),
        "text": ("x", "y", "s"),
        "annotate": ("text", "xy", "xytext"),
    }

    # Internal kwargs that are never recorded
    _SKIP_KWARGS = frozenset({"id", "track", "_array"})

    def __init__(self):
        self._figure_record: Optional["FigureRecord"] = None
        self._method_counters: Dict[str, int] = {}

    def start_figure(
        self,
        figsize: Tuple[float, float] = (6.4, 4.8),
        dpi: int = 300,
    ) -> "FigureRecord":
        """Start recording a new figure and reset the per-method call counters."""
        self._figure_record = FigureRecord(figsize=figsize, dpi=dpi)
        self._method_counters = {}
        return self._figure_record

    @property
    def figure_record(self) -> Optional["FigureRecord"]:
        """Get current figure record (None before the first figure is started)."""
        return self._figure_record

    def _generate_call_id(self, method_name: str) -> str:
        """Generate a call ID of the form ``"{method_name}_{NNN}"``, counting from 000."""
        counter = self._method_counters.get(method_name, 0)
        self._method_counters[method_name] = counter + 1
        return f"{method_name}_{counter:03d}"

    def record_call(
        self,
        ax_position: Tuple[int, int],
        method_name: str,
        args: tuple,
        kwargs: Dict[str, Any],
        call_id: Optional[str] = None,
    ) -> "CallRecord":
        """Record a plotting call.

        Parameters
        ----------
        ax_position : tuple
            (row, col) position of axes.
        method_name : str
            Name of the method called.
        args : tuple
            Positional arguments.
        kwargs : dict
            Keyword arguments.
        call_id : str, optional
            Custom ID for this call. When given, the auto-ID counter for
            ``method_name`` is not advanced.

        Returns
        -------
        CallRecord
            The recorded call.
        """
        if self._figure_record is None:
            self.start_figure()

        # Generate ID if not provided
        if call_id is None:
            call_id = self._generate_call_id(method_name)

        record = CallRecord(
            id=call_id,
            function=method_name,
            args=self._process_args(args, method_name),
            kwargs=self._process_kwargs(kwargs, method_name),
            ax_position=ax_position,
        )

        # Decorations and plotting calls are kept in separate lists
        ax_record = self._figure_record.get_or_create_axes(*ax_position)
        if method_name in self.DECORATION_METHODS:
            ax_record.add_decoration(record)
        else:
            ax_record.add_call(record)

        return record

    def _process_args(
        self,
        args: tuple,
        method_name: str,
    ) -> List[Dict[str, Any]]:
        """Process positional arguments for storage.

        Parameters
        ----------
        args : tuple
            Raw positional arguments.
        method_name : str
            Name of the method.

        Returns
        -------
        list
            One dict per argument with ``"name"`` and ``"data"``. Arrays (and
            pandas objects) also carry ``"dtype"``. Arrays too large to inline
            have ``"data": "__FILE__"`` plus a temporary ``"_array"`` entry
            that the serializer removes.
        """
        from ._utils._numpy_io import should_store_inline, to_serializable

        processed = []
        arg_names = self._get_arg_names(method_name, len(args))

        for name, value in zip(arg_names, args):
            if isinstance(value, np.ndarray):
                arr = value
            elif hasattr(value, "values") and not isinstance(value, dict):
                # pandas Series/DataFrame/Index. dicts also expose .values
                # (e.g. text()'s positional fontdict) and must not land here.
                arr = np.asarray(value)
            else:
                # Scalars, strings, lists, dicts, numpy scalars, other objects
                processed.append({"name": name, "data": self._to_builtin(value)})
                continue

            if should_store_inline(arr):
                processed.append({
                    "name": name,
                    "data": to_serializable(arr),
                    "dtype": str(arr.dtype),
                })
            else:
                # Mark for file storage (will be handled by serializer)
                processed.append({
                    "name": name,
                    "data": "__FILE__",
                    "dtype": str(arr.dtype),
                    "_array": arr,  # Temporary, removed during serialization
                })

        return processed

    def _get_arg_names(self, method_name: str, n_args: int) -> List[str]:
        """Get argument names for a method.

        Parameters
        ----------
        method_name : str
            Name of the method.
        n_args : int
            Number of arguments.

        Returns
        -------
        list
            ``n_args`` names: the known names for ``method_name`` first,
            then ``"arg{i}"`` for any remaining positions.
        """
        names = list(self._ARG_NAME_PATTERNS.get(method_name, ())[:n_args])
        names.extend(f"arg{i}" for i in range(len(names), n_args))
        return names

    def _process_kwargs(
        self,
        kwargs: Dict[str, Any],
        method_name: str,
    ) -> Dict[str, Any]:
        """Process keyword arguments for storage.

        Parameters
        ----------
        kwargs : dict
            Raw keyword arguments.
        method_name : str
            Name of the method (currently unused; kept for API symmetry).

        Returns
        -------
        dict
            Kwargs converted to builtin values, minus internal keys.
        """
        processed = {}

        for key, value in kwargs.items():
            if key in self._SKIP_KWARGS:
                continue
            try:
                processed[key] = self._to_builtin(value)
            except Exception:
                # Deliberate best-effort: a value that cannot even be
                # converted/stringified is dropped rather than aborting
                # the whole recording.
                continue

        return processed

    def _to_builtin(self, value: Any) -> Any:
        """Convert ``value`` into builtin, serializer-friendly types.

        numpy scalars become Python scalars. Lists, tuples and dicts are
        converted element-wise (tuples stay tuples). Already-serializable
        values are returned as-is. Arrays and pandas objects become (nested)
        lists. Anything else falls back to ``str(value)``.
        """
        if isinstance(value, np.generic):
            # Must precede _is_serializable: np.float64 subclasses float
            # but is still a numpy object.
            return value.item()
        if isinstance(value, (list, tuple)):
            items = [self._to_builtin(v) for v in value]
            return tuple(items) if isinstance(value, tuple) else items
        if isinstance(value, dict):
            return {k: self._to_builtin(v) for k, v in value.items()}
        if self._is_serializable(value):
            return value
        if isinstance(value, np.ndarray):
            return value.tolist()
        if hasattr(value, "values"):  # pandas Series/DataFrame/Index
            return np.asarray(value).tolist()
        return str(value)

    def _is_serializable(self, value: Any) -> bool:
        """Check if value is directly serializable to YAML."""
        if value is None:
            return True
        if isinstance(value, (bool, int, float, str)):
            return True
        if isinstance(value, (list, tuple)):
            return all(self._is_serializable(v) for v in value)
        if isinstance(value, dict):
            return all(
                isinstance(k, str) and self._is_serializable(v)
                for k, v in value.items()
            )
        return False
|