model-unfolder 0.2.6__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/PKG-INFO +2 -2
  2. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/README.md +1 -1
  3. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/__init__.py +31 -5
  4. model_unfolder-0.2.8/model_unfolder/adapters/transformer/__init__.py +8 -0
  5. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/blocks/attention.py +211 -15
  6. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/blocks/descriptions.py +6 -4
  7. model_unfolder-0.2.8/model_unfolder/adapters/transformer/parser.py +553 -0
  8. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/__init__.py +6 -0
  9. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/accessors.py +71 -0
  10. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/audio.py +124 -0
  11. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/builder.py +33 -0
  12. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/detect.py +182 -0
  13. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/fusion.py +115 -0
  14. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/schema.py +17 -0
  15. model_unfolder-0.2.8/model_unfolder/adapters/transformer/special_parts/modalities/vision.py +380 -0
  16. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/diagram.py +20 -2
  17. model_unfolder-0.2.8/model_unfolder/evidence/__init__.py +13 -0
  18. model_unfolder-0.2.8/model_unfolder/evidence/ast_scanner.py +87 -0
  19. model_unfolder-0.2.8/model_unfolder/evidence/inspector.py +28 -0
  20. model_unfolder-0.2.8/model_unfolder/evidence/models.py +169 -0
  21. model_unfolder-0.2.8/model_unfolder/evidence/patterns.py +392 -0
  22. model_unfolder-0.2.8/model_unfolder/evidence/sources.py +318 -0
  23. model_unfolder-0.2.8/model_unfolder/evidence/validate.py +153 -0
  24. model_unfolder-0.2.8/model_unfolder/expanded/__init__.py +89 -0
  25. model_unfolder-0.2.8/model_unfolder/expanded/attention.py +185 -0
  26. model_unfolder-0.2.8/model_unfolder/expanded/block_graph.py +41 -0
  27. model_unfolder-0.2.8/model_unfolder/expanded/code_evidence.py +62 -0
  28. model_unfolder-0.2.8/model_unfolder/expanded/ffn.py +85 -0
  29. model_unfolder-0.2.8/model_unfolder/expanded/grouping.py +90 -0
  30. model_unfolder-0.2.8/model_unfolder/expanded/layer_group.py +39 -0
  31. model_unfolder-0.2.8/model_unfolder/expanded/modalities.py +72 -0
  32. model_unfolder-0.2.8/model_unfolder/expanded/norms.py +13 -0
  33. model_unfolder-0.2.8/model_unfolder/expanded/ops.py +41 -0
  34. model_unfolder-0.2.8/model_unfolder/expanded/pathways.py +29 -0
  35. model_unfolder-0.2.8/model_unfolder/expanded/residual.py +25 -0
  36. model_unfolder-0.2.8/model_unfolder/expanded/sections.py +84 -0
  37. model_unfolder-0.2.8/model_unfolder/expanded/stack.py +59 -0
  38. model_unfolder-0.2.8/model_unfolder/expanded/utils.py +49 -0
  39. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/labels.py +4 -3
  40. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/parser.py +47 -2
  41. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/__init__.py +20 -0
  42. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/attention_types/__init__.py +4 -0
  43. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/attention_types/common.py +639 -0
  44. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/attention_types/grouped_query.py +105 -0
  45. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/attention_types/latent.py +265 -0
  46. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/attention_types/linear.py +3 -2
  47. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/attention_types/multi_head.py +96 -0
  48. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/attention_types/multi_query.py +103 -0
  49. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/attention_types/rwkv.py +3 -2
  50. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/attention_types/sliding_window.py +76 -0
  51. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modalities.py +16 -0
  52. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/__init__.py +15 -0
  53. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/audio.py +31 -0
  54. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/common.py +80 -0
  55. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/fusion_cross_attention.py +62 -0
  56. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/fusion_grid.py +119 -0
  57. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/fusion_placeholder.py +208 -0
  58. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/video.py +41 -0
  59. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/modality_views/vision.py +43 -0
  60. model_unfolder-0.2.8/model_unfolder/renderers/html/block_views/registry.py +60 -0
  61. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/cards.py +16 -1
  62. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/document.py +3 -0
  63. model_unfolder-0.2.8/model_unfolder/renderers/html/evidence.py +114 -0
  64. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/interactions.py +15 -10
  65. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/metadata.py +3 -0
  66. model_unfolder-0.2.8/model_unfolder/renderers/html/metadata_modalities.py +634 -0
  67. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/styles.py +75 -0
  68. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/svg.py +33 -4
  69. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/theme.py +1 -0
  70. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/views.py +31 -11
  71. model_unfolder-0.2.8/model_unfolder/renderers/html/views_modalities.py +161 -0
  72. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder.egg-info/PKG-INFO +2 -2
  73. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder.egg-info/SOURCES.txt +47 -18
  74. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/pyproject.toml +1 -1
  75. model_unfolder-0.2.8/tests/test_code_evidence.py +460 -0
  76. model_unfolder-0.2.8/tests/test_expanded_json.py +530 -0
  77. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/tests/test_smoke.py +346 -86
  78. model_unfolder-0.2.6/model_unfolder/adapters/transformer/__init__.py +0 -5
  79. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/__init__.py +0 -26
  80. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/cohere.py +0 -106
  81. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/deepseek.py +0 -106
  82. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/falcon.py +0 -183
  83. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/fallback.py +0 -192
  84. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/__init__.py +0 -26
  85. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/gemma2.py +0 -101
  86. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/gemma3.py +0 -137
  87. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gemma/gemma4.py +0 -200
  88. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/gpt_neox.py +0 -114
  89. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/jamba.py +0 -126
  90. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/llama.py +0 -138
  91. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/minimax.py +0 -99
  92. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/mistral.py +0 -119
  93. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/qwen.py +0 -148
  94. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/recurrent_gemma.py +0 -103
  95. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/rwkv.py +0 -74
  96. model_unfolder-0.2.6/model_unfolder/adapters/transformer/families/zamba.py +0 -127
  97. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/__init__.py +0 -33
  98. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/common.py +0 -310
  99. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/grouped_query.py +0 -90
  100. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/latent.py +0 -61
  101. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/multi_head.py +0 -74
  102. model_unfolder-0.2.6/model_unfolder/renderers/html/block_views/attention_types/multi_query.py +0 -118
  103. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/LICENSE +0 -0
  104. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/__init__.py +0 -0
  105. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/custom/__init__.py +0 -0
  106. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/diffusor/__init__.py +0 -0
  107. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/assembly.py +0 -0
  108. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/blocks/__init__.py +0 -0
  109. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/blocks/feed_forward.py +0 -0
  110. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/blocks/layers.py +0 -0
  111. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/blocks/model.py +0 -0
  112. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/common.py +0 -0
  113. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/special_parts/__init__.py +0 -0
  114. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/adapters/transformer/special_parts/per_layer_embedding.py +0 -0
  115. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/html_renderer.py +0 -0
  116. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/ir.py +0 -0
  117. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/params.py +0 -0
  118. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/__init__.py +0 -0
  119. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/__init__.py +0 -0
  120. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/attention.py +0 -0
  121. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/attention_types/state_space.py +0 -0
  122. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/feed_forward.py +0 -0
  123. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/mixture_of_experts.py +0 -0
  124. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/block_views/per_layer_embedding.py +0 -0
  125. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/sections.py +0 -0
  126. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder/renderers/html/utils.py +0 -0
  127. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder.egg-info/dependency_links.txt +0 -0
  128. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder.egg-info/requires.txt +0 -0
  129. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/model_unfolder.egg-info/top_level.txt +0 -0
  130. {model_unfolder-0.2.6 → model_unfolder-0.2.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: model-unfolder
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: Unfold any HuggingFace transformer into an interactive architecture diagram, inline in Jupyter.
5
5
  Author: model-unfolder contributors
6
6
  License: Apache-2.0
@@ -95,7 +95,7 @@ No extra config in `model_unfolder` itself.
95
95
  ```python
96
96
  diagram = unfold(cfg)
97
97
  diagram.save("model.html") # standalone interactive HTML
98
- diagram.save("model.json") # IR (no rendering)
98
+ diagram.save("model.json") # expanded architecture JSON (no rendering)
99
99
  diagram.param_count() # {"total": ..., "active": ..., "per_layer": [...]}
100
100
  diagram.to_ir() # full IR dict
101
101
  ```
@@ -71,7 +71,7 @@ No extra config in `model_unfolder` itself.
71
71
  ```python
72
72
  diagram = unfold(cfg)
73
73
  diagram.save("model.html") # standalone interactive HTML
74
- diagram.save("model.json") # IR (no rendering)
74
+ diagram.save("model.json") # expanded architecture JSON (no rendering)
75
75
  diagram.param_count() # {"total": ..., "active": ..., "per_layer": [...]}
76
76
  diagram.to_ir() # full IR dict
77
77
  ```
@@ -11,6 +11,7 @@ Outside Jupyter::
11
11
  diagram.save("kimi_k2.html")
12
12
  """
13
13
  from .diagram import Diagram
14
+ from .evidence import inspect_model_code
14
15
  from .parser import config_to_ir
15
16
  from .ir import ModelIR, LayerSpec, AttentionSpec, FFNSpec, CrossLayerEdge
16
17
  from .params import estimate_params
@@ -27,11 +28,19 @@ __all__ = [
27
28
  "FFNSpec",
28
29
  "CrossLayerEdge",
29
30
  "config_to_ir",
31
+ "inspect_model_code",
30
32
  "estimate_params",
31
33
  ]
32
34
 
33
35
 
34
- def unfold(cfg_or_id, token=None) -> Diagram:
36
+ def unfold(
37
+ cfg_or_id,
38
+ token=None,
39
+ *,
40
+ inspect_code: bool = False,
41
+ code_source: str = "local",
42
+ return_json: bool = False,
43
+ ):
35
44
  """Unfold a transformer into a renderable architecture diagram.
36
45
 
37
46
  Parameters
@@ -44,14 +53,31 @@ def unfold(cfg_or_id, token=None) -> Diagram:
44
53
  Optional Hugging Face token used only when ``cfg_or_id`` is a model ID.
45
54
  If omitted, ``HF_TOKEN`` and legacy Hugging Face token env vars are used
46
55
  when present.
56
+ inspect_code
57
+ If True, attach static source-code evidence to the IR. The code scanner
58
+ parses modeling files as text/AST and does not execute model code.
59
+ code_source
60
+ Source for code inspection: ``"local"`` (installed transformers),
61
+ ``"path"``, ``"hub"``, ``"auto"``, or a local file/directory path.
62
+ return_json
63
+ If True, return the expanded architecture JSON dict instead of the
64
+ renderable ``Diagram``. The JSON uses stable structural fields for
65
+ dimensions, projections, layer groups, operation graphs, cache behavior,
66
+ and trace paths instead of renderer labels/descriptions.
47
67
 
48
68
  Returns
49
69
  -------
50
- Diagram
51
- Renders inline in Jupyter; otherwise call ``.save()`` or ``.to_html()``.
70
+ Diagram | dict
71
+ ``Diagram`` by default; ``dict`` when ``return_json=True``.
52
72
  """
53
- ir = config_to_ir(cfg_or_id, token=token)
54
- return Diagram(ir)
73
+ ir = config_to_ir(
74
+ cfg_or_id,
75
+ token=token,
76
+ inspect_code=inspect_code,
77
+ code_source=code_source,
78
+ )
79
+ diagram = Diagram(ir)
80
+ return diagram.to_json() if return_json else diagram
55
81
 
56
82
 
57
83
  # friendly alias
@@ -0,0 +1,8 @@
1
+ """Transformer-LLM adapter.
2
+
3
+ There is exactly one parser (``parser.py``); see its module docstring for
4
+ the principle (config-driven, no per-family code).
5
+ """
6
+ from . import parser
7
+
8
+ ADAPTERS = [parser]
@@ -26,6 +26,18 @@ def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]
26
26
  q_out = _fmt(num_heads * head_dim) if (num_heads and head_dim) else hidden
27
27
  kv_out = _fmt(num_kv_heads * head_dim) if (num_kv_heads and head_dim) else hidden
28
28
  d_k = _fmt(head_dim) if head_dim else "d_k"
29
+ if attention.kind in {"mha", "gqa", "mqa"}:
30
+ return _sdpa_detailed_child_blocks(
31
+ attention.kind,
32
+ hidden,
33
+ q_out,
34
+ kv_out,
35
+ num_heads,
36
+ num_kv_heads,
37
+ d_k,
38
+ q_per_group,
39
+ )
40
+
29
41
  attention_title, attention_desc = _sdpa_operation_meta(attention, num_heads, num_kv_heads, d_k, q_per_group)
30
42
  return [
31
43
  {
@@ -36,12 +48,18 @@ def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]
36
48
  {
37
49
  "id": "k_proj",
38
50
  "title": "Key projection",
39
- "description": f"Linear; {hidden} -> {kv_out} ({num_kv_heads} KV-heads x {d_k} dims)",
51
+ "description": (
52
+ f"Linear; {hidden} -> {kv_out} ({num_kv_heads} KV-heads x {d_k} dims). "
53
+ "Cache ports show K/V write/read during generation: arrowhead for write, blunt tail for read."
54
+ ),
40
55
  },
41
56
  {
42
57
  "id": "v_proj",
43
58
  "title": "Value projection",
44
- "description": f"Linear; {hidden} -> {kv_out} ({num_kv_heads} KV-heads x {d_k} dims)",
59
+ "description": (
60
+ f"Linear; {hidden} -> {kv_out} ({num_kv_heads} KV-heads x {d_k} dims). "
61
+ "Cache ports show K/V write/read during generation: arrowhead for write, blunt tail for read."
62
+ ),
45
63
  },
46
64
  {
47
65
  "id": "qkv_dot",
@@ -56,6 +74,83 @@ def _sdpa_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]
56
74
  ]
57
75
 
58
76
 
77
+ def _sdpa_detailed_child_blocks(
78
+ kind: str,
79
+ hidden: str,
80
+ q_out: str,
81
+ kv_out: str,
82
+ num_heads: int,
83
+ num_kv_heads: int,
84
+ d_k: str,
85
+ q_per_group: int | None,
86
+ ) -> list[dict]:
87
+ kv_label = "1 shared K/V head" if kind == "mqa" else f"{num_kv_heads} KV-heads"
88
+ scaled_title = "Scaled attention scores"
89
+ scaled_desc = "Per head: QK^T / sqrt(dim); dot-product scores scaled for numerical stability"
90
+ if kind == "gqa":
91
+ scaled_title = "Grouped scaled dot-product attention"
92
+ group = f"; each KV head serves {q_per_group} query heads" if q_per_group else ""
93
+ scaled_desc = (
94
+ f"Grouped SDPA scores: {num_heads} query heads attend through "
95
+ f"{num_kv_heads} shared K/V heads{group}; scores use QK^T / sqrt(dim)"
96
+ )
97
+ elif kind == "mqa":
98
+ scaled_title = "Multi-query scaled dot-product attention"
99
+ scaled_desc = (
100
+ f"Multi-Query SDPA scores: {num_heads} query heads share one K/V stream; "
101
+ "scores use QK^T / sqrt(dim)"
102
+ )
103
+
104
+ return [
105
+ {
106
+ "id": "q_proj",
107
+ "title": "Query projection",
108
+ "description": f"Linear; {hidden} -> {q_out} ({num_heads} heads x {d_k} dims)",
109
+ },
110
+ {
111
+ "id": "k_proj",
112
+ "title": "Key projection",
113
+ "description": (
114
+ f"Linear; {hidden} -> {kv_out} ({kv_label} x {d_k} dims). "
115
+ "Cache ports show K/V write/read during generation: arrowhead for write, blunt tail for read."
116
+ ),
117
+ },
118
+ {
119
+ "id": "v_proj",
120
+ "title": "Value projection",
121
+ "description": (
122
+ f"Linear; {hidden} -> {kv_out} ({kv_label} x {d_k} dims). "
123
+ "Cache ports show K/V write/read during generation: arrowhead for write, blunt tail for read."
124
+ ),
125
+ },
126
+ {
127
+ "id": "scaled_scores",
128
+ "title": scaled_title,
129
+ "description": scaled_desc,
130
+ },
131
+ {
132
+ "id": "attn_softmax",
133
+ "title": "Softmax weights",
134
+ "description": "Normalize each query row into attention weights over source tokens",
135
+ },
136
+ {
137
+ "id": "attn_apply_v",
138
+ "title": "Apply values",
139
+ "description": "Multiply attention weights by V to produce one context vector per head",
140
+ },
141
+ {
142
+ "id": "concat_heads",
143
+ "title": "Concatenate heads",
144
+ "description": f"Stack all {num_heads} per-head context vectors back into width {q_out}",
145
+ },
146
+ {
147
+ "id": "o_proj",
148
+ "title": "Output projection",
149
+ "description": f"Linear; {q_out} -> {hidden} (mixes information across heads)",
150
+ },
151
+ ]
152
+
153
+
59
154
  def _sdpa_operation_meta(
60
155
  attention: AttentionSpec,
61
156
  num_heads: int,
@@ -102,40 +197,141 @@ def _mla_child_blocks(attention: AttentionSpec, hidden_size: int) -> list[dict]:
102
197
  num_heads = attention.num_heads or 0
103
198
  head_dim = attention.head_dim or 0
104
199
  q_out = _fmt(num_heads * head_dim) if (num_heads and head_dim) else hidden
105
- return [
200
+ query_children = [
106
201
  {
107
202
  "id": "mla_q",
108
203
  "label": "Q projection",
109
204
  "title": "Query projection",
110
205
  "description": (
111
- f"Projects hidden states into query heads through LoRA rank {q_rank}"
206
+ f"Projects hidden states into query latent space through LoRA rank {q_rank}"
112
207
  if attention.q_lora_rank
113
- else f"Q projection; {hidden} -> {q_out}"
208
+ else f"Projects hidden states directly into query heads; {hidden} -> {q_out}"
114
209
  ),
115
210
  },
211
+ {
212
+ "id": "mla_q_nope",
213
+ "label": "Q noPE",
214
+ "title": "Query content slice",
215
+ "description": "Query content component that does not receive rotary position encoding",
216
+ },
217
+ {
218
+ "id": "mla_q_rope",
219
+ "label": "Q RoPE",
220
+ "title": "Query positional slice",
221
+ "description": f"Query positional component prepared for rotary position encoding; dim {rope}",
222
+ },
223
+ {
224
+ "id": "mla_q_rope_apply",
225
+ "label": "Apply RoPE",
226
+ "title": "Apply RoPE to query",
227
+ "description": "Applies rotary position encoding to the query positional slice",
228
+ },
229
+ {
230
+ "id": "mla_q_concat",
231
+ "label": "Q concat",
232
+ "title": "Final MLA query",
233
+ "description": "Concatenates Q noPE with RoPE-encoded Q RoPE before score computation",
234
+ },
235
+ ]
236
+ kv_children = [
116
237
  {
117
238
  "id": "mla_kv_down",
118
239
  "label": "KV compress",
119
240
  "title": "K/V latent compression",
120
- "description": f"Compresses the token state into a shared latent K/V vector; {hidden} -> rank {kv_rank}",
241
+ "description": f"Compresses the token state into the shared latent K/V cache; {hidden} -> rank {kv_rank}",
242
+ },
243
+ {
244
+ "id": "mla_cache",
245
+ "label": "latent cache c_t",
246
+ "title": "Stored latent cache",
247
+ "description": (
248
+ f"Compressed K/V latent stored in the cache instead of full K and V heads; rank {kv_rank}. "
249
+ "Cache ports show write from compression and read back into K/V expansion."
250
+ ),
121
251
  },
122
252
  {
123
253
  "id": "mla_kv_up",
124
254
  "label": "KV expand",
125
255
  "title": "K/V head expansion",
126
- "description": f"Expands the latent K/V vector into per-head key/value content for {num_heads} query heads",
256
+ "description": f"Expands cached latent c_t into K noPE content and V values for {num_heads} query heads",
257
+ },
258
+ {
259
+ "id": "mla_k_nope",
260
+ "label": "K noPE",
261
+ "title": "Latent key content",
262
+ "description": "Key content expanded from the compressed K/V latent; concatenated with the RoPE key before scoring",
263
+ },
264
+ {
265
+ "id": "mla_k_rope",
266
+ "label": "K RoPE",
267
+ "title": "Key positional slice",
268
+ "description": f"Key positional component produced alongside the latent cache; dim {rope}",
269
+ },
270
+ {
271
+ "id": "mla_k_rope_apply",
272
+ "label": "Apply RoPE",
273
+ "title": "Apply RoPE to key",
274
+ "description": "Applies rotary position encoding to the key positional slice",
275
+ },
276
+ {
277
+ "id": "mla_k_merge",
278
+ "label": "K concat",
279
+ "title": "Composed MLA key",
280
+ "description": "Concatenates K noPE with the RoPE key side-channel before QK^T score computation",
281
+ },
282
+ {
283
+ "id": "mla_v",
284
+ "label": "V values",
285
+ "title": "Latent value heads",
286
+ "description": "Value heads expanded from the compressed K/V latent; consumed after softmax",
287
+ },
288
+ ]
289
+ return [
290
+ {
291
+ "id": "mla_query_path",
292
+ "label": "Query path",
293
+ "title": "MLA query path",
294
+ "description": (
295
+ "Builds Q by projecting the hidden state, splitting content and positional slices, "
296
+ "applying RoPE to the positional slice, then concatenating them"
297
+ ),
298
+ "detail_view": "mla_query_path",
299
+ "children": query_children,
300
+ },
301
+ {
302
+ "id": "mla_kv_path",
303
+ "label": "KV cache path",
304
+ "title": "MLA K/V cache path",
305
+ "description": (
306
+ f"Compresses hidden state into rank {kv_rank} latent cache, expands K/V content, "
307
+ "and combines K noPE with a RoPE key side-channel. Cache ports mark the latent write/read point."
308
+ ),
309
+ "detail_view": "mla_kv_cache_path",
310
+ "children": kv_children,
311
+ },
312
+ {
313
+ "id": "scaled_scores",
314
+ "label": "Latent scores",
315
+ "title": "Multi-Head Latent scores",
316
+ "description": "Q attends to expanded latent K plus the RoPE key side-channel; scores use QK^T / sqrt(dim)",
317
+ },
318
+ {
319
+ "id": "attn_softmax",
320
+ "label": "Softmax",
321
+ "title": "Softmax weights",
322
+ "description": "Normalize latent attention scores over source positions",
127
323
  },
128
324
  {
129
- "id": "mla_rope",
130
- "label": "RoPE key",
131
- "title": "Rotary key side-channel",
132
- "description": f"Separate positional key slice used with RoPE; dim {rope}",
325
+ "id": "attn_apply_v",
326
+ "label": "Apply V",
327
+ "title": "Apply latent values",
328
+ "description": "Multiply softmax weights by V expanded from the compressed K/V latent",
133
329
  },
134
330
  {
135
- "id": "mla_attn",
136
- "label": "Latent attention",
137
- "title": "Multi-head latent attention",
138
- "description": "Attention over decompressed latent K/V plus the RoPE side channel",
331
+ "id": "concat_heads",
332
+ "label": "Concat heads",
333
+ "title": "Concatenate latent heads",
334
+ "description": f"Stack all {num_heads} context heads back into width {q_out}",
139
335
  },
140
336
  {
141
337
  "id": "o_proj",
@@ -67,13 +67,14 @@ def describe_attention(attention: AttentionSpec) -> str:
67
67
  )
68
68
  if attention.q_lora_rank:
69
69
  text += f"; Q LoRA {_fmt(attention.q_lora_rank)}"
70
+ text += "; cache ports mark latent write/read state"
70
71
  return text
71
72
  if attention.kind == "mqa":
72
- return _with_attention_window(attention, f"Multi-query; {attention.num_heads} Q / 1 KV head")
73
+ return _with_attention_window(attention, f"Multi-query; {attention.num_heads} Q / 1 KV head; cache ports mark K/V write/read state")
73
74
  if attention.kind == "gqa":
74
75
  return _with_attention_window(attention, (
75
76
  f"Grouped-query; {attention.num_heads} Q / {attention.num_kv_heads} KV heads; "
76
- f"head dim {_fmt(attention.head_dim)}"
77
+ f"head dim {_fmt(attention.head_dim)}; cache ports mark K/V write/read state"
77
78
  ))
78
79
  if attention.kind == "ssm":
79
80
  shared = "; weight-shared across positions" if attention.shared else ""
@@ -94,11 +95,12 @@ def describe_attention(attention: AttentionSpec) -> str:
94
95
  if attention.no_rope:
95
96
  extras.append("NoPE")
96
97
  suffix = f"; {', '.join(extras)}" if extras else ""
97
- return _with_attention_window(attention, f"Multi-head; {attention.num_heads} heads; head dim {_fmt(attention.head_dim)}{suffix}")
98
+ cache_note = "; cache ports mark K/V write/read state"
99
+ return _with_attention_window(attention, f"Multi-head; {attention.num_heads} heads; head dim {_fmt(attention.head_dim)}{suffix}{cache_note}")
98
100
 
99
101
 
100
102
  def _attention_mask_prefix(attention: AttentionSpec) -> str:
101
- return "SWA" if attention.mask == "sliding" else ""
103
+ return "SW" if attention.mask == "sliding" else ""
102
104
 
103
105
 
104
106
  def _attention_mask_title_prefix(attention: AttentionSpec) -> str: