model-unfolder 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. model_unfolder/__init__.py +58 -0
  2. model_unfolder/adapters/__init__.py +15 -0
  3. model_unfolder/adapters/custom/__init__.py +8 -0
  4. model_unfolder/adapters/diffusor/__init__.py +8 -0
  5. model_unfolder/adapters/transformer/__init__.py +5 -0
  6. model_unfolder/adapters/transformer/assembly.py +57 -0
  7. model_unfolder/adapters/transformer/blocks.py +238 -0
  8. model_unfolder/adapters/transformer/common.py +35 -0
  9. model_unfolder/adapters/transformer/families/__init__.py +12 -0
  10. model_unfolder/adapters/transformer/families/deepseek.py +107 -0
  11. model_unfolder/adapters/transformer/families/gemma4.py +202 -0
  12. model_unfolder/adapters/transformer/families/llama.py +91 -0
  13. model_unfolder/adapters/transformer/special_parts/__init__.py +2 -0
  14. model_unfolder/adapters/transformer/special_parts/per_layer_embedding.py +220 -0
  15. model_unfolder/diagram.py +95 -0
  16. model_unfolder/html_renderer.py +5 -0
  17. model_unfolder/ir.py +163 -0
  18. model_unfolder/labels.py +166 -0
  19. model_unfolder/params.py +119 -0
  20. model_unfolder/parser.py +137 -0
  21. model_unfolder/renderers/__init__.py +1 -0
  22. model_unfolder/renderers/html/__init__.py +5 -0
  23. model_unfolder/renderers/html/block_views/__init__.py +20 -0
  24. model_unfolder/renderers/html/block_views/attention.py +91 -0
  25. model_unfolder/renderers/html/block_views/feed_forward.py +213 -0
  26. model_unfolder/renderers/html/block_views/per_layer_embedding.py +199 -0
  27. model_unfolder/renderers/html/cards.py +130 -0
  28. model_unfolder/renderers/html/document.py +157 -0
  29. model_unfolder/renderers/html/interactions.py +64 -0
  30. model_unfolder/renderers/html/metadata.py +265 -0
  31. model_unfolder/renderers/html/sections.py +60 -0
  32. model_unfolder/renderers/html/styles.py +283 -0
  33. model_unfolder/renderers/html/svg.py +349 -0
  34. model_unfolder/renderers/html/theme.py +24 -0
  35. model_unfolder/renderers/html/utils.py +28 -0
  36. model_unfolder/renderers/html/views.py +461 -0
  37. model_unfolder-0.2.0.dist-info/METADATA +122 -0
  38. model_unfolder-0.2.0.dist-info/RECORD +41 -0
  39. model_unfolder-0.2.0.dist-info/WHEEL +5 -0
  40. model_unfolder-0.2.0.dist-info/licenses/LICENSE +201 -0
  41. model_unfolder-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,202 @@
