lucid-dl 2.11.2__py3-none-any.whl → 2.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,818 @@
1
+ from dataclasses import dataclass
2
+ import json
3
+ from typing import Iterable, Literal
4
+
5
+ import lucid
6
+ import lucid.nn as nn
7
+
8
+ from lucid._tensor import Tensor
9
+ from lucid.types import _ShapeLike
10
+
11
+
12
__all__ = ["build_mermaid_chart"]


# Classes whose defining module lives under this prefix count as lucid's
# builtin layers; the next dotted component is their "subpackage" (e.g. "conv").
_NN_MODULES_PREFIX = "lucid.nn.modules."


# Builtin subpackage name -> (fill, stroke) hex colours for leaf-node styling.
# fmt: off
_BUILTIN_SUBPACKAGE_STYLE: dict[str, tuple[str, str]] = dict(
    conv=("#ffe8e8", "#c53030"),
    norm=("#e6fffa", "#2c7a7b"),
    activation=("#faf5ff", "#6b46c1"),
    linear=("#ebf8ff", "#2b6cb0"),
    pool=("#fefcbf", "#b7791f"),
    drop=("#edf2f7", "#4a5568"),
    transformer=("#e2e8f0", "#334155"),
    attention=("#f0fff4", "#2f855a"),
    vision=("#fdf2f8", "#b83280"),
    rnn=("#f0f9ff", "#0284c7"),
    sparse=("#f1f5f9", "#475569"),
    loss=("#fffbeb", "#d97706"),
    einops=("#ecfccb", "#65a30d"),
)
# fmt: on
35
+
36
+
37
@dataclass
class _ModuleNode:
    """One vertex of the module tree drawn by ``build_mermaid_chart``.

    ``group`` is only set when a run of structurally identical siblings was
    collapsed into this node; it then lists every module of the run, and
    ``module`` is the first of them.
    """

    module: nn.Module  # representative module for this node
    name: str  # dotted attribute path from the root ("" for the root itself)
    depth: int  # depth at which the node was built (the root is 0)
    children: list["_ModuleNode"]
    group: list[nn.Module] | None = None

    @property
    def count(self) -> int:
        """Number of modules this node stands for (1 unless collapsed)."""
        return len(self.group) if self.group is not None else 1

    def iter_modules(self) -> Iterable[nn.Module]:
        """Yield every module this node represents."""
        if self.group is not None:
            yield from self.group
        else:
            yield self.module
54
+
55
+
56
def _flatten_tensors(obj: object) -> list[Tensor]:
    """Collect every ``Tensor`` nested inside *obj*, depth-first, in order.

    Lists and tuples are walked element by element and dicts by value; any
    other non-Tensor object contributes nothing.
    """
    found: list[Tensor] = []

    def _visit(item: object) -> None:
        if isinstance(item, Tensor):
            found.append(item)
        elif isinstance(item, (list, tuple)):
            for sub in item:
                _visit(sub)
        elif isinstance(item, dict):
            for sub in item.values():
                _visit(sub)

    _visit(obj)
    return found
69
+
70
+
71
def _build_tree(
    module: nn.Module,
    depth: int,
    max_depth: int,
    name: str = "",
    *,
    collapse_repeats: bool = False,
    repeat_min: int = 3,
    hide_subpackages: set[str] | None = None,
    hide_module_names: set[str] | None = None,
) -> _ModuleNode:
    """Build the ``_ModuleNode`` tree for *module*, expanding to *max_depth*.

    Args:
        module: Module represented by the node being built.
        depth: Depth of *module* (callers start the root at ``0``).
        max_depth: Children are expanded only while ``depth < max_depth``.
        name: Dotted attribute path of *module* from the root.
        collapse_repeats: Merge runs of structurally identical siblings via
            ``_collapse_repeated_children``.
        repeat_min: Minimum run length that gets collapsed.
        hide_subpackages: Builtin subpackage names (as returned by
            ``_builtin_subpackage_key``) whose modules are hidden.
        hide_module_names: Class names whose modules are hidden.

    A hidden child is not dropped with its subtree: its own (already built)
    children are spliced into this node's child list in its place. The
    spliced nodes keep the depth they were built with.

    Returns:
        The node for *module*.
    """
    children: list[_ModuleNode] = []
    if depth < max_depth:
        for child_name, child in module._modules.items():
            path = f"{name}.{child_name}" if name else child_name

            node = _build_tree(
                child,
                depth + 1,
                max_depth,
                path,
                collapse_repeats=collapse_repeats,
                repeat_min=repeat_min,
                hide_subpackages=hide_subpackages,
                hide_module_names=hide_module_names,
            )

            # Reuse the same subpackage derivation as the colouring code so
            # "hidden" and "coloured as X" can never disagree.
            subpkg = _builtin_subpackage_key(child)
            excluded = bool(
                (hide_module_names and type(child).__name__ in hide_module_names)
                or (hide_subpackages and subpkg and subpkg in hide_subpackages)
            )

            if excluded:
                children.extend(node.children)
            else:
                children.append(node)

    if collapse_repeats and children:
        children = _collapse_repeated_children(children, repeat_min=repeat_min)

    return _ModuleNode(module=module, name=name, depth=depth, children=children)
121
+
122
+
123
def _module_label(module: nn.Module, show_params: bool) -> str:
    """Return the display text for *module*.

    Uses the module's ``_alt_name`` attribute when it is set to a non-empty
    value, otherwise the class name. With *show_params*, the value of
    ``module.parameter_size`` is appended with thousands separators.
    """
    alt_name = getattr(module, "_alt_name", "")
    label = alt_name if alt_name else type(module).__name__
    if not show_params:
        return label
    return f"{label} ({module.parameter_size:,} params)"
129
+
130
+
131
def _builtin_subpackage_key(module: nn.Module) -> str | None:
    """Return the ``lucid.nn.modules`` subpackage *module*'s class comes from.

    For a class whose ``__module__`` is ``"lucid.nn.modules.conv.foo"`` this
    is ``"conv"``. Classes defined outside that package yield ``None``.
    """
    mod_path = type(module).__module__
    if mod_path.startswith(_NN_MODULES_PREFIX):
        remainder = mod_path[len(_NN_MODULES_PREFIX):]
        if remainder:
            return remainder.partition(".")[0]
    return None
138
+
139
+
140
def _shape_text_color(module: nn.Module) -> str | None:
    """Return the stroke colour of *module*'s builtin subpackage, if any.

    Used to tint the shape annotation under a leaf label. Returns ``None``
    when the module is not a builtin layer or its subpackage has no entry in
    ``_BUILTIN_SUBPACKAGE_STYLE``.
    """
    subpkg = _builtin_subpackage_key(module)
    style = _BUILTIN_SUBPACKAGE_STYLE.get(subpkg) if subpkg is not None else None
    if style is None:
        return None
    _fill, stroke = style
    return stroke
151
+
152
+
153
+ def _parse_rgba(value: str) -> tuple[str, float] | None:
154
+ v = value.strip().lower()
155
+ if not v.startswith("rgba(") or not v.endswith(")"):
156
+ return None
157
+
158
+ inner = v[5:-1]
159
+ parts = [p.strip() for p in inner.split(",")]
160
+ if len(parts) != 4:
161
+ return None
162
+
163
+ try:
164
+ r, g, b = (int(float(x)) for x in parts[:3])
165
+ a = float(parts[3])
166
+ except Exception:
167
+ return None
168
+
169
+ r = max(0, min(255, r))
170
+ g = max(0, min(255, g))
171
+ b = max(0, min(255, b))
172
+ a = max(0.0, min(1.0, a))
173
+
174
+ return (f"#{r:02x}{g:02x}{b:02x}", a)
175
+
176
+
177
def _container_attr_label(node: _ModuleNode) -> str | None:
    """Return the attribute name to use as a container subgraph's title.

    Applies only to ``Sequential``/``ModuleList``/``ModuleDict`` nodes with a
    non-empty ``name`` (i.e. not the root). Returns the last segment of the
    dotted path, or ``None`` when that segment consists only of digits
    (an index inside a parent container rather than an attribute name).
    """
    containers = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
    if not (isinstance(node.module, containers) and node.name):
        return None

    leaf = node.name.rpartition(".")[2]
    return None if leaf.isdigit() else leaf
188
+
189
+
190
+ def _escape_label(text: str) -> str:
191
+ return text.replace('"', "&quot;")
192
+
193
+
194
+ def _shape_str(shape: object) -> str:
195
+ try:
196
+ if isinstance(shape, tuple):
197
+ return "(" + ",".join(str(x) for x in shape) + ")"
198
+ except Exception:
199
+ pass
200
+ return str(shape)
201
+
202
+
203
def _flatten_shapes(obj: object) -> list[tuple[int, ...]]:
    """Return the shape of every tensor nested in *obj* as a tuple of ints.

    Tensors are found with ``_flatten_tensors``. Any tensor whose ``.shape``
    cannot be read or converted to ints is skipped.
    """
    result: list[tuple[int, ...]] = []
    for tensor in _flatten_tensors(obj):
        try:
            dims = tuple(map(int, tensor.shape))
        except Exception:
            continue
        result.append(dims)
    return result
211
+
212
+
213
def _shapes_brief(shapes: list[tuple[int, ...]]) -> str:
    """Summarise a list of shapes for a node label.

    Returns ``"?"`` for no shapes and the formatted shape for exactly one.
    Otherwise returns the first shape followed by ``x<count>``, e.g.
    ``"(1,8)x3"``.
    """
    count = len(shapes)
    if count == 0:
        return "?"
    first = _shape_str(shapes[0])
    return first if count == 1 else f"{first}x{count}"
219
+
220
+
221
+ def _node_signature(node: _ModuleNode) -> tuple:
222
+ return (type(node.module), tuple(_node_signature(c) for c in node.children))
223
+
224
+
225
def _collapse_repeated_children(
    children: list[_ModuleNode], *, repeat_min: int
) -> list[_ModuleNode]:
    """Merge runs of consecutive, structurally identical sibling nodes.

    A maximal run of at least *repeat_min* adjacent nodes with equal
    ``_node_signature`` is replaced by one node. That node reuses the first
    node's module, name, depth and children, and records every module of the
    run in ``group``. Shorter runs are kept unchanged. When ``repeat_min <= 1``
    the input list itself is returned untouched.
    """
    if repeat_min <= 1:
        return children

    # Each signature walks a whole subtree, so compute it once per child.
    signatures = [_node_signature(child) for child in children]
    total = len(children)

    collapsed: list[_ModuleNode] = []
    start = 0
    while start < total:
        end = start + 1
        while end < total and signatures[end] == signatures[start]:
            end += 1

        run = children[start:end]
        if len(run) < repeat_min:
            collapsed.extend(run)
        else:
            head = run[0]
            collapsed.append(
                _ModuleNode(
                    module=head.module,
                    name=head.name,
                    depth=head.depth,
                    children=head.children,
                    group=[member.module for member in run],
                )
            )
        start = end

    return collapsed
256
+
257
+
258
def build_mermaid_chart(
    module: nn.Module,
    input_shape: _ShapeLike | list[_ShapeLike] | None = None,
    inputs: Iterable[Tensor] | Tensor | None = None,
    depth: int = 2,
    direction: str = "LR",
    include_io: bool = True,
    show_params: bool = False,
    return_lines: bool = False,
    copy_to_clipboard: bool = False,
    compact: bool = False,
    use_class_defs: bool = False,
    end_semicolons: bool = True,
    edge_mode: Literal["dataflow", "execution"] = "execution",
    collapse_repeats: bool = True,
    repeat_min: int = 2,
    color_by_subpackage: bool = True,
    container_name_from_attr: bool = True,
    edge_stroke_width: float = 2.0,
    emphasize_model_title: bool = True,
    model_title_font_px: int = 20,
    show_shapes: bool = False,
    hide_subpackages: Iterable[str] = (),
    hide_module_names: Iterable[str] = (),
    dash_multi_input_edges: bool = True,
    subgraph_fill: str = "#000000",
    subgraph_fill_opacity: float = 0.05,
    subgraph_stroke: str = "#000000",
    subgraph_stroke_opacity: float = 0.75,
    force_text_color: str | None = None,
    edge_curve: str = "natural",
    node_spacing: int = 50,
    rank_spacing: int = 50,
    **forward_kwargs,
) -> str | list[str]:
    """Render *module* as a Mermaid ``flowchart`` by tracing one forward pass.

    The module tree, expanded *depth* levels below the root, becomes nested
    subgraphs for modules with drawn children and plain nodes for the rest.
    Forward hooks on every submodule record which drawn node ran, and in
    ``"dataflow"`` mode which node produced each tensor; the edges are
    derived from those records.

    Args:
        module: Model to draw; it is called once with the inputs below.
        input_shape: Shape of a random input (``lucid.random.randn``), or a
            list of shapes for several positional inputs. A list whose items
            are not themselves lists/tuples is treated as a single shape.
        inputs: Concrete input tensor(s); used instead of *input_shape*.
        depth: Number of levels below the root to expand (must be ``>= 0``).
        direction: Mermaid flowchart direction, e.g. ``"LR"`` or ``"TB"``.
        include_io: Add ``Input``/``Output`` nodes and their edges.
        show_params: Append parameter counts to node labels.
        return_lines: Return the list of lines instead of one string.
        copy_to_clipboard: Also copy the text with ``_copy_to_clipboard``.
        compact: Join all lines into one; every statement is then
            ``;``-terminated so the single line still parses.
        use_class_defs: Style via ``classDef``/``:::class`` instead of
            per-node ``style`` lines.
        end_semicolons: Append ``;`` to node and edge statements.
        edge_mode: ``"execution"`` chains leaf nodes in call order;
            ``"dataflow"`` links tensor producers to their consumers.
        collapse_repeats / repeat_min: Merge runs of at least *repeat_min*
            identical siblings into one ``x N`` node.
        color_by_subpackage: Colour builtin leaf layers by subpackage.
        container_name_from_attr: Title ``Sequential``/``ModuleList``/
            ``ModuleDict`` subgraphs with their attribute name.
        edge_stroke_width: Default edge width in px (``linkStyle default``).
        emphasize_model_title / model_title_font_px: Enlarge the root title.
        show_shapes: Annotate shape-changing leaves and the I/O nodes.
        hide_subpackages / hide_module_names: Hide matching modules while
            keeping their children.
        dash_multi_input_edges: Draw edges into nodes that have several
            incoming edges as ``-.->`` links.
        subgraph_fill / subgraph_fill_opacity / subgraph_stroke /
            subgraph_stroke_opacity: Subgraph styling. An ``rgba(...)``
            colour overrides the matching opacity.
        force_text_color: Force every label colour through ``themeCSS``.
        edge_curve: Mermaid curve name. ``"round"`` is rendered as
            ``"step"`` curves with rounded line joins.
        node_spacing / rank_spacing: Mermaid layout spacing.
        **forward_kwargs: Extra keyword arguments passed to ``module(...)``.

    Returns:
        The chart text, or its lines when *return_lines* is true.

    Raises:
        ValueError: If neither *inputs* nor *input_shape* is given, if
            *depth* is negative, or if *edge_mode* is not supported.
    """
    if inputs is None and input_shape is None:
        raise ValueError("Either inputs or input_shape must be provided.")
    if depth < 0:
        raise ValueError("depth must be >= 0")
    if edge_mode not in ("dataflow", "execution"):
        raise ValueError(
            f"edge_mode must be 'dataflow' or 'execution', got {edge_mode!r}"
        )

    tree = _build_tree(
        module,
        depth=0,
        max_depth=depth,
        collapse_repeats=collapse_repeats,
        repeat_min=repeat_min,
        hide_subpackages=set(hide_subpackages),
        hide_module_names=set(hide_module_names),
    )

    # Pre-order list of drawn nodes; node i gets the Mermaid id "m{i}".
    nodes: list[_ModuleNode] = []

    def _collect(n: _ModuleNode) -> None:
        nodes.append(n)
        for c in n.children:
            _collect(c)

    _collect(tree)

    # Index every module a node stands for. This covers all members of a
    # collapsed group, not just its representative, so that group members
    # resolve to the same id AND are recognised as containers in execution
    # mode.
    module_to_node: dict[nn.Module, _ModuleNode] = {}
    module_to_id: dict[nn.Module, str] = {}
    for idx, n in enumerate(nodes):
        for mod in n.iter_modules():
            module_to_node[mod] = n
            module_to_id[mod] = f"m{idx}"

    # Child -> parent over the full module tree. Modules that are not drawn
    # (too deep, hidden, or inside non-representative group members) are
    # attributed to their nearest drawn ancestor.
    parent_map: dict[nn.Module, nn.Module] = {}

    def _walk(mod: nn.Module) -> None:
        for child in mod._modules.values():
            parent_map[child] = mod
            _walk(child)

    _walk(module)
    has_non_root_included = any(mod is not module for mod in module_to_id)

    def _map_to_included(mod: nn.Module) -> nn.Module | None:
        cur = mod
        while cur not in module_to_id and cur in parent_map:
            cur = parent_map[cur]
        return cur if cur in module_to_id else None

    hooks = []
    edges: set[tuple[str, str]] = set()
    # id(tensor) -> drawn module that produced it (filled in dataflow mode).
    tensor_producer: dict[int, nn.Module] = {}
    # Holds every tensor keyed in ``tensor_producer`` until we are done.
    # Otherwise a freed intermediate's id() could be reused by a new tensor,
    # and that tensor would be credited to the wrong producer.
    produced_tensors: list[Tensor] = []
    module_in_shapes: dict[nn.Module, list[tuple[int, ...]]] = {}
    module_out_shapes: dict[nn.Module, list[tuple[int, ...]]] = {}
    input_node_id = "input"
    output_node_id = "output"
    root_module = module
    exec_order: list[nn.Module] = []

    def _hook(
        mod: nn.Module, input_arg: tuple, output: Tensor | tuple[Tensor, ...]
    ) -> None:
        mapped_mod = _map_to_included(mod)
        if mapped_mod is None:
            return

        if show_shapes:
            # Later calls for the same drawn node overwrite earlier ones.
            module_in_shapes[mapped_mod] = _flatten_shapes(input_arg)
            module_out_shapes[mapped_mod] = _flatten_shapes(output)

        if edge_mode == "dataflow":
            for t in _flatten_tensors(input_arg):
                producer = tensor_producer.get(id(t))
                if producer is None:
                    # Unknown origin: attribute it to the model input, but
                    # skip the root's own call when inner nodes exist.
                    if include_io and (
                        mapped_mod is not root_module or not has_non_root_included
                    ):
                        edges.add((input_node_id, module_to_id[mapped_mod]))
                elif producer is not mapped_mod:
                    edges.add((module_to_id[producer], module_to_id[mapped_mod]))

            for t in _flatten_tensors(output):
                key = id(t)
                # Don't let the root claim tensors an inner node already produced.
                if mapped_mod is root_module and key in tensor_producer:
                    continue
                tensor_producer[key] = mapped_mod
                produced_tensors.append(t)

        exec_order.append(mapped_mod)

    for mod in module.modules():
        hooks.append(mod.register_forward_hook(_hook))

    try:
        if inputs is not None:
            input_tensors = [inputs] if isinstance(inputs, Tensor) else list(inputs)
        else:
            # A list of shapes means several positional inputs; a plain list
            # of ints (e.g. [1, 3, 32, 32]) is a single shape.
            many_shapes = isinstance(input_shape, list) and all(
                isinstance(s, (list, tuple)) for s in input_shape
            )
            shapes = input_shape if many_shapes else [input_shape]
            input_tensors = [
                lucid.random.randn(shape, device=module.device) for shape in shapes
            ]

        outputs = module(*input_tensors, **forward_kwargs)
    finally:
        for remove in hooks:
            remove()

    model_input_shapes = _flatten_shapes(input_tensors)
    model_output_shapes = _flatten_shapes(outputs)

    if edge_mode == "execution":
        # Drop consecutive repeats (a drawn node fires once per inner module).
        seq: list[nn.Module] = []
        for m in exec_order:
            if not seq or seq[-1] is not m:
                seq.append(m)
        seq_no_root = [m for m in seq if m is not root_module] or [root_module]

        def _is_included_container(mod: nn.Module) -> bool:
            node = module_to_node.get(mod)
            return bool(node and node.children)

        # Containers are drawn as subgraphs, so only their leaves are chained.
        seq_effective = [
            m for m in seq_no_root if not _is_included_container(m)
        ] or seq_no_root

        for prev, cur in zip(seq_effective, seq_effective[1:]):
            edges.add((module_to_id[prev], module_to_id[cur]))

        if include_io and seq_effective:
            edges.add((input_node_id, module_to_id[seq_effective[0]]))
            edges.add((module_to_id[seq_effective[-1]], output_node_id))

    elif include_io:
        for t in _flatten_tensors(outputs):
            producer = tensor_producer.get(id(t))
            if producer is not None:
                edges.add((module_to_id[producer], output_node_id))

    node_ids = set(module_to_id.values())
    nodes_with_edges = {
        endpoint for edge in edges for endpoint in edge if endpoint in node_ids
    }
    container_node_ids = {module_to_id[n.module] for n in nodes if n.children}
    container_with_edges = container_node_ids & nodes_with_edges
    extra_edges: list[tuple[str, str]] = []

    # Edges touching a container attach to its "_in"/"_out" helper nodes.
    def _endpoint_in(node_id: str) -> str:
        return f"{node_id}_in" if node_id in container_with_edges else node_id

    def _endpoint_out(node_id: str) -> str:
        return f"{node_id}_out" if node_id in container_with_edges else node_id

    def _first_leaf_id(n: _ModuleNode) -> str:
        while n.children:
            n = n.children[0]
        return module_to_id[n.module]

    def _last_leaf_id(n: _ModuleNode) -> str:
        while n.children:
            n = n.children[-1]
        return module_to_id[n.module]

    init_cfg: dict[str, object] = {
        "flowchart": {
            "curve": "step" if edge_curve == "round" else edge_curve,
            "nodeSpacing": node_spacing,
            "rankSpacing": rank_spacing,
        }
    }
    css_parts: list[str] = []
    if edge_curve == "round":
        css_parts.append(
            ".edgePath path { stroke-linecap: round; stroke-linejoin: round; }"
        )
    if force_text_color:
        css_parts.append(
            f".nodeLabel, .edgeLabel, .cluster text, .node text "
            f"{{ fill: {force_text_color} !important; }} "
            f".node foreignObject *, .cluster foreignObject * "
            f"{{ color: {force_text_color} !important; }}"
        )
    if css_parts:
        init_cfg["themeCSS"] = " ".join(css_parts)

    lines: list[str] = [
        f"%%{{init: {json.dumps(init_cfg, separators=(',', ':'))} }}%%",
        f"flowchart {direction}",
    ]

    if edge_stroke_width and edge_stroke_width != 1.0:
        lines.append(f" linkStyle default stroke-width:{edge_stroke_width}px")

    if use_class_defs:
        lines.append(" classDef module fill:#f9f9f9,stroke:#333,stroke-width:1px;")
        lines.append(" classDef modelio fill:#fff3cd,stroke:#a67c00,stroke-width:1px;")
        lines.append(
            " classDef internalio fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
        )
        lines.append(
            " classDef anchor fill:transparent,stroke:transparent,color:transparent;"
        )
        lines.append(" classDef repeat fill:#e8f1ff,stroke:#2b6cb0,stroke-width:1px;")

    # The subgraph style depends only on the arguments; build it once.
    subgraph_style: str | None = None
    if subgraph_fill or subgraph_stroke:
        style_parts: list[str] = []

        sg_fill, sg_fill_opacity = subgraph_fill, subgraph_fill_opacity
        parsed = _parse_rgba(sg_fill)
        if parsed is not None:
            sg_fill, sg_fill_opacity = parsed
        if sg_fill:
            style_parts.append(f"fill:{sg_fill}")
            style_parts.append(f"fill-opacity:{sg_fill_opacity}")

        sg_stroke, sg_stroke_opacity = subgraph_stroke, subgraph_stroke_opacity
        if sg_stroke:
            parsed = _parse_rgba(sg_stroke)
            if parsed is not None:
                sg_stroke, sg_stroke_opacity = parsed
            style_parts.append(f"stroke:{sg_stroke}")
            style_parts.append(f"stroke-opacity:{sg_stroke_opacity}")
            style_parts.append("stroke-width:1px")

        subgraph_style = ",".join(style_parts)

    vertical_levels: set[int] = {depth - 2, depth - 1} if depth >= 3 else set()

    def _render(n: _ModuleNode, indent: str = " ") -> None:
        node_id = module_to_id[n.module]
        base_label = _module_label(n.module, show_params)

        if container_name_from_attr and n.children:
            attr_label = _container_attr_label(n)
            if attr_label is not None:
                base_label = attr_label

        base_label = _escape_label(base_label)
        label = base_label if n.count == 1 else f"{base_label} x {n.count}"

        if show_shapes and not n.children:
            in_shapes = module_in_shapes.get(n.module, [])
            out_shapes = module_out_shapes.get(n.module, [])
            # Annotate only leaves whose shape actually changes.
            if in_shapes != out_shapes and (in_shapes or out_shapes):
                color_css = ""
                if not force_text_color:
                    color = _shape_text_color(n.module)
                    color_css = f"color:{color};" if color else ""
                label = (
                    f"{label}<br/>"
                    f"<span style='font-size:11px;{color_css}font-weight:400'>"
                    f"{_shapes_brief(in_shapes)} \u2192 {_shapes_brief(out_shapes)}"
                    f"</span>"
                )

        if (
            emphasize_model_title
            and n.module is root_module
            and model_title_font_px
            and model_title_font_px > 0
        ):
            label = (
                f"<span style='font-size:{model_title_font_px}px;font-weight:700'>"
                f"{label}"
                f"</span>"
            )

        if not n.children:
            if use_class_defs:
                class_name = "module" if n.count == 1 else "repeat"
                lines.append(f'{indent}{node_id}["{label}"]:::{class_name}')
            elif n.count > 1:
                lines.append(f'{indent}{node_id}(["{label}"])')
            else:
                lines.append(f'{indent}{node_id}["{label}"]')
            return

        lines.append(f'{indent}subgraph sg_{node_id}["{label}"]')
        if n.depth in vertical_levels:
            lines.append(f"{indent} direction TB")
        if subgraph_style is not None:
            lines.append(f"{indent}style sg_{node_id} {subgraph_style}")

        if node_id in container_with_edges:
            # Entry/exit helper nodes, wired to the first and last leaves.
            in_id = f"{node_id}_in"
            out_id = f"{node_id}_out"
            if use_class_defs:
                lines.append(f'{indent} {in_id}(["Input"]):::internalio')
                lines.append(f'{indent} {out_id}(["Output"]):::internalio')
            else:
                lines.append(f'{indent} {in_id}(["Input"])')
                lines.append(f'{indent} {out_id}(["Output"])')
                lines.append(
                    f" style {in_id} fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
                )
                lines.append(
                    f" style {out_id} fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
                )
            extra_edges.append((in_id, _endpoint_in(_first_leaf_id(n))))
            extra_edges.append((_endpoint_out(_last_leaf_id(n)), out_id))
        elif node_id in nodes_with_edges:
            # Defensive fallback: give a bare edge endpoint an invisible node.
            # Every container id that has edges is in container_with_edges,
            # so this branch is not expected to trigger.
            if use_class_defs:
                lines.append(f'{indent} {node_id}[""]:::anchor')
            else:
                lines.append(f'{indent} {node_id}["\u200b"]')
                lines.append(
                    f" style {node_id} fill:transparent,stroke:transparent,color:transparent;"
                )

        for c in n.children:
            _render(c, indent + " ")
        lines.append(f"{indent}end")

    _render(tree)

    if include_io:
        input_label = "Input"
        output_label = "Output"
        if show_shapes:
            in_s = _shapes_brief(model_input_shapes)
            out_s = _shapes_brief(model_output_shapes)
            io_color = force_text_color or "#a67c00"
            input_label = (
                f"{input_label}<br/>"
                f"<span style='font-size:11px;color:{io_color};font-weight:400'>{in_s}</span>"
            )
            output_label = (
                f"{output_label}<br/>"
                f"<span style='font-size:11px;color:{io_color};font-weight:400'>{out_s}</span>"
            )
        if use_class_defs:
            lines.append(f' {input_node_id}["{input_label}"]:::modelio')
            lines.append(f' {output_node_id}["{output_label}"]:::modelio')
        else:
            lines.append(f' {input_node_id}["{input_label}"]')
            lines.append(f' {output_node_id}["{output_label}"]')
            lines.append(
                f" style {input_node_id} fill:#fff3cd,stroke:#a67c00,stroke-width:1px;"
            )
            lines.append(
                f" style {output_node_id} fill:#fff3cd,stroke:#a67c00,stroke-width:1px;"
            )

    if color_by_subpackage:
        for n in nodes:
            if n.children:
                continue
            subpkg = _builtin_subpackage_key(n.module)
            if subpkg is None:
                continue
            style = _BUILTIN_SUBPACKAGE_STYLE.get(subpkg)
            if style is None:
                continue
            fill, stroke = style
            # Node ids are always "m<idx>", never the I/O or helper ids.
            lines.append(
                f" style {module_to_id[n.module]} fill:{fill},stroke:{stroke},stroke-width:1px;"
            )

    render_edges: set[tuple[str, str]] = set()
    for src, dst in edges:
        src_id = _endpoint_out(src)
        dst_id = _endpoint_in(dst)
        if src_id != dst_id:
            render_edges.add((src_id, dst_id))
    for src, dst in extra_edges:
        if src != dst:
            render_edges.add((src, dst))

    indegree: dict[str, int] = {}
    for _, dst in render_edges:
        indegree[dst] = indegree.get(dst, 0) + 1

    for src, dst in sorted(render_edges):
        arrow = "-.->" if dash_multi_input_edges and indegree.get(dst, 0) > 1 else "-->"
        lines.append(f" {src} {arrow} {dst}")

    def _finalize_lines(src_lines: list[str]) -> list[str]:
        # Append ';' to node/edge statements; structural lines stay bare.
        if not end_semicolons:
            return src_lines

        bare_prefixes = (
            "flowchart ",
            "linkStyle ",
            "subgraph ",
            "classDef ",
            "class ",
            "style ",
            "%%",
        )
        out: list[str] = []
        for line in src_lines:
            stripped = line.rstrip()
            head = stripped.lstrip()
            if (
                not head
                or head == "end"
                or head.startswith(bare_prefixes)
                or stripped.endswith(";")
            ):
                out.append(line)
            else:
                out.append(f"{line};")
        return out

    def _terminated(line: str) -> str:
        # On a single line Mermaid needs ';' between *all* statements,
        # including after flowchart/subgraph/end/style/linkStyle.
        stripped = line.rstrip()
        head = stripped.lstrip()
        if not head or head.startswith("%%") or stripped.endswith(";"):
            return line
        return f"{line};"

    final_lines = _finalize_lines(lines)

    if compact:
        text = " ".join(_terminated(line) for line in final_lines)
    else:
        text = "\n".join(final_lines)

    if copy_to_clipboard:
        _copy_to_clipboard(text)

    if return_lines:
        return final_lines
    return text
