model-unfolder 0.2.5__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/PKG-INFO +12 -6
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/README.md +11 -5
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/assembly.py +30 -2
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/blocks/__init__.py +56 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/blocks/attention.py +479 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/blocks/descriptions.py +132 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/blocks/feed_forward.py +166 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/blocks/layers.py +123 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/blocks/model.py +52 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/__init__.py +26 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/cohere.py +106 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/deepseek.py +9 -10
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/falcon.py +183 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/__init__.py +3 -2
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/gemma/gemma2.py +101 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/gemma4.py +1 -3
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/gpt_neox.py +114 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/jamba.py +126 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/llama.py +57 -9
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/qwen.py +5 -1
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/recurrent_gemma.py +103 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/rwkv.py +74 -0
- model_unfolder-0.2.7/model_unfolder/adapters/transformer/families/zamba.py +127 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/ir.py +8 -1
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/labels.py +57 -10
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/parser.py +36 -3
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/__init__.py +38 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/attention.py +34 -3
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/__init__.py +26 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/common.py +518 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/grouped_query.py +98 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/latent.py +262 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/linear.py +59 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/multi_head.py +90 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/multi_query.py +89 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/rwkv.py +55 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/attention_types/state_space.py +44 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/feed_forward.py +95 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/block_views/mixture_of_experts.py +117 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/cards.py +234 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/document.py +201 -0
- model_unfolder-0.2.7/model_unfolder/renderers/html/interactions.py +125 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/metadata.py +3 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/styles.py +33 -17
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/svg.py +19 -2
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/views.py +55 -11
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder.egg-info/PKG-INFO +12 -6
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder.egg-info/SOURCES.txt +24 -1
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/pyproject.toml +1 -1
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/tests/test_smoke.py +215 -2
- model_unfolder-0.2.5/model_unfolder/adapters/transformer/blocks.py +0 -238
- model_unfolder-0.2.5/model_unfolder/adapters/transformer/families/__init__.py +0 -14
- model_unfolder-0.2.5/model_unfolder/renderers/html/block_views/__init__.py +0 -20
- model_unfolder-0.2.5/model_unfolder/renderers/html/block_views/feed_forward.py +0 -213
- model_unfolder-0.2.5/model_unfolder/renderers/html/cards.py +0 -130
- model_unfolder-0.2.5/model_unfolder/renderers/html/document.py +0 -157
- model_unfolder-0.2.5/model_unfolder/renderers/html/interactions.py +0 -64
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/LICENSE +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/custom/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/diffusor/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/common.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/fallback.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/gemma/gemma3.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/minimax.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/families/mistral.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/special_parts/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/special_parts/per_layer_embedding.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/diagram.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/html_renderer.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/params.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/__init__.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/block_views/per_layer_embedding.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/sections.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/theme.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/renderers/html/utils.py +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder.egg-info/dependency_links.txt +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder.egg-info/requires.txt +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder.egg-info/top_level.txt +0 -0
- {model_unfolder-0.2.5 → model_unfolder-0.2.7}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: model-unfolder
|
|
3
|
-
Version: 0.2.5
|
|
3
|
+
Version: 0.2.7
|
|
4
4
|
Summary: Unfold any HuggingFace transformer into an interactive architecture diagram, inline in Jupyter.
|
|
5
5
|
Author: model-unfolder contributors
|
|
6
6
|
License: Apache-2.0
|
|
@@ -108,12 +108,18 @@ Param estimates are close to published numbers — DeepSeek-V3 reports `~675B (~
|
|
|
108
108
|
|
|
109
109
|
| Family | Models |
|
|
110
110
|
|---|---|
|
|
111
|
-
| DeepSeek | DeepSeek-V2, DeepSeek-V3, Kimi K2 |
|
|
112
|
-
| Llama | Llama 3 / 3.1 / 3.2 / 3.3 |
|
|
111
|
+
| DeepSeek | DeepSeek-V2, DeepSeek-V3 (+ MTP head), Kimi K2 |
|
|
112
|
+
| Llama | Llama 3 / 3.1 / 3.2 / 3.3, OLMo-2, Llama 4 Scout / Maverick (MoE + iRoPE NoPE layers) |
|
|
113
113
|
| Mistral | Mistral 7B, Mixtral 8x7B / 8x22B, Mistral Medium 3.5 |
|
|
114
|
-
| Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 |
|
|
115
|
-
| Gemma | Gemma 3 / 3n, Gemma 4
|
|
116
|
-
|
|
|
114
|
+
| Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 (+ MTP) |
|
|
115
|
+
| Gemma | Gemma 2 9B / 27B (interleaved local+global), Gemma 3 / 3n (+ PLE), Gemma 4 31B / E2B / E4B (+ PLE), RecurrentGemma 2B / 9B (LRU + local attention) |
|
|
116
|
+
| Cohere | Command R, Command R+, Command R7B (QK-Norm attention) |
|
|
117
|
+
| Jamba | Jamba (SSM + attention hybrid, MoE) |
|
|
118
|
+
| Zamba | Zamba 7B, Zamba2 2.7B / 7B (Mamba SSM + weight-shared attention) |
|
|
119
|
+
| Mamba | Mamba 130M–2.8B, Mamba-2 (pure SSM, no attention) |
|
|
120
|
+
| Falcon | Falcon 7B / 40B (parallel attn+FFN), Falcon-H1 (Mamba-2 SSM) |
|
|
121
|
+
| MiniMax | MiniMax-Text-01 (lightning + softmax hybrid, MoE) |
|
|
122
|
+
| RWKV | RWKV-4 / 5 / 6 (pure recurrent, no attention) |
|
|
117
123
|
|
|
118
124
|
### Diffusors
|
|
119
125
|
|
|
@@ -84,12 +84,18 @@ Param estimates are close to published numbers — DeepSeek-V3 reports `~675B (~
|
|
|
84
84
|
|
|
85
85
|
| Family | Models |
|
|
86
86
|
|---|---|
|
|
87
|
-
| DeepSeek | DeepSeek-V2, DeepSeek-V3, Kimi K2 |
|
|
88
|
-
| Llama | Llama 3 / 3.1 / 3.2 / 3.3 |
|
|
87
|
+
| DeepSeek | DeepSeek-V2, DeepSeek-V3 (+ MTP head), Kimi K2 |
|
|
88
|
+
| Llama | Llama 3 / 3.1 / 3.2 / 3.3, OLMo-2, Llama 4 Scout / Maverick (MoE + iRoPE NoPE layers) |
|
|
89
89
|
| Mistral | Mistral 7B, Mixtral 8x7B / 8x22B, Mistral Medium 3.5 |
|
|
90
|
-
| Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 |
|
|
91
|
-
| Gemma | Gemma 3 / 3n, Gemma 4
|
|
92
|
-
|
|
|
90
|
+
| Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 (+ MTP) |
|
|
91
|
+
| Gemma | Gemma 2 9B / 27B (interleaved local+global), Gemma 3 / 3n (+ PLE), Gemma 4 31B / E2B / E4B (+ PLE), RecurrentGemma 2B / 9B (LRU + local attention) |
|
|
92
|
+
| Cohere | Command R, Command R+, Command R7B (QK-Norm attention) |
|
|
93
|
+
| Jamba | Jamba (SSM + attention hybrid, MoE) |
|
|
94
|
+
| Zamba | Zamba 7B, Zamba2 2.7B / 7B (Mamba SSM + weight-shared attention) |
|
|
95
|
+
| Mamba | Mamba 130M–2.8B, Mamba-2 (pure SSM, no attention) |
|
|
96
|
+
| Falcon | Falcon 7B / 40B (parallel attn+FFN), Falcon-H1 (Mamba-2 SSM) |
|
|
97
|
+
| MiniMax | MiniMax-Text-01 (lightning + softmax hybrid, MoE) |
|
|
98
|
+
| RWKV | RWKV-4 / 5 / 6 (pure recurrent, no attention) |
|
|
93
99
|
|
|
94
100
|
### Diffusors
|
|
95
101
|
|
{model_unfolder-0.2.5 → model_unfolder-0.2.7}/model_unfolder/adapters/transformer/assembly.py
RENAMED
|
@@ -5,7 +5,7 @@ from collections.abc import Iterable, Mapping
|
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
7
|
from ...ir import AttentionSpec, FFNSpec, LayerSpec
|
|
8
|
-
from .blocks import decoder_layer_blocks, decoder_only_render_spec
|
|
8
|
+
from .blocks import decoder_layer_blocks, decoder_only_render_spec, parallel_decoder_layer_blocks
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def decoder_layer(
|
|
@@ -15,15 +15,43 @@ def decoder_layer(
|
|
|
15
15
|
hidden_size: int,
|
|
16
16
|
*,
|
|
17
17
|
extra_blocks: Iterable[dict] | None = None,
|
|
18
|
+
norm_kind: str = "rmsnorm",
|
|
19
|
+
norm_placement: str = "pre",
|
|
18
20
|
) -> LayerSpec:
|
|
19
21
|
"""Build a decoder layer from parsed specs plus optional reusable parts."""
|
|
20
|
-
blocks = decoder_layer_blocks(attention, ffn, hidden_size)
|
|
22
|
+
blocks = decoder_layer_blocks(attention, ffn, hidden_size, norm_kind=norm_kind)
|
|
21
23
|
if extra_blocks:
|
|
22
24
|
blocks.extend(extra_blocks)
|
|
23
25
|
return LayerSpec(
|
|
24
26
|
index=index,
|
|
25
27
|
attention=attention,
|
|
26
28
|
ffn=ffn,
|
|
29
|
+
norm_kind=norm_kind,
|
|
30
|
+
norm_placement=norm_placement,
|
|
31
|
+
blocks=blocks,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def parallel_decoder_layer(
    index: int,
    attention: AttentionSpec,
    ffn: FFNSpec,
    hidden_size: int,
    *,
    norm_kind: str = "rmsnorm",
) -> LayerSpec:
    """Assemble the ``LayerSpec`` for a parallel-residual decoder layer.

    This is the GPT-NeoX / GPT-J style topology: attention and FFN read from
    one shared input norm, and their outputs are summed in a single residual
    add instead of two sequential adds. The block list comes from
    ``parallel_decoder_layer_blocks`` and the norm placement is always
    ``"pre"``.
    """
    return LayerSpec(
        index=index,
        attention=attention,
        ffn=ffn,
        norm_kind=norm_kind,
        norm_placement="pre",
        blocks=parallel_decoder_layer_blocks(
            attention,
            ffn,
            hidden_size,
            norm_kind=norm_kind,
        ),
    )
|
|
29
57
|
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Reusable transformer block descriptions for renderers.
|
|
2
|
+
|
|
3
|
+
Adapters attach these block parts to the IR. Renderers can then draw generic
|
|
4
|
+
decoder-only transformer layouts without rediscovering model-specific names
|
|
5
|
+
or labels with another layer of ``if model_type`` logic.
|
|
6
|
+
|
|
7
|
+
The package is split by responsibility:
|
|
8
|
+
|
|
9
|
+
* ``model``: model-level input/output bookend blocks.
|
|
10
|
+
* ``layers``: sequential and parallel decoder-layer topology.
|
|
11
|
+
* ``attention``: reusable attention/SSM/recurrent child block contracts.
|
|
12
|
+
* ``feed_forward``: dense, gated, and MoE FFN child block contracts.
|
|
13
|
+
* ``descriptions``: labels, titles, and short metadata strings.
|
|
14
|
+
|
|
15
|
+
Each block carries two orthogonal tags:
|
|
16
|
+
|
|
17
|
+
* ``role`` — semantic ("norm", "attention", "ffn", "residual", "gate") used
|
|
18
|
+
for tooltips, click handlers, and the inspect cards.
|
|
19
|
+
* ``kind`` — rendering shape ("norm", "linear", "activation", "attention",
|
|
20
|
+
"ffn", "residual_add", "gate_mul", "embedding", "output", "source") used
|
|
21
|
+
by the architecture view to pick a glyph and lay out a slot.
|
|
22
|
+
|
|
23
|
+
Edges between blocks travel on the destination side as plain string fields:
|
|
24
|
+
|
|
25
|
+
* ``residual_from: "<other_block_id>"`` — the residual_add block consumes the
|
|
26
|
+
*input* of the named block (the standard pre-attention bypass pattern).
|
|
27
|
+
* ``lane: "left" | "right"`` — the block is rendered off the central chain
|
|
28
|
+
and connected via ``tap_from`` / ``feeds``. Reusable parts such as
|
|
29
|
+
per-layer embeddings use this instead of model-specific renderer logic.
|
|
30
|
+
"""
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
from .attention import attention_child_blocks
|
|
34
|
+
from .descriptions import (
|
|
35
|
+
attention_label,
|
|
36
|
+
attention_title,
|
|
37
|
+
describe_attention,
|
|
38
|
+
describe_ffn,
|
|
39
|
+
)
|
|
40
|
+
from .feed_forward import ffn_child_blocks, ffn_detail_view
|
|
41
|
+
from .layers import decoder_layer_blocks, parallel_decoder_layer_blocks
|
|
42
|
+
from .model import decoder_model_blocks, decoder_only_render_spec
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
"attention_child_blocks",
|
|
46
|
+
"attention_label",
|
|
47
|
+
"attention_title",
|
|
48
|
+
"decoder_layer_blocks",
|
|
49
|
+
"decoder_model_blocks",
|
|
50
|
+
"decoder_only_render_spec",
|
|
51
|
+
"describe_attention",
|
|
52
|
+
"describe_ffn",
|
|
53
|
+
"ffn_child_blocks",
|
|
54
|
+
"ffn_detail_view",
|
|
55
|
+
"parallel_decoder_layer_blocks",
|
|
56
|
+
]
|
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
"""Reusable attention-family child block declarations."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from ....ir import AttentionSpec
|
|
5
|
+
from ..common import format_dim as _fmt
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def attention_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Return the child block descriptions for one attention-like mixer.

    The builder is chosen from ``attention.kind``; any kind without a
    dedicated builder is described as softmax dot-product attention.
    """
    specialised = {
        "mla": _mla_child_blocks,
        "ssm": _ssm_child_blocks,
        "recurrent": _recurrent_child_blocks,
        "rwkv": _rwkv_child_blocks,
        "linear": _linear_attention_child_blocks,
    }
    kind = attention.kind
    if kind in specialised:
        return specialised[kind](attention, hidden_size)
    return _sdpa_child_blocks(attention, hidden_size)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of a softmax dot-product attention layer.

    ``mha``, ``gqa`` and ``mqa`` kinds are delegated to
    ``_sdpa_detailed_child_blocks``. Any other kind gets a compact five-block
    layout: Q/K/V projections, one fused attention step, output projection.
    When head counts or the head dim are missing, projection widths fall back
    to the formatted hidden size and the head dim is shown as ``d_k``.
    """
    hidden = _fmt(hidden_size)
    heads = attention.num_heads or 0
    kv_heads = attention.num_kv_heads or heads
    per_head = attention.head_dim or 0

    # Queries-per-KV-head is only reported when the heads divide evenly.
    q_per_group = None
    if heads and kv_heads and heads % kv_heads == 0:
        q_per_group = heads // kv_heads

    q_out = hidden
    if heads and per_head:
        q_out = _fmt(heads * per_head)
    kv_out = hidden
    if kv_heads and per_head:
        kv_out = _fmt(kv_heads * per_head)
    d_k = _fmt(per_head) if per_head else "d_k"

    if attention.kind in {"mha", "gqa", "mqa"}:
        return _sdpa_detailed_child_blocks(
            attention.kind, hidden, q_out, kv_out, heads, kv_heads, d_k, q_per_group
        )

    fused_title, fused_desc = _sdpa_operation_meta(attention, heads, kv_heads, d_k, q_per_group)
    # K and V share the same projection shape, so they share one description.
    kv_desc = f"Linear; {hidden} -> {kv_out} ({kv_heads} KV-heads x {d_k} dims)"
    rows = (
        ("q_proj", "Query projection", f"Linear; {hidden} -> {q_out} ({heads} heads x {d_k} dims)"),
        ("k_proj", "Key projection", kv_desc),
        ("v_proj", "Value projection", kv_desc),
        ("qkv_dot", fused_title, fused_desc),
        ("o_proj", "Output projection", f"Linear; {q_out} -> {hidden} (recombines all {heads} heads)"),
    )
    return [
        {"id": block_id, "title": title, "description": description}
        for block_id, title, description in rows
    ]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _sdpa_detailed_child_blocks(
    kind: str,
    hidden: str,
    q_out: str,
    kv_out: str,
    num_heads: int,
    num_kv_heads: int,
    d_k: str,
    q_per_group: int | None,
) -> list[dict]:
    """Build the eight-step breakdown of MHA / GQA / MQA attention.

    ``hidden``, ``q_out``, ``kv_out`` and ``d_k`` are already-formatted
    display strings. ``kind`` selects the wording of the K/V projections and
    of the score step; anything other than ``"gqa"`` or ``"mqa"`` gets the
    plain per-head wording.
    """
    if kind == "mqa":
        kv_label = "1 shared K/V head"
        scores_title = "Multi-query scaled dot-product attention"
        scores_desc = (
            f"Multi-Query SDPA scores: {num_heads} query heads share one K/V stream; "
            "scores use QK^T / sqrt(dim)"
        )
    elif kind == "gqa":
        kv_label = f"{num_kv_heads} KV-heads"
        scores_title = "Grouped scaled dot-product attention"
        # The group-size clause is dropped when the ratio is unknown.
        group = f"; each KV head serves {q_per_group} query heads" if q_per_group else ""
        scores_desc = (
            f"Grouped SDPA scores: {num_heads} query heads attend through "
            f"{num_kv_heads} shared K/V heads{group}; scores use QK^T / sqrt(dim)"
        )
    else:
        kv_label = f"{num_kv_heads} KV-heads"
        scores_title = "Scaled attention scores"
        scores_desc = "Per head: QK^T / sqrt(dim); dot-product scores scaled for numerical stability"

    kv_desc = f"Linear; {hidden} -> {kv_out} ({kv_label} x {d_k} dims)"
    steps = (
        ("q_proj", "Query projection", f"Linear; {hidden} -> {q_out} ({num_heads} heads x {d_k} dims)"),
        ("k_proj", "Key projection", kv_desc),
        ("v_proj", "Value projection", kv_desc),
        ("scaled_scores", scores_title, scores_desc),
        (
            "attn_softmax",
            "Softmax weights",
            "Normalize each query row into attention weights over source tokens",
        ),
        (
            "attn_apply_v",
            "Apply values",
            "Multiply attention weights by V to produce one context vector per head",
        ),
        (
            "concat_heads",
            "Concatenate heads",
            f"Stack all {num_heads} per-head context vectors back into width {q_out}",
        ),
        ("o_proj", "Output projection", f"Linear; {q_out} -> {hidden} (mixes information across heads)"),
    )
    return [
        {"id": block_id, "title": title, "description": description}
        for block_id, title, description in steps
    ]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _sdpa_operation_meta(
    attention: AttentionSpec,
    num_heads: int,
    num_kv_heads: int,
    d_k: str,
    q_per_group: int | None,
) -> tuple[str, str]:
    """Return ``(title, description)`` for the fused attention step.

    Only ``attention.kind`` is read from the spec. ``"mqa"`` and ``"gqa"``
    get head-sharing wording; every other kind gets the generic scaled
    dot-product text including the output shape.
    """
    formula = f"scores = softmax(QK^T / sqrt({d_k})); "
    kind = attention.kind

    if kind == "mqa":
        title = "Multi-query scaled dot-product attention"
        detail = f"{num_heads} query heads share one K/V head"
    elif kind == "gqa":
        title = "Grouped scaled dot-product attention"
        # The group-size clause is dropped when the ratio is unknown.
        group = f"; each KV head serves {q_per_group} query heads" if q_per_group else ""
        detail = f"{num_heads} query heads attend through {num_kv_heads} shared KV heads{group}"
    else:
        title = "Scaled dot-product attention"
        detail = "context = scores * V; " + f"output shape [batch, {num_heads}, seq, {d_k}]"

    return title, formula + detail
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _mla_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe Multi-Head Latent Attention as nested child blocks.

    The result has two expandable groups (query path and K/V cache path, each
    with its own ``children`` and ``detail_view``) followed by the score,
    softmax, apply-V, concat and output-projection steps. When
    ``attention.q_lora_rank`` is falsy the query projection is described as a
    direct projection instead of a LoRA-rank one.
    """

    def _node(block_id: str, label: str, title: str, description: str, **extra) -> dict:
        # Keeps the key order id, label, title, description, then any extras.
        return {"id": block_id, "label": label, "title": title, "description": description, **extra}

    hidden = _fmt(hidden_size)
    uses_q_lora = bool(attention.q_lora_rank)
    q_rank = _fmt(attention.q_lora_rank) if uses_q_lora else "direct"
    kv_rank = _fmt(attention.kv_lora_rank)
    rope = _fmt(attention.rope_dim)
    heads = attention.num_heads or 0
    per_head = attention.head_dim or 0
    q_out = _fmt(heads * per_head) if (heads and per_head) else hidden

    if uses_q_lora:
        q_proj_desc = f"Projects hidden states into query latent space through LoRA rank {q_rank}"
    else:
        q_proj_desc = f"Projects hidden states directly into query heads; {hidden} -> {q_out}"

    query_children = [
        _node("mla_q", "Q projection", "Query projection", q_proj_desc),
        _node(
            "mla_q_nope",
            "Q noPE",
            "Query content slice",
            "Query content component that does not receive rotary position encoding",
        ),
        _node(
            "mla_q_rope",
            "Q RoPE",
            "Query positional slice",
            f"Query positional component prepared for rotary position encoding; dim {rope}",
        ),
        _node(
            "mla_q_rope_apply",
            "Apply RoPE",
            "Apply RoPE to query",
            "Applies rotary position encoding to the query positional slice",
        ),
        _node(
            "mla_q_concat",
            "Q concat",
            "Final MLA query",
            "Concatenates Q noPE with RoPE-encoded Q RoPE before score computation",
        ),
    ]

    kv_children = [
        _node(
            "mla_kv_down",
            "KV compress",
            "K/V latent compression",
            f"Compresses the token state into the shared latent K/V cache; {hidden} -> rank {kv_rank}",
        ),
        _node(
            "mla_cache",
            "latent cache c_t",
            "Stored latent cache",
            f"Compressed K/V latent stored in the cache instead of full K and V heads; rank {kv_rank}",
        ),
        _node(
            "mla_kv_up",
            "KV expand",
            "K/V head expansion",
            f"Expands cached latent c_t into K noPE content and V values for {heads} query heads",
        ),
        _node(
            "mla_k_nope",
            "K noPE",
            "Latent key content",
            "Key content expanded from the compressed K/V latent; concatenated with the RoPE key before scoring",
        ),
        _node(
            "mla_k_rope",
            "K RoPE",
            "Key positional slice",
            f"Key positional component produced alongside the latent cache; dim {rope}",
        ),
        _node(
            "mla_k_rope_apply",
            "Apply RoPE",
            "Apply RoPE to key",
            "Applies rotary position encoding to the key positional slice",
        ),
        _node(
            "mla_k_merge",
            "K concat",
            "Composed MLA key",
            "Concatenates K noPE with the RoPE key side-channel before QK^T score computation",
        ),
        _node(
            "mla_v",
            "V values",
            "Latent value heads",
            "Value heads expanded from the compressed K/V latent; consumed after softmax",
        ),
    ]

    query_path = _node(
        "mla_query_path",
        "Query path",
        "MLA query path",
        "Builds Q by projecting the hidden state, splitting content and positional slices, "
        "applying RoPE to the positional slice, then concatenating them",
        detail_view="mla_query_path",
        children=query_children,
    )
    kv_path = _node(
        "mla_kv_path",
        "KV cache path",
        "MLA K/V cache path",
        f"Compresses hidden state into rank {kv_rank} latent cache, expands K/V content, "
        "and combines K noPE with a RoPE key side-channel",
        detail_view="mla_kv_cache_path",
        children=kv_children,
    )

    return [
        query_path,
        kv_path,
        _node(
            "scaled_scores",
            "Latent scores",
            "Multi-Head Latent scores",
            "Q attends to expanded latent K plus the RoPE key side-channel; scores use QK^T / sqrt(dim)",
        ),
        _node(
            "attn_softmax",
            "Softmax",
            "Softmax weights",
            "Normalize latent attention scores over source positions",
        ),
        _node(
            "attn_apply_v",
            "Apply V",
            "Apply latent values",
            "Multiply softmax weights by V expanded from the compressed K/V latent",
        ),
        _node(
            "concat_heads",
            "Concat heads",
            "Concatenate latent heads",
            f"Stack all {heads} context heads back into width {q_out}",
        ),
        _node("o_proj", "Linear (out)", "Output projection", f"Linear; {q_out} -> {hidden}"),
    ]
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _ssm_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe a state-space mixer as five child blocks.

    The blocks are input projection, local convolution, selective scan, gate
    and output projection. The state dimension shown in the scan description
    is read from ``attention.head_dim``.
    """
    hidden = _fmt(hidden_size)
    state = _fmt(attention.head_dim)
    rows = (
        (
            "ssm_in_proj",
            "Input projection",
            "SSM input projection",
            f"Project hidden activations into SSM channels; hidden {hidden}",
        ),
        (
            "ssm_conv",
            "Local conv",
            "Short convolution",
            "Depthwise local mixing before the state-space recurrence",
        ),
        (
            "ssm_scan",
            "Selective scan",
            "Selective state-space scan",
            f"Token recurrence with state dimension {state}",
        ),
        (
            "ssm_gate",
            "Gate",
            "SSM gate",
            "Element-wise gate controlling the recurrent output",
        ),
        (
            "ssm_out_proj",
            "Output projection",
            "SSM output projection",
            f"Project SSM channels back to hidden dim {hidden}",
        ),
    )
    return [
        {"id": block_id, "label": label, "title": title, "description": description}
        for block_id, label, title, description in rows
    ]
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _recurrent_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe a linear-recurrent (LRU) mixer as four child blocks.

    The blocks are input projection, recurrent state, gate and output
    projection. The recurrent width shown in the descriptions is read from
    ``attention.head_dim``.
    """
    hidden = _fmt(hidden_size)
    width = _fmt(attention.head_dim)

    children: list[dict] = []

    def _add(block_id: str, label: str, title: str, description: str) -> None:
        children.append({"id": block_id, "label": label, "title": title, "description": description})

    _add(
        "lru_in_proj",
        "Input projection",
        "LRU input projection",
        f"Linear; hidden {hidden} -> recurrent width {width}",
    )
    _add(
        "lru_state",
        "Recurrent state",
        "Linear recurrent state",
        f"State update over sequence positions; width {width}",
    )
    _add(
        "lru_gate",
        "Gate",
        "Recurrent gate",
        "Element-wise gate controlling recurrent features",
    )
    _add(
        "lru_out_proj",
        "Output projection",
        "LRU output projection",
        f"Linear; recurrent width {width} -> hidden {hidden}",
    )
    return children
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def _rwkv_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe an RWKV time-mixing layer as five child blocks.

    The blocks are receptance, key, value, time-mix recurrence and output
    projection. A missing ``attention.num_heads`` is reported as 0 heads.
    """
    fields = ("id", "label", "title", "description")
    hidden = _fmt(hidden_size)
    head_count = attention.num_heads or 0
    rows = [
        (
            "rwkv_receptance",
            "Receptance",
            "Receptance gate",
            f"Token-wise gate over hidden dim {hidden}",
        ),
        (
            "rwkv_key",
            "Key",
            "RWKV key projection",
            f"Key-like channel mixing over {head_count} recurrent heads",
        ),
        (
            "rwkv_value",
            "Value",
            "RWKV value projection",
            "Value channel sent through time-mixing recurrence",
        ),
        (
            "rwkv_time_mix",
            "Time-mix",
            "Time-decay recurrence",
            "Linear-time weighted recurrence replacing self-attention",
        ),
        (
            "rwkv_out",
            "Output projection",
            "RWKV output projection",
            f"Project mixed channels back to hidden dim {hidden}",
        ),
    ]
    return [dict(zip(fields, row)) for row in rows]
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _linear_attention_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe a linear-attention layer as six child blocks.

    The blocks are Q/K/V projections, kernel feature map, linear mix and
    output projection. Projection widths fall back to the formatted hidden
    size when the head count or head dim is missing.
    """
    hidden = _fmt(hidden_size)
    heads = attention.num_heads or 0
    kv_heads = attention.num_kv_heads or heads
    per_head = attention.head_dim or 0

    q_out = hidden
    if heads and per_head:
        q_out = _fmt(heads * per_head)
    kv_out = hidden
    if kv_heads and per_head:
        kv_out = _fmt(kv_heads * per_head)

    def _projection(block_id: str, label: str, title: str, src: str, dst: str) -> dict:
        return {
            "id": block_id,
            "label": label,
            "title": title,
            "description": f"Linear; {src} -> {dst}",
        }

    return [
        _projection("q_proj", "Linear (Q)", "Query projection", hidden, q_out),
        _projection("k_proj", "Linear (K)", "Key projection", hidden, kv_out),
        _projection("v_proj", "Linear (V)", "Value projection", hidden, kv_out),
        {
            "id": "kernel_map",
            "label": "Kernel map",
            "title": "Feature map",
            "description": "Apply kernel feature map so attention can be accumulated linearly",
        },
        {
            "id": "linear_mix",
            "label": "Linear mix",
            "title": "Linear attention mix",
            "description": "Prefix/state accumulation computes attention in linear time",
        },
        _projection("o_proj", "Linear (out)", "Output projection", q_out, hidden),
    ]
|