lucid-dl 2.11.2__py3-none-any.whl → 2.11.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lucid/models/imgclf/crossvit.py +1 -1
- lucid/models/imgclf/efficientformer.py +2 -2
- lucid/models/imgclf/maxvit.py +1 -1
- lucid/models/imgclf/pvt.py +2 -2
- lucid/models/imggen/vae.py +1 -1
- lucid/models/objdet/efficientdet.py +24 -8
- lucid/models/objdet/rcnn.py +1 -1
- lucid/models/objdet/util.py +5 -0
- lucid/nn/module.py +142 -13
- lucid/types.py +58 -0
- lucid/visual/__init__.py +1 -0
- lucid/visual/graph.py +3 -0
- lucid/visual/mermaid.py +818 -0
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/METADATA +30 -21
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/RECORD +18 -17
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/WHEEL +1 -1
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/top_level.txt +0 -0
lucid/visual/mermaid.py
ADDED
|
@@ -0,0 +1,818 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import json
|
|
3
|
+
from typing import Iterable, Literal
|
|
4
|
+
|
|
5
|
+
import lucid
|
|
6
|
+
import lucid.nn as nn
|
|
7
|
+
|
|
8
|
+
from lucid._tensor import Tensor
|
|
9
|
+
from lucid.types import _ShapeLike
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Names exported by `from lucid.visual.mermaid import *`.
__all__ = ["build_mermaid_chart"]


# Dotted-module prefix used to recognise built-in layers from `type(m).__module__`.
_NN_MODULES_PREFIX = "lucid.nn.modules."


