model-unfolder 0.2.6__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/PKG-INFO +1 -1
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/blocks/attention.py +194 -13
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/__init__.py +5 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention_types/__init__.py +4 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention_types/common.py +209 -1
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention_types/grouped_query.py +31 -23
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/latent.py +262 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention_types/linear.py +3 -2
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/multi_head.py +90 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/multi_query.py +89 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention_types/rwkv.py +3 -2
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/svg.py +19 -2
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder.egg-info/PKG-INFO +1 -1
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/pyproject.toml +1 -1
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/tests/test_smoke.py +3 -1
- model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/latent.py +0 -61
- model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/multi_head.py +0 -74
- model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/multi_query.py +0 -118
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/LICENSE +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/README.md +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/custom/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/diffusor/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/assembly.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/blocks/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/blocks/descriptions.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/blocks/feed_forward.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/blocks/layers.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/blocks/model.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/common.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/cohere.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/deepseek.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/falcon.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/fallback.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/gemma2.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/gemma3.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/gemma4.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gpt_neox.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/jamba.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/llama.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/minimax.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/mistral.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/qwen.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/recurrent_gemma.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/rwkv.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/zamba.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/special_parts/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/special_parts/per_layer_embedding.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/diagram.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/html_renderer.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/ir.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/labels.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/params.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/parser.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/__init__.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention_types/state_space.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/feed_forward.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/mixture_of_experts.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/per_layer_embedding.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/cards.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/document.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/interactions.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/metadata.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/sections.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/styles.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/theme.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/utils.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/views.py +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder.egg-info/SOURCES.txt +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder.egg-info/dependency_links.txt +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder.egg-info/requires.txt +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder.egg-info/top_level.txt +0 -0
- {model_unfolder-0.2.6 → model_unfolder-0.2.7}/setup.cfg +0 -0
|
@@ -26,6 +26,18 @@ def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]
|
|
|
26
26
|
q_out = _fmt(num_heads * head_dim) if (num_heads and head_dim) else hidden
|
|
27
27
|
kv_out = _fmt(num_kv_heads * head_dim) if (num_kv_heads and head_dim) else hidden
|
|
28
28
|
d_k = _fmt(head_dim) if head_dim else "d_k"
|
|
29
|
+
if attention.kind in {"mha", "gqa", "mqa"}:
|
|
30
|
+
return _sdpa_detailed_child_blocks(
|
|
31
|
+
attention.kind,
|
|
32
|
+
hidden,
|
|
33
|
+
q_out,
|
|
34
|
+
kv_out,
|
|
35
|
+
num_heads,
|
|
36
|
+
num_kv_heads,
|
|
37
|
+
d_k,
|
|
38
|
+
q_per_group,
|
|
39
|
+
)
|
|
40
|
+
|
|
29
41
|
attention_title, attention_desc = _sdpa_operation_meta(attention, num_heads, num_kv_heads, d_k, q_per_group)
|
|
30
42
|
return [
|
|
31
43
|
{
|
|
@@ -56,6 +68,77 @@ def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]
|
|
|
56
68
|
]
|
|
57
69
|
|
|
58
70
|
|
|
71
|
+
def _sdpa_detailed_child_blocks(
|
|
72
|
+
kind: str,
|
|
73
|
+
hidden: str,
|
|
74
|
+
q_out: str,
|
|
75
|
+
kv_out: str,
|
|
76
|
+
num_heads: int,
|
|
77
|
+
num_kv_heads: int,
|
|
78
|
+
d_k: str,
|
|
79
|
+
q_per_group: int | None,
|
|
80
|
+
) -> list[dict]:
|
|
81
|
+
kv_label = "1 shared K/V head" if kind == "mqa" else f"{num_kv_heads} KV-heads"
|
|
82
|
+
scaled_title = "Scaled attention scores"
|
|
83
|
+
scaled_desc = "Per head: QK^T / sqrt(dim); dot-product scores scaled for numerical stability"
|
|
84
|
+
if kind == "gqa":
|
|
85
|
+
scaled_title = "Grouped scaled dot-product attention"
|
|
86
|
+
group = f"; each KV head serves {q_per_group} query heads" if q_per_group else ""
|
|
87
|
+
scaled_desc = (
|
|
88
|
+
f"Grouped SDPA scores: {num_heads} query heads attend through "
|
|
89
|
+
f"{num_kv_heads} shared K/V heads{group}; scores use QK^T / sqrt(dim)"
|
|
90
|
+
)
|
|
91
|
+
elif kind == "mqa":
|
|
92
|
+
scaled_title = "Multi-query scaled dot-product attention"
|
|
93
|
+
scaled_desc = (
|
|
94
|
+
f"Multi-Query SDPA scores: {num_heads} query heads share one K/V stream; "
|
|
95
|
+
"scores use QK^T / sqrt(dim)"
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
return [
|
|
99
|
+
{
|
|
100
|
+
"id": "q_proj",
|
|
101
|
+
"title": "Query projection",
|
|
102
|
+
"description": f"Linear; {hidden} -> {q_out} ({num_heads} heads x {d_k} dims)",
|
|
103
|
+
},
|
|
104
|
+
{
|
|
105
|
+
"id": "k_proj",
|
|
106
|
+
"title": "Key projection",
|
|
107
|
+
"description": f"Linear; {hidden} -> {kv_out} ({kv_label} x {d_k} dims)",
|
|
108
|
+
},
|
|
109
|
+
{
|
|
110
|
+
"id": "v_proj",
|
|
111
|
+
"title": "Value projection",
|
|
112
|
+
"description": f"Linear; {hidden} -> {kv_out} ({kv_label} x {d_k} dims)",
|
|
113
|
+
},
|
|
114
|
+
{
|
|
115
|
+
"id": "scaled_scores",
|
|
116
|
+
"title": scaled_title,
|
|
117
|
+
"description": scaled_desc,
|
|
118
|
+
},
|
|
119
|
+
{
|
|
120
|
+
"id": "attn_softmax",
|
|
121
|
+
"title": "Softmax weights",
|
|
122
|
+
"description": "Normalize each query row into attention weights over source tokens",
|
|
123
|
+
},
|
|
124
|
+
{
|
|
125
|
+
"id": "attn_apply_v",
|
|
126
|
+
"title": "Apply values",
|
|
127
|
+
"description": "Multiply attention weights by V to produce one context vector per head",
|
|
128
|
+
},
|
|
129
|
+
{
|
|
130
|
+
"id": "concat_heads",
|
|
131
|
+
"title": "Concatenate heads",
|
|
132
|
+
"description": f"Stack all {num_heads} per-head context vectors back into width {q_out}",
|
|
133
|
+
},
|
|
134
|
+
{
|
|
135
|
+
"id": "o_proj",
|
|
136
|
+
"title": "Output projection",
|
|
137
|
+
"description": f"Linear; {q_out} -> {hidden} (mixes information across heads)",
|
|
138
|
+
},
|
|
139
|
+
]
|
|
140
|
+
|
|
141
|
+
|
|
59
142
|
def _sdpa_operation_meta(
|
|
60
143
|
attention: AttentionSpec,
|
|
61
144
|
num_heads: int,
|
|
@@ -102,40 +185,138 @@ def _mla_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
|
|
|
102
185
|
num_heads = attention.num_heads or 0
|
|
103
186
|
head_dim = attention.head_dim or 0
|
|
104
187
|
q_out = _fmt(num_heads * head_dim) if (num_heads and head_dim) else hidden
|
|
105
|
-
|
|
188
|
+
query_children = [
|
|
106
189
|
{
|
|
107
190
|
"id": "mla_q",
|
|
108
191
|
"label": "Q projection",
|
|
109
192
|
"title": "Query projection",
|
|
110
193
|
"description": (
|
|
111
|
-
f"Projects hidden states into query
|
|
194
|
+
f"Projects hidden states into query latent space through LoRA rank {q_rank}"
|
|
112
195
|
if attention.q_lora_rank
|
|
113
|
-
else f"
|
|
196
|
+
else f"Projects hidden states directly into query heads; {hidden} -> {q_out}"
|
|
114
197
|
),
|
|
115
198
|
},
|
|
199
|
+
{
|
|
200
|
+
"id": "mla_q_nope",
|
|
201
|
+
"label": "Q noPE",
|
|
202
|
+
"title": "Query content slice",
|
|
203
|
+
"description": "Query content component that does not receive rotary position encoding",
|
|
204
|
+
},
|
|
205
|
+
{
|
|
206
|
+
"id": "mla_q_rope",
|
|
207
|
+
"label": "Q RoPE",
|
|
208
|
+
"title": "Query positional slice",
|
|
209
|
+
"description": f"Query positional component prepared for rotary position encoding; dim {rope}",
|
|
210
|
+
},
|
|
211
|
+
{
|
|
212
|
+
"id": "mla_q_rope_apply",
|
|
213
|
+
"label": "Apply RoPE",
|
|
214
|
+
"title": "Apply RoPE to query",
|
|
215
|
+
"description": "Applies rotary position encoding to the query positional slice",
|
|
216
|
+
},
|
|
217
|
+
{
|
|
218
|
+
"id": "mla_q_concat",
|
|
219
|
+
"label": "Q concat",
|
|
220
|
+
"title": "Final MLA query",
|
|
221
|
+
"description": "Concatenates Q noPE with RoPE-encoded Q RoPE before score computation",
|
|
222
|
+
},
|
|
223
|
+
]
|
|
224
|
+
kv_children = [
|
|
116
225
|
{
|
|
117
226
|
"id": "mla_kv_down",
|
|
118
227
|
"label": "KV compress",
|
|
119
228
|
"title": "K/V latent compression",
|
|
120
|
-
"description": f"Compresses the token state into
|
|
229
|
+
"description": f"Compresses the token state into the shared latent K/V cache; {hidden} -> rank {kv_rank}",
|
|
230
|
+
},
|
|
231
|
+
{
|
|
232
|
+
"id": "mla_cache",
|
|
233
|
+
"label": "latent cache c_t",
|
|
234
|
+
"title": "Stored latent cache",
|
|
235
|
+
"description": f"Compressed K/V latent stored in the cache instead of full K and V heads; rank {kv_rank}",
|
|
121
236
|
},
|
|
122
237
|
{
|
|
123
238
|
"id": "mla_kv_up",
|
|
124
239
|
"label": "KV expand",
|
|
125
240
|
"title": "K/V head expansion",
|
|
126
|
-
"description": f"Expands
|
|
241
|
+
"description": f"Expands cached latent c_t into K noPE content and V values for {num_heads} query heads",
|
|
242
|
+
},
|
|
243
|
+
{
|
|
244
|
+
"id": "mla_k_nope",
|
|
245
|
+
"label": "K noPE",
|
|
246
|
+
"title": "Latent key content",
|
|
247
|
+
"description": "Key content expanded from the compressed K/V latent; concatenated with the RoPE key before scoring",
|
|
248
|
+
},
|
|
249
|
+
{
|
|
250
|
+
"id": "mla_k_rope",
|
|
251
|
+
"label": "K RoPE",
|
|
252
|
+
"title": "Key positional slice",
|
|
253
|
+
"description": f"Key positional component produced alongside the latent cache; dim {rope}",
|
|
254
|
+
},
|
|
255
|
+
{
|
|
256
|
+
"id": "mla_k_rope_apply",
|
|
257
|
+
"label": "Apply RoPE",
|
|
258
|
+
"title": "Apply RoPE to key",
|
|
259
|
+
"description": "Applies rotary position encoding to the key positional slice",
|
|
260
|
+
},
|
|
261
|
+
{
|
|
262
|
+
"id": "mla_k_merge",
|
|
263
|
+
"label": "K concat",
|
|
264
|
+
"title": "Composed MLA key",
|
|
265
|
+
"description": "Concatenates K noPE with the RoPE key side-channel before QK^T score computation",
|
|
266
|
+
},
|
|
267
|
+
{
|
|
268
|
+
"id": "mla_v",
|
|
269
|
+
"label": "V values",
|
|
270
|
+
"title": "Latent value heads",
|
|
271
|
+
"description": "Value heads expanded from the compressed K/V latent; consumed after softmax",
|
|
272
|
+
},
|
|
273
|
+
]
|
|
274
|
+
return [
|
|
275
|
+
{
|
|
276
|
+
"id": "mla_query_path",
|
|
277
|
+
"label": "Query path",
|
|
278
|
+
"title": "MLA query path",
|
|
279
|
+
"description": (
|
|
280
|
+
"Builds Q by projecting the hidden state, splitting content and positional slices, "
|
|
281
|
+
"applying RoPE to the positional slice, then concatenating them"
|
|
282
|
+
),
|
|
283
|
+
"detail_view": "mla_query_path",
|
|
284
|
+
"children": query_children,
|
|
285
|
+
},
|
|
286
|
+
{
|
|
287
|
+
"id": "mla_kv_path",
|
|
288
|
+
"label": "KV cache path",
|
|
289
|
+
"title": "MLA K/V cache path",
|
|
290
|
+
"description": (
|
|
291
|
+
f"Compresses hidden state into rank {kv_rank} latent cache, expands K/V content, "
|
|
292
|
+
"and combines K noPE with a RoPE key side-channel"
|
|
293
|
+
),
|
|
294
|
+
"detail_view": "mla_kv_cache_path",
|
|
295
|
+
"children": kv_children,
|
|
296
|
+
},
|
|
297
|
+
{
|
|
298
|
+
"id": "scaled_scores",
|
|
299
|
+
"label": "Latent scores",
|
|
300
|
+
"title": "Multi-Head Latent scores",
|
|
301
|
+
"description": "Q attends to expanded latent K plus the RoPE key side-channel; scores use QK^T / sqrt(dim)",
|
|
302
|
+
},
|
|
303
|
+
{
|
|
304
|
+
"id": "attn_softmax",
|
|
305
|
+
"label": "Softmax",
|
|
306
|
+
"title": "Softmax weights",
|
|
307
|
+
"description": "Normalize latent attention scores over source positions",
|
|
127
308
|
},
|
|
128
309
|
{
|
|
129
|
-
"id": "
|
|
130
|
-
"label": "
|
|
131
|
-
"title": "
|
|
132
|
-
"description":
|
|
310
|
+
"id": "attn_apply_v",
|
|
311
|
+
"label": "Apply V",
|
|
312
|
+
"title": "Apply latent values",
|
|
313
|
+
"description": "Multiply softmax weights by V expanded from the compressed K/V latent",
|
|
133
314
|
},
|
|
134
315
|
{
|
|
135
|
-
"id": "
|
|
136
|
-
"label": "
|
|
137
|
-
"title": "
|
|
138
|
-
"description": "
|
|
316
|
+
"id": "concat_heads",
|
|
317
|
+
"label": "Concat heads",
|
|
318
|
+
"title": "Concatenate latent heads",
|
|
319
|
+
"description": f"Stack all {num_heads} context heads back into width {q_out}",
|
|
139
320
|
},
|
|
140
321
|
{
|
|
141
322
|
"id": "o_proj",
|
{model_unfolder-0.2.6 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/__init__.py
RENAMED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
4
|
from .attention import attention_card, attention_card_css, build_attention_view
|
|
5
|
+
from .attention_types import build_mla_kv_cache_view, build_mla_query_path_view
|
|
5
6
|
from .feed_forward import build_dense_ffn_view, build_ffn_view
|
|
6
7
|
from .mixture_of_experts import build_moe_expert_view, build_moe_view
|
|
7
8
|
from .per_layer_embedding import build_per_layer_embedding_view
|
|
@@ -28,6 +29,10 @@ def block_detail_svg(ir: dict, info: dict, mount_id: str, block: dict) -> str |
|
|
|
28
29
|
|
|
29
30
|
def sub_block_detail_svg(ir: dict, info: dict, mount_id: str, child: dict) -> str | None:
|
|
30
31
|
"""Return a rich SVG for a clicked node inside a detail view."""
|
|
32
|
+
if child.get("detail_view") == "mla_query_path":
|
|
33
|
+
return build_mla_query_path_view(ir, info, mount_id, child)
|
|
34
|
+
if child.get("detail_view") == "mla_kv_cache_path":
|
|
35
|
+
return build_mla_kv_cache_view(ir, info, mount_id, child)
|
|
31
36
|
if child.get("detail_view") == "moe_expert":
|
|
32
37
|
return build_moe_expert_view(ir, info, mount_id, child)
|
|
33
38
|
return None
|
|
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
|
|
4
4
|
from .grouped_query import build as build_gqa_attention_view
|
|
5
5
|
from .latent import build as build_mla_attention_view
|
|
6
|
+
from .latent import build_kv_cache_view as build_mla_kv_cache_view
|
|
7
|
+
from .latent import build_query_path_view as build_mla_query_path_view
|
|
6
8
|
from .linear import build as build_linear_attention_view
|
|
7
9
|
from .multi_head import build as build_sdpa_attention_view
|
|
8
10
|
from .multi_query import build as build_mqa_attention_view
|
|
@@ -14,6 +16,8 @@ __all__ = [
|
|
|
14
16
|
"build_gqa_attention_view",
|
|
15
17
|
"build_linear_attention_view",
|
|
16
18
|
"build_mla_attention_view",
|
|
19
|
+
"build_mla_kv_cache_view",
|
|
20
|
+
"build_mla_query_path_view",
|
|
17
21
|
"build_mqa_attention_view",
|
|
18
22
|
"build_recurrent_view",
|
|
19
23
|
"build_rwkv_view",
|
|
@@ -11,7 +11,7 @@ from ...svg import (
|
|
|
11
11
|
_svg_text,
|
|
12
12
|
_v_line,
|
|
13
13
|
)
|
|
14
|
-
from ...theme import C, FONT_MONO, GAP
|
|
14
|
+
from ...theme import C, FONT_HEAD, FONT_MONO, GAP
|
|
15
15
|
from ...utils import _fmt_int
|
|
16
16
|
|
|
17
17
|
|
|
@@ -38,6 +38,9 @@ def gqa_grouping_panel(
|
|
|
38
38
|
num_kv_heads: int,
|
|
39
39
|
q_per_group: int | None,
|
|
40
40
|
) -> dict:
|
|
41
|
+
if w <= 300:
|
|
42
|
+
return _gqa_compact_legend(parts, x, y, w, h, num_heads, num_kv_heads, q_per_group)
|
|
43
|
+
|
|
41
44
|
parts.append(_svg_tag("rect", {
|
|
42
45
|
"x": x,
|
|
43
46
|
"y": y,
|
|
@@ -81,6 +84,93 @@ def gqa_grouping_panel(
|
|
|
81
84
|
}
|
|
82
85
|
|
|
83
86
|
|
|
87
|
+
def _gqa_compact_legend(
|
|
88
|
+
parts: list[str],
|
|
89
|
+
x: float,
|
|
90
|
+
y: float,
|
|
91
|
+
w: float,
|
|
92
|
+
h: float,
|
|
93
|
+
num_heads: int,
|
|
94
|
+
num_kv_heads: int,
|
|
95
|
+
q_per_group: int | None,
|
|
96
|
+
) -> dict:
|
|
97
|
+
parts.append(_svg_tag("rect", {
|
|
98
|
+
"x": x,
|
|
99
|
+
"y": y,
|
|
100
|
+
"width": w,
|
|
101
|
+
"height": h,
|
|
102
|
+
"rx": 14,
|
|
103
|
+
"ry": 14,
|
|
104
|
+
"fill": C["bg_card"],
|
|
105
|
+
"stroke": C["border"],
|
|
106
|
+
"stroke-width": 0.7,
|
|
107
|
+
}))
|
|
108
|
+
parts.append(_svg_text(
|
|
109
|
+
x + 18,
|
|
110
|
+
y + 24,
|
|
111
|
+
"KV sharing pattern",
|
|
112
|
+
{
|
|
113
|
+
"fill": C["text"],
|
|
114
|
+
"font-family": FONT_MONO,
|
|
115
|
+
"font-size": 10,
|
|
116
|
+
"font-weight": 700,
|
|
117
|
+
"letter-spacing": "0.08em",
|
|
118
|
+
},
|
|
119
|
+
))
|
|
120
|
+
subtitle = (
|
|
121
|
+
f"{q_per_group} Q heads per KV"
|
|
122
|
+
if q_per_group
|
|
123
|
+
else f"{num_heads} Q / {num_kv_heads} KV"
|
|
124
|
+
)
|
|
125
|
+
parts.append(_svg_text(
|
|
126
|
+
x + 18,
|
|
127
|
+
y + 42,
|
|
128
|
+
subtitle,
|
|
129
|
+
{"fill": C["muted"], "font-family": FONT_MONO, "font-size": 9},
|
|
130
|
+
))
|
|
131
|
+
|
|
132
|
+
card_y = y + 58
|
|
133
|
+
card_h = 34
|
|
134
|
+
gap = 7
|
|
135
|
+
card_w = w - 36
|
|
136
|
+
for i, (top, bottom) in enumerate(_gqa_card_specs(num_heads, num_kv_heads, q_per_group)):
|
|
137
|
+
cy = card_y + i * (card_h + gap)
|
|
138
|
+
parts.append(_svg_tag("rect", {
|
|
139
|
+
"x": x + 18,
|
|
140
|
+
"y": cy,
|
|
141
|
+
"width": card_w,
|
|
142
|
+
"height": card_h,
|
|
143
|
+
"rx": 9,
|
|
144
|
+
"ry": 9,
|
|
145
|
+
"fill": C["badge_bg"],
|
|
146
|
+
"stroke": C["border"],
|
|
147
|
+
"stroke-width": 0.7,
|
|
148
|
+
}))
|
|
149
|
+
parts.append(_svg_text(
|
|
150
|
+
x + 36,
|
|
151
|
+
cy + 14,
|
|
152
|
+
top,
|
|
153
|
+
{"fill": C["text"], "font-family": FONT_MONO, "font-size": 10, "font-weight": 700},
|
|
154
|
+
))
|
|
155
|
+
parts.append(_svg_text(
|
|
156
|
+
x + 36,
|
|
157
|
+
cy + 28,
|
|
158
|
+
bottom,
|
|
159
|
+
{"fill": C["muted"], "font-family": FONT_MONO, "font-size": 8.5},
|
|
160
|
+
))
|
|
161
|
+
|
|
162
|
+
return {
|
|
163
|
+
"left": x,
|
|
164
|
+
"right": x + w,
|
|
165
|
+
"top": y,
|
|
166
|
+
"bottom": y + h,
|
|
167
|
+
"cx": x + w / 2,
|
|
168
|
+
"cy": y + h / 2,
|
|
169
|
+
"w": w,
|
|
170
|
+
"h": h,
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
|
|
84
174
|
def mqa_shared_kv_node(parts: list[str], x: float, y: float, w: float, h: float, num_heads: int) -> dict:
|
|
85
175
|
parts.append(_svg_tag("rect", {
|
|
86
176
|
"x": x,
|
|
@@ -308,3 +398,121 @@ def attn_dim_label(parts: list[str], x: float, y: float, text: str, *, anchor: s
|
|
|
308
398
|
"font-size": 10,
|
|
309
399
|
},
|
|
310
400
|
))
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def sdpa_fraction_block(
|
|
404
|
+
parts: list[str],
|
|
405
|
+
info: dict,
|
|
406
|
+
shadow_id: str,
|
|
407
|
+
node_id: str,
|
|
408
|
+
x: float,
|
|
409
|
+
y: float,
|
|
410
|
+
w: float,
|
|
411
|
+
h: float,
|
|
412
|
+
*,
|
|
413
|
+
numerator: str = "Q K^T",
|
|
414
|
+
denominator: str = "sqrt(dim)",
|
|
415
|
+
) -> dict:
|
|
416
|
+
"""Green formula block for the scaled score step in SDPA-style attention."""
|
|
417
|
+
children = [
|
|
418
|
+
_svg_tag("rect", {
|
|
419
|
+
"x": x,
|
|
420
|
+
"y": y,
|
|
421
|
+
"width": w,
|
|
422
|
+
"height": h,
|
|
423
|
+
"rx": 11,
|
|
424
|
+
"ry": 11,
|
|
425
|
+
"fill": C["block"],
|
|
426
|
+
"stroke": C["block_alt"],
|
|
427
|
+
"stroke-width": 0.6,
|
|
428
|
+
"filter": f"url(#{shadow_id})",
|
|
429
|
+
}),
|
|
430
|
+
_svg_text(
|
|
431
|
+
x + w / 2,
|
|
432
|
+
y + h * 0.32,
|
|
433
|
+
numerator,
|
|
434
|
+
{
|
|
435
|
+
"text-anchor": "middle",
|
|
436
|
+
"dominant-baseline": "central",
|
|
437
|
+
"fill": C["text_block"],
|
|
438
|
+
"font-family": FONT_HEAD,
|
|
439
|
+
"font-size": 22,
|
|
440
|
+
"pointer-events": "none",
|
|
441
|
+
},
|
|
442
|
+
),
|
|
443
|
+
_svg_tag("line", {
|
|
444
|
+
"x1": x + 72,
|
|
445
|
+
"y1": y + h * 0.52,
|
|
446
|
+
"x2": x + w - 72,
|
|
447
|
+
"y2": y + h * 0.52,
|
|
448
|
+
"stroke": C["text_block"],
|
|
449
|
+
"stroke-width": 1.7,
|
|
450
|
+
"stroke-linecap": "round",
|
|
451
|
+
"pointer-events": "none",
|
|
452
|
+
}),
|
|
453
|
+
_svg_text(
|
|
454
|
+
x + w / 2,
|
|
455
|
+
y + h * 0.73,
|
|
456
|
+
denominator,
|
|
457
|
+
{
|
|
458
|
+
"text-anchor": "middle",
|
|
459
|
+
"dominant-baseline": "central",
|
|
460
|
+
"fill": C["text_block"],
|
|
461
|
+
"font-family": FONT_HEAD,
|
|
462
|
+
"font-size": 19,
|
|
463
|
+
"pointer-events": "none",
|
|
464
|
+
},
|
|
465
|
+
),
|
|
466
|
+
]
|
|
467
|
+
parts.append(_svg_tag("g", {"class": "uf-node", "data-id": node_id}, "".join(children)))
|
|
468
|
+
return {"left": x, "right": x + w, "top": y, "bottom": y + h, "cx": x + w / 2, "cy": y + h / 2, "w": w, "h": h}
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def sdpa_dot_operator(parts: list[str], info: dict, shadow_id: str, node_id: str, cx: float, cy: float) -> dict:
|
|
472
|
+
"""Small green dot-product operator used for attention weights × V."""
|
|
473
|
+
r = 16
|
|
474
|
+
children = [
|
|
475
|
+
_svg_tag("circle", {
|
|
476
|
+
"cx": cx,
|
|
477
|
+
"cy": cy,
|
|
478
|
+
"r": r,
|
|
479
|
+
"fill": C["block"],
|
|
480
|
+
"stroke": C["block_alt"],
|
|
481
|
+
"stroke-width": 0.6,
|
|
482
|
+
"filter": f"url(#{shadow_id})",
|
|
483
|
+
}),
|
|
484
|
+
_svg_tag("circle", {
|
|
485
|
+
"cx": cx,
|
|
486
|
+
"cy": cy,
|
|
487
|
+
"r": 5,
|
|
488
|
+
"fill": "none",
|
|
489
|
+
"stroke": C["text_block"],
|
|
490
|
+
"stroke-width": 2,
|
|
491
|
+
"pointer-events": "none",
|
|
492
|
+
}),
|
|
493
|
+
]
|
|
494
|
+
parts.append(_svg_tag("g", {"class": "uf-node", "data-id": node_id}, "".join(children)))
|
|
495
|
+
return {"left": cx - r, "right": cx + r, "top": cy - r, "bottom": cy + r, "cx": cx, "cy": cy, "r": r}
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def input_to_block(x1: float, y1: float, x2: float, y2: float, arrow_id: str, *, lane_offset: float = 26) -> str:
|
|
499
|
+
"""Route an upward input into the lower edge of a block without grazing it."""
|
|
500
|
+
lane_y = y2 + lane_offset
|
|
501
|
+
r = 10
|
|
502
|
+
d = (
|
|
503
|
+
f"M {x1:g} {y1:g} "
|
|
504
|
+
f"L {x1:g} {lane_y + r:g} "
|
|
505
|
+
f"Q {x1:g} {lane_y:g} {x1 + r:g} {lane_y:g} "
|
|
506
|
+
f"L {x2 - r:g} {lane_y:g} "
|
|
507
|
+
f"Q {x2:g} {lane_y:g} {x2:g} {lane_y - r:g} "
|
|
508
|
+
f"L {x2:g} {y2:g}"
|
|
509
|
+
)
|
|
510
|
+
return _svg_tag("path", {
|
|
511
|
+
"d": d,
|
|
512
|
+
"fill": "none",
|
|
513
|
+
"stroke": C["arrow"],
|
|
514
|
+
"stroke-width": 1.6,
|
|
515
|
+
"stroke-linecap": "round",
|
|
516
|
+
"stroke-linejoin": "round",
|
|
517
|
+
"marker-end": f"url(#{arrow_id})",
|
|
518
|
+
})
|
|
@@ -5,6 +5,7 @@ from ...svg import (
|
|
|
5
5
|
_branch_dot,
|
|
6
6
|
_defs,
|
|
7
7
|
_elbow_hv,
|
|
8
|
+
_elbow_vh,
|
|
8
9
|
_ids,
|
|
9
10
|
_rect_block,
|
|
10
11
|
_region_rect,
|
|
@@ -17,16 +18,19 @@ from ...theme import C, GAP
|
|
|
17
18
|
from ...utils import _fmt_int
|
|
18
19
|
from .common import (
|
|
19
20
|
gqa_grouping_panel,
|
|
21
|
+
input_to_block,
|
|
20
22
|
kv_cache_badge,
|
|
21
23
|
output_stem,
|
|
22
24
|
placed_figure,
|
|
23
25
|
queries_per_kv_group,
|
|
26
|
+
sdpa_dot_operator,
|
|
27
|
+
sdpa_fraction_block,
|
|
24
28
|
)
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
def build(ir: dict, info: dict, mount_id: str) -> str:
|
|
28
32
|
"""Detail view for grouped-query attention."""
|
|
29
|
-
w, h =
|
|
33
|
+
w, h = 820, 880
|
|
30
34
|
arrow_id, shadow_id = _ids(mount_id, "gqa-attn")
|
|
31
35
|
parts = [_defs(arrow_id, shadow_id)]
|
|
32
36
|
parts.append(_region_rect(40, 30, w - 80, h - 60, C["bg_outer"]))
|
|
@@ -42,31 +46,38 @@ def build(ir: dict, info: dict, mount_id: str) -> str:
|
|
|
42
46
|
q_per_group = queries_per_kv_group(num_heads, num_kv_heads)
|
|
43
47
|
|
|
44
48
|
cx = w / 2
|
|
45
|
-
o_proj = _rect_block(body, info, shadow_id, "o_proj", cx -
|
|
46
|
-
|
|
49
|
+
o_proj = _rect_block(body, info, shadow_id, "o_proj", cx - 100, 72, 200, 52, "Linear (out)")
|
|
50
|
+
concat = _rect_block(
|
|
47
51
|
body,
|
|
48
52
|
info,
|
|
49
53
|
shadow_id,
|
|
50
|
-
"
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
["
|
|
56
|
-
font_size=
|
|
54
|
+
"concat_heads",
|
|
55
|
+
cx - 112,
|
|
56
|
+
164,
|
|
57
|
+
224,
|
|
58
|
+
54,
|
|
59
|
+
["Concat heads", f"{num_heads} x {d_k}" if num_heads else "per head"],
|
|
60
|
+
font_size=16,
|
|
57
61
|
)
|
|
62
|
+
value_dot = sdpa_dot_operator(body, info, shadow_id, "attn_apply_v", cx, 276)
|
|
63
|
+
softmax = _rect_block(body, info, shadow_id, "attn_softmax", cx - 96, 344, 192, 52, "Softmax")
|
|
64
|
+
scaled_scores = sdpa_fraction_block(body, info, shadow_id, "scaled_scores", cx - 140, 452, 280, 82)
|
|
65
|
+
gqa_grouping_panel(body, 64, 92, 220, 218, num_heads, num_kv_heads, q_per_group)
|
|
58
66
|
|
|
59
|
-
|
|
67
|
+
body.append(_v_line(scaled_scores, softmax, arrow_id))
|
|
68
|
+
body.append(_v_line(softmax, value_dot, arrow_id))
|
|
69
|
+
body.append(_v_line(value_dot, concat, arrow_id))
|
|
70
|
+
body.append(_v_line(concat, o_proj, arrow_id))
|
|
60
71
|
|
|
61
|
-
proj_w, proj_h, proj_y =
|
|
62
|
-
q_proj = _rect_block(body, info, shadow_id, "q_proj",
|
|
72
|
+
proj_w, proj_h, proj_y = 185, 52, 704
|
|
73
|
+
q_proj = _rect_block(body, info, shadow_id, "q_proj", 78, proj_y, proj_w, proj_h, ["Linear (Q)", f"{num_heads} heads"], font_size=15)
|
|
63
74
|
kv_head_label = f"{num_kv_heads} head" if num_kv_heads == 1 else f"{num_kv_heads} heads"
|
|
64
|
-
k_proj = _rect_block(body, info, shadow_id, "k_proj",
|
|
65
|
-
v_proj = _rect_block(body, info, shadow_id, "v_proj",
|
|
75
|
+
k_proj = _rect_block(body, info, shadow_id, "k_proj", cx - proj_w / 2, proj_y, proj_w, proj_h, ["Linear (K)", kv_head_label], font_size=15)
|
|
76
|
+
v_proj = _rect_block(body, info, shadow_id, "v_proj", w - 78 - proj_w, proj_y, proj_w, proj_h, ["Linear (V)", kv_head_label], font_size=15)
|
|
66
77
|
|
|
67
|
-
branch_x, branch_y = cx,
|
|
78
|
+
branch_x, branch_y = cx, 792
|
|
68
79
|
body.append(_svg_tag("line", {
|
|
69
|
-
"x1": branch_x, "y1": branch_y +
|
|
80
|
+
"x1": branch_x, "y1": branch_y + 34, "x2": branch_x, "y2": branch_y,
|
|
70
81
|
"stroke": C["arrow"], "stroke-width": 1.6, "stroke-linecap": "round",
|
|
71
82
|
"fill": "none",
|
|
72
83
|
}))
|
|
@@ -75,12 +86,9 @@ def build(ir: dict, info: dict, mount_id: str) -> str:
|
|
|
75
86
|
body.append(_elbow_hv(branch_x, branch_y, v_proj["cx"], v_proj["bottom"] + GAP, arrow_id))
|
|
76
87
|
body.append(_branch_dot(branch_x, branch_y))
|
|
77
88
|
|
|
78
|
-
|
|
79
|
-
body.append(_v_seg(
|
|
80
|
-
body.append(
|
|
81
|
-
body.append(_v_seg(v_proj["cx"], v_proj["top"], panel_entry_y, arrow_id))
|
|
82
|
-
body.append(_v_line(panel, sdpa, arrow_id))
|
|
83
|
-
body.append(_v_line(sdpa, o_proj, arrow_id))
|
|
89
|
+
body.append(input_to_block(q_proj["cx"], q_proj["top"], scaled_scores["left"] + 92, scaled_scores["bottom"], arrow_id))
|
|
90
|
+
body.append(_v_seg(k_proj["cx"], k_proj["top"], scaled_scores["bottom"], arrow_id))
|
|
91
|
+
body.append(_elbow_vh(v_proj["cx"], v_proj["top"], value_dot["right"] + GAP, value_dot["cy"], arrow_id))
|
|
84
92
|
output_stem(body, cx, o_proj, arrow_id, hidden, show_label=False)
|
|
85
93
|
|
|
86
94
|
if q_per_group and q_per_group > 1:
|