752
+
753
+
754
def _copy_to_clipboard(text: str) -> None:
    """Copy *text* to the system clipboard, trying platform tools in turn.

    The order of attempts is:

    - macOS: ``pbcopy``.
    - Windows: ``clip``, then PowerShell ``Set-Clipboard``.
    - Other platforms: ``wl-copy`` (only when ``WAYLAND_DISPLAY`` is set),
      then ``xclip``, then ``xsel``.

    If none of these succeeds, ``tkinter`` is used as a last resort. Text is
    sent to the tools UTF-8 encoded on stdin.

    Raises:
        RuntimeError: If every method failed. The message lists each failure.
    """
    import os
    import shutil
    import subprocess
    import sys

    errors: list[str] = []
    payload = text.encode("utf-8")

    def _try(cmd: list[str]) -> bool:
        try:
            subprocess.run(cmd, input=payload, check=True)
            return True
        except Exception as e:
            errors.append(f"{cmd!r}: {type(e).__name__}: {e}")
            return False

    if sys.platform == "darwin":
        if shutil.which("pbcopy") and _try(["pbcopy"]):
            return

    elif sys.platform.startswith("win"):
        if shutil.which("clip") and _try(["clip"]):
            return
        if shutil.which("powershell"):
            try:
                # Set-Clipboard does not read stdin by itself; ``$input``
                # enumerates the piped stdin lines and feeds them to it.
                # -NonInteractive prevents a parameter prompt from blocking.
                subprocess.run(
                    [
                        "powershell",
                        "-NoProfile",
                        "-NonInteractive",
                        "-Command",
                        "$input | Set-Clipboard",
                    ],
                    input=payload,
                    check=True,
                )
                return
            except Exception as e:
                errors.append(f"powershell Set-Clipboard: {type(e).__name__}: {e}")

    else:
        if os.environ.get("WAYLAND_DISPLAY") and shutil.which("wl-copy"):
            if _try(["wl-copy"]):
                return
        if shutil.which("xclip"):
            if _try(["xclip", "-selection", "clipboard"]):
                return
        if shutil.which("xsel"):
            if _try(["xsel", "--clipboard", "--input"]):
                return

    try:
        import tkinter

        # NOTE(review): on X11 the clipboard is owned by this Tk window, so
        # the text may not survive root.destroy() -- confirm on Linux.
        root = tkinter.Tk()
        root.withdraw()
        root.clipboard_clear()
        root.clipboard_append(text)
        root.update()
        root.destroy()
        return

    except Exception as e:
        errors.append(f"tkinter: {type(e).__name__}: {e}")

    detail = "\n".join(f"- {msg}" for msg in errors) if errors else "- (no details)"
    raise RuntimeError(
        "Failed to copy to clipboard. Install a clipboard utility (macOS: pbcopy; "
        "Wayland: wl-copy; X11: xclip/xsel; Windows: clip/powershell) or enable tkinter.\n"
        f"{detail}"
    )