figrecipe 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
figrecipe/_recorder.py ADDED
@@ -0,0 +1,435 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Core recording functionality for figrecipe."""
4
+
5
+ from collections import OrderedDict
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+ import uuid
10
+
11
+ import matplotlib
12
+ import numpy as np
13
+
14
+
15
@dataclass
class CallRecord:
    """A single recorded matplotlib method invocation.

    ``to_dict`` serializes every field except ``ax_position``. The position
    has to be passed back in explicitly when rebuilding with ``from_dict``.
    """

    id: str
    function: str
    args: List[Dict[str, Any]]
    kwargs: Dict[str, Any]
    # ISO-8601 local time captured when the record is instantiated.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    # (row, col) of the owning axes; not part of the serialized dict.
    ax_position: Tuple[int, int] = (0, 0)

    def to_dict(self) -> Dict[str, Any]:
        """Return the serializable mapping for this call (without ``ax_position``)."""
        serialized_fields = ("id", "function", "args", "kwargs", "timestamp")
        return {name: getattr(self, name) for name in serialized_fields}

    @classmethod
    def from_dict(cls, data: Dict[str, Any], ax_position: Tuple[int, int] = (0, 0)) -> "CallRecord":
        """Rebuild a record from the mapping produced by ``to_dict``.

        ``id``, ``function``, ``args`` and ``kwargs`` are required; a missing
        one raises ``KeyError``. A missing ``timestamp`` becomes ``""``.
        """
        # Looked up in the same order as the constructor fields so the first
        # missing key is the one reported.
        mandatory = {name: data[name] for name in ("id", "function", "args", "kwargs")}
        return cls(
            **mandatory,
            timestamp=data.get("timestamp", ""),
            ax_position=ax_position,
        )
47
+
48
+
49
@dataclass
class AxesRecord:
    """Every recorded call for one axes, kept as two ordered lists.

    Plotting calls and decoration calls are stored separately so they can be
    serialized under separate keys.
    """

    # (row, col) grid position of this axes.
    position: Tuple[int, int]
    # Plotting calls, in the order they were added.
    calls: List[CallRecord] = field(default_factory=list)
    # Decoration calls (set_xlabel, legend, ...), in the order they were added.
    decorations: List[CallRecord] = field(default_factory=list)

    def add_call(self, record: CallRecord) -> None:
        """Append *record* to the plotting calls."""
        calls = self.calls
        calls.append(record)

    def add_decoration(self, record: CallRecord) -> None:
        """Append *record* to the decoration calls."""
        decorations = self.decorations
        decorations.append(record)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize both call lists; ``position`` is not included."""
        return {
            bucket: [entry.to_dict() for entry in getattr(self, bucket)]
            for bucket in ("calls", "decorations")
        }
71
+
72
+
73
@dataclass
class FigureRecord:
    """Record of an entire figure.

    Holds figure-level metadata and settings plus one ``AxesRecord`` per
    subplot position, keyed by ``get_axes_key(row, col)``.
    """

    # Short random identifier of the form "fig_<8 hex chars>".
    id: str = field(default_factory=lambda: f"fig_{uuid.uuid4().hex[:8]}")
    # ISO-8601 local timestamp taken when the record is created.
    created: str = field(default_factory=lambda: datetime.now().isoformat())
    # matplotlib version active when the record is created.
    matplotlib_version: str = field(default_factory=lambda: matplotlib.__version__)
    figsize: Tuple[float, float] = (6.4, 4.8)
    dpi: int = 300
    # Axes records keyed by "ax_{row}_{col}". The string forward reference
    # keeps this annotation independent of class definition order.
    axes: Dict[str, "AxesRecord"] = field(default_factory=dict)
    # Layout parameters (subplots_adjust)
    layout: Optional[Dict[str, float]] = None
    # Style parameters
    style: Optional[Dict[str, Any]] = None
    # Constrained layout flag
    constrained_layout: bool = False

    def get_axes_key(self, row: int, col: int) -> str:
        """Return the ``axes`` dictionary key for the axes at (row, col)."""
        return f"ax_{row}_{col}"

    def get_or_create_axes(self, row: int, col: int) -> "AxesRecord":
        """Return the axes record at (row, col), creating an empty one if absent."""
        key = self.get_axes_key(row, col)
        if key not in self.axes:
            self.axes[key] = AxesRecord(position=(row, col))
        return self.axes[key]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary for serialization.

        ``layout`` and ``style`` are only emitted when set, and
        ``constrained_layout`` only when True.
        """
        result = {
            "figrecipe": "1.0",
            "id": self.id,
            "created": self.created,
            "matplotlib_version": self.matplotlib_version,
            "figure": {
                "figsize": list(self.figsize),
                "dpi": self.dpi,
            },
            "axes": {k: v.to_dict() for k, v in self.axes.items()},
        }
        # Add layout if set
        if self.layout is not None:
            result["figure"]["layout"] = self.layout
        # Add style if set
        if self.style is not None:
            result["figure"]["style"] = self.style
        # Add constrained_layout if True
        if self.constrained_layout:
            result["figure"]["constrained_layout"] = True
        return result

    @staticmethod
    def _parse_axes_key(ax_key: Any) -> Tuple[int, int]:
        """Return the (row, col) encoded in a key such as ``"ax_0_1"``.

        Keys that do not follow the ``ax_<row>_<col>`` pattern (too few parts,
        or non-integer row/col) map to ``(0, 0)`` instead of raising.
        """
        parts = str(ax_key).split("_")
        if len(parts) >= 3:
            try:
                return int(parts[1]), int(parts[2])
            except ValueError:
                # e.g. "ax_a_b": same fallback as a key with too few parts.
                pass
        return 0, 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FigureRecord":
        """Create a record from a dictionary produced by ``to_dict``.

        Missing optional fields fall back to defaults. A ``figure``, ``axes``
        or per-axes section that is absent or null is treated as empty.
        """
        # ``or {}`` also covers sections that are present but null, e.g. a
        # YAML key written with no value.
        fig_data = data.get("figure") or {}
        record = cls(
            id=data.get("id", f"fig_{uuid.uuid4().hex[:8]}"),
            created=data.get("created", ""),
            matplotlib_version=data.get("matplotlib_version", ""),
            figsize=tuple(fig_data.get("figsize", [6.4, 4.8])),
            dpi=fig_data.get("dpi", 300),
            layout=fig_data.get("layout"),
            style=fig_data.get("style"),
            constrained_layout=fig_data.get("constrained_layout", False),
        )

        # Reconstruct axes; the position is recovered from the key itself.
        for ax_key, ax_data in (data.get("axes") or {}).items():
            row, col = cls._parse_axes_key(ax_key)
            ax_data = ax_data or {}

            ax_record = AxesRecord(position=(row, col))
            for call_data in ax_data.get("calls") or []:
                ax_record.calls.append(CallRecord.from_dict(call_data, (row, col)))
            for dec_data in ax_data.get("decorations") or []:
                ax_record.decorations.append(CallRecord.from_dict(dec_data, (row, col)))

            record.axes[ax_key] = ax_record

        return record
158
+
159
+
160
class Recorder:
    """Central recorder for tracking matplotlib calls.

    Owns the ``FigureRecord`` being built and turns each intercepted method
    call into a ``CallRecord`` stored on the axes at the given position.
    """

    # Plotting methods that create artists
    PLOTTING_METHODS = {
        "plot", "scatter", "bar", "barh", "hist", "hist2d",
        "boxplot", "violinplot", "pie", "errorbar", "fill",
        "fill_between", "fill_betweenx", "stackplot", "stem",
        "step", "imshow", "pcolor", "pcolormesh", "contour",
        "contourf", "quiver", "barbs", "streamplot", "hexbin",
        "tripcolor", "triplot", "tricontour", "tricontourf",
        "eventplot", "stairs", "ecdf", "matshow", "spy",
        "loglog", "semilogx", "semilogy", "acorr", "xcorr",
        "specgram", "psd", "csd", "cohere", "angle_spectrum",
        "magnitude_spectrum", "phase_spectrum",
    }

    # Decoration methods; record_call files these under
    # AxesRecord.decorations instead of AxesRecord.calls.
    DECORATION_METHODS = {
        "set_xlabel", "set_ylabel", "set_title", "set_xlim",
        "set_ylim", "legend", "grid", "axhline", "axvline",
        "axhspan", "axvspan", "text", "annotate",
    }

    # Positional-argument names for common methods. Unlisted methods, and
    # positions beyond a pattern's length, fall back to "arg{i}".
    # Built once here instead of on every _get_arg_names call.
    _ARG_NAME_PATTERNS: Dict[str, Tuple[str, ...]] = {
        "plot": ("x", "y", "fmt"),
        "scatter": ("x", "y", "s", "c"),
        "bar": ("x", "height", "width", "bottom"),
        "barh": ("y", "width", "height", "left"),
        "hist": ("x", "bins"),
        "imshow": ("X",),
        "contour": ("X", "Y", "Z", "levels"),
        "contourf": ("X", "Y", "Z", "levels"),
        "fill_between": ("x", "y1", "y2"),
        "errorbar": ("x", "y", "yerr", "xerr"),
        "text": ("x", "y", "s"),
        "annotate": ("text", "xy", "xytext"),
    }

    # Keyword arguments that are internal and never stored.
    _SKIP_KWARGS = frozenset({"id", "track", "_array"})

    def __init__(self):
        self._figure_record: Optional["FigureRecord"] = None
        # Per-method counters used to build sequential call IDs.
        self._method_counters: Dict[str, int] = {}

    def start_figure(
        self,
        figsize: Tuple[float, float] = (6.4, 4.8),
        dpi: int = 300,
    ) -> "FigureRecord":
        """Start recording a new figure, resetting the call-ID counters."""
        self._figure_record = FigureRecord(figsize=figsize, dpi=dpi)
        self._method_counters = {}
        return self._figure_record

    @property
    def figure_record(self) -> Optional["FigureRecord"]:
        """Get current figure record (None before any figure was started)."""
        return self._figure_record

    def _generate_call_id(self, method_name: str) -> str:
        """Return the next ID for *method_name*, e.g. ``plot_000``, ``plot_001``."""
        counter = self._method_counters.get(method_name, 0)
        self._method_counters[method_name] = counter + 1
        return f"{method_name}_{counter:03d}"

    def record_call(
        self,
        ax_position: Tuple[int, int],
        method_name: str,
        args: tuple,
        kwargs: Dict[str, Any],
        call_id: Optional[str] = None,
    ) -> "CallRecord":
        """Record a plotting call.

        A figure is started implicitly if none is active.

        Parameters
        ----------
        ax_position : tuple
            (row, col) position of axes.
        method_name : str
            Name of the method called.
        args : tuple
            Positional arguments.
        kwargs : dict
            Keyword arguments.
        call_id : str, optional
            Custom ID for this call. Generated from a per-method counter
            when omitted.

        Returns
        -------
        CallRecord
            The recorded call.
        """
        if self._figure_record is None:
            self.start_figure()

        # Generate ID if not provided
        if call_id is None:
            call_id = self._generate_call_id(method_name)

        # Process args into serializable format
        processed_args = self._process_args(args, method_name)

        # Drop internal keys and convert values to plain Python data
        # (no default-value filtering is done here).
        processed_kwargs = self._process_kwargs(kwargs, method_name)

        record = CallRecord(
            id=call_id,
            function=method_name,
            args=processed_args,
            kwargs=processed_kwargs,
            ax_position=ax_position,
        )

        # Add to appropriate axes
        ax_record = self._figure_record.get_or_create_axes(*ax_position)

        if method_name in self.DECORATION_METHODS:
            ax_record.add_decoration(record)
        else:
            ax_record.add_call(record)

        return record

    def _process_args(
        self,
        args: tuple,
        method_name: str,
    ) -> List[Dict[str, Any]]:
        """Process positional arguments for storage.

        Parameters
        ----------
        args : tuple
            Raw positional arguments.
        method_name : str
            Name of the method.

        Returns
        -------
        list
            One dict per argument with ``name`` and ``data``. Arrays and
            array-likes also carry ``dtype``. Arrays too large to inline get
            ``data="__FILE__"`` plus a temporary ``_array`` entry for the
            serializer.
        """
        from ._utils._numpy_io import should_store_inline, to_serializable

        processed = []
        arg_names = self._get_arg_names(method_name, len(args))

        for name, value in zip(arg_names, args):
            if isinstance(value, np.ndarray) or self._is_array_like(value):
                arr = value if isinstance(value, np.ndarray) else np.asarray(value)
                if should_store_inline(arr):
                    processed.append({
                        "name": name,
                        "data": to_serializable(arr),
                        "dtype": str(arr.dtype),
                    })
                else:
                    # Mark for file storage (will be handled by serializer)
                    processed.append({
                        "name": name,
                        "data": "__FILE__",
                        "dtype": str(arr.dtype),
                        "_array": arr,  # Temporary, removed during serialization
                    })
            else:
                # Scalar, container or other value
                try:
                    data = self._to_native(value)
                except (TypeError, ValueError):
                    data = str(value)
                processed.append({"name": name, "data": data})

        return processed

    def _get_arg_names(self, method_name: str, n_args: int) -> List[str]:
        """Get argument names for a method.

        Parameters
        ----------
        method_name : str
            Name of the method.
        n_args : int
            Number of arguments.

        Returns
        -------
        list
            Exactly ``n_args`` names. Known names come first, padded with
            ``arg{i}`` for the remaining positions.
        """
        known = self._ARG_NAME_PATTERNS.get(method_name, ())
        names = list(known[:n_args])
        names.extend(f"arg{i}" for i in range(len(names), n_args))
        return names

    def _process_kwargs(
        self,
        kwargs: Dict[str, Any],
        method_name: str,
    ) -> Dict[str, Any]:
        """Process keyword arguments for storage.

        Parameters
        ----------
        kwargs : dict
            Raw keyword arguments.
        method_name : str
            Name of the method (currently unused).

        Returns
        -------
        dict
            The kwargs with internal keys (``id``, ``track``, ``_array``)
            removed and values converted by ``_to_native``.
        """
        processed = {}

        for key, value in kwargs.items():
            if key in self._SKIP_KWARGS:
                continue
            try:
                processed[key] = self._to_native(value)
            except Exception:
                # Deliberately best-effort: a value that cannot be converted
                # (or even stringified) is left out rather than aborting the
                # whole recording.
                continue

        return processed

    def _is_serializable(self, value: Any) -> bool:
        """Check if value is directly serializable to YAML."""
        if value is None:
            return True
        if isinstance(value, (bool, int, float, str)):
            return True
        if isinstance(value, (list, tuple)):
            return all(self._is_serializable(v) for v in value)
        if isinstance(value, dict):
            return all(
                isinstance(k, str) and self._is_serializable(v)
                for k, v in value.items()
            )
        return False

    @staticmethod
    def _is_array_like(value: Any) -> bool:
        """Return True for objects exposing a ``values`` *attribute*.

        pandas objects expose ``values`` as a property. ``dict.values`` is a
        method, so requiring a non-callable attribute keeps dicts out of the
        array path.
        """
        values = getattr(value, "values", None)
        return values is not None and not callable(values)

    def _to_native(self, value: Any) -> Any:
        """Convert *value* into plain Python data where possible.

        - Values accepted by ``_is_serializable`` are returned unchanged.
        - NumPy scalars become Python scalars, so ``np.float32(0.5)`` is
          stored as ``0.5`` rather than ``"0.5"``.
        - Arrays and array-likes become nested lists.
        - Dicts, lists and tuples are converted element-wise.
        - Anything else falls back to ``str(value)``.
        """
        if self._is_serializable(value):
            return value
        if isinstance(value, np.generic):
            native = value.item()
            return native if self._is_serializable(native) else str(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        if self._is_array_like(value):
            return np.asarray(value).tolist()
        if isinstance(value, dict):
            return {k: self._to_native(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            items = [self._to_native(v) for v in value]
            return items if isinstance(value, list) else tuple(items)
        return str(value)