# fmt: off
# Colour table keyed by the module-path segment that directly follows
# `_NN_MODULES_PREFIX`. Each value is a `(fill, stroke)` pair of hex colours;
# the stroke colour is also reused as the text colour of shape annotations.
_BUILTIN_SUBPACKAGE_STYLE: dict[str, tuple[str, str]] = {
    "conv": ("#ffe8e8", "#c53030"),
    "norm": ("#e6fffa", "#2c7a7b"),
    "activation": ("#faf5ff", "#6b46c1"),
    "linear": ("#ebf8ff", "#2b6cb0"),
    "pool": ("#fefcbf", "#b7791f"),
    "drop": ("#edf2f7", "#4a5568"),
    "transformer": ("#e2e8f0", "#334155"),
    "attention": ("#f0fff4", "#2f855a"),
    "vision": ("#fdf2f8", "#b83280"),
    "rnn": ("#f0f9ff", "#0284c7"),
    "sparse": ("#f1f5f9", "#475569"),
    "loss": ("#fffbeb", "#d97706"),
    "einops": ("#ecfccb", "#65a30d"),
}
# fmt: on
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class _ModuleNode:
    """One entry of the module tree that is rendered into the chart.

    A node normally stands for a single module. When consecutive siblings
    were collapsed, `group` lists every module the node stands for and
    `module` is the first of them.
    """

    module: nn.Module  # representative module of this node
    name: str  # dotted attribute path from the root ("" for the root)
    depth: int  # nesting level at which the node was created (root is 0)
    children: list["_ModuleNode"]  # child nodes in registration order
    group: list[nn.Module] | None = None  # collapsed siblings, if any

    @property
    def count(self) -> int:
        """Number of modules this node represents (1 unless collapsed)."""
        if self.group is None:
            return 1
        return len(self.group)

    def iter_modules(self) -> Iterable[nn.Module]:
        """Yield every module represented by this node."""
        members = [self.module] if self.group is None else self.group
        yield from members
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _flatten_tensors(obj: object) -> list[Tensor]:
    """Collect every `Tensor` found in a nested list/tuple/dict structure.

    Tensors are returned in traversal order; dicts contribute their values
    only. Any other kind of object contributes nothing.
    """
    if isinstance(obj, Tensor):
        return [obj]

    if isinstance(obj, (list, tuple)):
        nested = obj
    elif isinstance(obj, dict):
        nested = obj.values()
    else:
        return []

    found: list[Tensor] = []
    for element in nested:
        found.extend(_flatten_tensors(element))
    return found
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _build_tree(
    module: nn.Module,
    depth: int,
    max_depth: int,
    name: str = "",
    *,
    collapse_repeats: bool = False,
    repeat_min: int = 3,
    hide_subpackages: set[str] | None = None,
    hide_module_names: set[str] | None = None,
) -> _ModuleNode:
    """Recursively build the `_ModuleNode` tree rooted at *module*.

    Args:
        module: Module that becomes the root of the returned (sub)tree.
        depth: Nesting level of *module*; children are only expanded while
            `depth < max_depth`.
        max_depth: Deepest level that is still expanded.
        name: Dotted attribute path of *module* relative to the chart root.
        collapse_repeats: Merge runs of structurally identical siblings.
        repeat_min: Minimum run length that is merged.
        hide_subpackages: Built-in subpackage names whose modules are hidden.
        hide_module_names: Class names whose modules are hidden.

    Returns:
        The node for *module*. A hidden child is dropped from the tree, but
        its own (already built) children are spliced into its parent.
    """
    children: list[_ModuleNode] = []
    if depth < max_depth:
        for child_name, child in module._modules.items():
            path = f"{name}.{child_name}" if name else child_name

            node = _build_tree(
                child,
                depth + 1,
                max_depth,
                path,
                collapse_repeats=collapse_repeats,
                repeat_min=repeat_min,
                hide_subpackages=hide_subpackages,
                hide_module_names=hide_module_names,
            )

            # Reuse the shared helper instead of re-deriving the subpackage
            # from the module path here.
            child_subpkg = _builtin_subpackage_key(child)

            excluded = False
            if hide_module_names and type(child).__name__ in hide_module_names:
                excluded = True
            if hide_subpackages and child_subpkg and child_subpkg in hide_subpackages:
                excluded = True

            if excluded:
                # Keep the hidden module's descendants visible.
                children.extend(node.children)
            else:
                children.append(node)

    if collapse_repeats and children:
        children = _collapse_repeated_children(children, repeat_min=repeat_min)

    return _ModuleNode(module=module, name=name, depth=depth, children=children)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _module_label(module: nn.Module, show_params: bool) -> str:
    """Return the display label for *module*.

    The label is the module's `_alt_name` attribute when it is set and
    non-empty, otherwise its class name. With *show_params* the module's
    `parameter_size` is appended with thousands separators.
    """
    label = getattr(module, "_alt_name", "") or type(module).__name__
    if not show_params:
        return label
    return f"{label} ({module.parameter_size:,} params)"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _builtin_subpackage_key(module: nn.Module) -> str | None:
    """Return the built-in subpackage name of *module*'s class, if any.

    The name is the first dotted segment of the class's `__module__` after
    `_NN_MODULES_PREFIX`. `None` is returned when the class lives elsewhere
    or when nothing follows the prefix.
    """
    dotted = type(module).__module__
    if not dotted.startswith(_NN_MODULES_PREFIX):
        return None

    tail = dotted[len(_NN_MODULES_PREFIX) :]
    if not tail:
        return None

    head, _, _ = tail.partition(".")
    return head
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _shape_text_color(module: nn.Module) -> str | None:
    """Return the text colour for *module*'s shape annotation.

    This is the stroke colour registered for the module's built-in
    subpackage, or `None` when the module is not built-in or its
    subpackage has no entry in `_BUILTIN_SUBPACKAGE_STYLE`.
    """
    subpkg = _builtin_subpackage_key(module)
    if subpkg is None:
        return None

    palette = _BUILTIN_SUBPACKAGE_STYLE.get(subpkg)
    return None if palette is None else palette[1]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _parse_rgba(value: str) -> tuple[str, float] | None:
|
|
154
|
+
v = value.strip().lower()
|
|
155
|
+
if not v.startswith("rgba(") or not v.endswith(")"):
|
|
156
|
+
return None
|
|
157
|
+
|
|
158
|
+
inner = v[5:-1]
|
|
159
|
+
parts = [p.strip() for p in inner.split(",")]
|
|
160
|
+
if len(parts) != 4:
|
|
161
|
+
return None
|
|
162
|
+
|
|
163
|
+
try:
|
|
164
|
+
r, g, b = (int(float(x)) for x in parts[:3])
|
|
165
|
+
a = float(parts[3])
|
|
166
|
+
except Exception:
|
|
167
|
+
return None
|
|
168
|
+
|
|
169
|
+
r = max(0, min(255, r))
|
|
170
|
+
g = max(0, min(255, g))
|
|
171
|
+
b = max(0, min(255, b))
|
|
172
|
+
a = max(0.0, min(1.0, a))
|
|
173
|
+
|
|
174
|
+
return (f"#{r:02x}{g:02x}{b:02x}", a)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _container_attr_label(node: _ModuleNode) -> str | None:
    """Return the attribute name under which a container node is stored.

    Only `Sequential`, `ModuleList` and `ModuleDict` nodes with a non-empty
    path qualify. The last path segment is returned unless it is purely
    numeric, in which case `None` is returned.
    """
    is_container = isinstance(
        node.module, (nn.Sequential, nn.ModuleList, nn.ModuleDict)
    )
    if not is_container or not node.name:
        return None

    _, _, leaf = node.name.rpartition(".")
    return None if leaf.isdigit() else leaf
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _escape_label(text: str) -> str:
|
|
191
|
+
return text.replace('"', """)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _shape_str(shape: object) -> str:
|
|
195
|
+
try:
|
|
196
|
+
if isinstance(shape, tuple):
|
|
197
|
+
return "(" + ",".join(str(x) for x in shape) + ")"
|
|
198
|
+
except Exception:
|
|
199
|
+
pass
|
|
200
|
+
return str(shape)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _flatten_shapes(obj: object) -> list[tuple[int, ...]]:
    """Return the shape of every tensor nested inside *obj*.

    Each shape is converted to a tuple of ints. A tensor whose shape cannot
    be read or converted is skipped.
    """
    collected: list[tuple[int, ...]] = []
    for tensor in _flatten_tensors(obj):
        try:
            dims = tuple(int(d) for d in tensor.shape)
        except Exception:
            continue
        collected.append(dims)
    return collected
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _shapes_brief(shapes: list[tuple[int, ...]]) -> str:
    """Summarise a list of shapes in one short string.

    Returns ``"?"`` for an empty list, the formatted shape for a single
    entry, and ``"<first shape>x<count>"`` when there are several.
    """
    if not shapes:
        return "?"

    first = _shape_str(shapes[0])
    total = len(shapes)
    return first if total == 1 else f"{first}x{total}"
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _node_signature(node: _ModuleNode) -> tuple:
    """Return a structural fingerprint of *node*.

    The fingerprint pairs the node's module class with the fingerprints of
    its children, so two nodes compare equal when their subtrees use the
    same classes in the same arrangement.
    """
    child_signatures = tuple(map(_node_signature, node.children))
    return (type(node.module), child_signatures)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _collapse_repeated_children(
    children: list[_ModuleNode], *, repeat_min: int
) -> list[_ModuleNode]:
    """Merge runs of consecutive, structurally identical siblings.

    A run of at least *repeat_min* adjacent nodes with equal
    `_node_signature` is replaced by one node that copies the first node of
    the run and lists every module of the run in `group`. Shorter runs are
    kept untouched. With `repeat_min <= 1` the input list is returned as is.
    """
    if repeat_min <= 1:
        return children

    signatures = [_node_signature(child) for child in children]
    total = len(children)

    collapsed: list[_ModuleNode] = []
    start = 0
    while start < total:
        # Extend the run while the following siblings share the signature.
        stop = start + 1
        while stop < total and signatures[stop] == signatures[start]:
            stop += 1

        run = children[start:stop]
        if len(run) < repeat_min:
            collapsed.extend(run)
        else:
            head = run[0]
            merged = _ModuleNode(
                module=head.module,
                name=head.name,
                depth=head.depth,
                children=head.children,
                group=[member.module for member in run],
            )
            collapsed.append(merged)

        start = stop

    return collapsed
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def build_mermaid_chart(
    module: nn.Module,
    input_shape: _ShapeLike | list[_ShapeLike] | None = None,
    inputs: Iterable[Tensor] | Tensor | None = None,
    depth: int = 2,
    direction: str = "LR",
    include_io: bool = True,
    show_params: bool = False,
    return_lines: bool = False,
    copy_to_clipboard: bool = False,
    compact: bool = False,
    use_class_defs: bool = False,
    end_semicolons: bool = True,
    edge_mode: Literal["dataflow", "execution"] = "execution",
    collapse_repeats: bool = True,
    repeat_min: int = 2,
    color_by_subpackage: bool = True,
    container_name_from_attr: bool = True,
    edge_stroke_width: float = 2.0,
    emphasize_model_title: bool = True,
    model_title_font_px: int = 20,
    show_shapes: bool = False,
    hide_subpackages: Iterable[str] = (),
    hide_module_names: Iterable[str] = (),
    dash_multi_input_edges: bool = True,
    subgraph_fill: str = "#000000",
    subgraph_fill_opacity: float = 0.05,
    subgraph_stroke: str = "#000000",
    subgraph_stroke_opacity: float = 0.75,
    force_text_color: str | None = None,
    edge_curve: str = "natural",
    node_spacing: int = 50,
    rank_spacing: int = 50,
    **forward_kwargs,
) -> str | list[str]:
    """Render *module* as a Mermaid ``flowchart`` by tracing one forward pass.

    The module tree is expanded down to *depth*, forward hooks are attached
    to every submodule, the model is run once, and the observed calls are
    turned into edges. Hooks are always removed again before returning.

    Args:
        module: Root module to chart.
        input_shape: Shape (or list of shapes) used to create random input
            tensors on `module.device` when *inputs* is not given.
        inputs: A tensor, or an iterable of tensors, passed positionally to
            the module. Takes precedence over *input_shape*.
        depth: How many nesting levels of children to expand (``>= 0``).
        direction: Text placed after ``flowchart`` in the header line.
        include_io: Add model-level ``Input`` / ``Output`` nodes and edges.
        show_params: Append the parameter count to every label.
        return_lines: Return the list of lines instead of the joined text.
        copy_to_clipboard: Also copy the joined text to the clipboard.
        compact: Join lines with a space instead of a newline.
        use_class_defs: Emit ``classDef`` rules and attach classes to nodes
            instead of emitting per-node ``style`` lines for I/O nodes.
        end_semicolons: Append ``;`` to node and edge lines.
        edge_mode: ``"execution"`` chains nodes in the order their hooks
            fired; ``"dataflow"`` links producer to consumer by tensor
            identity.
        collapse_repeats: Merge runs of structurally identical siblings.
        repeat_min: Minimum run length that is merged.
        color_by_subpackage: Colour built-in leaf nodes by subpackage.
        container_name_from_attr: Label container subgraphs with their
            attribute name when one is available.
        edge_stroke_width: Edge width in px; a ``linkStyle default`` line
            is emitted when this is truthy and not ``1.0``.
        emphasize_model_title: Render the root label bold and enlarged.
        model_title_font_px: Font size used for the emphasised root label.
        show_shapes: Annotate leaves whose input and output shapes differ,
            and the model-level I/O nodes, with tensor shapes.
        hide_subpackages: Built-in subpackage names to hide.
        hide_module_names: Class names to hide.
        dash_multi_input_edges: Draw dashed arrows into any node that has
            more than one incoming edge.
        subgraph_fill: Subgraph fill colour; an ``rgba(...)`` value also
            overrides *subgraph_fill_opacity*.
        subgraph_fill_opacity: Fill opacity for subgraphs.
        subgraph_stroke: Subgraph border colour; an ``rgba(...)`` value
            also overrides *subgraph_stroke_opacity*.
        subgraph_stroke_opacity: Border opacity for subgraphs.
        force_text_color: When set, a CSS override forcing this text colour.
        edge_curve: Flowchart ``curve`` setting; ``"round"`` is emitted as
            ``"step"`` plus rounded line caps/joins via CSS.
        node_spacing: Flowchart ``nodeSpacing`` setting.
        rank_spacing: Flowchart ``rankSpacing`` setting.
        **forward_kwargs: Extra keyword arguments for the forward call.

    Returns:
        The chart as one string, or as a list of lines when *return_lines*
        is true.

    Raises:
        ValueError: If neither *inputs* nor *input_shape* is given, or if
            *depth* is negative.
    """
    if inputs is None and input_shape is None:
        raise ValueError("Either inputs or input_shape must be provided.")
    if depth < 0:
        raise ValueError("depth must be >= 0")

    tree = _build_tree(
        module,
        depth=0,
        max_depth=depth,
        collapse_repeats=collapse_repeats,
        repeat_min=repeat_min,
        hide_subpackages=set(hide_subpackages),
        hide_module_names=set(hide_module_names),
    )

    # Pre-order list of every node in the tree; the index becomes the node id.
    nodes: list[_ModuleNode] = []

    def _collect(n: _ModuleNode) -> None:
        nodes.append(n)
        for c in n.children:
            _collect(c)

    _collect(tree)

    module_to_node: dict[nn.Module, _ModuleNode] = {n.module: n for n in nodes}

    # Every module of a collapsed group shares the id of its node.
    module_to_id: dict[nn.Module, str] = {}
    for idx, n in enumerate(nodes):
        node_id = f"m{idx}"
        for mod in n.iter_modules():
            module_to_id[mod] = node_id

    def _build_parent_map(root: nn.Module) -> dict[nn.Module, nn.Module]:
        parent: dict[nn.Module, nn.Module] = {}

        def _walk(mod: nn.Module) -> None:
            for child in mod._modules.values():
                parent[child] = mod
                _walk(child)

        _walk(root)
        return parent

    parent_map = _build_parent_map(module)
    has_non_root_included = any(mod is not module for mod in module_to_id)

    def _map_to_included(mod: nn.Module) -> nn.Module | None:
        # Walk up to the nearest ancestor that has a node in the chart.
        cur = mod
        while cur not in module_to_id and cur in parent_map:
            cur = parent_map[cur]
        return cur if cur in module_to_id else None

    hooks = []
    edges: set[tuple[str, str]] = set()
    tensor_producer: dict[int, nn.Module] = {}
    module_in_shapes: dict[nn.Module, list[tuple[int, ...]]] = {}
    module_out_shapes: dict[nn.Module, list[tuple[int, ...]]] = {}
    input_node_id = "input"
    output_node_id = "output"
    root_module = module
    exec_order: list[nn.Module] = []

    def _hook(
        mod: nn.Module, input_arg: tuple, output: Tensor | tuple[Tensor, ...]
    ) -> None:
        mapped_mod = _map_to_included(mod)
        if mapped_mod is None:
            return

        if show_shapes:
            module_in_shapes[mapped_mod] = _flatten_shapes(input_arg)
            module_out_shapes[mapped_mod] = _flatten_shapes(output)

        if edge_mode == "dataflow":
            input_tensors = _flatten_tensors(input_arg)
            for t in input_tensors:
                producer = tensor_producer.get(id(t))
                if producer is None:
                    # Tensor with no recorded producer: treat as model input.
                    if include_io and (
                        mapped_mod is not root_module or not has_non_root_included
                    ):
                        edges.add((input_node_id, module_to_id[mapped_mod]))

                else:
                    if producer is not mapped_mod:
                        edges.add((module_to_id[producer], module_to_id[mapped_mod]))

            output_tensors = _flatten_tensors(output)
            for t in output_tensors:
                key = id(t)
                # The root must not take over tensors produced by a child.
                if mapped_mod is root_module and key in tensor_producer:
                    continue
                tensor_producer[key] = mapped_mod

        exec_order.append(mapped_mod)

    try:
        # Register inside the try so a failure partway through registration
        # still removes the hooks that were already attached.
        for mod in module.modules():
            hooks.append(mod.register_forward_hook(_hook))

        if inputs is None:
            if isinstance(input_shape, list):
                input_tensors = [
                    lucid.random.randn(shape, device=module.device)
                    for shape in input_shape
                ]
            else:
                input_tensors = [lucid.random.randn(input_shape, device=module.device)]
        else:
            if isinstance(inputs, Tensor):
                input_tensors = [inputs]
            else:
                input_tensors = list(inputs)

        outputs = module(*input_tensors, **forward_kwargs)
    finally:
        for remove in hooks:
            remove()

    model_input_shapes = _flatten_shapes(input_tensors)
    model_output_shapes = _flatten_shapes(outputs)

    if edge_mode == "execution":
        # Drop immediate repeats of the same node from the call order.
        seq = []
        for m in exec_order:
            if not seq or seq[-1] is not m:
                seq.append(m)
        seq_no_root = [m for m in seq if m is not root_module]
        if not seq_no_root:
            seq_no_root = [root_module]

        def _is_included_container(mod: nn.Module) -> bool:
            node = module_to_node.get(mod)
            return bool(node and node.children)

        # Prefer leaves; fall back to containers when nothing else ran.
        seq_effective = [m for m in seq_no_root if not _is_included_container(m)]
        if not seq_effective:
            seq_effective = seq_no_root

        for prev, cur in zip(seq_effective, seq_effective[1:]):
            edges.add((module_to_id[prev], module_to_id[cur]))

        if include_io and seq_effective:
            edges.add((input_node_id, module_to_id[seq_effective[0]]))
            edges.add((module_to_id[seq_effective[-1]], output_node_id))

    else:
        if include_io:
            output_tensors = _flatten_tensors(outputs)
            for t in output_tensors:
                producer = tensor_producer.get(id(t))
                if producer is not None:
                    edges.add((module_to_id[producer], output_node_id))

    node_ids = set(module_to_id.values())
    nodes_with_edges: set[str] = set()

    for src, dst in edges:
        if src in node_ids:
            nodes_with_edges.add(src)
        if dst in node_ids:
            nodes_with_edges.add(dst)

    container_node_ids = {module_to_id[n.module] for n in nodes if n.children}
    container_with_edges = container_node_ids & nodes_with_edges
    extra_edges: list[tuple[str, str]] = []

    # Edges touching a container attach to its internal "_in"/"_out" nodes.
    def _endpoint_in(node_id: str) -> str:
        return f"{node_id}_in" if node_id in container_with_edges else node_id

    def _endpoint_out(node_id: str) -> str:
        return f"{node_id}_out" if node_id in container_with_edges else node_id

    def _first_leaf_id(n: _ModuleNode) -> str:
        cur = n
        while cur.children:
            cur = cur.children[0]
        return module_to_id[cur.module]

    def _last_leaf_id(n: _ModuleNode) -> str:
        cur = n
        while cur.children:
            cur = cur.children[-1]
        return module_to_id[cur.module]

    lines: list[str] = []
    init_cfg: dict[str, object] = {
        "flowchart": {
            "curve": "step" if edge_curve == "round" else edge_curve,
            "nodeSpacing": node_spacing,
            "rankSpacing": rank_spacing,
        }
    }
    css_parts: list[str] = []
    if edge_curve == "round":
        css_parts.append(
            ".edgePath path { stroke-linecap: round; stroke-linejoin: round; }"
        )
    if force_text_color:
        css_parts.append(
            f".nodeLabel, .edgeLabel, .cluster text, .node text "
            f"{{ fill: {force_text_color} !important; }} "
            f".node foreignObject *, .cluster foreignObject * "
            f"{{ color: {force_text_color} !important; }}"
        )
    if css_parts:
        init_cfg["themeCSS"] = " ".join(css_parts)

    lines.append(f"%%{{init: {json.dumps(init_cfg, separators=(',', ':'))} }}%%")
    lines.append(f"flowchart {direction}")

    if edge_stroke_width and edge_stroke_width != 1.0:
        lines.append(f" linkStyle default stroke-width:{edge_stroke_width}px")

    if use_class_defs:
        lines.append(" classDef module fill:#f9f9f9,stroke:#333,stroke-width:1px;")
        lines.append(" classDef modelio fill:#fff3cd,stroke:#a67c00,stroke-width:1px;")
        lines.append(
            " classDef internalio fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
        )
        lines.append(
            " classDef anchor fill:transparent,stroke:transparent,color:transparent;"
        )
        lines.append(" classDef repeat fill:#e8f1ff,stroke:#2b6cb0,stroke-width:1px;")

    # Subgraphs at the two deepest container levels are laid out top-to-bottom.
    vertical_levels: set[int] = set()
    if depth >= 3:
        vertical_levels = {depth - 2, depth - 1}

    def _render(n: _ModuleNode, indent: str = " ") -> None:
        node_id = module_to_id[n.module]
        base_label = _module_label(n.module, show_params)

        if container_name_from_attr and n.children:
            attr_label = _container_attr_label(n)
            if attr_label is not None:
                base_label = attr_label

        base_label = _escape_label(base_label)
        label_text = base_label if n.count == 1 else f"{base_label} x {n.count}"

        if show_shapes and not n.children:
            in_shapes = module_in_shapes.get(n.module, [])
            out_shapes = module_out_shapes.get(n.module, [])

            if in_shapes != out_shapes and (in_shapes or out_shapes):
                ins = _shapes_brief(in_shapes)
                outs = _shapes_brief(out_shapes)
                color_css = ""
                if not force_text_color:
                    color = _shape_text_color(n.module)
                    color_css = f"color:{color};" if color else ""
                label_text = (
                    f"{label_text}<br/>"
                    f"<span style='font-size:11px;{color_css}font-weight:400'>"
                    f"{ins} \u2192 {outs}"
                    f"</span>"
                )

        label = label_text
        if (
            emphasize_model_title
            and n.module is root_module
            and model_title_font_px
            and model_title_font_px > 0
        ):
            label = (
                f"<span style='font-size:{model_title_font_px}px;font-weight:700'>"
                f"{label}"
                f"</span>"
            )

        if n.children:
            lines.append(f'{indent}subgraph sg_{node_id}["{label}"]')
            if n.depth in vertical_levels:
                lines.append(f"{indent} direction TB")
            if subgraph_fill or subgraph_stroke:
                parts: list[str] = []

                fill = subgraph_fill
                fill_opacity = subgraph_fill_opacity
                # Only parse a non-empty fill, mirroring the stroke handling.
                parsed = _parse_rgba(fill) if fill else None

                if parsed is not None:
                    fill, fill_opacity = parsed
                if fill:
                    parts.append(f"fill:{fill}")
                    parts.append(f"fill-opacity:{fill_opacity}")

                stroke = subgraph_stroke
                stroke_opacity = subgraph_stroke_opacity
                if stroke:
                    parsed = _parse_rgba(stroke)
                    if parsed is not None:
                        stroke, stroke_opacity = parsed

                    parts.append(f"stroke:{stroke}")
                    parts.append(f"stroke-opacity:{stroke_opacity}")
                    parts.append("stroke-width:1px")

                lines.append(f'{indent}style sg_{node_id} {",".join(parts)}')

            if node_id in nodes_with_edges:
                if node_id in container_with_edges:
                    in_id = f"{node_id}_in"
                    out_id = f"{node_id}_out"

                    if use_class_defs:
                        lines.append(f'{indent} {in_id}(["Input"]):::internalio')
                        lines.append(f'{indent} {out_id}(["Output"]):::internalio')
                    else:
                        lines.append(f'{indent} {in_id}(["Input"])')
                        lines.append(f'{indent} {out_id}(["Output"])')
                        lines.append(
                            f" style {in_id} fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
                        )
                        lines.append(
                            f" style {out_id} fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
                        )

                    # Wire the container's I/O nodes to its first/last leaf.
                    extra_edges.append((in_id, _endpoint_in(_first_leaf_id(n))))
                    extra_edges.append((_endpoint_out(_last_leaf_id(n)), out_id))

                else:
                    # NOTE(review): a container id that has edges is always in
                    # container_with_edges, so this invisible-anchor fallback
                    # looks unreachable - confirm before removing.
                    if use_class_defs:
                        lines.append(f'{indent} {node_id}[""]:::anchor')
                    else:
                        lines.append(f'{indent} {node_id}["\u200b"]')
                        lines.append(
                            f" style {node_id} fill:transparent,stroke:transparent,color:transparent;"
                        )

            for c in n.children:
                _render(c, indent + " ")
            lines.append(f"{indent}end")

        else:
            if use_class_defs:
                class_name = "module" if n.count == 1 else "repeat"
                lines.append(f'{indent}{node_id}["{label}"]:::{class_name}')
            else:
                # Collapsed groups get a stadium-shaped node.
                if n.count > 1:
                    lines.append(f'{indent}{node_id}(["{label}"])')
                else:
                    lines.append(f'{indent}{node_id}["{label}"]')

    _render(tree)

    if include_io:
        input_label = "Input"
        output_label = "Output"
        if show_shapes:
            in_s = _shapes_brief(model_input_shapes)
            out_s = _shapes_brief(model_output_shapes)
            io_color = force_text_color or "#a67c00"
            input_label = (
                f"{input_label}<br/>"
                f"<span style='font-size:11px;color:{io_color};font-weight:400'>{in_s}</span>"
            )
            output_label = (
                f"{output_label}<br/>"
                f"<span style='font-size:11px;color:{io_color};font-weight:400'>{out_s}</span>"
            )
        if use_class_defs:
            lines.append(f' {input_node_id}["{input_label}"]:::modelio')
            lines.append(f' {output_node_id}["{output_label}"]:::modelio')
        else:
            lines.append(f' {input_node_id}["{input_label}"]')
            lines.append(f' {output_node_id}["{output_label}"]')
            lines.append(
                f" style {input_node_id} fill:#fff3cd,stroke:#a67c00,stroke-width:1px;"
            )
            lines.append(
                f" style {output_node_id} fill:#fff3cd,stroke:#a67c00,stroke-width:1px;"
            )

    if color_by_subpackage:
        for n in nodes:
            if n.children:
                continue

            subpkg = _builtin_subpackage_key(n.module)
            if subpkg is None:
                continue
            style = _BUILTIN_SUBPACKAGE_STYLE.get(subpkg)
            if style is None:
                continue

            fill, stroke = style
            node_id = module_to_id[n.module]
            if node_id in {input_node_id, output_node_id}:
                continue

            if node_id.endswith("_in") or node_id.endswith("_out"):
                continue
            lines.append(
                f" style {node_id} fill:{fill},stroke:{stroke},stroke-width:1px;"
            )

    # Redirect container endpoints and drop self-loops.
    render_edges: set[tuple[str, str]] = set()
    for src, dst in edges:
        src_id = _endpoint_out(src)
        dst_id = _endpoint_in(dst)
        if src_id != dst_id:
            render_edges.add((src_id, dst_id))

    for src, dst in extra_edges:
        if src != dst:
            render_edges.add((src, dst))

    indegree: dict[str, int] = {}
    for _, dst in render_edges:
        indegree[dst] = indegree.get(dst, 0) + 1

    # Sorted so the emitted text is deterministic.
    for src, dst in sorted(render_edges):
        arrow = "-.->" if dash_multi_input_edges and indegree.get(dst, 0) > 1 else "-->"
        lines.append(f" {src} {arrow} {dst}")

    def _finalize_lines(src_lines: list[str]) -> list[str]:
        if not end_semicolons:
            return src_lines

        out: list[str] = []
        for line in src_lines:
            stripped = line.rstrip()
            head = stripped.lstrip()
            # Directives, block delimiters and blank lines stay untouched.
            if (
                not head
                or head.startswith("flowchart ")
                or head.startswith("linkStyle ")
                or head.startswith("subgraph ")
                or head == "end"
                or head.startswith("classDef ")
                or head.startswith("class ")
                or head.startswith("style ")
                or head.startswith("%%")
            ):
                out.append(line)
                continue

            if stripped.endswith(";"):
                out.append(line)
            else:
                out.append(f"{line};")

        return out

    final_lines = _finalize_lines(lines)

    if compact:
        text = " ".join(final_lines)
    else:
        text = "\n".join(final_lines)
    if copy_to_clipboard:
        _copy_to_clipboard(text)

    if return_lines:
        return final_lines
    return text
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def _copy_to_clipboard(text: str) -> None:
    """Copy *text* to the system clipboard with the first backend that works.

    Platform command-line tools are tried first (``pbcopy`` on macOS;
    ``clip`` then PowerShell on Windows; ``wl-copy``, ``xclip``, ``xsel``
    elsewhere), then `tkinter` as a last resort.

    Raises:
        RuntimeError: If every backend fails; the message lists the error
            recorded for each attempt.
    """
    import os
    import shutil
    import subprocess
    import sys

    errors: list[str] = []

    def _try(cmd: list[str]) -> bool:
        # Feed the text to `cmd` on stdin; record the failure instead of raising.
        # NOTE(review): the payload is always UTF-8 - confirm that Windows
        # `clip` decodes non-ASCII characters correctly from this encoding.
        try:
            subprocess.run(cmd, input=text.encode("utf-8"), check=True)
            return True
        except Exception as e:
            errors.append(f"{cmd!r}: {type(e).__name__}: {e}")
            return False

    if sys.platform == "darwin":
        if shutil.which("pbcopy") and _try(["pbcopy"]):
            return

    elif sys.platform.startswith("win"):
        if shutil.which("clip") and _try(["clip"]):
            return
        if shutil.which("powershell"):
            try:
                subprocess.run(
                    ["powershell", "-NoProfile", "-Command", "Set-Clipboard"],
                    input=text.encode("utf-8"),
                    check=True,
                )
                return
            except Exception as e:
                errors.append(f"powershell Set-Clipboard: {type(e).__name__}: {e}")

    else:
        if os.environ.get("WAYLAND_DISPLAY") and shutil.which("wl-copy"):
            if _try(["wl-copy"]):
                return
        if shutil.which("xclip"):
            if _try(["xclip", "-selection", "clipboard"]):
                return
        if shutil.which("xsel"):
            if _try(["xsel", "--clipboard", "--input"]):
                return

    try:
        import tkinter

        root = tkinter.Tk()
        try:
            root.withdraw()
            root.clipboard_clear()
            root.clipboard_append(text)
            root.update()
        finally:
            # Always tear the hidden window down, even if a clipboard call failed.
            root.destroy()
        return

    except Exception as e:
        errors.append(f"tkinter: {type(e).__name__}: {e}")

    detail = "\n".join(f"- {msg}" for msg in errors) if errors else "- (no details)"
    raise RuntimeError(
        "Failed to copy to clipboard. Install a clipboard utility (macOS: pbcopy; "
        "Wayland: wl-copy; X11: xclip/xsel; Windows: clip/powershell) or enable tkinter.\n"
        f"{detail}"
    )
|