model-unfolder 0.2.4__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/PKG-INFO +12 -5
  2. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/README.md +11 -4
  3. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/assembly.py +30 -2
  4. model_unfolder-0.2.6/model_unfolder/adapters/transformer/blocks/__init__.py +56 -0
  5. model_unfolder-0.2.6/model_unfolder/adapters/transformer/blocks/attention.py +298 -0
  6. model_unfolder-0.2.6/model_unfolder/adapters/transformer/blocks/descriptions.py +132 -0
  7. model_unfolder-0.2.6/model_unfolder/adapters/transformer/blocks/feed_forward.py +166 -0
  8. model_unfolder-0.2.6/model_unfolder/adapters/transformer/blocks/layers.py +123 -0
  9. model_unfolder-0.2.6/model_unfolder/adapters/transformer/blocks/model.py +52 -0
  10. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/__init__.py +26 -0
  11. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/cohere.py +106 -0
  12. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/families/deepseek.py +9 -10
  13. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/falcon.py +183 -0
  14. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/__init__.py +26 -0
  15. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/gemma2.py +101 -0
  16. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/gemma3.py +137 -0
  17. {model_unfolder-0.2.4/model_unfolder/adapters/transformer/families → model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma}/gemma4.py +5 -7
  18. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gpt_neox.py +114 -0
  19. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/jamba.py +126 -0
  20. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/families/llama.py +57 -9
  21. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/minimax.py +99 -0
  22. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/families/qwen.py +5 -1
  23. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/recurrent_gemma.py +103 -0
  24. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/rwkv.py +74 -0
  25. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/zamba.py +127 -0
  26. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/ir.py +8 -1
  27. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/labels.py +57 -10
  28. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/parser.py +36 -3
  29. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/__init__.py +33 -0
  30. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/block_views/attention.py +34 -3
  31. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/__init__.py +22 -0
  32. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/common.py +310 -0
  33. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/grouped_query.py +90 -0
  34. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/latent.py +61 -0
  35. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/linear.py +58 -0
  36. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/multi_head.py +74 -0
  37. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/multi_query.py +118 -0
  38. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/rwkv.py +54 -0
  39. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/state_space.py +44 -0
  40. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/feed_forward.py +95 -0
  41. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/mixture_of_experts.py +117 -0
  42. model_unfolder-0.2.6/model_unfolder/renderers/html/cards.py +234 -0
  43. model_unfolder-0.2.6/model_unfolder/renderers/html/document.py +201 -0
  44. model_unfolder-0.2.6/model_unfolder/renderers/html/interactions.py +125 -0
  45. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/metadata.py +3 -0
  46. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/sections.py +10 -0
  47. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/styles.py +39 -17
  48. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/views.py +55 -11
  49. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder.egg-info/PKG-INFO +12 -5
  50. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder.egg-info/SOURCES.txt +28 -2
  51. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/pyproject.toml +1 -1
  52. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/tests/test_smoke.py +213 -2
  53. model_unfolder-0.2.4/model_unfolder/adapters/transformer/blocks.py +0 -238
  54. model_unfolder-0.2.4/model_unfolder/adapters/transformer/families/__init__.py +0 -13
  55. model_unfolder-0.2.4/model_unfolder/renderers/html/block_views/__init__.py +0 -20
  56. model_unfolder-0.2.4/model_unfolder/renderers/html/block_views/feed_forward.py +0 -213
  57. model_unfolder-0.2.4/model_unfolder/renderers/html/cards.py +0 -130
  58. model_unfolder-0.2.4/model_unfolder/renderers/html/document.py +0 -157
  59. model_unfolder-0.2.4/model_unfolder/renderers/html/interactions.py +0 -64
  60. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/LICENSE +0 -0
  61. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/__init__.py +0 -0
  62. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/__init__.py +0 -0
  63. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/custom/__init__.py +0 -0
  64. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/diffusor/__init__.py +0 -0
  65. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/__init__.py +0 -0
  66. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/common.py +0 -0
  67. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/families/fallback.py +0 -0
  68. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/families/mistral.py +0 -0
  69. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/special_parts/__init__.py +0 -0
  70. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/adapters/transformer/special_parts/per_layer_embedding.py +0 -0
  71. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/diagram.py +0 -0
  72. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/html_renderer.py +0 -0
  73. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/params.py +0 -0
  74. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/__init__.py +0 -0
  75. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/__init__.py +0 -0
  76. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/block_views/per_layer_embedding.py +0 -0
  77. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/svg.py +0 -0
  78. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/theme.py +0 -0
  79. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder/renderers/html/utils.py +0 -0
  80. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder.egg-info/dependency_links.txt +0 -0
  81. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder.egg-info/requires.txt +0 -0
  82. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/model_unfolder.egg-info/top_level.txt +0 -0
  83. {model_unfolder-0.2.4 → model_unfolder-0.2.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: model-unfolder
3
- Version: 0.2.4
3
+ Version: 0.2.6
4
4
  Summary: Unfold any HuggingFace transformer into an interactive architecture diagram, inline in Jupyter.
5
5
  Author: model-unfolder contributors
6
6
  License: Apache-2.0
@@ -108,11 +108,18 @@ Param estimates are close to published numbers — DeepSeek-V3 reports `~675B (~
108
108
 
109
109
  | Family | Models |
110
110
  |---|---|
111
- | DeepSeek | DeepSeek-V2, DeepSeek-V3, Kimi K2 |
112
- | Llama | Llama 3 / 3.1 / 3.2 / 3.3 |
111
+ | DeepSeek | DeepSeek-V2, DeepSeek-V3 (+ MTP head), Kimi K2 |
112
+ | Llama | Llama 3 / 3.1 / 3.2 / 3.3, OLMo-2, Llama 4 Scout / Maverick (MoE + iRoPE NoPE layers) |
113
113
  | Mistral | Mistral 7B, Mixtral 8x7B / 8x22B, Mistral Medium 3.5 |
114
- | Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 (hybrid linear+full attn) |
115
- | Gemma | Gemma 3, Gemma 4 (31B, E2B, E4B) |
114
+ | Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 (+ MTP) |
115
+ | Gemma | Gemma 2 9B / 27B (interleaved local+global), Gemma 3 / 3n (+ PLE), Gemma 4 31B / E2B / E4B (+ PLE), RecurrentGemma 2B / 9B (LRU + local attention) |
116
+ | Cohere | Command R, Command R+, Command R7B (QK-Norm attention) |
117
+ | Jamba | Jamba (SSM + attention hybrid, MoE) |
118
+ | Zamba | Zamba 7B, Zamba2 2.7B / 7B (Mamba SSM + weight-shared attention) |
119
+ | Mamba | Mamba 130M–2.8B, Mamba-2 (pure SSM, no attention) |
120
+ | Falcon | Falcon 7B / 40B (parallel attn+FFN), Falcon-H1 (Mamba-2 SSM) |
121
+ | MiniMax | MiniMax-Text-01 (lightning + softmax hybrid, MoE) |
122
+ | RWKV | RWKV-4 / 5 / 6 (pure recurrent, no attention) |
116
123
 
117
124
  ### Diffusors
118
125
 
@@ -84,11 +84,18 @@ Param estimates are close to published numbers — DeepSeek-V3 reports `~675B (~
84
84
 
85
85
  | Family | Models |
86
86
  |---|---|
87
- | DeepSeek | DeepSeek-V2, DeepSeek-V3, Kimi K2 |
88
- | Llama | Llama 3 / 3.1 / 3.2 / 3.3 |
87
+ | DeepSeek | DeepSeek-V2, DeepSeek-V3 (+ MTP head), Kimi K2 |
88
+ | Llama | Llama 3 / 3.1 / 3.2 / 3.3, OLMo-2, Llama 4 Scout / Maverick (MoE + iRoPE NoPE layers) |
89
89
  | Mistral | Mistral 7B, Mixtral 8x7B / 8x22B, Mistral Medium 3.5 |
90
- | Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 (hybrid linear+full attn) |
91
- | Gemma | Gemma 3, Gemma 4 (31B, E2B, E4B) |
90
+ | Qwen | Qwen2 / 2.5, Qwen2-MoE, Qwen3, Qwen3-MoE, Qwen3.5 / 3.6 (+ MTP) |
91
+ | Gemma | Gemma 2 9B / 27B (interleaved local+global), Gemma 3 / 3n (+ PLE), Gemma 4 31B / E2B / E4B (+ PLE), RecurrentGemma 2B / 9B (LRU + local attention) |
92
+ | Cohere | Command R, Command R+, Command R7B (QK-Norm attention) |
93
+ | Jamba | Jamba (SSM + attention hybrid, MoE) |
94
+ | Zamba | Zamba 7B, Zamba2 2.7B / 7B (Mamba SSM + weight-shared attention) |
95
+ | Mamba | Mamba 130M–2.8B, Mamba-2 (pure SSM, no attention) |
96
+ | Falcon | Falcon 7B / 40B (parallel attn+FFN), Falcon-H1 (Mamba-2 SSM) |
97
+ | MiniMax | MiniMax-Text-01 (lightning + softmax hybrid, MoE) |
98
+ | RWKV | RWKV-4 / 5 / 6 (pure recurrent, no attention) |
92
99
 
93
100
  ### Diffusors
94
101
 
@@ -5,7 +5,7 @@ from collections.abc import Iterable, Mapping
5
5
  from typing import Any
6
6
 
7
7
  from ...ir import AttentionSpec, FFNSpec, LayerSpec
8
- from .blocks import decoder_layer_blocks, decoder_only_render_spec
8
+ from .blocks import decoder_layer_blocks, decoder_only_render_spec, parallel_decoder_layer_blocks
9
9
 
10
10
 
11
11
  def decoder_layer(
@@ -15,15 +15,43 @@ def decoder_layer(
15
15
  hidden_size: int,
16
16
  *,
17
17
  extra_blocks: Iterable[dict] | None = None,
18
+ norm_kind: str = "rmsnorm",
19
+ norm_placement: str = "pre",
18
20
  ) -> LayerSpec:
19
21
  """Build a decoder layer from parsed specs plus optional reusable parts."""
20
- blocks = decoder_layer_blocks(attention, ffn, hidden_size)
22
+ blocks = decoder_layer_blocks(attention, ffn, hidden_size, norm_kind=norm_kind)
21
23
  if extra_blocks:
22
24
  blocks.extend(extra_blocks)
23
25
  return LayerSpec(
24
26
  index=index,
25
27
  attention=attention,
26
28
  ffn=ffn,
29
+ norm_kind=norm_kind,
30
+ norm_placement=norm_placement,
31
+ blocks=blocks,
32
+ )
33
+
34
+
35
def parallel_decoder_layer(
    index: int,
    attention: AttentionSpec,
    ffn: FFNSpec,
    hidden_size: int,
    *,
    norm_kind: str = "rmsnorm",
) -> LayerSpec:
    """Assemble the ``LayerSpec`` for a parallel-residual decoder layer.

    In the parallel layout (GPT-NeoX / GPT-J) attention and the FFN read the
    same normalised input, and their outputs are folded into a single residual
    add instead of two sequential ones.

    The render blocks are produced by ``parallel_decoder_layer_blocks``; the
    norm placement is always recorded as ``"pre"``.
    """
    return LayerSpec(
        index=index,
        attention=attention,
        ffn=ffn,
        norm_kind=norm_kind,
        norm_placement="pre",
        blocks=parallel_decoder_layer_blocks(
            attention,
            ffn,
            hidden_size,
            norm_kind=norm_kind,
        ),
    )
29
57
 
@@ -0,0 +1,56 @@
1
+ """Reusable transformer block descriptions for renderers.
2
+
3
+ Adapters attach these block parts to the IR. Renderers can then draw generic
4
+ decoder-only transformer layouts without rediscovering model-specific names
5
+ or labels with another layer of ``if model_type`` logic.
6
+
7
+ The package is split by responsibility:
8
+
9
+ * ``model``: model-level input/output bookend blocks.
10
+ * ``layers``: sequential and parallel decoder-layer topology.
11
+ * ``attention``: reusable attention/SSM/recurrent child block contracts.
12
+ * ``feed_forward``: dense, gated, and MoE FFN child block contracts.
13
+ * ``descriptions``: labels, titles, and short metadata strings.
14
+
15
+ Each block carries two orthogonal tags:
16
+
17
+ * ``role`` — semantic ("norm", "attention", "ffn", "residual", "gate") used
18
+ for tooltips, click handlers, and the inspect cards.
19
+ * ``kind`` — rendering shape ("norm", "linear", "activation", "attention",
20
+ "ffn", "residual_add", "gate_mul", "embedding", "output", "source") used
21
+ by the architecture view to pick a glyph and lay out a slot.
22
+
23
+ Edges between blocks travel on the destination side as plain string fields:
24
+
25
+ * ``residual_from: "<other_block_id>"`` — the residual_add block consumes the
26
+ *input* of the named block (the standard pre-attention bypass pattern).
27
+ * ``lane: "left" | "right"`` — the block is rendered off the central chain
28
+ and connected via ``tap_from`` / ``feeds``. Reusable parts such as
29
+ per-layer embeddings use this instead of model-specific renderer logic.
30
+ """
31
+ from __future__ import annotations
32
+
33
+ from .attention import attention_child_blocks
34
+ from .descriptions import (
35
+ attention_label,
36
+ attention_title,
37
+ describe_attention,
38
+ describe_ffn,
39
+ )
40
+ from .feed_forward import ffn_child_blocks, ffn_detail_view
41
+ from .layers import decoder_layer_blocks, parallel_decoder_layer_blocks
42
+ from .model import decoder_model_blocks, decoder_only_render_spec
43
+
44
+ __all__ = [
45
+ "attention_child_blocks",
46
+ "attention_label",
47
+ "attention_title",
48
+ "decoder_layer_blocks",
49
+ "decoder_model_blocks",
50
+ "decoder_only_render_spec",
51
+ "describe_attention",
52
+ "describe_ffn",
53
+ "ffn_child_blocks",
54
+ "ffn_detail_view",
55
+ "parallel_decoder_layer_blocks",
56
+ ]
@@ -0,0 +1,298 @@
1
+ """Reusable attention-family child block declarations."""
2
+ from __future__ import annotations
3
+
4
+ from ....ir import AttentionSpec
5
+ from ..common import format_dim as _fmt
6
+
7
+
8
def attention_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Return the child block descriptions for one attention-like layer.

    The builder is picked from ``attention.kind``; any kind without a
    dedicated builder is described as scaled dot-product attention.
    """
    kind = attention.kind
    if kind == "mla":
        make_blocks = _mla_child_blocks
    elif kind == "ssm":
        make_blocks = _ssm_child_blocks
    elif kind == "recurrent":
        make_blocks = _recurrent_child_blocks
    elif kind == "rwkv":
        make_blocks = _rwkv_child_blocks
    elif kind == "linear":
        make_blocks = _linear_attention_child_blocks
    else:
        # MHA / MQA / GQA and unknown kinds share the softmax-attention layout.
        make_blocks = _sdpa_child_blocks
    return make_blocks(attention, hidden_size)
18
+
19
+
20
def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of a softmax (scaled dot-product) attention layer.

    Returns five ``id`` / ``title`` / ``description`` dicts, in order:
    ``q_proj``, ``k_proj``, ``v_proj``, ``qkv_dot`` and ``o_proj``. When head
    counts or the head dim are missing, projection widths fall back to the
    hidden size and the per-head dim is shown as the literal ``"d_k"``.
    """
    model_dim = _fmt(hidden_size)
    n_query = attention.num_heads or 0
    # Without an explicit KV head count, use the query head count.
    n_kv = attention.num_kv_heads or n_query
    per_head = attention.head_dim or 0

    # The group size is only reported when query heads divide evenly over KV heads.
    group_size = None
    if n_query and n_kv and n_query % n_kv == 0:
        group_size = n_query // n_kv

    query_width = _fmt(n_query * per_head) if (n_query and per_head) else model_dim
    kv_width = _fmt(n_kv * per_head) if (n_kv and per_head) else model_dim
    per_head_text = _fmt(per_head) if per_head else "d_k"
    op_title, op_desc = _sdpa_operation_meta(attention, n_query, n_kv, per_head_text, group_size)

    # Key and value projections share the same shape description.
    kv_desc = f"Linear; {model_dim} -> {kv_width} ({n_kv} KV-heads x {per_head_text} dims)"
    rows = (
        (
            "q_proj",
            "Query projection",
            f"Linear; {model_dim} -> {query_width} ({n_query} heads x {per_head_text} dims)",
        ),
        ("k_proj", "Key projection", kv_desc),
        ("v_proj", "Value projection", kv_desc),
        ("qkv_dot", op_title, op_desc),
        (
            "o_proj",
            "Output projection",
            f"Linear; {query_width} -> {model_dim} (recombines all {n_query} heads)",
        ),
    )
    return [
        {"id": block_id, "title": title, "description": description}
        for block_id, title, description in rows
    ]
57
+
58
+
59
+ def _sdpa_operation_meta(
60
+ attention: AttentionSpec,
61
+ num_heads: int,
62
+ num_kv_heads: int,
63
+ d_k: str,
64
+ q_per_group: int | None,
65
+ ) -> tuple[str, str]:
66
+ if attention.kind == "mqa":
67
+ return (
68
+ "Multi-query scaled dot-product attention",
69
+ (
70
+ f"scores = softmax(QK^T / sqrt({d_k})); "
71
+ f"{num_heads} query heads share one K/V head"
72
+ ),
73
+ )
74
+ if attention.kind == "gqa":
75
+ group = (
76
+ f"; each KV head serves {q_per_group} query heads"
77
+ if q_per_group
78
+ else ""
79
+ )
80
+ return (
81
+ "Grouped scaled dot-product attention",
82
+ (
83
+ f"scores = softmax(QK^T / sqrt({d_k})); "
84
+ f"{num_heads} query heads attend through {num_kv_heads} shared KV heads{group}"
85
+ ),
86
+ )
87
+ return (
88
+ "Scaled dot-product attention",
89
+ (
90
+ f"scores = softmax(QK^T / sqrt({d_k})); "
91
+ "context = scores * V; "
92
+ f"output shape [batch, {num_heads}, seq, {d_k}]"
93
+ ),
94
+ )
95
+
96
+
97
def _mla_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of a multi-head latent attention (MLA) layer.

    Returns six ``id`` / ``label`` / ``title`` / ``description`` dicts: the
    query projection, K/V compression and expansion, the RoPE key side
    channel, the attention op, and the output projection. The query
    description mentions the LoRA rank only when ``q_lora_rank`` is set.
    """
    model_dim = _fmt(hidden_size)
    q_rank = _fmt(attention.q_lora_rank) if attention.q_lora_rank else "direct"
    kv_rank = _fmt(attention.kv_lora_rank)
    rope_dim = _fmt(attention.rope_dim)
    n_heads = attention.num_heads or 0
    per_head = attention.head_dim or 0
    query_width = _fmt(n_heads * per_head) if (n_heads and per_head) else model_dim

    if attention.q_lora_rank:
        q_desc = f"Projects hidden states into query heads through LoRA rank {q_rank}"
    else:
        q_desc = f"Q projection; {model_dim} -> {query_width}"

    fields = ("id", "label", "title", "description")
    rows = (
        ("mla_q", "Q projection", "Query projection", q_desc),
        (
            "mla_kv_down",
            "KV compress",
            "K/V latent compression",
            f"Compresses the token state into a shared latent K/V vector; {model_dim} -> rank {kv_rank}",
        ),
        (
            "mla_kv_up",
            "KV expand",
            "K/V head expansion",
            f"Expands the latent K/V vector into per-head key/value content for {n_heads} query heads",
        ),
        (
            "mla_rope",
            "RoPE key",
            "Rotary key side-channel",
            f"Separate positional key slice used with RoPE; dim {rope_dim}",
        ),
        (
            "mla_attn",
            "Latent attention",
            "Multi-head latent attention",
            "Attention over decompressed latent K/V plus the RoPE side channel",
        ),
        (
            "o_proj",
            "Linear (out)",
            "Output projection",
            f"Linear; {query_width} -> {model_dim}",
        ),
    )
    return [dict(zip(fields, row)) for row in rows]
147
+
148
+
149
def _ssm_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of a selective state-space (SSM) layer.

    Returns five ``id`` / ``label`` / ``title`` / ``description`` dicts: input
    projection, short convolution, selective scan, gate, and output
    projection. ``attention.head_dim`` is reported as the state dimension.
    """
    model_dim = _fmt(hidden_size)
    state_dim = _fmt(attention.head_dim)

    fields = ("id", "label", "title", "description")
    rows = (
        (
            "ssm_in_proj",
            "Input projection",
            "SSM input projection",
            f"Project hidden activations into SSM channels; hidden {model_dim}",
        ),
        (
            "ssm_conv",
            "Local conv",
            "Short convolution",
            "Depthwise local mixing before the state-space recurrence",
        ),
        (
            "ssm_scan",
            "Selective scan",
            "Selective state-space scan",
            f"Token recurrence with state dimension {state_dim}",
        ),
        (
            "ssm_gate",
            "Gate",
            "SSM gate",
            "Element-wise gate controlling the recurrent output",
        ),
        (
            "ssm_out_proj",
            "Output projection",
            "SSM output projection",
            f"Project SSM channels back to hidden dim {model_dim}",
        ),
    )
    return [dict(zip(fields, row)) for row in rows]
184
+
185
+
186
def _recurrent_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of a linear recurrent unit (LRU) layer.

    Returns four ``id`` / ``label`` / ``title`` / ``description`` dicts: input
    projection, recurrent state, gate, and output projection.
    ``attention.head_dim`` is reported as the recurrent width.
    """
    model_dim = _fmt(hidden_size)
    lru_width = _fmt(attention.head_dim)

    fields = ("id", "label", "title", "description")
    rows = (
        (
            "lru_in_proj",
            "Input projection",
            "LRU input projection",
            f"Linear; hidden {model_dim} -> recurrent width {lru_width}",
        ),
        (
            "lru_state",
            "Recurrent state",
            "Linear recurrent state",
            f"State update over sequence positions; width {lru_width}",
        ),
        (
            "lru_gate",
            "Gate",
            "Recurrent gate",
            "Element-wise gate controlling recurrent features",
        ),
        (
            "lru_out_proj",
            "Output projection",
            "LRU output projection",
            f"Linear; recurrent width {lru_width} -> hidden {model_dim}",
        ),
    )
    return [dict(zip(fields, row)) for row in rows]
215
+
216
+
217
def _rwkv_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of an RWKV token-mixing layer.

    Returns five ``id`` / ``label`` / ``title`` / ``description`` dicts:
    receptance, key, value, time-mix recurrence, and output projection.
    A missing head count is reported as 0.
    """
    model_dim = _fmt(hidden_size)
    n_heads = attention.num_heads or 0

    fields = ("id", "label", "title", "description")
    rows = (
        (
            "rwkv_receptance",
            "Receptance",
            "Receptance gate",
            f"Token-wise gate over hidden dim {model_dim}",
        ),
        (
            "rwkv_key",
            "Key",
            "RWKV key projection",
            f"Key-like channel mixing over {n_heads} recurrent heads",
        ),
        (
            "rwkv_value",
            "Value",
            "RWKV value projection",
            "Value channel sent through time-mixing recurrence",
        ),
        (
            "rwkv_time_mix",
            "Time-mix",
            "Time-decay recurrence",
            "Linear-time weighted recurrence replacing self-attention",
        ),
        (
            "rwkv_out",
            "Output projection",
            "RWKV output projection",
            f"Project mixed channels back to hidden dim {model_dim}",
        ),
    )
    return [dict(zip(fields, row)) for row in rows]
252
+
253
+
254
def _linear_attention_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
    """Describe the child blocks of a linear-attention layer.

    Returns six ``id`` / ``label`` / ``title`` / ``description`` dicts: the
    Q/K/V projections, the kernel feature map, the linear mix, and the output
    projection. Projection widths fall back to the hidden size when head
    counts or the head dim are missing.
    """
    model_dim = _fmt(hidden_size)
    n_query = attention.num_heads or 0
    # Without an explicit KV head count, use the query head count.
    n_kv = attention.num_kv_heads or n_query
    per_head = attention.head_dim or 0
    query_width = _fmt(n_query * per_head) if (n_query and per_head) else model_dim
    kv_width = _fmt(n_kv * per_head) if (n_kv and per_head) else model_dim

    # Key and value projections share the same shape description.
    kv_desc = f"Linear; {model_dim} -> {kv_width}"
    fields = ("id", "label", "title", "description")
    rows = (
        ("q_proj", "Linear (Q)", "Query projection", f"Linear; {model_dim} -> {query_width}"),
        ("k_proj", "Linear (K)", "Key projection", kv_desc),
        ("v_proj", "Linear (V)", "Value projection", kv_desc),
        (
            "kernel_map",
            "Kernel map",
            "Feature map",
            "Apply kernel feature map so attention can be accumulated linearly",
        ),
        (
            "linear_mix",
            "Linear mix",
            "Linear attention mix",
            "Prefix/state accumulation computes attention in linear time",
        ),
        ("o_proj", "Linear (out)", "Output projection", f"Linear; {query_width} -> {model_dim}"),
    )
    return [dict(zip(fields, row)) for row in rows]
@@ -0,0 +1,132 @@
1
+ """Labels, titles, and descriptions for transformer block specs."""
2
+ from __future__ import annotations
3
+
4
+ from ....ir import AttentionSpec, FFNSpec
5
+ from ....labels import activation_label
6
+ from ..common import format_dim as _fmt
7
+
8
+
9
def attention_label(attention: AttentionSpec) -> list[str]:
    """Return the two-line label for an attention-like block.

    The first element is the family name and the second a qualifier
    ("Attention", a tag list in parentheses, "Block", ...). Softmax attention
    kinds (MLA, MQA, GQA and plain multi-head) get an ``"SWA · "`` prefix on
    the first line when ``attention.mask == "sliding"``. SSM, LRU, RWKV and
    linear-attention labels never carry the prefix.
    """
    kind = attention.kind
    # Sliding-window layers are marked on the first label line.
    swa = "SWA · " if attention.mask == "sliding" else ""

    if kind == "mla":
        return [f"{swa}Multi-Head Latent", "Attention"]
    if kind == "mqa":
        return [f"{swa}Multi-Query", "Attention"]
    if kind == "gqa":
        qualifier = "(QK-Norm)" if attention.qk_norm else "Attention"
        return [f"{swa}Grouped-Query", qualifier]
    if kind == "ssm":
        return ["Selective SSM", "(Shared)" if attention.shared else "Block"]
    if kind == "recurrent":
        return ["Linear Recurrent", "Unit (LRU)"]
    if kind == "rwkv":
        return ["RWKV", "Token-Mixing"]
    if kind == "linear":
        return ["Linear", "Attention"]

    # Plain multi-head attention: optional tags go on the second line.
    tags = []
    if attention.qk_norm:
        tags.append("QK-Norm")
    if attention.no_rope:
        tags.append("NoPE")
    # The sliding-window prefix applies here too, so the label agrees with
    # the title and description generated for the same layer.
    if tags:
        return [f"{swa}Multi-Head Attn", f"({', '.join(tags)})"]
    return [f"{swa}Multi-Head", "Attention"]
37
+
38
+
39
def attention_title(attention: AttentionSpec) -> str:
    """Return the one-line title for an attention-like block.

    The title is the family name for ``attention.kind`` (``"Attention"`` for
    unknown kinds), prefixed with ``"Sliding-window"`` when the mask is
    sliding, and followed by any of the QK-Norm / weight-shared / NoPE flags
    in parentheses.
    """
    family_names = {
        "mqa": "Multi-query attention",
        "mla": "Multi-head latent attention",
        "gqa": "Grouped-query attention",
        "ssm": "Selective state-space model (Mamba)",
        "recurrent": "Linear Recurrent Unit (LRU)",
        "rwkv": "RWKV token-mixing",
        "linear": "Linear attention",
    }
    title = family_names.get(attention.kind, "Attention")
    if attention.mask == "sliding":
        title = f"Sliding-window {title}"

    flags = (
        (attention.qk_norm, "QK-Norm"),
        (attention.shared, "weight-shared"),
        (attention.no_rope, "NoPE"),
    )
    active = [name for enabled, name in flags if enabled]
    if not active:
        return title
    return f"{title} ({', '.join(active)})"
60
+
61
+
62
def describe_attention(attention: AttentionSpec) -> str:
    """Return a short ``"; "``-separated summary of an attention-like block.

    The wording depends on ``attention.kind``. MQA, GQA and plain multi-head
    summaries also pass through ``_with_attention_window``, which appends the
    sliding-window size when one applies.
    """
    kind = attention.kind

    if kind == "mla":
        parts = [
            "Multi-head latent attention",
            f"{attention.num_heads} heads",
            f"KV LoRA {_fmt(attention.kv_lora_rank)}",
        ]
        if attention.q_lora_rank:
            parts.append(f"Q LoRA {_fmt(attention.q_lora_rank)}")
        return "; ".join(parts)

    if kind == "mqa":
        summary = f"Multi-query; {attention.num_heads} Q / 1 KV head"
        return _with_attention_window(attention, summary)

    if kind == "gqa":
        summary = "; ".join((
            "Grouped-query",
            f"{attention.num_heads} Q / {attention.num_kv_heads} KV heads",
            f"head dim {_fmt(attention.head_dim)}",
        ))
        return _with_attention_window(attention, summary)

    if kind == "ssm":
        sharing_note = "; weight-shared across positions" if attention.shared else ""
        return f"Selective SSM; state dim {_fmt(attention.head_dim)}{sharing_note}"

    if kind == "recurrent":
        return f"Linear Recurrent Unit; LRU width {_fmt(attention.head_dim)}"

    if kind == "rwkv":
        return f"RWKV token-mixing; {attention.num_heads} heads"

    if kind == "linear":
        return "; ".join((
            "Linear attention",
            f"{attention.num_heads} Q / {attention.num_kv_heads} KV",
            f"head dim {_fmt(attention.head_dim)}",
        ))

    # Fallback: plain multi-head attention, with optional flags at the end.
    flags = []
    if attention.qk_norm:
        flags.append("QK-Norm")
    if attention.no_rope:
        flags.append("NoPE")
    parts = [
        "Multi-head",
        f"{attention.num_heads} heads",
        f"head dim {_fmt(attention.head_dim)}",
    ]
    if flags:
        parts.append(", ".join(flags))
    return _with_attention_window(attention, "; ".join(parts))
98
+
99
+
100
+ def _attention_mask_prefix(attention: AttentionSpec) -> str:
101
+ return "SWA" if attention.mask == "sliding" else ""
102
+
103
+
104
+ def _attention_mask_title_prefix(attention: AttentionSpec) -> str:
105
+ return "Sliding-window" if attention.mask == "sliding" else ""
106
+
107
+
108
+ def _prefixed_label(prefix: str, first: str, second: str) -> list[str]:
109
+ return [f"{prefix} · {first}", second] if prefix else [first, second]
110
+
111
+
112
+ def _prefixed_title(prefix: str, title: str) -> str:
113
+ return f"{prefix} {title}" if prefix else title
114
+
115
+
116
+ def _with_attention_window(attention: AttentionSpec, text: str) -> str:
117
+ if attention.mask == "sliding" and attention.window_size:
118
+ return f"{text}; sliding window {_fmt(attention.window_size)}"
119
+ return text
120
+
121
+
122
def describe_ffn(ffn: FFNSpec) -> str:
    """Return a short ``"; "``-separated summary of a feed-forward block.

    Dense blocks report whether they are gated, the activation label, and the
    hidden size. MoE blocks report the expert count, the top-k routing (plus
    shared experts, if any), the percentage of active experts when both counts
    are known, and the expert hidden size, falling back to
    ``intermediate_size`` when no expert-specific size is set.
    """
    if ffn.kind != "moe":
        gate_prefix = "gated " if ffn.gated else ""
        return f"{gate_prefix}FFN; {activation_label(ffn.activation)}; hidden {_fmt(ffn.intermediate_size)}"

    expert_count = f"{_fmt(ffn.num_experts)} experts"
    routing = f"top-{ffn.num_experts_per_tok}"
    if ffn.num_shared_experts:
        routing += f" + {ffn.num_shared_experts} shared"

    segments = ["MoE", expert_count, routing]
    if ffn.num_experts and ffn.num_experts_per_tok:
        active_pct = 100 * ffn.num_experts_per_tok / ffn.num_experts
        segments.append(f"{active_pct:.1f}% active")
    segments.append(f"expert hidden {_fmt(ffn.expert_intermediate_size or ffn.intermediate_size)}")
    return "; ".join(segments)