gsim 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,630 @@
1
+ """Stack visualization utility for PDK layer stacks.
2
+
3
+ Prints ASCII diagrams showing the layer stack structure with z-positions,
4
+ thicknesses, and layer numbers.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Any
11
+
12
+ import plotly.graph_objects as go
13
+ from gdsfactory.technology import LayerLevel
14
+ from gdsfactory.technology import LayerStack as GfLayerStack
15
+
16
+
17
@dataclass
class StackLayer:
    """A single layer of a PDK stack, flattened for display purposes.

    Attributes:
        name: Key of the layer in the source ``LayerStack.layers`` mapping.
        zmin: Lower z extent of the layer (um).
        zmax: Upper z extent of the layer (um).
        thickness: Layer thickness (um).
        material: Material name, or ``None`` when the stack gives none.
        gds_layer: GDS layer number, or ``None`` when it cannot be resolved.
        layer_type: Display category: ``"conductor"``, ``"via"``,
            ``"dielectric"`` or ``"substrate"``.
    """

    # Required geometry / identity fields (positional order is public API).
    name: str
    zmin: float
    zmax: float
    thickness: float
    # Optional metadata.
    material: str | None = None
    gds_layer: int | None = None
    layer_type: str = "conductor"  # conductor, via, dielectric, substrate
28
+
29
+
30
+ def _get_gds_layer_number(layer_level: LayerLevel) -> int | None:
31
+ """Extract GDS layer number from LayerLevel."""
32
+ layer: Any = layer_level.layer
33
+
34
+ # Handle tuple
35
+ if isinstance(layer, tuple):
36
+ return int(layer[0])
37
+
38
+ # Handle int
39
+ if isinstance(layer, int):
40
+ return int(layer)
41
+
42
+ # Handle LogicalLayer or enum with nested layer
43
+ if hasattr(layer, "layer"):
44
+ inner = layer.layer
45
+ if hasattr(inner, "layer"):
46
+ return int(inner.layer)
47
+ if isinstance(inner, int):
48
+ return int(inner)
49
+
50
+ # Handle enum with value
51
+ if hasattr(layer, "value"):
52
+ if isinstance(layer.value, tuple):
53
+ return int(layer.value[0])
54
+ return int(layer.value)
55
+
56
+ return None
57
+
58
+
59
+ def _classify_layer(name: str) -> str:
60
+ """Classify layer type based on name."""
61
+ name_lower = name.lower()
62
+
63
+ if "via" in name_lower or "cont" in name_lower:
64
+ return "via"
65
+ if "substrate" in name_lower or name_lower == "sub":
66
+ return "substrate"
67
+ if any(
68
+ m in name_lower
69
+ for m in ["metal", "topmetal", "m1", "m2", "m3", "m4", "m5", "poly", "active"]
70
+ ):
71
+ return "conductor"
72
+
73
+ return "dielectric"
74
+
75
+
76
def parse_layer_stack(layer_stack: GfLayerStack) -> list[StackLayer]:
    """Parse a gdsfactory LayerStack into a list of StackLayer objects.

    Missing ``zmin`` / ``thickness`` values default to 0. A negative
    ``thickness`` (a layer extending downward from ``zmin``) is normalized so
    every returned layer satisfies ``zmin <= zmax`` and ``thickness >= 0``.
    The overlap grouping and to-scale drawing compare ``zmin`` against
    ``zmax`` and would otherwise misbehave.

    Args:
        layer_stack: gdsfactory LayerStack object

    Returns:
        List of StackLayer objects sorted by zmin (ascending)
    """
    layers: list[StackLayer] = []

    for name, level in layer_stack.layers.items():
        z_start = level.zmin if level.zmin is not None else 0.0
        thickness = level.thickness if level.thickness is not None else 0.0
        z_end = z_start + thickness

        layers.append(
            StackLayer(
                name=name,
                zmin=min(z_start, z_end),
                zmax=max(z_start, z_end),
                thickness=abs(thickness),
                material=level.material or None,
                gds_layer=_get_gds_layer_number(level),
                layer_type=_classify_layer(name),
            )
        )

    # Sort by zmin ascending
    layers.sort(key=lambda layer: layer.zmin)
    return layers
110
+
111
+
112
+ def _format_layer_name(name: str, _max_len: int = 20) -> str:
113
+ """Format layer name with abbreviation in parentheses."""
114
+ # Common abbreviations
115
+ abbrevs = {
116
+ "topmetal2": "TM2",
117
+ "topmetal1": "TM1",
118
+ "topvia2": "TV2",
119
+ "topvia1": "TV1",
120
+ "metal5": "M5",
121
+ "metal4": "M4",
122
+ "metal3": "M3",
123
+ "metal2": "M2",
124
+ "metal1": "M1",
125
+ "via4": "V4",
126
+ "via3": "V3",
127
+ "via2": "V2",
128
+ "via1": "V1",
129
+ "poly": "Poly",
130
+ "active": "Act",
131
+ "substrate": "Sub",
132
+ }
133
+
134
+ name_lower = name.lower()
135
+ if name_lower in abbrevs:
136
+ abbrev = abbrevs[name_lower]
137
+ display = name.capitalize() if name[0].islower() else name
138
+ return f"{display} ({abbrev})"
139
+
140
+ return name
141
+
142
+
143
def print_stack(pdk) -> str:
    """Print an ASCII diagram of the layer stack.

    The diagram is drawn top-down:

    * every layer that is neither substrate nor active/poly gets its own
      boxed row, highest ``zmax`` first;
    * a generic "(dielectric / oxide)" row follows;
    * then an optional combined active/poly row whose bottom edge is
      labelled z = 0;
    * and finally the substrate, if the stack has one.

    Args:
        pdk: A PDK module with LAYER_STACK, or a LayerStack directly

    Returns:
        The formatted string (also printed to stdout).

    Examples:
        ```python
        import ihp

        print_stack(ihp)
        ```
    """
    # Extract LayerStack from PDK module if needed
    layer_stack = pdk.LAYER_STACK if hasattr(pdk, "LAYER_STACK") else pdk

    layers = parse_layer_stack(layer_stack)

    if not layers:
        result = "No layers found in stack"
        print(result)
        return result

    # Separate layers by type
    substrate_layer = None
    active_layers = []  # active, poly, etc.
    metal_layers = []  # everything else: metals, vias, dielectrics
    for layer in layers:
        if layer.layer_type == "substrate":
            substrate_layer = layer  # if several match, the last one wins
        elif layer.name.lower() in ("active", "poly", "gatpoly"):
            active_layers.append(layer)
        else:
            metal_layers.append(layer)

    # Build the diagram
    lines = []
    width = 50
    box_width = width - 4
    # Blank z-label column; same width as f"{z:7.2f} ".
    gutter = f"{'':>8}"

    # Title
    title = "Layer Stack"
    lines.append(f" Z (um){title:^{width + 10}}")
    lines.append(" " + "─" * (width + 18))

    # Sort metal layers by zmax descending for top-down drawing
    metal_layers_sorted = sorted(
        metal_layers, key=lambda layer: layer.zmax, reverse=True
    )

    # Draw top border. Without upper layers there is no z value to label it,
    # but the box still needs a top edge.
    if metal_layers_sorted:
        lines.append(f"{metal_layers_sorted[0].zmax:7.2f} ┌{'─' * box_width}┐")
    else:
        lines.append(f"{gutter}┌{'─' * box_width}┐")

    # Draw each metal layer from top to bottom
    for layer in metal_layers_sorted:
        display_name = _format_layer_name(layer.name)
        thickness_str = (
            f"{layer.thickness:.2f} um"
            if layer.thickness >= 0.01
            else f"{layer.thickness * 1000:.0f} nm"
        )
        # GDS layer 0 is a valid layer number, so test against None.
        layer_str = (
            f"Layer {layer.gds_layer}" if layer.gds_layer is not None else ""
        )

        name_part = f"{display_name:^{box_width - 24}}"
        info_part = f"{thickness_str:>10} {layer_str:<10}"
        content = f"{name_part}{info_part}"

        lines.append(f"{gutter}│{content:^{box_width}}│")
        lines.append(f"{layer.zmin:7.2f} ├{'─' * box_width}┤")

    # Dielectric/oxide region
    lines.append(f"{gutter}│{'(dielectric / oxide)':^{box_width}}│")

    # Active layers (active, poly)
    if active_layers:
        active_sorted = sorted(
            active_layers, key=lambda layer: layer.zmax, reverse=True
        )
        z_top = active_sorted[0].zmax
        third = box_width // 3
        tail = "─" * (box_width - 2 * third - 2)
        lines.append(f"{z_top:7.2f} ├{'─' * third}┬{'─' * third}┬{tail}┤")

        shown = active_sorted[:2]  # only the two highest are listed
        names = " ".join(layer.name.capitalize() for layer in shown)
        gds_layers = ", ".join(
            str(layer.gds_layer) for layer in shown if layer.gds_layer is not None
        )
        gds_part = f" Layer {gds_layers}" if gds_layers else ""
        content = f"{names} ~{shown[0].thickness:.1f} um{gds_part}"
        lines.append(f"{gutter}│{content:^{box_width}}│")
        lines.append(f"{0.00:7.2f} ├{'─' * third}┴{'─' * third}┴{tail}┤")
    else:
        lines.append(f"{0.00:7.2f} ├{'─' * box_width}┤")

    # Substrate
    if substrate_layer is not None:
        lines.append(f"{gutter}│{'':^{box_width}}│")
        lines.append(f"{gutter}│{'Substrate (Si)':^{box_width}}│")
        sub_thickness = f"{abs(substrate_layer.thickness):.0f} um"
        lines.append(f"{gutter}│{sub_thickness:^{box_width}}│")
        lines.append(f"{gutter}│{'':^{box_width}}│")
        lines.append(f"{substrate_layer.zmin:7.0f} └{'─' * box_width}┘")
    else:
        lines.append(f"{gutter}└{'─' * box_width}┘")

    lines.append(" " + "─" * (width + 18))

    result = "\n".join(lines)
    print(result)
    return result
255
+
256
+
257
+ def _find_overlap_groups(layers: list[StackLayer]) -> list[list[StackLayer]]:
258
+ """Group layers that overlap in z-range.
259
+
260
+ Returns list of groups, where each group contains layers that overlap.
261
+ Non-overlapping layers are in their own single-element groups.
262
+ """
263
+ if not layers:
264
+ return []
265
+
266
+ # Sort by zmin
267
+ sorted_layers = sorted(layers, key=lambda layer: layer.zmin)
268
+
269
+ groups = []
270
+ current_group = [sorted_layers[0]]
271
+ group_zmax = sorted_layers[0].zmax
272
+
273
+ for layer in sorted_layers[1:]:
274
+ # Check if this layer overlaps with current group
275
+ if layer.zmin < group_zmax:
276
+ # Overlaps - add to current group
277
+ current_group.append(layer)
278
+ group_zmax = max(group_zmax, layer.zmax)
279
+ else:
280
+ # No overlap - start new group
281
+ groups.append(current_group)
282
+ current_group = [layer]
283
+ group_zmax = layer.zmax
284
+
285
+ groups.append(current_group)
286
+ return groups
287
+
288
+
289
def plot_stack(pdk, width: float = 600, height: float = 800, to_scale: bool = False):
    """Create an interactive plotly visualization of the layer stack.

    Each layer is drawn as a labelled rectangle, with an invisible marker at
    its centre that carries the hover text. Two buttons switch between two
    views:

    * "Uniform": one equal-height row per layer, ordered by ``zmin``.
    * "To Scale": real z extents; layers whose z ranges overlap are placed
      side by side in columns.

    Both the rectangles and the hover markers move when the view changes.

    Args:
        pdk: A PDK module with LAYER_STACK, or a LayerStack directly
        width: Figure width in pixels
        height: Figure height in pixels
        to_scale: If True, start in the to-scale view (actual z dimensions).
            If False (default), start with fixed height for all layers for
            better visibility.

    Returns:
        plotly Figure object (displays automatically in notebooks)

    Examples:
        ```python
        import ihp

        plot_stack(ihp)
        ```
    """
    # Extract LayerStack from PDK module if needed
    layer_stack = pdk.LAYER_STACK if hasattr(pdk, "LAYER_STACK") else pdk

    layers = parse_layer_stack(layer_stack)

    if not layers:
        fig = go.Figure()
        fig.add_annotation(
            text="No layers found",
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            showarrow=False,
        )
        return fig

    # Color scheme by layer type
    colors = {
        "conductor": "#4CAF50",  # Green for metals/actives
        "via": "#87CEEB",  # Sky blue for vias
        "substrate": "#D0D0D0",  # Light gray for substrate
        "dielectric": "#D0D0D0",  # Light gray for dielectrics
    }
    fallback_color = "#CCCCCC"

    # layer.name -> (column index, number of columns in its overlap group);
    # only the to-scale view splits the width into columns.
    layer_columns: dict[str, tuple[int, int]] = {}
    for group in _find_overlap_groups(layers):
        for col_idx, layer in enumerate(group):
            layer_columns[layer.name] = (col_idx, len(group))

    # Sort layers by zmin for consistent ordering (also the trace order)
    sorted_layers = sorted(layers, key=lambda layer: layer.zmin)

    # Uniform (not to scale) view: one unit-height row per layer.
    uniform_height = 1.0
    uniform_positions = {
        layer.name: (row * uniform_height, (row + 1) * uniform_height)
        for row, layer in enumerate(sorted_layers)
    }

    # Base box width and x-coordinates
    total_width = 4
    base_x0 = 0

    def layer_box(
        layer: StackLayer, use_scale: bool
    ) -> tuple[float, float, float, float]:
        """Return the (x0, x1, y0, y1) rectangle of *layer* in one view."""
        if use_scale:
            col_idx, num_cols = layer_columns[layer.name]
            col_width = total_width / num_cols
            x0 = base_x0 + col_idx * col_width
            return x0, x0 + col_width, layer.zmin, layer.zmax
        y0, y1 = uniform_positions[layer.name]
        return base_x0, base_x0 + total_width, y0, y1

    def hover_text(layer: StackLayer) -> str:
        """HTML hover text for *layer*; GDS layer 0 is shown, not 'N/A'."""
        gds = "N/A" if layer.gds_layer is None else layer.gds_layer
        return (
            f"<b>{layer.name}</b><br>"
            f"GDS Layer: {gds}<br>"
            f"Type: {layer.layer_type}<br>"
            f"Z: {layer.zmin:.2f} - {layer.zmax:.2f} µm<br>"
            f"Thickness: {layer.thickness:.3f} µm<br>"
            f"Material: {layer.material or 'N/A'}"
        )

    def build_layout_update(use_scale: bool) -> dict:
        """Layout properties (rectangles, labels, y-axis) for one view."""
        shapes = []
        annotations = []

        for layer in sorted_layers:
            x0, x1, y0, y1 = layer_box(layer, use_scale)
            y_mid = (y0 + y1) / 2

            shapes.append(
                dict(
                    type="rect",
                    x0=x0,
                    x1=x1,
                    y0=y0,
                    y1=y1,
                    fillcolor=colors.get(layer.layer_type, fallback_color),
                    line=dict(color="black", width=1),
                    layer="below",
                )
            )

            annotations.append(
                dict(
                    x=(x0 + x1) / 2,
                    y=y_mid,
                    text=f"<b>{layer.name}</b>",
                    showarrow=False,
                    font=dict(size=10),
                )
            )

            # GDS layer number on the right side of the box
            if layer.gds_layer is not None:
                annotations.append(
                    dict(
                        x=x1 - 0.1,
                        y=y_mid,
                        text=f"{layer.gds_layer}",
                        showarrow=False,
                        font=dict(size=9, color="gray"),
                        xanchor="right",
                    )
                )

        if use_scale:
            y_title = "Z (µm)"
            y_range = [
                min(layer.zmin for layer in sorted_layers) - 1,
                max(layer.zmax for layer in sorted_layers) + 1,
            ]
        else:
            y_title = "Layer (not to scale)"
            y_range = [-0.5, len(sorted_layers) * uniform_height + 0.5]

        return dict(
            shapes=shapes,
            annotations=annotations,
            yaxis=dict(
                title=y_title,
                range=y_range,
                showgrid=False,
                zeroline=False,
                showticklabels=use_scale,
            ),
        )

    def build_trace_update(use_scale: bool) -> dict:
        """Hover-marker positions for one view: one [x] / [y] list per trace.

        Both x and y are included so the markers follow the rectangles into
        their columns in the to-scale view.
        """
        xs = []
        ys = []
        for layer in sorted_layers:
            x0, x1, y0, y1 = layer_box(layer, use_scale)
            xs.append([(x0 + x1) / 2])
            ys.append([(y0 + y1) / 2])
        return {"x": xs, "y": ys}

    fig = go.Figure()

    # One invisible scatter per layer (in sorted_layers order) for hover info.
    initial_traces = build_trace_update(to_scale)
    for layer, xs, ys in zip(sorted_layers, initial_traces["x"], initial_traces["y"]):
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode="markers",
                marker=dict(size=20, opacity=0),
                hoverinfo="text",
                hovertext=hover_text(layer),
                showlegend=False,
                name=f"hover_{layer.name}",
            )
        )

    initial_layout = build_layout_update(to_scale)

    # Layout with view-toggle buttons
    fig.update_layout(
        title="Layer Stack",
        width=width,
        height=height,
        shapes=initial_layout["shapes"],
        annotations=initial_layout["annotations"],
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[base_x0 - 0.5, base_x0 + total_width + 0.5],
        ),
        yaxis=initial_layout["yaxis"],
        plot_bgcolor="white",
        hoverlabel=dict(bgcolor="white"),
        updatemenus=[
            dict(
                type="buttons",
                direction="left",
                x=0.0,
                y=1.15,
                xanchor="left",
                yanchor="top",
                # Highlight the button matching the initial view.
                active=1 if to_scale else 0,
                buttons=[
                    dict(
                        label="Uniform",
                        method="update",
                        args=[build_trace_update(False), build_layout_update(False)],
                    ),
                    dict(
                        label="To Scale",
                        method="update",
                        args=[build_trace_update(True), build_layout_update(True)],
                    ),
                ],
            ),
        ],
    )

    return fig
586
+
587
+
588
def print_stack_table(pdk) -> str:
    """Print a table of layer information.

    Each row is one layer, ordered top to bottom (descending ``zmin``). The
    columns are GDS layer number, classified type, z extent, thickness and
    material. A missing GDS number or material is shown as ``-``.

    Args:
        pdk: A PDK module with LAYER_STACK, or a LayerStack directly

    Returns:
        The formatted string (also printed to stdout).

    Examples:
        ```python
        import ihp

        print_stack_table(ihp)
        ```
    """
    layer_stack = pdk.LAYER_STACK if hasattr(pdk, "LAYER_STACK") else pdk

    layers = parse_layer_stack(layer_stack)

    rule = "=" * 80
    lines = [
        "\nLayer Stack Table",
        rule,
        (
            f"{'Layer':<15} {'GDS':<8} {'Type':<12} "
            f"{'Z-min':>10} {'Z-max':>10} {'Thick':>10} {'Material':<12}"
        ),
        "-" * 80,
    ]

    # Sort by zmin descending (top to bottom)
    for layer in sorted(layers, key=lambda layer: layer.zmin, reverse=True):
        # GDS layer 0 is valid, so only None counts as "unknown".
        gds = "-" if layer.gds_layer is None else str(layer.gds_layer)
        material = layer.material or "-"
        lines.append(
            f"{layer.name:<15} {gds:<8} {layer.layer_type:<12} "
            f"{layer.zmin:>10.2f} {layer.zmax:>10.2f} "
            f"{layer.thickness:>10.2f} {material:<12}"
        )

    lines.append(rule)

    result = "\n".join(lines)
    print(result)
    return result
gsim/viz.py ADDED
@@ -0,0 +1,86 @@
1
+ """Visualization utilities for gsim.
2
+
3
+ This module provides visualization tools for meshes and simulation results.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+
10
+ import meshio # type: ignore[import-untyped]
11
+ import pyvista as pv # type: ignore[import-untyped]
12
+
13
+
14
def plot_mesh(
    msh_path: str | Path,
    output: str | Path | None = None,
    show_groups: list[str] | None = None,
    interactive: bool = True,
) -> None:
    """Plot mesh wireframe using PyVista.

    Args:
        msh_path: Path to .msh file
        output: Output PNG path (only used if interactive=False). Defaults
            to ``msh_path`` with its suffix replaced by ``.png``.
        show_groups: List of group name patterns to show (None = all).
            A physical group is shown when any pattern is a substring of its
            name, e.g. ["metal", "P"] to show metal layers and ports. If no
            cells match, a notice is printed and the whole mesh is drawn.
        interactive: If True, open interactive 3D viewer.
            If False, save static PNG to output path.

    Example:
        >>> plot_mesh("./sim/palace.msh", show_groups=["metal", "P"])
    """
    msh_path = Path(msh_path)

    # Physical-group names come from meshio's field_data (name -> [tag, dim]);
    # the PyVista mesh below only carries the integer tags ("gmsh:physical").
    mio = meshio.read(msh_path)
    group_map = {tag: name for name, (tag, _) in mio.field_data.items()}

    # Load mesh with pyvista
    mesh = pv.read(msh_path)

    if interactive:
        plotter = pv.Plotter(window_size=[1200, 900])
    else:
        plotter = pv.Plotter(off_screen=True, window_size=[1200, 900])

    plotter.set_background("white")

    drew_groups = False
    if show_groups:
        # Filter to matching groups
        ids = [
            tag
            for tag, name in group_map.items()
            if any(p in name for p in show_groups)
        ]
        colors = ["red", "blue", "green", "orange", "purple", "cyan"]
        for i, gid in enumerate(ids):
            subset = mesh.extract_cells(mesh.cell_data["gmsh:physical"] == gid)
            if subset.n_cells > 0:
                plotter.add_mesh(
                    subset,
                    style="wireframe",
                    color=colors[i % len(colors)],
                    line_width=1,
                    label=group_map.get(gid, str(gid)),
                )
                drew_groups = True

        if drew_groups:
            plotter.add_legend()
        else:
            # add_legend() needs at least one labelled actor; rather than
            # failing (or rendering an empty scene), fall back to the full mesh.
            print(
                f"No mesh cells matched show_groups={show_groups!r}; "
                f"available groups: {sorted(group_map.values())}. "
                "Showing the full mesh."
            )

    if not drew_groups:
        plotter.add_mesh(mesh, style="wireframe", color="black", line_width=1)

    plotter.camera_position = "iso"

    if interactive:
        plotter.show()
        return

    if output is None:
        output = msh_path.with_suffix(".png")
    try:
        plotter.screenshot(str(output))
    finally:
        # Release the off-screen render window even if the screenshot fails.
        plotter.close()

    # Display in notebook if available
    try:
        from IPython.display import Image, display  # type: ignore[import-untyped]

        display(Image(str(output)))
    except ImportError:
        print(f"Saved: {output}")