model-unfolder 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_unfolder/__init__.py +58 -0
- model_unfolder/adapters/__init__.py +15 -0
- model_unfolder/adapters/custom/__init__.py +8 -0
- model_unfolder/adapters/diffusor/__init__.py +8 -0
- model_unfolder/adapters/transformer/__init__.py +5 -0
- model_unfolder/adapters/transformer/assembly.py +57 -0
- model_unfolder/adapters/transformer/blocks.py +238 -0
- model_unfolder/adapters/transformer/common.py +35 -0
- model_unfolder/adapters/transformer/families/__init__.py +12 -0
- model_unfolder/adapters/transformer/families/deepseek.py +107 -0
- model_unfolder/adapters/transformer/families/gemma4.py +202 -0
- model_unfolder/adapters/transformer/families/llama.py +91 -0
- model_unfolder/adapters/transformer/special_parts/__init__.py +2 -0
- model_unfolder/adapters/transformer/special_parts/per_layer_embedding.py +220 -0
- model_unfolder/diagram.py +95 -0
- model_unfolder/html_renderer.py +5 -0
- model_unfolder/ir.py +163 -0
- model_unfolder/labels.py +166 -0
- model_unfolder/params.py +119 -0
- model_unfolder/parser.py +137 -0
- model_unfolder/renderers/__init__.py +1 -0
- model_unfolder/renderers/html/__init__.py +5 -0
- model_unfolder/renderers/html/block_views/__init__.py +20 -0
- model_unfolder/renderers/html/block_views/attention.py +91 -0
- model_unfolder/renderers/html/block_views/feed_forward.py +213 -0
- model_unfolder/renderers/html/block_views/per_layer_embedding.py +199 -0
- model_unfolder/renderers/html/cards.py +130 -0
- model_unfolder/renderers/html/document.py +157 -0
- model_unfolder/renderers/html/interactions.py +64 -0
- model_unfolder/renderers/html/metadata.py +265 -0
- model_unfolder/renderers/html/sections.py +60 -0
- model_unfolder/renderers/html/styles.py +283 -0
- model_unfolder/renderers/html/svg.py +349 -0
- model_unfolder/renderers/html/theme.py +24 -0
- model_unfolder/renderers/html/utils.py +28 -0
- model_unfolder/renderers/html/views.py +461 -0
- model_unfolder-0.2.0.dist-info/METADATA +122 -0
- model_unfolder-0.2.0.dist-info/RECORD +41 -0
- model_unfolder-0.2.0.dist-info/WHEEL +5 -0
- model_unfolder-0.2.0.dist-info/licenses/LICENSE +201 -0
- model_unfolder-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""Adapter for Google's Gemma 4 family (E2B / E4B / 31B / 26B-A4B).
|
|
2
|
+
|
|
3
|
+
Gemma 4 nests its language-model fields under ``text_config`` since the
|
|
4
|
+
top-level config also covers the vision and audio encoders. The text stack
|
|
5
|
+
itself has a few features the older Llama/Gemma adapters don't model:
|
|
6
|
+
|
|
7
|
+
* Per-layer ``layer_types`` array tagging each block as ``sliding_attention``
|
|
8
|
+
or ``full_attention`` (instead of Gemma 3's "every Nth layer" pattern).
|
|
9
|
+
* Dual attention shape: sliding layers use ``head_dim`` + ``num_key_value_heads``;
|
|
10
|
+
global layers use ``global_head_dim`` + ``num_global_key_value_heads``.
|
|
11
|
+
* Optional MoE FFN (26B-A4B) controlled by ``enable_moe_block``.
|
|
12
|
+
* Per-Layer Embeddings (PLE) on small models — ``hidden_size_per_layer_input``
|
|
13
|
+
is the parallel conditioning dim, ``vocab_size_per_layer_input`` its vocab.
|
|
14
|
+
* Shared-KV layers — the last ``num_kv_shared_layers`` layers reuse K/V from
|
|
15
|
+
the most recent non-shared layer of the same attention type.
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from ....ir import AttentionSpec, CrossLayerEdge, FFNSpec, ModelIR
|
|
22
|
+
from ..assembly import decoder_extras, decoder_layer
|
|
23
|
+
from ..common import architecture_name, get_config_value as _g, model_name
|
|
24
|
+
from ..special_parts.per_layer_embedding import (
|
|
25
|
+
per_layer_embedding_blocks,
|
|
26
|
+
per_layer_embedding_extras,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
_TOP_TYPES = {"gemma4"}
|
|
31
|
+
_TEXT_TYPES = {"gemma4_text"}
|
|
32
|
+
_ARCH_HINTS = ("gemma4",)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def matches(cfg: Any) -> bool:
    """Return True when ``cfg`` describes a Gemma 4 model.

    An ``architectures`` entry containing one of ``_ARCH_HINTS``
    (case-insensitive) wins immediately. Otherwise the lower-cased
    ``model_type`` must be a known wrapper or text-only type.
    """
    architectures = _g(cfg, "architectures") or []
    for arch in architectures:
        lowered = arch.lower()
        if any(hint in lowered for hint in _ARCH_HINTS):
            return True

    model_type = (_g(cfg, "model_type", "") or "").lower()
    return model_type in (_TOP_TYPES | _TEXT_TYPES)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def parse(cfg: Any) -> ModelIR:
    """Translate a Gemma 4 config into a :class:`ModelIR`.

    ``cfg`` may be the multimodal wrapper or a bare ``gemma4_text`` config.
    Language-model fields are read from the sub-config chosen by
    ``_text_config``. Every decoder layer gets the following:

    * An attention spec whose shape depends on its layer type. Sliding layers
      use ``head_dim`` / ``num_key_value_heads`` and ``sliding_window``. Every
      other layer is treated as global and uses ``global_head_dim`` /
      ``num_global_key_value_heads``, each falling back to the sliding value
      when absent.
    * A gated FFN spec: MoE when ``enable_moe_block`` is truthy and
      ``num_experts`` is non-zero, dense otherwise.
    * PLE side blocks when ``hidden_size_per_layer_input`` is non-zero.

    Layers in the trailing ``num_kv_shared_layers`` range also get a
    ``kv_share`` :class:`CrossLayerEdge` from the layer picked by
    ``_last_matching_layer``.
    """
    arch_name = architecture_name(cfg, "gemma4")
    text_cfg = _text_config(cfg)

    num_layers = _g(text_cfg, "num_hidden_layers", 0)
    hidden_size = _g(text_cfg, "hidden_size", 0)
    intermediate_size = _g(text_cfg, "intermediate_size", 0)
    activation = (
        _g(text_cfg, "hidden_activation")
        or _g(text_cfg, "hidden_act")
        or "gelu_pytorch_tanh"
    ).lower()

    num_q = _g(text_cfg, "num_attention_heads", 0)
    # ``or`` rather than a getter default, so an explicit ``null`` also falls
    # back to MHA instead of leaking ``None`` into the attention spec.
    num_kv = _g(text_cfg, "num_key_value_heads") or num_q
    num_kv_global = _g(text_cfg, "num_global_key_value_heads") or num_kv
    head_dim = _g(text_cfg, "head_dim") or (hidden_size // num_q if num_q else None)
    head_dim_global = _g(text_cfg, "global_head_dim") or head_dim
    sliding_window = _g(text_cfg, "sliding_window")

    # Resolve exactly one type per layer up front. Entries missing from (or
    # beyond the end of) ``layer_types`` default to full attention. The SAME
    # resolved list feeds the KV-share lookup below, so a layer's mask and its
    # KV source always agree. Previously the raw list was passed, and shared
    # layers past its end silently lost their ``kv_share`` edge.
    declared_types = list(_g(text_cfg, "layer_types") or [])
    layer_types = [
        declared_types[i] if i < len(declared_types) else "full_attention"
        for i in range(num_layers)
    ]

    moe_enabled = bool(_g(text_cfg, "enable_moe_block"))
    num_experts = _g(text_cfg, "num_experts") or 0
    top_k = _g(text_cfg, "top_k_experts") or 0
    moe_intermediate_size = _g(text_cfg, "moe_intermediate_size") or 0

    num_kv_shared = _g(text_cfg, "num_kv_shared_layers") or 0
    # Index of the first layer that reuses K/V. It is clamped so an oversized
    # ``num_kv_shared_layers`` cannot produce a negative index.
    first_shared = max(num_layers - num_kv_shared, 0)

    ple_dim = _g(text_cfg, "hidden_size_per_layer_input") or 0
    ple_vocab = _g(text_cfg, "vocab_size_per_layer_input") or _g(text_cfg, "vocab_size", 0)

    layers = []
    cross_edges: list[CrossLayerEdge] = []
    for i, layer_type in enumerate(layer_types):
        if "sliding" in layer_type:
            mask, window = "sliding", sliding_window
            kv_heads, this_head_dim = num_kv, head_dim
        else:
            mask, window = "global", None
            kv_heads, this_head_dim = num_kv_global, head_dim_global

        kv_source: int | None = None
        if i >= first_shared:
            kv_source = _last_matching_layer(layer_types, i, first_shared)
        if kv_source is not None:
            cross_edges.append(
                CrossLayerEdge(
                    kind="kv_share",
                    from_layer=kv_source,
                    to_layer=i,
                    shared=["K", "V"],
                )
            )

        attn = AttentionSpec(
            kind=_attention_kind(num_q, kv_heads),
            num_heads=num_q,
            num_kv_heads=kv_heads,
            head_dim=this_head_dim,
            mask=mask,
            window_size=window,
            kv_source_layer=kv_source,
        )

        if moe_enabled and num_experts:
            ffn = FFNSpec(
                kind="moe",
                activation=activation,
                intermediate_size=intermediate_size,
                gated=True,
                num_experts=num_experts,
                num_experts_per_tok=top_k,
                expert_intermediate_size=moe_intermediate_size or intermediate_size,
            )
        else:
            ffn = FFNSpec(
                kind="dense",
                activation=activation,
                intermediate_size=intermediate_size,
                gated=True,
            )

        extra_blocks = []
        if ple_dim:
            # NOTE(review): this assumes the PLE gate applies the model's
            # hidden activation, as Hugging Face's Gemma 3n decoder layer
            # does. Confirm for Gemma 4.
            extra_blocks.extend(
                per_layer_embedding_blocks(hidden_size, ple_dim, activation=activation)
            )
        layers.append(
            decoder_layer(i, attn, ffn, hidden_size, extra_blocks=extra_blocks)
        )

    vocab_size = _g(text_cfg, "vocab_size", 0)
    tie_word_embeddings = bool(
        _g(text_cfg, "tie_word_embeddings", _g(cfg, "tie_word_embeddings", False))
    )

    extras = decoder_extras(
        vocab_size,
        hidden_size,
        tie_word_embeddings,
        per_layer_embedding_extras(hidden_size, ple_dim, ple_vocab, num_layers)
        if ple_dim
        else None,
    )
    # Gemma 4 specific flags, surfaced only when set.
    if num_kv_shared:
        extras["num_kv_shared_layers"] = num_kv_shared
    if _g(text_cfg, "attention_k_eq_v"):
        extras["attention_k_eq_v"] = True
    if _g(text_cfg, "use_double_wide_mlp"):
        extras["use_double_wide_mlp"] = True

    return ModelIR(
        name=model_name(cfg, arch_name),
        architecture=arch_name,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_position_embeddings=_g(text_cfg, "max_position_embeddings"),
        tie_word_embeddings=tie_word_embeddings,
        layers=layers,
        cross_layer_edges=cross_edges,
        extras=extras,
    )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _text_config(cfg: Any) -> Any:
    """Return the config that holds the language-model fields.

    A config whose ``model_type`` is already a text-only type is returned
    as-is. Otherwise its ``text_config`` is returned when present, and
    ``cfg`` itself when it is not.
    """
    model_type = (_g(cfg, "model_type", "") or "").lower()
    if model_type not in _TEXT_TYPES:
        nested = _g(cfg, "text_config")
        if nested is not None:
            return nested
    return cfg
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _attention_kind(num_q: int, num_kv: int) -> str:
|
|
181
|
+
if not num_q:
|
|
182
|
+
return "mha"
|
|
183
|
+
if num_kv == num_q:
|
|
184
|
+
return "mha"
|
|
185
|
+
if num_kv == 1:
|
|
186
|
+
return "mqa"
|
|
187
|
+
return "gqa"
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _last_matching_layer(layer_types: list, i: int, first_shared: int) -> int | None:
|
|
191
|
+
"""Source layer for a KV-shared layer.
|
|
192
|
+
|
|
193
|
+
Per the Gemma 4 release notes: shared layers reuse K/V from the most
|
|
194
|
+
recent non-shared layer of the *same* attention type (sliding or full).
|
|
195
|
+
"""
|
|
196
|
+
if not layer_types or i >= len(layer_types):
|
|
197
|
+
return None
|
|
198
|
+
target_type = layer_types[i]
|
|
199
|
+
for j in range(min(first_shared, len(layer_types)) - 1, -1, -1):
|
|
200
|
+
if layer_types[j] == target_type:
|
|
201
|
+
return j
|
|
202
|
+
return None
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Adapter for Llama, Mistral, Qwen, and similar GQA/MHA dense models."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from ....ir import AttentionSpec, FFNSpec, ModelIR
|
|
7
|
+
from ..assembly import decoder_extras, decoder_layer
|
|
8
|
+
from ..common import architecture_name, get_config_value as _g, model_name
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
_FAMILIES = {"llama", "mistral", "qwen2", "qwen3", "phi3", "gemma"}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def matches(cfg: Any) -> bool:
    """Return True when ``cfg`` looks like a Llama-style dense decoder.

    Any ``architectures`` entry containing one of the family hints
    (case-insensitive substring) matches. Failing that, ``model_type`` must
    be an exact member of ``_FAMILIES``.
    """
    arch_hints = ("llama", "mistral", "qwen", "phi3")
    architectures = _g(cfg, "architectures") or []
    model_type = _g(cfg, "model_type", "")
    if any(hint in arch.lower() for arch in architectures for hint in arch_hints):
        return True
    return model_type in _FAMILIES
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def parse(cfg: Any) -> ModelIR:
    """Translate a Llama-style dense decoder config into a :class:`ModelIR`.

    Every layer shares the same attention head layout and a gated dense FFN.
    Each layer's mask is chosen in this priority order:

    1. The layer's own ``layer_types`` entry, when one exists ("sliding"
       substring means sliding, anything else means causal).
    2. ``sliding_window_pattern``: the last layer of each group of
       ``pattern`` layers is full causal and the rest are sliding.
    3. A bare ``sliding_window``: all layers are sliding.
    4. Otherwise, plain causal attention.

    A config that explicitly sets ``use_sliding_window`` to false has its
    ``sliding_window`` ignored.
    """
    num_layers = _g(cfg, "num_hidden_layers", 0)
    num_heads = _g(cfg, "num_attention_heads", 0)
    # ``or`` rather than a getter default, so an explicit ``null`` also means
    # "one KV head per query head" instead of producing a None head count.
    num_kv_heads = _g(cfg, "num_key_value_heads") or num_heads
    hidden_size = _g(cfg, "hidden_size", 0)
    head_dim = _g(cfg, "head_dim") or (hidden_size // num_heads if num_heads else None)

    if num_kv_heads == num_heads:
        attn_kind = "mha"
    elif num_kv_heads == 1:
        attn_kind = "mqa"
    else:
        attn_kind = "gqa"

    sliding_window = _g(cfg, "sliding_window")
    # Qwen2-family configs publish a ``sliding_window`` size but switch the
    # feature off with ``use_sliding_window: false``. Without this check every
    # layer of such a model falls into the blanket "all sliding" branch below.
    if _g(cfg, "use_sliding_window") is False:
        sliding_window = None
    sliding_pattern = _g(cfg, "sliding_window_pattern")
    layer_types = _g(cfg, "layer_types")

    intermediate_size = _g(cfg, "intermediate_size", 0)
    # Prefer ``hidden_activation`` over ``hidden_act``, matching the Gemma 4
    # adapter, so Gemma configs that set it report the activation they use.
    activation = (_g(cfg, "hidden_activation") or _g(cfg, "hidden_act") or "silu").lower()

    arch_name = architecture_name(cfg, "llama")

    layers = []
    for i in range(num_layers):
        if layer_types and i < len(layer_types):
            if "sliding" in layer_types[i]:
                mask, win = "sliding", sliding_window
            else:
                mask, win = "causal", None
        elif sliding_pattern and sliding_window:
            # The last layer of every ``sliding_pattern``-sized group is global.
            is_global = (i % sliding_pattern) == (sliding_pattern - 1)
            mask = "causal" if is_global else "sliding"
            win = None if is_global else sliding_window
        elif sliding_window:
            mask, win = "sliding", sliding_window
        else:
            mask, win = "causal", None

        attn = AttentionSpec(
            kind=attn_kind,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            mask=mask,
            window_size=win,
        )
        ffn = FFNSpec(
            kind="dense",
            activation=activation,
            intermediate_size=intermediate_size,
            gated=True,
        )
        layers.append(decoder_layer(i, attn, ffn, hidden_size))

    vocab_size = _g(cfg, "vocab_size", 0)
    tie_word_embeddings = bool(_g(cfg, "tie_word_embeddings", False))
    return ModelIR(
        name=model_name(cfg, arch_name),
        architecture=arch_name,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        max_position_embeddings=_g(cfg, "max_position_embeddings"),
        tie_word_embeddings=tie_word_embeddings,
        layers=layers,
        extras=decoder_extras(vocab_size, hidden_size, tie_word_embeddings),
    )
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""Reusable Per-Layer Embedding (PLE) transformer part.
|
|
2
|
+
|
|
3
|
+
This module owns the canonical IR shape for PLE-style conditioning. A model
|
|
4
|
+
family adapter should only detect that the config has such a pathway, then
|
|
5
|
+
attach these blocks and extras. The renderer consumes the declared block
|
|
6
|
+
metadata, so the feature is not tied to any one model family.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from ....labels import activation_label
|
|
11
|
+
from ..common import format_dim as _fmt
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Default ids for the PLE side block, its residual add, and the external
# per-layer-input pathway. The builders below accept overrides for each.
DEFAULT_BLOCK_ID = "ple"
DEFAULT_ADD_ID = "add3"
DEFAULT_PATHWAY_ID = "per_layer_input"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def per_layer_embedding_blocks(
    hidden_size: int,
    embedding_dim: int,
    *,
    activation: str = "gelu",
    block_id: str = DEFAULT_BLOCK_ID,
    add_id: str = DEFAULT_ADD_ID,
    pathway_id: str = DEFAULT_PATHWAY_ID,
    lane: str = "left",
    tap_from: str = "rms1",
    feeds: str | None = None,
    residual_from: str = "add2",
) -> list[dict]:
    """Return the two layer blocks that make up a reusable PLE side pathway.

    The first block (id ``block_id``) describes the chain::

        hidden -> gate -> activation -> multiply(per-layer vector) -> projection
               -> norm -> residual add.

    It sits off the central chain via ``lane`` / ``tap_from`` / ``feeds``
    (``feeds`` defaults to ``add_id``). This lets other adapters reuse the
    same shape without renderer branches for their model family. The second
    block is the residual add (id ``add_id``) whose residual source is
    ``residual_from``.
    """
    target = feeds or add_id
    hidden = _fmt(hidden_size)
    emb = _fmt(embedding_dim)
    act_name = _activation_label(activation)
    ids = _child_ids(block_id)
    external = f"{pathway_id}[L]"

    def node(node_id: str, label: str, title: str, description: str) -> dict:
        # Every child node uses the same key layout.
        return {"id": node_id, "label": label, "title": title, "description": description}

    children = [
        node(ids["gate"], "Linear (gate)", "Per-layer input gate", f"Linear; {hidden} -> {emb}"),
        node(ids["activation"], act_name, "PLE activation", f"Element-wise {act_name}"),
        node(
            ids["multiply"],
            "x",
            "Per-layer gate (x)",
            f"Element-wise multiply by {external} ({emb}-d vector sourced from the parallel pathway)",
        ),
        node(
            pathway_id,
            external,
            "Per-layer input vector",
            f"{emb}-d vector produced outside the layer stack for layer L.",
        ),
        node(ids["projection"], "Linear (up)", "Per-layer projection", f"Linear; {emb} -> {hidden}"),
        node(ids["norm"], "RMSNorm", "Post-PLE norm", f"RMSNorm; dim {hidden}"),
    ]

    detail = {
        "view": "per_layer_embedding",
        "view_id": block_id,
        "pathway_id": pathway_id,
        "nodes": ids,
        "input_label": "in (hidden)",
        "output_label": "out -> add (residual)",
        "external_label": external,
        "external_description": f"({emb}-d, built outside layers)",
        "hidden_size": hidden_size,
        "embedding_dim": embedding_dim,
    }
    ple_block = {
        "id": block_id,
        "role": "ple",
        "kind": "ple",
        "label": "PLE",
        "title": "Per-Layer Embeddings",
        "description": (
            f"Per-layer gate-and-project; {hidden} -> {emb} -> {hidden}. "
            "Multiplied by a per-layer vector built outside the stack."
        ),
        "detail_view": "per_layer_embedding",
        "detail": detail,
        "lane": lane,
        "tap_from": tap_from,
        "feeds": target,
        "children": children,
    }
    residual_block = {
        "id": add_id,
        "role": "residual",
        "kind": "residual_add",
        "residual_from": residual_from,
        "label": "+",
        "title": "Residual add (PLE)",
        "description": "post-FFN + PLE output",
    }
    return [ple_block, residual_block]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def per_layer_embedding_pathway(
    hidden_size: int,
    embedding_dim: int,
    vocab_size: int,
    num_layers: int,
    *,
    pathway_id: str = DEFAULT_PATHWAY_ID,
    block_id: str = DEFAULT_BLOCK_ID,
) -> dict:
    """Return the external pathway descriptor consumed by the PLE block.

    The descriptor lists the construction steps outside the layer stack: a
    per-layer token lookup, a projection of the hidden state, and their
    rescaled sum. Its ``tap_block`` points at the PLE block's multiply node.
    """
    hidden = _fmt(hidden_size)
    emb = _fmt(embedding_dim)
    vocab = _fmt(vocab_size)
    layers = _fmt(num_layers)
    multiply_id = _child_ids(block_id)["multiply"]

    def step(suffix: str, label: str, kind: str, description: str) -> dict:
        # Construction-step ids are namespaced under ``block_id``.
        return {"id": f"{block_id}_{suffix}", "label": label, "kind": kind, "description": description}

    construction = [
        step("lookup", "embed_tokens_per_layer", "embedding", f"Lookup; {vocab} -> {layers} x {emb}"),
        step("proj_in", "per_layer_model_projection", "linear", f"Linear; {hidden} -> {layers} x {emb}"),
        step("combine", "(token + context) / sqrt(2)", "scale_add", "Sum the two pathways and rescale."),
    ]
    return {
        "id": pathway_id,
        "label": "Per-Layer Embeddings",
        "short_label": "PLE",
        "description": (
            f"Parallel pathway producing one {emb}-d vector per layer per token; "
            "feeds every layer's PLE gate."
        ),
        "feeds": "every_layer",
        "tap_block": multiply_id,
        "construction": construction,
    }
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def per_layer_embedding_extras(
    hidden_size: int,
    embedding_dim: int,
    vocab_size: int,
    num_layers: int,
    *,
    pathway_id: str = DEFAULT_PATHWAY_ID,
    block_id: str = DEFAULT_BLOCK_ID,
) -> dict:
    """Return top-level IR extras describing a reusable PLE pathway.

    ``per_layer_embeddings`` summarises the pathway: its per-layer vector
    width under ``"hidden"``, its vocabulary, and its id.
    ``external_pathways`` holds the single descriptor built by
    ``per_layer_embedding_pathway``.
    """
    summary = {
        "hidden": embedding_dim,
        "vocab": vocab_size,
        "pathway_id": pathway_id,
    }
    pathway = per_layer_embedding_pathway(
        hidden_size,
        embedding_dim,
        vocab_size,
        num_layers,
        pathway_id=pathway_id,
        block_id=block_id,
    )
    return {"per_layer_embeddings": summary, "external_pathways": [pathway]}
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _child_ids(block_id: str) -> dict[str, str]:
|
|
210
|
+
return {
|
|
211
|
+
"gate": f"{block_id}_gate",
|
|
212
|
+
"activation": f"{block_id}_act",
|
|
213
|
+
"multiply": f"{block_id}_mul",
|
|
214
|
+
"projection": f"{block_id}_proj",
|
|
215
|
+
"norm": f"{block_id}_norm",
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _activation_label(activation: str) -> str:
    """Return the display label for ``activation``, using "gelu" when it is empty or None."""
    name = activation if activation else "gelu"
    return activation_label(name)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Diagram — the renderable object.
|
|
2
|
+
|
|
3
|
+
Implements ``_repr_html_`` so it auto-renders inline in Jupyter (like
|
|
4
|
+
``matplotlib`` or a ``pandas`` DataFrame). Outside notebooks, call
|
|
5
|
+
``.save(path)`` to write a portable HTML file.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import uuid
|
|
11
|
+
from .ir import ModelIR
|
|
12
|
+
from .html_renderer import render_document, render_fragment
|
|
13
|
+
from .params import estimate_params, humanize
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Diagram:
    """A renderable diagram of a transformer architecture.

    Wraps a :class:`ModelIR`. ``_repr_html_`` lets Jupyter render it inline,
    and ``save`` writes either an interactive HTML document or the raw IR
    as JSON.
    """

    def __init__(self, ir: ModelIR):
        self.ir = ir
        # Random per-instance id handed to the HTML renderers as the mount point.
        self._mount_id = f"uf-{uuid.uuid4().hex[:10]}"
        self._params = estimate_params(ir)
        # Lazily built IR dict, plus rendered HTML keyed by ``standalone``.
        self._ir_cache: dict | None = None
        self._html_cache: dict[bool, str] = {}

    def to_ir(self) -> dict:
        """Return the underlying IR (plus param estimates) as a plain dict.

        The result is a deep copy of the internal cache. Callers may mutate
        it freely without changing later ``to_html`` / ``save`` output.
        """
        import copy

        return copy.deepcopy(self._ir_dict())

    def _ir_dict(self) -> dict:
        """Build the IR dict on first use and return the cached object (internal; do not mutate)."""
        if self._ir_cache is None:
            d = self.ir.to_dict()
            p = self._params
            d["params"] = {
                "total": p["total"],
                "active": p["active"],
                "total_h": humanize(p["total"]),
                "active_h": humanize(p["active"]),
                "is_sparse": p["is_sparse"],
            }
            self._ir_cache = d
        return self._ir_cache

    def param_count(self) -> dict:
        """Return parameter-count estimates: total / active / per-layer breakdown.

        A deep copy is returned, so callers cannot alter the estimates that
        ``to_ir`` and ``repr`` rely on.
        """
        import copy

        return copy.deepcopy(self._params)

    def _repr_html_(self) -> str:
        """Jupyter calls this; the returned HTML string is rendered inline."""
        return self._html(standalone=False)

    def to_html(self, standalone: bool = True) -> str:
        """Return the diagram as an HTML string.

        Parameters
        ----------
        standalone : bool
            If True (default), wraps the diagram in a full HTML document.
            If False, returns a fragment usable for embedding (Jupyter mode).
        """
        return self._html(standalone=standalone)

    def save(self, path: str) -> str:
        """Save the diagram to disk and return ``path``.

        - ``.html`` — interactive standalone document
        - ``.json`` — the underlying IR (no rendering)

        The extension check is case-insensitive. Any other extension raises
        ``ValueError`` before a file is touched. The content is produced
        before the file is opened, so a rendering or serialization error does
        not leave an empty or truncated file behind.
        """
        ext = os.path.splitext(path)[1].lower()
        if ext == ".html":
            content = self.to_html(standalone=True)
        elif ext == ".json":
            content = json.dumps(self._ir_dict(), indent=2)
        else:
            raise ValueError(
                f"Unsupported extension {ext!r}. Use .html or .json."
            )
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        return path

    def _html(self, standalone: bool) -> str:
        """Render once per ``standalone`` value and return the cached HTML."""
        if standalone not in self._html_cache:
            render = render_document if standalone else render_fragment
            self._html_cache[standalone] = render(self._ir_dict(), self._mount_id)
        return self._html_cache[standalone]

    def __repr__(self) -> str:
        return (
            f"<Diagram {self.ir.name!r} · {self.ir.num_layers} layers · "
            f"~{humanize(self._params['total'])} params"
            + (f" ({humanize(self._params['active'])} active)" if self._params['is_sparse'] else "")
            + ">"
        )
|