1
+ """Adapter for Google's Gemma 4 family (E2B / E4B / 31B / 26B-A4B).
2
+
3
+ Gemma 4 nests its language-model fields under ``text_config`` since the
4
+ top-level config also covers the vision and audio encoders. The text stack
5
+ itself has a few features the older Llama/Gemma adapters don't model:
6
+
7
+ * Per-layer ``layer_types`` array tagging each block as ``sliding_attention``
8
+ or ``full_attention`` (instead of Gemma 3's "every Nth layer" pattern).
9
+ * Dual attention shape: sliding layers use ``head_dim`` + ``num_key_value_heads``;
10
+ global layers use ``global_head_dim`` + ``num_global_key_value_heads``.
11
+ * Optional MoE FFN (26B-A4B) controlled by ``enable_moe_block``.
12
+ * Per-Layer Embeddings (PLE) on small models — ``hidden_size_per_layer_input``
13
+ is the parallel conditioning dim, ``vocab_size_per_layer_input`` its vocab.
14
+ * Shared-KV layers — the last ``num_kv_shared_layers`` layers reuse K/V from
15
+ the last earlier layer of the same attention type.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ from typing import Any
20
+
21
+ from ....ir import AttentionSpec, CrossLayerEdge, FFNSpec, ModelIR
22
+ from ..assembly import decoder_extras, decoder_layer
23
+ from ..common import architecture_name, get_config_value as _g, model_name
24
+ from ..special_parts.per_layer_embedding import (
25
+ per_layer_embedding_blocks,
26
+ per_layer_embedding_extras,
27
+ )
28
+
29
+
30
+ _TOP_TYPES = {"gemma4"}
31
+ _TEXT_TYPES = {"gemma4_text"}
32
+ _ARCH_HINTS = ("gemma4",)
33
+
34
+
35
def matches(cfg: Any) -> bool:
    """Return True when ``cfg`` looks like a Gemma 4 config.

    A config matches when any entry of its ``architectures`` list contains
    one of the ``_ARCH_HINTS`` substrings (compared in lower case), or when
    its lower-cased ``model_type`` is one of the known wrapper or text-only
    type names.
    """
    for arch_entry in _g(cfg, "architectures") or []:
        lowered = arch_entry.lower()
        for hint in _ARCH_HINTS:
            if hint in lowered:
                return True

    # A missing or null ``model_type`` is treated as the empty string.
    declared_type = (_g(cfg, "model_type", "") or "").lower()
    return declared_type in _TOP_TYPES or declared_type in _TEXT_TYPES
43
+
44
+
45
def parse(cfg: Any) -> ModelIR:
    """Translate a Gemma 4 config into a :class:`ModelIR`.

    All language-model fields are read from the sub-config returned by
    ``_text_config``. For every layer the function picks the sliding or
    global attention shape from ``layer_types`` (layers without an entry are
    treated as ``"full_attention"``), records a ``kv_share`` cross-layer edge
    for layers at or beyond the first shared index, builds a MoE or dense FFN
    spec, and attaches the Per-Layer Embedding blocks when
    ``hidden_size_per_layer_input`` is set.
    """
    arch_name = architecture_name(cfg, "gemma4")
    text_cfg = _text_config(cfg)

    # --- Core dimensions -------------------------------------------------
    layer_count = _g(text_cfg, "num_hidden_layers", 0)
    hidden_size = _g(text_cfg, "hidden_size", 0)
    intermediate_size = _g(text_cfg, "intermediate_size", 0)
    raw_activation = (
        _g(text_cfg, "hidden_activation")
        or _g(text_cfg, "hidden_act")
        or "gelu_pytorch_tanh"
    )
    activation = raw_activation.lower()

    # --- Attention shapes (sliding vs. global) ---------------------------
    num_q = _g(text_cfg, "num_attention_heads", 0)
    sliding_kv_heads = _g(text_cfg, "num_key_value_heads", num_q)
    global_kv_heads = _g(text_cfg, "num_global_key_value_heads") or sliding_kv_heads
    sliding_head_dim = _g(text_cfg, "head_dim")
    if not sliding_head_dim:
        # Fall back to the conventional hidden_size / heads split.
        sliding_head_dim = hidden_size // num_q if num_q else None
    global_head_dim = _g(text_cfg, "global_head_dim") or sliding_head_dim
    sliding_window = _g(text_cfg, "sliding_window")

    layer_types = _g(text_cfg, "layer_types") or []

    # --- Mixture-of-experts settings -------------------------------------
    moe_enabled = bool(_g(text_cfg, "enable_moe_block"))
    num_experts = _g(text_cfg, "num_experts") or 0
    top_k = _g(text_cfg, "top_k_experts") or 0
    moe_intermediate_size = _g(text_cfg, "moe_intermediate_size") or 0

    # --- Shared-KV layers: the trailing ``num_kv_shared`` layers ----------
    num_kv_shared = _g(text_cfg, "num_kv_shared_layers") or 0
    first_shared = layer_count
    if num_kv_shared:
        first_shared = layer_count - num_kv_shared

    # --- Per-Layer Embeddings --------------------------------------------
    ple_dim = _g(text_cfg, "hidden_size_per_layer_input") or 0
    ple_vocab = _g(text_cfg, "vocab_size_per_layer_input") or _g(text_cfg, "vocab_size", 0)

    layers = []
    cross_edges: list[CrossLayerEdge] = []
    for idx in range(layer_count):
        layer_type = layer_types[idx] if idx < len(layer_types) else "full_attention"
        if "sliding" in layer_type:
            mask, window = "sliding", sliding_window
            kv_heads, layer_head_dim = sliding_kv_heads, sliding_head_dim
        else:
            mask, window = "global", None
            kv_heads, layer_head_dim = global_kv_heads, global_head_dim

        # Only layers in the shared tail look for an earlier K/V source.
        kv_source: int | None = None
        if idx >= first_shared:
            kv_source = _last_matching_layer(layer_types, idx, first_shared)
        if kv_source is not None:
            edge = CrossLayerEdge(
                kind="kv_share",
                from_layer=kv_source,
                to_layer=idx,
                shared=["K", "V"],
            )
            cross_edges.append(edge)

        attn = AttentionSpec(
            kind=_attention_kind(num_q, kv_heads),
            num_heads=num_q,
            num_kv_heads=kv_heads,
            head_dim=layer_head_dim,
            mask=mask,
            window_size=window,
            kv_source_layer=kv_source,
        )

        # A fresh spec object is built for every layer, as before.
        if not (moe_enabled and num_experts):
            ffn = FFNSpec(
                kind="dense",
                activation=activation,
                intermediate_size=intermediate_size,
                gated=True,
            )
        else:
            ffn = FFNSpec(
                kind="moe",
                activation=activation,
                intermediate_size=intermediate_size,
                gated=True,
                num_experts=num_experts,
                num_experts_per_tok=top_k,
                expert_intermediate_size=moe_intermediate_size or intermediate_size,
            )

        extra_blocks = (
            list(per_layer_embedding_blocks(hidden_size, ple_dim, activation="gelu"))
            if ple_dim
            else []
        )
        layers.append(
            decoder_layer(idx, attn, ffn, hidden_size, extra_blocks=extra_blocks)
        )

    vocab_size = _g(text_cfg, "vocab_size", 0)
    # The text config wins; the wrapper config is the fallback.
    tie_word_embeddings = bool(
        _g(text_cfg, "tie_word_embeddings", _g(cfg, "tie_word_embeddings", False))
    )

    ple_extras = None
    if ple_dim:
        ple_extras = per_layer_embedding_extras(hidden_size, ple_dim, ple_vocab, layer_count)
    extras = decoder_extras(vocab_size, hidden_size, tie_word_embeddings, ple_extras)

    if num_kv_shared:
        extras["num_kv_shared_layers"] = num_kv_shared
    # Boolean feature flags are copied through only when truthy.
    for flag in ("attention_k_eq_v", "use_double_wide_mlp"):
        if _g(text_cfg, flag):
            extras[flag] = True

    return ModelIR(
        name=model_name(cfg, arch_name),
        architecture=arch_name,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_position_embeddings=_g(text_cfg, "max_position_embeddings"),
        tie_word_embeddings=tie_word_embeddings,
        layers=layers,
        cross_layer_edges=cross_edges,
        extras=extras,
    )
170
+
171
+
172
def _text_config(cfg: Any) -> Any:
    """Return the config object that holds the language-model fields.

    ``cfg`` itself is returned when its ``model_type`` is already a text-only
    type or when it has no ``text_config`` entry; otherwise the nested
    ``text_config`` is returned.
    """
    declared_type = (_g(cfg, "model_type", "") or "").lower()
    if declared_type not in _TEXT_TYPES:
        nested = _g(cfg, "text_config")
        if nested is not None:
            return nested
    return cfg
178
+
179
+
180
+ def _attention_kind(num_q: int, num_kv: int) -> str:
181
+ if not num_q:
182
+ return "mha"
183
+ if num_kv == num_q:
184
+ return "mha"
185
+ if num_kv == 1:
186
+ return "mqa"
187
+ return "gqa"
188
+
189
+
190
+ def _last_matching_layer(layer_types: list, i: int, first_shared: int) -> int | None:
191
+ """Source layer for a KV-shared layer.
192
+
193
+ Per the Gemma 4 release notes: shared layers reuse K/V from the most
194
+ recent non-shared layer of the *same* attention type (sliding or full).
195
+ """
196
+ if not layer_types or i >= len(layer_types):
197
+ return None
198
+ target_type = layer_types[i]
199
+ for j in range(min(first_shared, len(layer_types)) - 1, -1, -1):
200
+ if layer_types[j] == target_type:
201
+ return j
202
+ return None
@@ -0,0 +1,91 @@
1
+ """Adapter for Llama, Mistral, Qwen, and similar GQA/MHA dense models."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+
6
+ from ....ir import AttentionSpec, FFNSpec, ModelIR
7
+ from ..assembly import decoder_extras, decoder_layer
8
+ from ..common import architecture_name, get_config_value as _g, model_name
9
+
10
+
11
+ _FAMILIES = {"llama", "mistral", "qwen2", "qwen3", "phi3", "gemma"}
12
+
13
+
14
def matches(cfg: Any) -> bool:
    """Return True when ``cfg`` belongs to a family this adapter handles.

    A config matches when any ``architectures`` entry contains one of the
    family substrings (compared in lower case), or when its ``model_type``
    is listed in ``_FAMILIES``.

    ``model_type`` is normalised the same way the architecture names are:
    a null value is treated as the empty string and the comparison is
    case-insensitive, so ``"Llama"`` and ``"llama"`` behave alike.
    """
    arch_hints = ("llama", "mistral", "qwen", "phi3")
    for arch in _g(cfg, "architectures") or []:
        lowered = arch.lower()
        if any(hint in lowered for hint in arch_hints):
            return True

    model_type = (_g(cfg, "model_type", "") or "").lower()
    return model_type in _FAMILIES
