texer 0.5.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
texer/pgfplots.py ADDED
@@ -0,0 +1,1381 @@
1
+ """PGFPlots classes for LaTeX figure generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Literal
7
+
8
+ from texer.specs import Spec, Iter, resolve_value
9
+ from texer.utils import format_options, hex_to_pgf_rgb, indent, is_hex_color
10
+
11
# ---------------------------------------------------------------------------
# Type aliases for common PGF option values.
#
# These are typing hints only: most fields that use them are typed as
# ``Alias | str | Spec`` so other values are still accepted.
# ---------------------------------------------------------------------------

# Plot mark shapes (``mark=...``).
MarkStyle = Literal[
    "*",
    "x",
    "+",
    "-",
    "|",
    "o",
    "asterisk",
    "star",
    "10-pointed star",
    "oplus",
    "oplus*",
    "otimes",
    "otimes*",
    "square",
    "square*",
    "triangle",
    "triangle*",
    "diamond",
    "diamond*",
    "pentagon",
    "pentagon*",
    "Mercedes star",
    "Mercedes star flipped",
    "halfcircle",
    "halfcircle*",
    "halfsquare*",
    "halfdiamond*",
]

# Dash patterns for plot lines.
LineStyle = Literal[
    "solid",
    "dotted",
    "densely dotted",
    "loosely dotted",
    "dashed",
    "densely dashed",
    "loosely dashed",
    "dashdotted",
    "densely dashdotted",
    "loosely dashdotted",
    "dashdotdotted",
    "densely dashdotdotted",
    "loosely dashdotdotted",
]

# Named colors.
ColorName = Literal[
    "red",
    "green",
    "blue",
    "cyan",
    "magenta",
    "yellow",
    "black",
    "gray",
    "white",
    "darkgray",
    "lightgray",
    "brown",
    "lime",
    "olive",
    "orange",
    "pink",
    "purple",
    "teal",
    "violet",
]

# Legend placement values (``legend pos=...``).
LegendPos = Literal[
    "north west",
    "north east",
    "south west",
    "south east",
    "north",
    "south",
    "east",
    "west",
    "outer north east",
]

# Axis line styles (``axis lines=...``).
AxisLines = Literal[
    "left",
    "center",
    "right",
    "box",
    "middle",
    "none",
]

# Grid styles (``grid=...``).
GridStyle = Literal[
    "major",
    "minor",
    "both",
    "none",
]
43
+
44
+
45
@dataclass
class Coordinates:
    """Coordinates for a plot.

    Points come either from ``source`` or from separate array-likes:

    - ``source``: a static sequence of points, or an ``Iter``/``Spec`` resolved at
      render time.
    - ``x``/``y``, optionally with ``z`` and ``marker_size``: array-likes such as
      lists, tuples, ranges or numpy arrays.

    When marker-size data is present the points are rendered as a PGFPlots
    ``table`` because ``\\thisrow{size}`` only works with table input. Otherwise
    they are rendered as ``coordinates {...}``.

    Examples:
        # Static coordinates (list of tuples; lists of lists also work)
        Coordinates([(0, 1), (1, 2), (2, 4)])

        # From separate x, y arrays (numpy arrays, lists, tuples, ranges)
        Coordinates(x=[0, 1, 2], y=[1, 2, 4])
        Coordinates(x=np.array([0, 1, 2]), y=np.array([1, 2, 4]))

        # 3D coordinates
        Coordinates([(0, 0, 1), (1, 1, 2)])
        Coordinates(x=[0, 1], y=[0, 1], z=[1, 2])

        # Dynamic coordinates from data
        Coordinates(Iter(Ref("points"), x=Ref("x"), y=Ref("y")))

        # With marker size data (for scatter plots)
        Coordinates(x=[0, 1, 2], y=[1, 2, 4], marker_size=[5, 10, 15])
        Coordinates(Iter(Ref("points"), x=Ref("x"), y=Ref("y"), marker_size=Ref("size")))

        # Control precision (default 6 significant figures)
        Coordinates(x=[0.123456789], y=[0.987654321])  # Outputs (0.123457, 0.987654)
        Coordinates(x=[0.123456789], y=[0.987654321], precision=3)  # Outputs (0.123, 0.988)
        Coordinates(x=[0.123456789], y=[0.987654321], precision=None)  # No rounding
    """

    source: list[tuple[Any, ...]] | Iter | Spec | None = None
    x: Any = None
    y: Any = None
    z: Any = None
    marker_size: Any = None  # For data-driven marker sizes
    precision: int | None = 6  # Number of significant figures (None = no rounding)

    def __post_init__(self) -> None:
        """Validate that either ``source`` or ``x``/``y`` is provided, but not both.

        Raises:
            ValueError: If neither form is given, both forms are given, or ``x`` is
                given without ``y``.
        """
        if self.source is None and self.x is None:
            raise ValueError("Either 'source' or 'x' and 'y' must be provided")
        if self.source is not None and (self.x is not None or self.y is not None):
            raise ValueError("Cannot specify both 'source' and 'x'/'y' parameters")
        if self.x is not None and self.y is None:
            raise ValueError("If 'x' is provided, 'y' must also be provided")

    def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
        """Render the coordinates to LaTeX.

        Args:
            data: Data context used to resolve Spec-valued inputs.
            scope: Optional variable scope used during Spec resolution.

        Returns:
            A ``coordinates {...}`` fragment, or a ``table {...}`` fragment when
            marker-size data is present.
        """
        if scope is None:
            scope = {}

        if self.x is not None:
            # Separate arrays: each may itself be a Spec, so resolve before zipping.
            x_resolved = resolve_value(self.x, data, scope)
            y_resolved = resolve_value(self.y, data, scope)
            z_resolved = resolve_value(self.z, data, scope) if self.z is not None else None
            marker_size_resolved = (
                resolve_value(self.marker_size, data, scope) if self.marker_size is not None else None
            )
            points = self._arrays_to_points_resolved(x_resolved, y_resolved, z_resolved, marker_size_resolved)
        elif isinstance(self.source, (Iter, Spec)):
            points = self.source.resolve(data, scope)
        else:
            points = self.source  # type: ignore[assignment]

        # Marker-size data can come from the marker_size attribute or from an Iter source.
        has_marker_size = self.marker_size is not None or (
            isinstance(self.source, Iter) and self.source.marker_size is not None
        )
        if has_marker_size:
            # \thisrow{size} only works with table input, so switch to the table format.
            return self._render_as_table(points)

        return "coordinates {" + " ".join(self._format_point(point) for point in points) + "}"

    def _format_point(self, point: Any) -> str:
        """Format one point as a PGFPlots coordinate, e.g. ``(1, 2)``.

        Tuples and lists are treated as multi-component points; any other value
        is treated as a single component.
        """
        if isinstance(point, (tuple, list)):
            return "(" + ", ".join(self._format_value(v) for v in point) + ")"
        return f"({self._format_value(point)})"

    def _render_as_table(self, points: list[tuple[Any, ...]]) -> str:
        """Render the points as an inline table (needed for marker_size support).

        Points of length 4 are read as ``(x, y, z, size)``; otherwise the header
        is ``x y size``.
        """
        if not points:
            return "table {x y size\n}"

        # Determine 3D from the first point's length: (x, y, z, size).
        is_3d = len(points[0]) == 4
        header = "x y z size" if is_3d else "x y size"

        rows = [header]
        for point in points:
            rows.append(" ".join(self._format_value(v) for v in point))

        return "table {\n" + "\n".join(rows) + "\n}"

    def _arrays_to_points_resolved(
        self, x: Any, y: Any, z: Any = None, marker_size: Any = None
    ) -> list[tuple[Any, ...]]:
        """Zip resolved x, y (and optional z, marker_size) arrays into point tuples.

        Raises:
            ValueError: If the arrays have different lengths.
        """
        x_list = self._to_list(x)
        y_list = self._to_list(y)

        if z is not None:
            z_list = self._to_list(z)
            if marker_size is not None:
                marker_size_list = self._to_list(marker_size)
                if not (len(x_list) == len(y_list) == len(z_list) == len(marker_size_list)):
                    raise ValueError(
                        f"x, y, z, and marker_size must have the same length "
                        f"(got {len(x_list)}, {len(y_list)}, {len(z_list)}, {len(marker_size_list)})"
                    )
                return list(zip(x_list, y_list, z_list, marker_size_list))
            if not (len(x_list) == len(y_list) == len(z_list)):
                raise ValueError(f"x, y, and z must have the same length (got {len(x_list)}, {len(y_list)}, {len(z_list)})")
            return list(zip(x_list, y_list, z_list))

        if marker_size is not None:
            marker_size_list = self._to_list(marker_size)
            if not (len(x_list) == len(y_list) == len(marker_size_list)):
                raise ValueError(
                    f"x, y, and marker_size must have the same length "
                    f"(got {len(x_list)}, {len(y_list)}, {len(marker_size_list)})"
                )
            return list(zip(x_list, y_list, marker_size_list))
        if len(x_list) != len(y_list):
            raise ValueError(f"x and y must have the same length (got {len(x_list)}, {len(y_list)})")
        return list(zip(x_list, y_list))

    @staticmethod
    def _to_list(arr: Any) -> list[Any]:
        """Convert an array-like to a plain list.

        Objects exposing ``tolist()`` (numpy arrays, pandas Series, ...) use it,
        which also turns numpy scalars into Python numbers. Lists, tuples and
        any other non-string, non-dict iterable (``range``, generators, ...) are
        materialized with ``list()``.

        Raises:
            TypeError: If ``arr`` is a string/bytes/dict or is not iterable.
        """
        if hasattr(arr, "tolist"):
            result: list[Any] = arr.tolist()
            return result
        if isinstance(arr, (list, tuple)):
            return list(arr)
        # Strings and mappings are iterable but are never a sensible column of values.
        if isinstance(arr, (str, bytes, dict)):
            raise TypeError(f"Expected array-like object, got {type(arr)}")
        try:
            iterator = iter(arr)
        except TypeError:
            raise TypeError(f"Expected array-like object, got {type(arr)}") from None
        return list(iterator)

    def _format_value(self, value: Any) -> str:
        """Format a value with ``self.precision`` significant figures.

        Non-numeric values (and all values when ``precision`` is None) are
        returned via ``str()`` unchanged; zero is rendered as ``"0"``.
        """
        if self.precision is None:
            return str(value)

        try:
            val = float(value)
            if val == 0:
                return "0"
            # 'g' uses significant figures and switches between fixed and
            # scientific notation automatically.
            format_str = f"{{:.{self.precision}g}}"
            return format_str.format(val)
        except (ValueError, TypeError):
            # Not a number, return as-is
            return str(value)
223
+
224
+
225
@dataclass
class AddPlot:
    """An \\addplot command for PGFPlots.

    Data comes from either ``coords`` (a ``Coordinates``) or ``expression``,
    the latter optionally with ``domain``/``samples``. Options typed
    ``... | Spec`` may be Specs (e.g. ``Ref``); they are resolved against the
    data at render time.

    Examples:
        AddPlot(
            color="blue",
            mark="*",
            coords=Coordinates([(0, 1), (1, 2)])
        )

        AddPlot(
            style="dashed",
            domain="0:10",
            expression="x^2"
        )

        # Scatter plot with data-driven marker sizes
        AddPlot(
            scatter=True,
            only_marks=True,
            coords=Coordinates(x=[0, 1, 2], y=[1, 2, 4], marker_size=[5, 10, 15])
        )
    """

    # Coordinate-based plot
    coords: Coordinates | None = None

    # Expression-based plot
    expression: str | Spec | None = None
    domain: str | Spec | None = None
    samples: int | Spec | None = None

    # Style options
    color: ColorName | str | Spec | None = None
    mark: MarkStyle | str | Spec | None = None
    mark_size: str | float | Spec | None = None  # Static marker size (e.g., "3pt" or 3)
    style: LineStyle | str | Spec | None = None
    line_width: str | Spec | None = None
    only_marks: bool | Spec = False
    no_marks: bool | Spec = False
    smooth: bool | Spec = False
    thick: bool | Spec = False

    # Scatter plot options (for data-driven marker sizes)
    scatter: bool | Spec = False
    scatter_src: str | Spec | None = None  # Which coordinate controls marker size ("explicit" uses meta column)

    # Plot name for legend
    name: str | Spec | None = None

    # 3D options
    surf: bool = False
    mesh: bool = False

    # Error bars
    error_bars: bool = False
    error_bar_style: dict[str, Any] = field(default_factory=dict)

    # Cycle list option
    use_cycle_list: bool = False

    # Raw options escape hatch
    _raw_options: str | None = None

    def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
        """Render the addplot command.

        Args:
            data: Data context used to resolve Spec-valued options and coordinates.
            scope: Optional variable scope used during Spec resolution.

        Returns:
            A single ``\\addplot...;`` (or ``\\addplot3...;``) statement.
        """
        if scope is None:
            scope = {}

        def resolve_spec(value: Any) -> Any:
            # Resolve Specs only; plain values are used exactly as given.
            return resolve_value(value, data, scope) if isinstance(value, Spec) else value

        parts = []

        # Build options (resolve Specs like Ref)
        options: dict[str, Any] = {}
        if self.color:
            color_value = resolve_value(self.color, data, scope)
            # Convert hex colors to PGF RGB format
            if isinstance(color_value, str) and is_hex_color(color_value):
                color_value = hex_to_pgf_rgb(color_value)
            options["color"] = color_value
        if self.mark:
            options["mark"] = resolve_value(self.mark, data, scope)
        if self.mark_size:
            mark_size_val = resolve_value(self.mark_size, data, scope)
            # If numeric, add pt unit; otherwise use as-is
            if isinstance(mark_size_val, (int, float)):
                options["mark size"] = f"{mark_size_val}pt"
            else:
                options["mark size"] = mark_size_val
        if self.style:
            resolved_style = resolve_value(self.style, data, scope)
            options[resolved_style] = True
        line_width = resolve_spec(self.line_width)
        if line_width:
            options["line width"] = line_width
        # Flags may be Specs: resolve before testing, since the truthiness of the
        # unresolved Spec object says nothing about the resolved value.
        if resolve_spec(self.only_marks):
            options["only marks"] = True
        if resolve_spec(self.no_marks):
            options["mark"] = "none"
        if resolve_spec(self.smooth):
            options["smooth"] = True
        if resolve_spec(self.thick):
            options["thick"] = True
        domain = resolve_spec(self.domain)
        if domain:
            options["domain"] = domain
        samples = resolve_spec(self.samples)
        if samples:
            options["samples"] = samples
        if self.surf:
            options["surf"] = True
        if self.mesh:
            options["mesh"] = True

        # Scatter plot options
        scatter_enabled = resolve_spec(self.scatter)

        # Marker-size data can sit directly on Coordinates (x/y/marker_size style)
        # or on an Iter source (Iter style).
        has_marker_size_data = False
        if self.coords:
            if self.coords.marker_size is not None:
                has_marker_size_data = True
            elif isinstance(self.coords.source, Iter) and self.coords.source.marker_size is not None:
                has_marker_size_data = True

        if scatter_enabled:
            if self.scatter_src:
                # User explicitly wants scatter with color mapping
                options["scatter"] = True
                scatter_src_val = resolve_value(self.scatter_src, data, scope)
                options["scatter src"] = scatter_src_val
            elif has_marker_size_data:
                # Variable marker sizes without scatter coloring: scatter mode is
                # required for the marker code, but the mapped color is replaced
                # by the current color (.) so the user's color is kept.
                options["scatter"] = True
                options["scatter/use mapped color"] = "{draw=.!0!.,fill=.!0!.}"
                options["visualization depends on"] = r"{\thisrow{size} \as \perpointmarksize}"
                options["scatter/@pre marker code/.append style"] = "{/tikz/mark size=\\perpointmarksize}"
            else:
                # scatter=True but no marker_size and no scatter_src: basic scatter mode
                options["scatter"] = True

        # 3D variant
        base_cmd = "\\addplot3" if self.surf or self.mesh else "\\addplot"

        # Use the cycle list (+) when requested explicitly, or when no
        # color/mark/style options were given.
        has_style_options = bool(self.color or self.mark or self.style or self.line_width or self.mark_size)
        should_use_cycle = self.use_cycle_list or not has_style_options
        plot_cmd = base_cmd + "+" if should_use_cycle else base_cmd

        opts_str = format_options(options, self._raw_options)
        if opts_str:
            parts.append(f"{plot_cmd}[{opts_str}]")
        else:
            parts.append(plot_cmd)

        # Add coordinates or expression
        if self.coords is not None:
            parts.append(self.coords.render(data, scope))
        elif self.expression is not None:
            parts.append(f"{{{resolve_spec(self.expression)}}}")

        return " ".join(parts) + ";"

    def __repr__(self) -> str:
        return f"AddPlot(color={self.color!r}, mark={self.mark!r}, ...)"
397
+
398
+
399
@dataclass
class Legend:
    """Legend entries for a plot, rendered as ``\\legend{...}``.

    ``\\legend`` takes a comma-separated list, so an entry that itself contains a
    comma is wrapped in braces to keep it a single entry.

    Examples:
        Legend(["Series A", "Series B"])
        Legend([Ref("legend_label")])
    """

    entries: list[Any] | Iter | Spec

    def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
        """Render the legend command.

        Args:
            data: Data context used to resolve the entries.
            scope: Optional variable scope used during resolution.

        Returns:
            The ``\\legend{...}`` command string.

        Raises:
            TypeError: If the entries resolve to None, stay an unresolved Iter, or
                are not a non-string iterable.
        """
        if scope is None:
            scope = {}

        from texer.eval import _evaluate_impl

        # Resolve entries if it's a Spec (like Iter)
        entries = resolve_value(self.entries, data, scope)

        if entries is None:
            raise TypeError(
                f"Legend entries resolved to None. "
                f"Check that your Iter source path exists in the data. "
                f"Entries spec: {self.entries!r}"
            )

        if isinstance(entries, Iter):
            raise TypeError(
                f"Legend entries is an unresolved Iter object. "
                f"This usually means the Iter's source path was not found or returned None. "
                f"Iter source: {entries.source!r}. "
                f"Available data keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}"
            )

        if not hasattr(entries, "__iter__") or isinstance(entries, str):
            raise TypeError(
                f"Legend entries must be a list or iterable, got {type(entries).__name__}. "
                f"If using an Iter, ensure the source path exists and contains a collection."
            )

        resolved = [
            self._format_entry(_evaluate_impl(entry, data, scope, escape=False))
            for entry in entries
        ]

        return "\\legend{" + ", ".join(resolved) + "}"

    @staticmethod
    def _format_entry(text: Any) -> str:
        """Return ``text`` as one ``\\legend`` list item, bracing it if it contains a comma."""
        entry = str(text)
        return "{" + entry + "}" if "," in entry else entry
447
+
448
+
449
@dataclass
class Axis:
    """A PGFPlots axis environment.

    Options left as ``None`` are omitted from the environment's option list.
    Options typed ``... | Spec`` may be Specs, resolved against the data at
    render time.

    Examples:
        Axis(
            xlabel="Time (s)",
            ylabel="Temperature (K)",
            plots=[AddPlot(...)],
            legend=["Data"]
        )
    """

    plots: list[AddPlot] | Iter | Spec = field(default_factory=list)

    # Axis labels
    xlabel: str | Spec | None = None
    ylabel: str | Spec | None = None
    zlabel: str | Spec | None = None
    title: str | Spec | None = None
    title_style: str | Spec | None = None

    # Axis limits
    xmin: float | Spec | None = None
    xmax: float | Spec | None = None
    ymin: float | Spec | None = None
    ymax: float | Spec | None = None
    zmin: float | Spec | None = None
    zmax: float | Spec | None = None

    # Legend
    legend: list[Any] | Legend | Iter | Spec | None = None
    legend_pos: LegendPos | str | Spec | None = None
    legend_style: str | Spec | None = None
    legend_cell_align: Literal["left", "center", "right"] | str | Spec | None = None
    legend_columns: int | Spec | None = None
    transpose_legend: bool | Spec | None = None

    # Grid
    grid: GridStyle | bool | Spec | None = None

    # Axis type
    axis_type: Literal["axis", "semilogxaxis", "semilogyaxis", "loglogaxis"] = "axis"

    # Scale
    width: str | Spec | None = None
    height: str | Spec | None = None

    # Other common options
    enlargelimits: bool | float | Spec | None = None
    clip: bool | Spec | None = None
    axis_lines: AxisLines | str | Spec | None = None

    # Cycle list options
    cycle_list_name: str | Spec | None = None
    cycle_list: list[dict[str, Any]] | list[str] | Spec | None = None

    # Tick positions and labels
    xtick: list[float | int] | str | Spec | None = None
    ytick: list[float | int] | str | Spec | None = None
    ztick: list[float | int] | str | Spec | None = None
    xticklabels: list[str] | str | Spec | None = None
    yticklabels: list[str] | str | Spec | None = None
    zticklabels: list[str] | str | Spec | None = None

    # Raw options escape hatch
    _raw_options: str | None = None

    @staticmethod
    def _format_pgf_list(values: Any) -> str:
        """Format a sequence as a braced PGF list, e.g. ``[1, 2]`` -> ``{1,2}``.

        Items containing a comma are wrapped in their own braces so PGF does not
        split them into several list items.
        """
        items = []
        for value in values:
            text = str(value)
            items.append("{" + text + "}" if "," in text else text)
        return "{" + ",".join(items) + "}"

    def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
        """Render the axis environment.

        Args:
            data: Data context used to resolve Spec-valued options, plots and legend.
            scope: Optional variable scope used during Spec resolution.

        Returns:
            The ``\\begin{<axis_type>}[...] ... \\end{<axis_type>}`` block with one
            plot command per indented line.
        """
        if scope is None:
            scope = {}

        from texer.eval import _evaluate_impl

        options: dict[str, Any] = {}

        # Labels (resolve if Spec; evaluated with escape=False)
        for key in ("xlabel", "ylabel", "zlabel", "title"):
            value = getattr(self, key)
            if value is not None:
                options[key] = _evaluate_impl(value, data, scope, escape=False)

        # Limits (resolve if Spec)
        for key in ("xmin", "xmax", "ymin", "ymax", "zmin", "zmax"):
            value = getattr(self, key)
            if value is not None:
                options[key] = resolve_value(value, data, scope)

        # Legend and title styling (resolve if Spec); order matches the option output order.
        for attr, key in (
            ("legend_pos", "legend pos"),
            ("legend_style", "legend style"),
            ("title_style", "title style"),
            ("legend_cell_align", "legend cell align"),
            ("legend_columns", "legend columns"),
        ):
            value = getattr(self, attr)
            if value is not None:
                options[key] = resolve_value(value, data, scope)
        if self.transpose_legend is not None and resolve_value(self.transpose_legend, data, scope):
            options["transpose legend"] = True

        # Grid (resolve if Spec); True means the major grid.
        grid_value = resolve_value(self.grid, data, scope) if isinstance(self.grid, Spec) else self.grid
        if grid_value is True:
            options["grid"] = "major"
        elif grid_value:
            options["grid"] = grid_value

        # Dimensions and other options (resolve if Spec)
        for attr, key in (
            ("width", "width"),
            ("height", "height"),
            ("enlargelimits", "enlargelimits"),
            ("clip", "clip"),
            ("axis_lines", "axis lines"),
        ):
            value = getattr(self, attr)
            if value is not None:
                options[key] = resolve_value(value, data, scope)

        # Cycle list options (resolve if Spec); a named cycle list takes precedence.
        if self.cycle_list_name is not None:
            options["cycle list name"] = resolve_value(self.cycle_list_name, data, scope)
        elif self.cycle_list is not None:
            cycle_list_resolved = resolve_value(self.cycle_list, data, scope)
            cycle_entries = []
            for entry in cycle_list_resolved:
                if isinstance(entry, dict):
                    # key=value pairs wrapped in braces
                    cycle_entries.append("{" + format_options(entry, None) + "}")
                else:
                    cycle_entries.append(str(entry))
            options["cycle list"] = "{" + ",".join(cycle_entries) + "}"

        # Tick positions and labels (resolve if Spec). Lists/tuples become braced
        # PGF lists; any other value (e.g. a raw "{0,1,2}" string) passes through.
        for key in ("xtick", "ytick", "ztick", "xticklabels", "yticklabels", "zticklabels"):
            raw = getattr(self, key)
            if raw is None:
                continue
            value = resolve_value(raw, data, scope)
            options[key] = self._format_pgf_list(value) if isinstance(value, (list, tuple)) else value

        opts_str = format_options(options, self._raw_options)

        lines = []

        # Opening
        if opts_str:
            lines.append(f"\\begin{{{self.axis_type}}}[{opts_str}]")
        else:
            lines.append(f"\\begin{{{self.axis_type}}}")

        # Plots (handle Iter specially to preserve scope)
        if isinstance(self.plots, Iter):
            # Resolve the Iter source to get items
            if isinstance(self.plots.source, str):
                import glom  # type: ignore[import-untyped]
                items = glom.glom(data, self.plots.source)
            else:
                items = self.plots.source.resolve(data, scope)

            # For each item, extend the scope with the item's fields and render the template.
            for item in items:
                item_scope = dict(scope) if scope else {}
                if isinstance(item, dict):
                    item_scope.update(item)
                plot = resolve_value(self.plots.template, item, item_scope)
                lines.append(f"  {plot.render(data, item_scope)}")
        else:
            # Regular list of plots
            plots = resolve_value(self.plots, data, scope)
            for plot in plots:
                lines.append(f"  {plot.render(data, scope)}")

        # Legend
        if self.legend is not None:
            legend = self.legend if isinstance(self.legend, Legend) else Legend(self.legend)
            lines.append(f"  {legend.render(data, scope)}")

        # Closing
        lines.append(f"\\end{{{self.axis_type}}}")

        return "\n".join(lines)
691
+
692
+
693
+ @dataclass
694
+ class NextGroupPlot:
695
+ """A \\nextgroupplot command within a groupplot environment.
696
+
697
+ Examples:
698
+ NextGroupPlot(
699
+ title="Plot 1",
700
+ xlabel="X",
701
+ plots=[AddPlot(...)]
702
+ )
703
+ """
704
+
705
+ plots: list[AddPlot] | Iter | Spec = field(default_factory=list)
706
+
707
+ # Axis labels
708
+ xlabel: str | Spec | None = None
709
+ ylabel: str | Spec | None = None
710
+ zlabel: str | Spec | None = None
711
+ title: str | Spec | None = None
712
+ title_style: str | Spec | None = None
713
+
714
+ # Axis limits
715
+ xmin: float | Spec | None = None
716
+ xmax: float | Spec | None = None
717
+ ymin: float | Spec | None = None
718
+ ymax: float | Spec | None = None
719
+ zmin: float | Spec | None = None
720
+ zmax: float | Spec | None = None
721
+
722
+ # Legend
723
+ legend: list[Any] | Legend | Iter | Spec | None = None
724
+ legend_pos: LegendPos | str | Spec | None = None
725
+ legend_style: str | Spec | None = None
726
+ legend_cell_align: Literal["left", "center", "right"] | str | Spec | None = None
727
+ legend_columns: int | Spec | None = None
728
+ transpose_legend: bool | Spec | None = None
729
+
730
+ # Grid
731
+ grid: GridStyle | bool | Spec | None = None
732
+
733
+ # Other options
734
+ enlargelimits: bool | float | Spec | None = None
735
+ clip: bool | Spec | None = None
736
+ axis_lines: AxisLines | str | Spec | None = None
737
+
738
+ # Cycle list options
739
+ cycle_list_name: str | Spec | None = None
740
+ cycle_list: list[dict[str, Any]] | list[str] | Spec | None = None
741
+
742
+ # Tick positions and labels
743
+ xtick: list[float | int] | str | Spec | None = None
744
+ ytick: list[float | int] | str | Spec | None = None
745
+ ztick: list[float | int] | str | Spec | None = None
746
+ xticklabels: list[str] | str | Spec | None = None
747
+ yticklabels: list[str] | str | Spec | None = None
748
+ zticklabels: list[str] | str | Spec | None = None
749
+
750
+ # Raw options escape hatch
751
+ _raw_options: str | None = None
752
+
753
+ def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
754
+ """Render the nextgroupplot command and its contents."""
755
+ if scope is None:
756
+ scope = {}
757
+
758
+ from texer.eval import _evaluate_impl
759
+
760
+ # Build options
761
+ options: dict[str, Any] = {}
762
+
763
+ # Labels (resolve if Spec)
764
+ if self.xlabel is not None:
765
+ options["xlabel"] = _evaluate_impl(self.xlabel, data, scope, escape=False)
766
+ if self.ylabel is not None:
767
+ options["ylabel"] = _evaluate_impl(self.ylabel, data, scope, escape=False)
768
+ if self.zlabel is not None:
769
+ options["zlabel"] = _evaluate_impl(self.zlabel, data, scope, escape=False)
770
+ if self.title is not None:
771
+ options["title"] = _evaluate_impl(self.title, data, scope, escape=False)
772
+
773
+ # Limits (resolve if Spec)
774
+ if self.xmin is not None:
775
+ options["xmin"] = resolve_value(self.xmin, data, scope)
776
+ if self.xmax is not None:
777
+ options["xmax"] = resolve_value(self.xmax, data, scope)
778
+ if self.ymin is not None:
779
+ options["ymin"] = resolve_value(self.ymin, data, scope)
780
+ if self.ymax is not None:
781
+ options["ymax"] = resolve_value(self.ymax, data, scope)
782
+ if self.zmin is not None:
783
+ options["zmin"] = resolve_value(self.zmin, data, scope)
784
+ if self.zmax is not None:
785
+ options["zmax"] = resolve_value(self.zmax, data, scope)
786
+
787
+ # Legend options (resolve if Spec)
788
+ if self.legend_pos is not None:
789
+ options["legend pos"] = resolve_value(self.legend_pos, data, scope)
790
+ if self.legend_style is not None:
791
+ options["legend style"] = resolve_value(self.legend_style, data, scope)
792
+ if self.title_style is not None:
793
+ options["title style"] = resolve_value(self.title_style, data, scope)
794
+ if self.legend_cell_align is not None:
795
+ options["legend cell align"] = resolve_value(self.legend_cell_align, data, scope)
796
+ if self.legend_columns is not None:
797
+ options["legend columns"] = resolve_value(self.legend_columns, data, scope)
798
+ if self.transpose_legend is not None:
799
+ transpose_value = resolve_value(self.transpose_legend, data, scope)
800
+ if transpose_value:
801
+ options["transpose legend"] = True
802
+
803
+ # Grid (resolve if Spec)
804
+ grid_value = resolve_value(self.grid, data, scope) if isinstance(self.grid, Spec) else self.grid
805
+ if grid_value is True:
806
+ options["grid"] = "major"
807
+ elif grid_value:
808
+ options["grid"] = grid_value
809
+
810
+ # Other options (resolve if Spec)
811
+ if self.enlargelimits is not None:
812
+ options["enlargelimits"] = resolve_value(self.enlargelimits, data, scope)
813
+ if self.clip is not None:
814
+ options["clip"] = resolve_value(self.clip, data, scope)
815
+ if self.axis_lines is not None:
816
+ options["axis lines"] = resolve_value(self.axis_lines, data, scope)
817
+
818
+ # Cycle list options (resolve if Spec)
819
+ if self.cycle_list_name is not None:
820
+ options["cycle list name"] = resolve_value(self.cycle_list_name, data, scope)
821
+ elif self.cycle_list is not None:
822
+ cycle_list_resolved = resolve_value(self.cycle_list, data, scope)
823
+ # Format cycle list
824
+ cycle_entries = []
825
+ for entry in cycle_list_resolved:
826
+ if isinstance(entry, dict):
827
+ # Format as key=value pairs wrapped in braces
828
+ entry_str = format_options(entry, None)
829
+ cycle_entries.append("{" + entry_str + "}")
830
+ else:
831
+ # Plain string entry
832
+ cycle_entries.append(str(entry))
833
+ options["cycle list"] = "{" + ",".join(cycle_entries) + "}"
834
+
835
+ # Tick positions (resolve if Spec)
836
+ if self.xtick is not None:
837
+ xtick_val = resolve_value(self.xtick, data, scope)
838
+ if isinstance(xtick_val, list):
839
+ options["xtick"] = "{" + ",".join(str(v) for v in xtick_val) + "}"
840
+ else:
841
+ options["xtick"] = xtick_val
842
+ if self.ytick is not None:
843
+ ytick_val = resolve_value(self.ytick, data, scope)
844
+ if isinstance(ytick_val, list):
845
+ options["ytick"] = "{" + ",".join(str(v) for v in ytick_val) + "}"
846
+ else:
847
+ options["ytick"] = ytick_val
848
+ if self.ztick is not None:
849
+ ztick_val = resolve_value(self.ztick, data, scope)
850
+ if isinstance(ztick_val, list):
851
+ options["ztick"] = "{" + ",".join(str(v) for v in ztick_val) + "}"
852
+ else:
853
+ options["ztick"] = ztick_val
854
+
855
+ # Tick labels (resolve if Spec)
856
+ if self.xticklabels is not None:
857
+ xticklabels_val = resolve_value(self.xticklabels, data, scope)
858
+ if isinstance(xticklabels_val, list):
859
+ options["xticklabels"] = "{" + ",".join(str(v) for v in xticklabels_val) + "}"
860
+ else:
861
+ options["xticklabels"] = xticklabels_val
862
+ if self.yticklabels is not None:
863
+ yticklabels_val = resolve_value(self.yticklabels, data, scope)
864
+ if isinstance(yticklabels_val, list):
865
+ options["yticklabels"] = "{" + ",".join(str(v) for v in yticklabels_val) + "}"
866
+ else:
867
+ options["yticklabels"] = yticklabels_val
868
+ if self.zticklabels is not None:
869
+ zticklabels_val = resolve_value(self.zticklabels, data, scope)
870
+ if isinstance(zticklabels_val, list):
871
+ options["zticklabels"] = "{" + ",".join(str(v) for v in zticklabels_val) + "}"
872
+ else:
873
+ options["zticklabels"] = zticklabels_val
874
+
875
+ # Format options
876
+ opts_str = format_options(options, self._raw_options)
877
+
878
+ lines = []
879
+
880
+ # Opening
881
+ if opts_str:
882
+ lines.append(f"\\nextgroupplot[{opts_str}]")
883
+ else:
884
+ lines.append("\\nextgroupplot")
885
+
886
+ # Plots (handle Iter specially to preserve scope)
887
+ if isinstance(self.plots, Iter):
888
+ # Resolve the Iter source to get items
889
+ if isinstance(self.plots.source, str):
890
+ import glom
891
+ items = glom.glom(data, self.plots.source)
892
+ else:
893
+ items = self.plots.source.resolve(data, scope)
894
+
895
+ # For each item, create updated scope and render template
896
+ for item in items:
897
+ item_scope = dict(scope) if scope else {}
898
+ if isinstance(item, dict):
899
+ item_scope.update(item)
900
+ # Resolve and render the template with the item scope
901
+ plot = resolve_value(self.plots.template, item, item_scope)
902
+ lines.append(f" {plot.render(data, item_scope)}")
903
+ else:
904
+ # Regular list of plots
905
+ plots = resolve_value(self.plots, data, scope)
906
+ for plot in plots:
907
+ lines.append(f" {plot.render(data, scope)}")
908
+
909
+ # Legend
910
+ if self.legend is not None:
911
+ if isinstance(self.legend, Legend):
912
+ lines.append(f" {self.legend.render(data, scope)}")
913
+ else:
914
+ legend = Legend(self.legend)
915
+ lines.append(f" {legend.render(data, scope)}")
916
+
917
+ return "\n".join(lines)
918
+
919
+
920
@dataclass
class GroupPlot:
    """A groupplot environment for creating multiple plots in a grid layout.

    Group-style settings (``group_size``, ``horizontal_sep``, ...) are collected
    into the ``group style={...}`` option. The remaining fields become options of
    the ``groupplot`` environment itself and therefore apply to every subplot.
    Fields typed ``Spec`` are resolved against ``data``/``scope`` at render time.

    Examples:
        GroupPlot(
            group_size="2 by 2",
            plots=[
                NextGroupPlot(title="Plot 1", plots=[...]),
                NextGroupPlot(title="Plot 2", plots=[...]),
                NextGroupPlot(title="Plot 3", plots=[...]),
                NextGroupPlot(title="Plot 4", plots=[...]),
            ]
        )
    """

    plots: list[NextGroupPlot] | Iter | Spec = field(default_factory=list)

    # Group style options (emitted inside ``group style={...}``)
    group_size: str | Spec | None = None  # e.g., "2 by 2"
    horizontal_sep: str | Spec | None = None
    vertical_sep: str | Spec | None = None
    xlabels_at: str | Spec | None = None  # e.g., "edge bottom"
    ylabels_at: str | Spec | None = None  # e.g., "edge left"
    xticklabels_at: str | Spec | None = None
    yticklabels_at: str | Spec | None = None

    # Common axis options (applied to all subplots)
    width: str | Spec | None = None
    height: str | Spec | None = None
    xmin: float | Spec | None = None
    xmax: float | Spec | None = None
    ymin: float | Spec | None = None
    ymax: float | Spec | None = None

    # Cycle list options (applied to all subplots); cycle_list_name wins if both set
    cycle_list_name: str | Spec | None = None
    cycle_list: list[dict[str, Any]] | list[str] | Spec | None = None

    # Raw options escape hatch: passed to format_options as raw option text
    _raw_options: str | None = None
    # Raw text passed to format_options together with the group style options
    _raw_group_style: str | None = None

    def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
        """Render the groupplot environment.

        Args:
            data: Data used to resolve ``Spec`` fields and ``Iter`` sources.
            scope: Optional variable scope. For an ``Iter`` of plots, every dict
                item is merged into a copy of this scope before its template is
                rendered.

        Returns:
            The ``\\begin{groupplot}`` ... ``\\end{groupplot}`` LaTeX source, with
            each subplot's non-empty lines indented by two spaces.
        """
        if scope is None:
            scope = {}

        options: dict[str, Any] = {}

        # Group style goes first, wrapped in braces so its commas stay grouped.
        group_style_str = format_options(
            self._group_style_options(data, scope), self._raw_group_style
        )
        if group_style_str:
            options["group style"] = f"{{{group_style_str}}}"

        # Common options, in a fixed order (resolve if Spec)
        for key, value in (
            ("width", self.width),
            ("height", self.height),
            ("xmin", self.xmin),
            ("xmax", self.xmax),
            ("ymin", self.ymin),
            ("ymax", self.ymax),
        ):
            if value is not None:
                options[key] = resolve_value(value, data, scope)

        # Cycle list options (resolve if Spec)
        if self.cycle_list_name is not None:
            options["cycle list name"] = resolve_value(self.cycle_list_name, data, scope)
        elif self.cycle_list is not None:
            options["cycle list"] = self._format_cycle_list(
                resolve_value(self.cycle_list, data, scope)
            )

        opts_str = format_options(options, self._raw_options)

        lines = [
            f"\\begin{{groupplot}}[{opts_str}]" if opts_str else "\\begin{groupplot}"
        ]
        for rendered in self._render_plots(data, scope):
            lines.extend(f"  {line}" if line else line for line in rendered.split("\n"))
        lines.append("\\end{groupplot}")

        return "\n".join(lines)

    def _group_style_options(self, data: Any, scope: dict[str, Any]) -> dict[str, Any]:
        """Collect the non-None group-style settings, resolving any Specs."""
        pairs = (
            ("group size", self.group_size),
            ("horizontal sep", self.horizontal_sep),
            ("vertical sep", self.vertical_sep),
            ("xlabels at", self.xlabels_at),
            ("ylabels at", self.ylabels_at),
            ("xticklabels at", self.xticklabels_at),
            ("yticklabels at", self.yticklabels_at),
        )
        return {
            key: resolve_value(value, data, scope)
            for key, value in pairs
            if value is not None
        }

    @staticmethod
    def _format_cycle_list(entries: Any) -> str:
        """Format resolved cycle-list entries as a braced, comma-separated list.

        Dict entries become ``{key=value,...}``. String entries are emitted as-is,
        except that an entry containing a comma (e.g. ``"red,mark=*"``) is wrapped
        in braces. Otherwise the comma would split it into several list entries.
        """
        formatted = []
        for entry in entries:
            if isinstance(entry, dict):
                formatted.append("{" + format_options(entry, None) + "}")
                continue
            text = str(entry)
            if "," in text and not (text.startswith("{") and text.endswith("}")):
                text = "{" + text + "}"
            formatted.append(text)
        return "{" + ",".join(formatted) + "}"

    def _render_plots(self, data: Any, scope: dict[str, Any]) -> list[str]:
        """Render every subplot, expanding an ``Iter`` with a per-item scope."""
        if not isinstance(self.plots, Iter):
            # Regular list of plots (or a Spec resolving to one)
            return [plot.render(data, scope) for plot in resolve_value(self.plots, data, scope)]

        source = self.plots.source
        if isinstance(source, str):
            import glom  # only needed for string (path) sources

            items = glom.glom(data, source)
        elif isinstance(source, Spec):
            items = source.resolve(data, scope)
        else:
            items = source

        rendered = []
        for item in items:
            item_scope = dict(scope) if scope else {}
            if isinstance(item, dict):
                item_scope.update(item)
            # NOTE(review): the template is resolved with the item as its data
            # root, but rendered against the full data; kept as in the original.
            plot = resolve_value(self.plots.template, item, item_scope)
            rendered.append(plot.render(data, item_scope))
        return rendered
1068
+
1069
+
1070
@dataclass
class PGFPlot:
    """A complete PGFPlots tikzpicture.

    Examples:
        # Single axis
        PGFPlot(
            Axis(
                xlabel="X",
                ylabel="Y",
                plots=[AddPlot(coords=Coordinates([...]))]
            )
        )

        # Multiple plots in a grid with groupplot
        PGFPlot(
            GroupPlot(
                group_size="2 by 2",
                plots=[
                    NextGroupPlot(...),
                    NextGroupPlot(...),
                ]
            )
        )
    """

    axis: Axis | GroupPlot
    # Lines emitted verbatim before \begin{tikzpicture} by render().
    preamble: list[str] = field(default_factory=list)
    # Optional tikzpicture ``scale`` option; a Spec is resolved at render time.
    scale: float | Spec | None = None
    # Raw tikzpicture options, passed to format_options alongside ``scale``.
    _raw_options: str | None = None

    def render(self, data: Any, scope: dict[str, Any] | None = None) -> str:
        """Render the complete tikzpicture.

        Args:
            data: Data used to resolve Specs in this figure and its axis.
            scope: Optional variable scope forwarded to the axis.

        Returns:
            The ``preamble`` lines followed by the ``tikzpicture`` environment,
            with the axis's non-empty lines indented by two spaces.
        """
        if scope is None:
            scope = {}

        # Preamble (for standalone use)
        lines = list(self.preamble)

        options: dict[str, Any] = {}
        if self.scale is not None:
            # Resolve a Spec like every other Spec-capable field does. Previously
            # the Spec object itself was handed to format_options unresolved.
            options["scale"] = (
                resolve_value(self.scale, data, scope)
                if isinstance(self.scale, Spec)
                else self.scale
            )

        opts_str = format_options(options, self._raw_options)

        if opts_str:
            lines.append(f"\\begin{{tikzpicture}}[{opts_str}]")
        else:
            lines.append("\\begin{tikzpicture}")

        for line in self.axis.render(data, scope).split("\n"):
            lines.append(f"  {line}" if line else line)

        lines.append("\\end{tikzpicture}")

        return "\n".join(lines)

    def with_preamble(self, data: Any = None) -> str:
        """Return LaTeX code including package imports for standalone use.

        Args:
            data: Optional data dict for rendering (default: empty dict).

        Returns:
            A complete ``standalone`` document that loads pgfplots (compat=1.18)
            and the groupplots library, with the rendered figure as its body.
        """
        if data is None:
            data = {}

        preamble = [
            "\\documentclass{standalone}",
            "\\usepackage{pgfplots}",
            "\\pgfplotsset{compat=1.18}",
            "\\usepgfplotslibrary{groupplots}",
            "",
            "\\begin{document}",
        ]
        # NOTE(review): render() emits self.preamble lines here, i.e. after
        # \begin{document}; confirm they are not meant to be package imports.
        content = self.render(data)
        closing = ["\\end{document}"]

        return "\n".join(preamble + [content] + closing)

    def save_to_file(
        self,
        file_path: str,
        data: Any = None,
        with_preamble: bool = True,
    ) -> None:
        """Save the LaTeX code to a file.

        Args:
            file_path: Path to the output .tex file.
            data: Optional data dict for rendering (default: empty dict).
            with_preamble: Whether to include document preamble for standalone compilation (default: True).

        Examples:
            # Save with preamble for standalone compilation
            plot.save_to_file("my_plot.tex")

            # Save just the tikzpicture content
            plot.save_to_file("my_plot.tex", with_preamble=False)

            # Save with data
            plot.save_to_file("my_plot.tex", data=my_data)
        """
        if data is None:
            data = {}

        latex_code = self.with_preamble(data) if with_preamble else self.render(data)

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(latex_code)

    def compile_to_pdf(
        self,
        tex_file_path: str,
        data: Any = None,
        output_dir: str | None = None,
    ) -> str:
        """Save to .tex file and compile to PDF using pdflatex.

        Args:
            tex_file_path: Path to save the .tex file (e.g., "my_plot.tex").
            data: Optional data dict for rendering (default: empty dict).
            output_dir: Optional output directory for compilation (default: same
                as .tex file). It is created if it does not exist yet.

        Returns:
            Path to the generated PDF file.

        Raises:
            RuntimeError: If pdflatex is not available or compilation fails.
            OSError: If the .tex file cannot be written or the output directory
                cannot be created.

        Examples:
            # Simple compilation
            pdf_path = plot.compile_to_pdf("my_plot.tex")

            # With data
            pdf_path = plot.compile_to_pdf("my_plot.tex", data=my_data)

            # Specify output directory
            pdf_path = plot.compile_to_pdf("my_plot.tex", output_dir="/tmp")
        """
        import shutil
        import subprocess
        from pathlib import Path

        if shutil.which("pdflatex") is None:
            raise RuntimeError(
                "pdflatex not found. Please install a LaTeX distribution (e.g., TeX Live, MiKTeX)."
            )

        self.save_to_file(tex_file_path, data=data, with_preamble=True)

        tex_path = Path(tex_file_path).resolve()
        output_path: Path = (
            tex_path.parent if output_dir is None else Path(output_dir).resolve()
        )
        # pdflatex expects -output-directory to exist already.
        output_path.mkdir(parents=True, exist_ok=True)

        try:
            subprocess.run(
                [
                    "pdflatex",
                    "-interaction=nonstopmode",
                    f"-output-directory={output_path}",
                    str(tex_path),
                ],
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"pdflatex compilation failed:\n{e.stderr}\n\nOutput:\n{e.stdout}"
            ) from e

        return str(output_path / tex_path.with_suffix(".pdf").name)
1262
+
1263
+
1264
+ # Convenience classes for specialized axis types
1265
@dataclass
class SemiLogXAxis(Axis):
    """Axis preset for a ``semilogxaxis``: logarithmic x scale, linear y scale.

    Only the default of ``axis_type`` differs from :class:`Axis`.
    """

    axis_type: Literal[
        "axis",
        "semilogxaxis",
        "semilogyaxis",
        "loglogaxis",
    ] = "semilogxaxis"
1270
+
1271
+
1272
@dataclass
class SemiLogYAxis(Axis):
    """Axis preset for a ``semilogyaxis``: linear x scale, logarithmic y scale.

    Only the default of ``axis_type`` differs from :class:`Axis`.
    """

    axis_type: Literal[
        "axis",
        "semilogxaxis",
        "semilogyaxis",
        "loglogaxis",
    ] = "semilogyaxis"
1277
+
1278
+
1279
@dataclass
class LogLogAxis(Axis):
    """Axis preset for a ``loglogaxis``: logarithmic scale on both x and y.

    Only the default of ``axis_type`` differs from :class:`Axis`.
    """

    axis_type: Literal[
        "axis",
        "semilogxaxis",
        "semilogyaxis",
        "loglogaxis",
    ] = "loglogaxis"
1284
+
1285
+
1286
+ # Helper for creating simple line plots
1287
def simple_plot(
    x: list[float],
    y: list[float],
    xlabel: str = "x",
    ylabel: str = "y",
    title: str | None = None,
    color: str = "blue",
    mark: str = "*",
    precision: int | None = 6,
) -> PGFPlot:
    """Create a simple line plot from x and y data.

    Args:
        x: X-axis data points.
        y: Y-axis data points (must have the same length as ``x``).
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.
        title: Optional plot title.
        color: Line/marker color.
        mark: Marker style.
        precision: Number of significant figures for coordinates (default: 6, None for no rounding).

    Returns:
        A PGFPlot object ready for rendering: one Axis holding one AddPlot.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    # Materialize once so any iterable works and lengths can be compared.
    xs = list(x)
    ys = list(y)
    # zip() would silently drop the unmatched tail and yield a truncated plot.
    if len(xs) != len(ys):
        raise ValueError(
            f"x and y must have the same length (got {len(xs)} and {len(ys)})"
        )

    coords = Coordinates(list(zip(xs, ys)), precision=precision)

    return PGFPlot(
        Axis(
            xlabel=xlabel,
            ylabel=ylabel,
            title=title,
            plots=[AddPlot(color=color, mark=mark, coords=coords)],
        )
    )
1322
+
1323
+
1324
+ # Helper for creating scatter plots with data-driven marker sizes
1325
def scatter_plot(
    x: list[float],
    y: list[float],
    marker_size: list[float],
    xlabel: str = "x",
    ylabel: str = "y",
    title: str | None = None,
    color: str = "blue",
    mark: str = "*",
    precision: int | None = 6,
) -> PGFPlot:
    """Build a bubble chart: a marks-only scatter series with per-point sizes.

    Args:
        x: X-axis data points.
        y: Y-axis data points.
        marker_size: Marker size for each data point (in pt units).
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.
        title: Optional plot title.
        color: Marker color.
        mark: Marker style.
        precision: Number of significant figures for coordinates (default: 6, None for no rounding).

    Returns:
        A PGFPlot wrapping one Axis that holds a single AddPlot series with
        ``only_marks`` and ``scatter`` enabled.

    Examples:
        # Create a bubble chart
        plot = scatter_plot(
            x=[1, 2, 3, 4, 5],
            y=[2, 4, 3, 5, 4],
            marker_size=[5, 10, 15, 20, 25],
            xlabel="X Value",
            ylabel="Y Value",
            title="Bubble Chart"
        )
        print(plot.render({}))
    """
    bubbles = AddPlot(
        color=color,
        mark=mark,
        only_marks=True,
        scatter=True,
        coords=Coordinates(x=x, y=y, marker_size=marker_size, precision=precision),
    )
    axis = Axis(xlabel=xlabel, ylabel=ylabel, title=title, plots=[bubbles])
    return PGFPlot(axis)