23
+
24
+
25
def parse(cfg: Any) -> ModelIR:
    """Translate a Llama-style dense decoder config into a :class:`ModelIR`.

    Every layer gets the same attention head layout and a gated dense FFN.
    The per-layer attention mask is decided by ``_layer_mask``: an explicit
    ``layer_types`` entry wins, otherwise the ``sliding_window`` /
    ``sliding_window_pattern`` fields are used.

    A config that sets ``use_sliding_window`` to ``False`` is treated as
    having no sliding window in that fallback path, even if it also carries
    a ``sliding_window`` value; without this, such configs were drawn with
    sliding attention on every layer.
    """
    # ``or`` fallbacks guard against keys that are present but null
    # (presumably returned as None by the config accessor; verify there).
    num_layers = _g(cfg, "num_hidden_layers", 0) or 0
    num_heads = _g(cfg, "num_attention_heads", 0)
    num_kv_heads = _g(cfg, "num_key_value_heads", num_heads) or num_heads
    hidden_size = _g(cfg, "hidden_size", 0)
    head_dim = _g(cfg, "head_dim") or (hidden_size // num_heads if num_heads else None)

    if num_kv_heads == num_heads:
        attn_kind = "mha"
    elif num_kv_heads == 1:
        attn_kind = "mqa"
    else:
        attn_kind = "gqa"

    sliding_window = _g(cfg, "sliding_window")
    sliding_pattern = _g(cfg, "sliding_window_pattern")
    layer_types = _g(cfg, "layer_types")
    # Only an explicit ``False`` disables the window; a missing flag keeps
    # the previous behaviour.
    swa_disabled = _g(cfg, "use_sliding_window") is False

    intermediate_size = _g(cfg, "intermediate_size", 0)
    activation = (_g(cfg, "hidden_act", "silu") or "silu").lower()

    arch_name = architecture_name(cfg, "llama")

    layers = []
    for i in range(num_layers):
        mask, win = _layer_mask(i, layer_types, sliding_window, sliding_pattern, swa_disabled)
        attn = AttentionSpec(
            kind=attn_kind,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            mask=mask,
            window_size=win,
        )
        ffn = FFNSpec(
            kind="dense",
            activation=activation,
            intermediate_size=intermediate_size,
            gated=True,
        )
        layers.append(decoder_layer(i, attn, ffn, hidden_size))

    vocab_size = _g(cfg, "vocab_size", 0)
    tie_word_embeddings = bool(_g(cfg, "tie_word_embeddings", False))
    return ModelIR(
        name=model_name(cfg, arch_name),
        architecture=arch_name,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_position_embeddings=_g(cfg, "max_position_embeddings"),
        tie_word_embeddings=tie_word_embeddings,
        layers=layers,
        extras=decoder_extras(vocab_size, hidden_size, tie_word_embeddings),
    )


def _layer_mask(index, layer_types, sliding_window, sliding_pattern, swa_disabled):
    """Return ``(mask, window_size)`` for layer ``index``."""
    # An explicit per-layer type always takes precedence.
    if layer_types and index < len(layer_types):
        if "sliding" in layer_types[index]:
            return "sliding", sliding_window
        return "causal", None

    if swa_disabled or not sliding_window:
        return "causal", None

    if sliding_pattern:
        # Every ``sliding_pattern``-th layer (the last of each group) is causal.
        if index % sliding_pattern != sliding_pattern - 1:
            return "sliding", sliding_window
        return "causal", None

    return "sliding", sliding_window
@@ -0,0 +1,2 @@
1
+ """Reusable transformer layer parts."""
2
+
@@ -0,0 +1,220 @@
1
+ """Reusable Per-Layer Embedding (PLE) transformer part.
2
+
3
+ This module owns the canonical IR shape for PLE-style conditioning. A model
4
+ family adapter should only detect that the config has such a pathway, then
5
+ attach these blocks and extras. The renderer consumes the declared block
6
+ metadata, so the feature is not tied to any one model family.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from ....labels import activation_label
11
+ from ..common import format_dim as _fmt
12
+
13
+
14
# Default identifiers used by the PLE helpers in this module.
DEFAULT_BLOCK_ID = "ple"  # id of the PLE side block; also prefixes its child ids
DEFAULT_ADD_ID = "add3"  # id of the residual-add block emitted after the PLE block
DEFAULT_PATHWAY_ID = "per_layer_input"  # id of the external per-layer pathway
17
+
18
+
19
def per_layer_embedding_blocks(
    hidden_size: int,
    embedding_dim: int,
    *,
    activation: str = "gelu",
    block_id: str = DEFAULT_BLOCK_ID,
    add_id: str = DEFAULT_ADD_ID,
    pathway_id: str = DEFAULT_PATHWAY_ID,
    lane: str = "left",
    tap_from: str = "rms1",
    feeds: str | None = None,
    residual_from: str = "add2",
) -> list[dict]:
    """Build the two layer blocks that make up a PLE side pathway.

    The first block describes the PLE unit itself, with child nodes for
    gate -> activation -> multiply(per-layer vector) -> projection -> norm
    plus a node for the external per-layer vector. The second block is the
    residual add that receives the PLE output.

    ``lane``, ``tap_from`` and ``feeds`` are stored on the PLE block as
    given; ``feeds`` falls back to ``add_id`` when it is falsy.
    """
    if not feeds:
        feeds = add_id
    hidden_txt = _fmt(hidden_size)
    emb_txt = _fmt(embedding_dim)
    act_txt = _activation_label(activation)
    node_ids = _child_ids(block_id)

    # One row per child node: (id, label, title, description).
    child_rows = [
        (
            node_ids["gate"],
            "Linear (gate)",
            "Per-layer input gate",
            f"Linear; {hidden_txt} -> {emb_txt}",
        ),
        (
            node_ids["activation"],
            act_txt,
            "PLE activation",
            f"Element-wise {act_txt}",
        ),
        (
            node_ids["multiply"],
            "x",
            "Per-layer gate (x)",
            f"Element-wise multiply by {pathway_id}[L] "
            f"({emb_txt}-d vector sourced from the parallel pathway)",
        ),
        (
            pathway_id,
            f"{pathway_id}[L]",
            "Per-layer input vector",
            f"{emb_txt}-d vector produced outside the layer stack for layer L.",
        ),
        (
            node_ids["projection"],
            "Linear (up)",
            "Per-layer projection",
            f"Linear; {emb_txt} -> {hidden_txt}",
        ),
        (
            node_ids["norm"],
            "RMSNorm",
            "Post-PLE norm",
            f"RMSNorm; dim {hidden_txt}",
        ),
    ]
    child_keys = ("id", "label", "title", "description")
    children = [dict(zip(child_keys, row)) for row in child_rows]

    # Metadata consumed by the "per_layer_embedding" detail view.
    detail = {
        "view": "per_layer_embedding",
        "view_id": block_id,
        "pathway_id": pathway_id,
        "nodes": node_ids,
        "input_label": "in (hidden)",
        "output_label": "out -> add (residual)",
        "external_label": f"{pathway_id}[L]",
        "external_description": f"({emb_txt}-d, built outside layers)",
        "hidden_size": hidden_size,
        "embedding_dim": embedding_dim,
    }

    ple_block = {
        "id": block_id,
        "role": "ple",
        "kind": "ple",
        "label": "PLE",
        "title": "Per-Layer Embeddings",
        "description": (
            f"Per-layer gate-and-project; {hidden_txt} -> {emb_txt} -> {hidden_txt}. "
            "Multiplied by a per-layer vector built outside the stack."
        ),
        "detail_view": "per_layer_embedding",
        "detail": detail,
        "lane": lane,
        "tap_from": tap_from,
        "feeds": feeds,
        "children": children,
    }
    residual_block = {
        "id": add_id,
        "role": "residual",
        "kind": "residual_add",
        "residual_from": residual_from,
        "label": "+",
        "title": "Residual add (PLE)",
        "description": "post-FFN + PLE output",
    }
    return [ple_block, residual_block]
130
+
131
+
132
def per_layer_embedding_pathway(
    hidden_size: int,
    embedding_dim: int,
    vocab_size: int,
    num_layers: int,
    *,
    pathway_id: str = DEFAULT_PATHWAY_ID,
    block_id: str = DEFAULT_BLOCK_ID,
) -> dict:
    """Describe the external pathway that feeds the PLE block.

    The returned dict names the pathway, points ``tap_block`` at the PLE
    block's multiply node, and lists three construction steps: the
    per-layer embedding lookup, the projection from the hidden state, and
    the step that combines the two.
    """
    hidden_txt = _fmt(hidden_size)
    emb_txt = _fmt(embedding_dim)
    vocab_txt = _fmt(vocab_size)
    layers_txt = _fmt(num_layers)
    multiply_id = _child_ids(block_id)["multiply"]

    # One row per construction step: (id suffix, label, kind, description).
    step_rows = (
        (
            "lookup",
            "embed_tokens_per_layer",
            "embedding",
            f"Lookup; {vocab_txt} -> {layers_txt} x {emb_txt}",
        ),
        (
            "proj_in",
            "per_layer_model_projection",
            "linear",
            f"Linear; {hidden_txt} -> {layers_txt} x {emb_txt}",
        ),
        (
            "combine",
            "(token + context) / sqrt(2)",
            "scale_add",
            "Sum the two pathways and rescale.",
        ),
    )
    construction = [
        {
            "id": f"{block_id}_{suffix}",
            "label": label,
            "kind": kind,
            "description": description,
        }
        for suffix, label, kind, description in step_rows
    ]

    summary = (
        f"Parallel pathway producing one {emb_txt}-d vector per layer per token; "
        "feeds every layer's PLE gate."
    )
    return {
        "id": pathway_id,
        "label": "Per-Layer Embeddings",
        "short_label": "PLE",
        "description": summary,
        "feeds": "every_layer",
        "tap_block": multiply_id,
        "construction": construction,
    }
178
+
179
+
180
def per_layer_embedding_extras(
    hidden_size: int,
    embedding_dim: int,
    vocab_size: int,
    num_layers: int,
    *,
    pathway_id: str = DEFAULT_PATHWAY_ID,
    block_id: str = DEFAULT_BLOCK_ID,
) -> dict:
    """Build the top-level IR extras for a PLE pathway.

    The result has two entries: ``per_layer_embeddings``, a summary of the
    embedding width, vocabulary size and pathway id, and
    ``external_pathways``, a one-element list holding the descriptor from
    ``per_layer_embedding_pathway``.
    """
    summary = {
        "hidden": embedding_dim,
        "vocab": vocab_size,
        "pathway_id": pathway_id,
    }
    pathway = per_layer_embedding_pathway(
        hidden_size,
        embedding_dim,
        vocab_size,
        num_layers,
        pathway_id=pathway_id,
        block_id=block_id,
    )
    return {
        "per_layer_embeddings": summary,
        "external_pathways": [pathway],
    }
207
+
208
+
209
+ def _child_ids(block_id: str) -> dict[str, str]:
210
+ return {
211
+ "gate": f"{block_id}_gate",
212
+ "activation": f"{block_id}_act",
213
+ "multiply": f"{block_id}_mul",
214
+ "projection": f"{block_id}_proj",
215
+ "norm": f"{block_id}_norm",
216
+ }
217
+
218
+
219
def _activation_label(activation: str) -> str:
    """Return the display label for ``activation``, using ``"gelu"`` when it is falsy."""
    name = activation if activation else "gelu"
    return activation_label(name)
@@ -0,0 +1,95 @@
1
+ """Diagram — the renderable object.
2
+
3
+ Implements ``_repr_html_`` so it auto-renders inline in Jupyter (like
4
+ ``matplotlib`` or a ``pandas`` DataFrame). Outside notebooks, call
5
+ ``.save(path)`` to write a portable HTML file.
6
+ """
7
+ from __future__ import annotations
8
+ import json
9
+ import os
10
+ import uuid
11
+ from .ir import ModelIR
12
+ from .html_renderer import render_document, render_fragment
13
+ from .params import estimate_params, humanize
14
+
15
+
16
class Diagram:
    """Renderable diagram built from a :class:`ModelIR`.

    The IR dict and each rendered HTML variant are computed on first use and
    then reused for the lifetime of the instance.
    """

    def __init__(self, ir: ModelIR):
        self.ir = ir
        # Random per-instance id handed to the renderer as the mount id.
        self._mount_id = f"uf-{uuid.uuid4().hex[:10]}"
        self._params = estimate_params(ir)
        self._ir_cache: dict | None = None
        # Keyed by the ``standalone`` flag.
        self._html_cache: dict[bool, str] = {}

    def to_ir(self) -> dict:
        """Return the IR as a plain dict with a ``"params"`` summary added.

        The dict is built once; later calls return the same object.
        """
        cached = self._ir_cache
        if cached is None:
            cached = self.ir.to_dict()
            estimates = self._params
            cached["params"] = {
                "total": estimates["total"],
                "active": estimates["active"],
                "total_h": humanize(estimates["total"]),
                "active_h": humanize(estimates["active"]),
                "is_sparse": estimates["is_sparse"],
            }
            self._ir_cache = cached
        return cached

    def param_count(self) -> dict:
        """Return the parameter estimates computed at construction time."""
        return self._params

    def _repr_html_(self) -> str:
        """Return the embeddable HTML fragment (the Jupyter display hook)."""
        return self._html(standalone=False)

    def to_html(self, standalone: bool = True) -> str:
        """Return the diagram as an HTML string.

        Parameters
        ----------
        standalone : bool
            True (default) gives a full HTML document; False gives a
            fragment suitable for embedding.
        """
        return self._html(standalone=standalone)

    def save(self, path: str) -> str:
        """Write the diagram to ``path`` and return ``path``.

        The extension (case-insensitive) selects the format:

        - ``.html`` — standalone HTML document
        - ``.json`` — the IR dict, indented by two spaces

        Raises ``ValueError`` for any other extension.
        """
        ext = os.path.splitext(path)[1].lower()
        if ext not in (".html", ".json"):
            raise ValueError(
                f"Unsupported extension {ext!r}. Use .html or .json."
            )
        with open(path, "w", encoding="utf-8") as f:
            if ext == ".html":
                f.write(self.to_html(standalone=True))
            else:
                json.dump(self.to_ir(), f, indent=2)
        return path

    def _html(self, standalone: bool) -> str:
        """Render (or fetch from cache) the document or fragment variant."""
        cache = self._html_cache
        if standalone not in cache:
            renderer = render_document if standalone else render_fragment
            cache[standalone] = renderer(self.to_ir(), self._mount_id)
        return cache[standalone]

    def __repr__(self) -> str:
        pieces = [
            f"<Diagram {self.ir.name!r} · {self.ir.num_layers} layers · ",
            f"~{humanize(self._params['total'])} params",
        ]
        # Active-parameter count is only shown for sparse models.
        if self._params['is_sparse']:
            pieces.append(f" ({humanize(self._params['active'])} active)")
        pieces.append(">")
        return "".join(pieces)
@@ -0,0 +1,5 @@
1
+ """Compatibility wrapper for the HTML/SVG renderer backend."""
2
+
3
+ from .renderers.html import render_document, render_fragment
4
+
5
+ __all__ = ["render_document", "render_fragment"]