@tryhamster/gerbil 1.0.0-rc.8 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +247 -84
- package/dist/architectures-C1I5V3Dt.mjs +6070 -0
- package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
- package/dist/browser/index.d.ts +264 -588
- package/dist/browser/index.d.ts.map +1 -1
- package/dist/browser/index.js +585 -2334
- package/dist/browser/index.js.map +1 -1
- package/dist/cli.mjs +625 -1098
- package/dist/cli.mjs.map +1 -1
- package/dist/defaults-9komdrbY.mjs +24 -0
- package/dist/defaults-9komdrbY.mjs.map +1 -0
- package/dist/frameworks/express.d.mts +1 -3
- package/dist/frameworks/express.d.mts.map +1 -1
- package/dist/frameworks/express.mjs +7 -7
- package/dist/frameworks/express.mjs.map +1 -1
- package/dist/frameworks/fastify.d.mts +1 -1
- package/dist/frameworks/fastify.d.mts.map +1 -1
- package/dist/frameworks/fastify.mjs +3 -3
- package/dist/frameworks/fastify.mjs.map +1 -1
- package/dist/frameworks/hono.d.mts +1 -1
- package/dist/frameworks/hono.d.mts.map +1 -1
- package/dist/frameworks/hono.mjs +4 -4
- package/dist/frameworks/hono.mjs.map +1 -1
- package/dist/frameworks/next.d.mts +3 -2
- package/dist/frameworks/next.d.mts.map +1 -1
- package/dist/frameworks/next.mjs +4 -4
- package/dist/frameworks/next.mjs.map +1 -1
- package/dist/frameworks/react.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts.map +1 -1
- package/dist/frameworks/trpc.mjs +4 -4
- package/dist/frameworks/trpc.mjs.map +1 -1
- package/dist/gerbil-BHrJJIa4.mjs +1656 -0
- package/dist/gerbil-BHrJJIa4.mjs.map +1 -0
- package/dist/gerbil-BT9fCydo.d.mts +488 -0
- package/dist/gerbil-BT9fCydo.d.mts.map +1 -0
- package/dist/gerbil-DomNfIr1.mjs +4 -0
- package/dist/gpu/hooks.d.mts +520 -0
- package/dist/gpu/hooks.d.mts.map +1 -0
- package/dist/gpu/hooks.mjs +1188 -0
- package/dist/gpu/hooks.mjs.map +1 -0
- package/dist/gpu/index.d.mts +2 -0
- package/dist/gpu/index.mjs +6 -0
- package/dist/gpu-33qCAtHW.mjs +3615 -0
- package/dist/gpu-33qCAtHW.mjs.map +1 -0
- package/dist/index-Dgmb2kE3.d.mts +245 -0
- package/dist/index-Dgmb2kE3.d.mts.map +1 -0
- package/dist/index-jEAL2s-A.d.mts +2022 -0
- package/dist/index-jEAL2s-A.d.mts.map +1 -0
- package/dist/index.d.mts +22 -487
- package/dist/index.d.mts.map +1 -1
- package/dist/index.mjs +13 -8
- package/dist/index.mjs.map +1 -1
- package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
- package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
- package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
- package/dist/integrations/ai-sdk.d.mts +75 -6
- package/dist/integrations/ai-sdk.d.mts.map +1 -1
- package/dist/integrations/ai-sdk.mjs +131 -15
- package/dist/integrations/ai-sdk.mjs.map +1 -1
- package/dist/integrations/langchain.d.mts +1 -1
- package/dist/integrations/langchain.d.mts.map +1 -1
- package/dist/integrations/langchain.mjs +5 -5
- package/dist/integrations/langchain.mjs.map +1 -1
- package/dist/integrations/llamaindex.d.mts +1 -1
- package/dist/integrations/llamaindex.d.mts.map +1 -1
- package/dist/integrations/llamaindex.mjs +5 -5
- package/dist/integrations/llamaindex.mjs.map +1 -1
- package/dist/integrations/mcp-client.mjs +3 -3
- package/dist/integrations/mcp-client.mjs.map +1 -1
- package/dist/integrations/mcp.d.mts +3 -2
- package/dist/integrations/mcp.d.mts.map +1 -1
- package/dist/integrations/mcp.mjs +5 -5
- package/dist/{mcp-BvbriaBy.mjs → mcp-1DaMsaBc.mjs} +4 -4
- package/dist/mcp-1DaMsaBc.mjs.map +1 -0
- package/dist/memory/index.d.mts +3 -0
- package/dist/memory/index.mjs +6 -0
- package/dist/memory-D1P7Tmda.mjs +4 -0
- package/dist/memory-DVN0MnIG.mjs +132 -0
- package/dist/memory-DVN0MnIG.mjs.map +1 -0
- package/dist/memory-Dj0J1v88.mjs +294 -0
- package/dist/memory-Dj0J1v88.mjs.map +1 -0
- package/dist/moonshine-stt-BLyVoRpB.mjs +4 -0
- package/dist/moonshine-stt-v_P_Ci_m.mjs +11936 -0
- package/dist/moonshine-stt-v_P_Ci_m.mjs.map +1 -0
- package/dist/{one-liner-s-lD8rCC.mjs → one-liner-DnQn7HJK.mjs} +14 -16
- package/dist/one-liner-DnQn7HJK.mjs.map +1 -0
- package/dist/repl-jV5gcJFA.mjs +9 -0
- package/dist/skills/index.d.mts +270 -320
- package/dist/skills/index.d.mts.map +1 -1
- package/dist/skills/index.mjs +5 -5
- package/dist/{skills-CD3Orlex.mjs → skills-DX8D59UH.mjs} +187 -32
- package/dist/skills-DX8D59UH.mjs.map +1 -0
- package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
- package/dist/tools-DQ1mPUw5.mjs.map +1 -0
- package/dist/{types-CiTc7ez3.d.mts → types-D6FiR_oh.d.mts} +106 -12
- package/dist/types-D6FiR_oh.d.mts.map +1 -0
- package/dist/types-DQBe2lFo.d.mts +165 -0
- package/dist/types-DQBe2lFo.d.mts.map +1 -0
- package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
- package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
- package/dist/vector-B0panuy6.mjs +95 -0
- package/dist/vector-B0panuy6.mjs.map +1 -0
- package/docs/PROJECT-STATE.md +321 -0
- package/docs/adding-a-model-family.md +280 -0
- package/docs/ai-sdk.md +70 -61
- package/docs/architecture/overview.md +17 -7
- package/docs/browser.md +203 -8
- package/docs/embeddings.md +156 -0
- package/docs/gerbil-site-native-migration.md +217 -0
- package/docs/gpu-engine/architectures.md +398 -0
- package/docs/gpu-engine/ir.md +372 -0
- package/docs/gpu-engine/kernels.md +718 -0
- package/docs/gpu-engine/paper.html +1759 -0
- package/docs/gpu-engine/paper.md +2109 -0
- package/docs/gpu-engine/safetensors.md +312 -0
- package/docs/gpu-engine/tokenizer.md +302 -0
- package/docs/memory-rag.md +91 -0
- package/docs/metal-safari-intel.md +190 -0
- package/docs/mobile-failure-diagnosis.md +124 -0
- package/docs/mobile.md +99 -0
- package/docs/observability.md +230 -0
- package/docs/onnx-removal-plan.md +339 -0
- package/docs/research/autoresearch-portable.md +904 -0
- package/docs/research/dispatch-reduction-hivemind.md +84 -0
- package/docs/research/ios-safari-model-caching.md +117 -0
- package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
- package/docs/research/native-stt-model-selection.md +49 -0
- package/docs/research/native-tts-model-selection.md +90 -0
- package/docs/research/native-vs-chromium-decision.md +152 -0
- package/docs/research/nemotron-mamba2-inference.md +910 -0
- package/docs/research/qwen35-multimodal.md +293 -0
- package/docs/research/qwen36-gemma4-targets.md +337 -0
- package/docs/research/sota-embedding-models.md +179 -0
- package/docs/research/sota-mobile-models-2026.md +263 -0
- package/docs/research/sota-modality-models.md +202 -0
- package/docs/research/tps-baselines.md +71 -0
- package/docs/research/webgpu-m4-reference.md +104 -0
- package/docs/site-update-plan.md +155 -0
- package/docs/structured-output.md +123 -0
- package/docs/stt.md +63 -446
- package/docs/tts.md +77 -499
- package/docs/vision.md +100 -338
- package/package.json +22 -7
- package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
- package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
- package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
- package/dist/gerbil-CJ3ifloF.mjs +0 -4
- package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
- package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
- package/dist/gerbil-qOTe1nl2.d.mts +0 -431
- package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
- package/dist/kokoro-BNTb6egA.mjs +0 -20210
- package/dist/kokoro-BNTb6egA.mjs.map +0 -1
- package/dist/kokoro-DFRQ1OeM.js +0 -20212
- package/dist/kokoro-DFRQ1OeM.js.map +0 -1
- package/dist/mcp-BvbriaBy.mjs.map +0 -1
- package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
- package/dist/repl-DveXw36T.mjs +0 -9
- package/dist/skills-CD3Orlex.mjs.map +0 -1
- package/dist/stt-CpLYbGFd.mjs +0 -433
- package/dist/stt-CpLYbGFd.mjs.map +0 -1
- package/dist/stt-DRPLEEHB.mjs +0 -3
- package/dist/stt-Te8Qz-Ay.js +0 -433
- package/dist/stt-Te8Qz-Ay.js.map +0 -1
- package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
- package/dist/transformers.web-DokyH3rP.js +0 -3
- package/dist/transformers.web-M6mCnEYJ.js +0 -30382
- package/dist/transformers.web-M6mCnEYJ.js.map +0 -1
- package/dist/tts-C0xx3CtE.js +0 -724
- package/dist/tts-C0xx3CtE.js.map +0 -1
- package/dist/tts-DXgsKGCe.mjs +0 -3
- package/dist/tts-DeGANMNV.mjs +0 -730
- package/dist/tts-DeGANMNV.mjs.map +0 -1
- package/dist/types-CiTc7ez3.d.mts.map +0 -1
- /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
- /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
- /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
# IR Deep Dive
|
|
2
|
+
|
|
3
|
+
The Intermediate Representation (IR) defined in `src/gpu/ir.ts` is the central contract of the GPU engine. Every component -- graph generators, the executor, the kernel registry, the model loader -- speaks this IR.
|
|
4
|
+
|
|
5
|
+
This document covers how the IR is designed, how to read a `ModelGraph`, how the executor interprets it, and how to add new operation types.
|
|
6
|
+
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
## Type Overview
|
|
10
|
+
|
|
11
|
+
```
|
|
12
|
+
ModelGraph
|
|
13
|
+
├── architecture: string "Qwen2ForCausalLM"
|
|
14
|
+
├── config: ModelArchConfig Resolved model dimensions
|
|
15
|
+
├── capabilities: ModelCapabilities { text, vision, moe }
|
|
16
|
+
├── tensors: Record<string, TensorDesc>
|
|
17
|
+
│ ├── "embed_tokens.weight" { shape: [vocab, hidden], dtype: "f32", storage: "constant" }
|
|
18
|
+
│ ├── "embed_out" { shape: ["T", hidden], dtype: "f32", storage: "activation" }
|
|
19
|
+
│ ├── "layer0_k_cache" { shape: ["L_max", kv_dim], dtype: "f32", storage: "kv_cache" }
|
|
20
|
+
│ └── ...
|
|
21
|
+
├── nodes: OpNode[]
|
|
22
|
+
│ ├── { id: "embed", opType: "Embedding", inputs: [...], outputs: [...] }
|
|
23
|
+
│ ├── { id: "layer0_norm1", opType: "RMSNorm", inputs: [...], outputs: [...] }
|
|
24
|
+
│ └── ...
|
|
25
|
+
├── executionOrder: string[] ["embed", "layer0_norm1", "layer0_q_proj", ...]
|
|
26
|
+
├── inputs: string[] ["input_ids"]
|
|
27
|
+
└── outputs: string[] ["logits"]
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
---
|
|
31
|
+
|
|
32
|
+
## TensorDesc In Detail
|
|
33
|
+
|
|
34
|
+
```typescript
|
|
35
|
+
interface TensorDesc {
|
|
36
|
+
name: string;
|
|
37
|
+
shape: (number | string)[];
|
|
38
|
+
dtype: DType;
|
|
39
|
+
storage: TensorStorage;
|
|
40
|
+
safetensorsKey?: string;
|
|
41
|
+
}
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### Shape Dimensions
|
|
45
|
+
|
|
46
|
+
Shapes can contain concrete numbers and symbolic strings:
|
|
47
|
+
|
|
48
|
+
| Symbol | Meaning | Resolved When |
|
|
49
|
+
|--------|---------|---------------|
|
|
50
|
+
| `"T"` | Current sequence length | Each forward pass (prefill: prompt length; decode: 1) |
|
|
51
|
+
| `"L_max"` | Total cached sequence length | Each forward pass (seqPos + T) |
|
|
52
|
+
|
|
53
|
+
Concrete examples:
|
|
54
|
+
- `[151936, 896]` -- embedding weight for Qwen2.5-0.5B (vocab_size x hidden_size)
|
|
55
|
+
- `["T", 896]` -- activation after embedding (seq_len x hidden_size)
|
|
56
|
+
- `["L_max", 128]` -- KV cache for one layer (total_seq x kv_dim, where kv_dim = num_kv_heads x head_dim = 2 x 64)
|
|
57
|
+
|
|
58
|
+
### Storage Types
|
|
59
|
+
|
|
60
|
+
| Storage | Lifetime | Allocation | Source |
|
|
61
|
+
|---------|----------|------------|--------|
|
|
62
|
+
| `"constant"` | Entire model session | Exact size from safetensors | Downloaded weights |
|
|
63
|
+
| `"activation"` | Overwritten each forward pass | Max size (T = maxSeqLen) | Computed during forward pass |
|
|
64
|
+
| `"kv_cache"` | Grows during generation | Max size (L_max = maxSeqLen) | Written by attention kernel |
|
|
65
|
+
|
|
66
|
+
### DType
|
|
67
|
+
|
|
68
|
+
| DType | Bytes per Element | TypedArray | Notes |
|
|
69
|
+
|-------|------------------|------------|-------|
|
|
70
|
+
| `"f32"` | 4 | `Float32Array` | Default for all computation |
|
|
71
|
+
| `"f16"` | 2 | `Uint16Array` (bitwise) | No native JS typed array |
|
|
72
|
+
| `"i32"` | 4 | `Int32Array` | |
|
|
73
|
+
| `"u32"` | 4 | `Uint32Array` | Used for input_ids |
|
|
74
|
+
| `"i4"` | 0.5 | Packed in `Uint32Array` | 8 values per u32 for INT4 weights |
|
|
75
|
+
|
|
76
|
+
---
|
|
77
|
+
|
|
78
|
+
## OpNode In Detail
|
|
79
|
+
|
|
80
|
+
```typescript
|
|
81
|
+
interface OpNode {
|
|
82
|
+
id: string;
|
|
83
|
+
opType: OpType;
|
|
84
|
+
inputs: string[];
|
|
85
|
+
outputs: string[];
|
|
86
|
+
attributes: Record<string, unknown>;
|
|
87
|
+
}
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### ID Convention
|
|
91
|
+
|
|
92
|
+
Node IDs follow a systematic pattern:
|
|
93
|
+
- `"embed"` -- global embedding node
|
|
94
|
+
- `"layer{i}_norm1"` -- input layernorm for layer i
|
|
95
|
+
- `"layer{i}_q_proj"` -- Q projection matmul for layer i
|
|
96
|
+
- `"layer{i}_rope"` -- RoPE for layer i
|
|
97
|
+
- `"layer{i}_attn"` -- attention for layer i
|
|
98
|
+
- `"layer{i}_resid1"` -- first residual add for layer i
|
|
99
|
+
- `"layer{i}_silu"` -- SiLU activation for layer i
|
|
100
|
+
- `"layer{i}_swiglu"` -- gate * up multiply for layer i
|
|
101
|
+
- `"final_norm"` -- final layer norm
|
|
102
|
+
- `"lm_head"` -- language model head matmul
|
|
103
|
+
|
|
104
|
+
### Input/Output Ordering
|
|
105
|
+
|
|
106
|
+
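The `inputs` and `outputs` arrays define which tensors the kernel reads and writes. **Order matters**: it matches the kernel's binding order -- each entry of `inputs` is bound first, in array order, followed by each entry of `outputs`, with the uniform params struct taking the last binding.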
|
|
107
|
+
|
|
108
|
+
For example, an RMSNorm node:
|
|
109
|
+
```typescript
|
|
110
|
+
{
|
|
111
|
+
id: "layer0_norm1",
|
|
112
|
+
opType: "RMSNorm",
|
|
113
|
+
inputs: ["embed_out", "layers.0.input_layernorm.weight"],
|
|
114
|
+
outputs: ["layer0_norm1_out"],
|
|
115
|
+
attributes: { hidden_size: 896, eps: 1e-6, seq_len_tensor: "embed_out" }
|
|
116
|
+
}
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
This maps to the RMSNorm WGSL bindings:
|
|
120
|
+
- `@binding(0)` input -> `embed_out` (from `inputs[0]`)
|
|
121
|
+
- `@binding(1)` weight -> `layers.0.input_layernorm.weight` (from `inputs[1]`)
|
|
122
|
+
- `@binding(2)` output -> `layer0_norm1_out` (from `outputs[0]`)
|
|
123
|
+
- `@binding(3)` params -> built from `attributes`
|
|
124
|
+
|
|
125
|
+
### Attributes
|
|
126
|
+
|
|
127
|
+
The `attributes` dictionary carries op-specific parameters. Common patterns:
|
|
128
|
+
|
|
129
|
+
| Attribute | Used By | Purpose |
|
|
130
|
+
|-----------|---------|---------|
|
|
131
|
+
| `hidden_size` | RMSNorm, Add | Dimension of the feature vector |
|
|
132
|
+
| `eps` | RMSNorm, LayerNorm | Numerical stability epsilon |
|
|
133
|
+
| `M_tensor` | MatMul | Name of tensor whose shape gives the M dimension |
|
|
134
|
+
| `K` | MatMul | Inner dimension |
|
|
135
|
+
| `N` | MatMul | Output dimension |
|
|
136
|
+
| `head_dim` | RoPE, Attention | Per-head dimension |
|
|
137
|
+
| `num_heads` | RoPE, Attention | Number of query heads |
|
|
138
|
+
| `num_kv_heads` | RoPE, Attention | Number of key/value heads (GQA) |
|
|
139
|
+
| `rope_base` | RoPE | Base frequency for position encoding |
|
|
140
|
+
| `causal` | Attention | Whether to apply causal mask |
|
|
141
|
+
| `layer_index` | Attention | Which layer (for KV cache buffer lookup) |
|
|
142
|
+
| `count_tensor` | SiLU, Mul, Add | Name of tensor whose element count determines dispatch size |
|
|
143
|
+
|
|
144
|
+
---
|
|
145
|
+
|
|
146
|
+
## Concrete Qwen2 Layer Example
|
|
147
|
+
|
|
148
|
+
For a Qwen2.5-0.5B model (hidden_size=896, num_layers=24, num_heads=14, num_kv_heads=2, head_dim=64, intermediate_size=4864), here is layer 0 in full:
|
|
149
|
+
|
|
150
|
+
### Tensors for Layer 0
|
|
151
|
+
|
|
152
|
+
```
|
|
153
|
+
Weight tensors (constant):
|
|
154
|
+
layers.0.input_layernorm.weight [896] f32
|
|
155
|
+
layers.0.self_attn.q_proj.weight [896, 896] f32
|
|
156
|
+
layers.0.self_attn.k_proj.weight [896, 128] f32 (2 heads * 64 dim)
|
|
157
|
+
layers.0.self_attn.v_proj.weight [896, 128] f32
|
|
158
|
+
layers.0.self_attn.o_proj.weight [896, 896] f32
|
|
159
|
+
layers.0.post_attention_layernorm.weight [896] f32
|
|
160
|
+
layers.0.mlp.gate_proj.weight [896, 4864] f32
|
|
161
|
+
layers.0.mlp.up_proj.weight [896, 4864] f32
|
|
162
|
+
layers.0.mlp.down_proj.weight [4864, 896] f32
|
|
163
|
+
|
|
164
|
+
Activation tensors:
|
|
165
|
+
layer0_norm1_out [T, 896] f32
|
|
166
|
+
layer0_q [T, 896] f32
|
|
167
|
+
layer0_k [T, 128] f32
|
|
168
|
+
layer0_v [T, 128] f32
|
|
169
|
+
layer0_q_rope [T, 896] f32
|
|
170
|
+
layer0_k_rope [T, 128] f32
|
|
171
|
+
layer0_attn_out [T, 896] f32
|
|
172
|
+
layer0_o_proj_out [T, 896] f32
|
|
173
|
+
layer0_resid1 [T, 896] f32
|
|
174
|
+
layer0_norm2_out [T, 896] f32
|
|
175
|
+
layer0_gate_out [T, 4864] f32
|
|
176
|
+
layer0_up_out [T, 4864] f32
|
|
177
|
+
layer0_silu_out [T, 4864] f32
|
|
178
|
+
layer0_swiglu_out [T, 4864] f32
|
|
179
|
+
layer0_mlp_out [T, 896] f32
|
|
180
|
+
layer0_resid2 [T, 896] f32
|
|
181
|
+
|
|
182
|
+
KV cache tensors:
|
|
183
|
+
layer0_k_cache [L_max, 128] f32
|
|
184
|
+
layer0_v_cache [L_max, 128] f32
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
### Nodes for Layer 0
|
|
188
|
+
|
|
189
|
+
```
|
|
190
|
+
1. layer0_norm1 RMSNorm [embed_out, layers.0.input_layernorm.weight]
|
|
191
|
+
-> [layer0_norm1_out]
|
|
192
|
+
|
|
193
|
+
2. layer0_q_proj MatMul [layer0_norm1_out, layers.0.self_attn.q_proj.weight]
|
|
194
|
+
-> [layer0_q] {K:896, N:896}
|
|
195
|
+
|
|
196
|
+
3. layer0_k_proj MatMul [layer0_norm1_out, layers.0.self_attn.k_proj.weight]
|
|
197
|
+
-> [layer0_k] {K:896, N:128}
|
|
198
|
+
|
|
199
|
+
4. layer0_v_proj MatMul [layer0_norm1_out, layers.0.self_attn.v_proj.weight]
|
|
200
|
+
-> [layer0_v] {K:896, N:128}
|
|
201
|
+
|
|
202
|
+
5. layer0_rope RoPE [layer0_q, layer0_k]
|
|
203
|
+
-> [layer0_q_rope, layer0_k_rope]
|
|
204
|
+
{head_dim:64, num_heads:14, num_kv_heads:2}
|
|
205
|
+
|
|
206
|
+
6. layer0_attn Attention [layer0_q_rope, layer0_k_rope, layer0_v,
|
|
207
|
+
layer0_k_cache, layer0_v_cache]
|
|
208
|
+
-> [layer0_attn_out]
|
|
209
|
+
{num_heads:14, num_kv_heads:2, head_dim:64, causal:true}
|
|
210
|
+
|
|
211
|
+
7. layer0_o_proj MatMul [layer0_attn_out, layers.0.self_attn.o_proj.weight]
|
|
212
|
+
-> [layer0_o_proj_out] {K:896, N:896}
|
|
213
|
+
|
|
214
|
+
8. layer0_resid1 Add [embed_out, layer0_o_proj_out]
|
|
215
|
+
-> [layer0_resid1]
|
|
216
|
+
|
|
217
|
+
9. layer0_norm2 RMSNorm [layer0_resid1, layers.0.post_attention_layernorm.weight]
|
|
218
|
+
-> [layer0_norm2_out]
|
|
219
|
+
|
|
220
|
+
10. layer0_gate MatMul [layer0_norm2_out, layers.0.mlp.gate_proj.weight]
|
|
221
|
+
-> [layer0_gate_out] {K:896, N:4864}
|
|
222
|
+
|
|
223
|
+
11. layer0_up MatMul [layer0_norm2_out, layers.0.mlp.up_proj.weight]
|
|
224
|
+
-> [layer0_up_out] {K:896, N:4864}
|
|
225
|
+
|
|
226
|
+
12. layer0_silu SiLU [layer0_gate_out]
|
|
227
|
+
-> [layer0_silu_out]
|
|
228
|
+
|
|
229
|
+
13. layer0_swiglu Mul [layer0_silu_out, layer0_up_out]
|
|
230
|
+
-> [layer0_swiglu_out]
|
|
231
|
+
|
|
232
|
+
14. layer0_down MatMul [layer0_swiglu_out, layers.0.mlp.down_proj.weight]
|
|
233
|
+
-> [layer0_mlp_out] {K:4864, N:896}
|
|
234
|
+
|
|
235
|
+
15. layer0_resid2 Add [layer0_resid1, layer0_mlp_out]
|
|
236
|
+
-> [layer0_resid2]
|
|
237
|
+
```
|
|
238
|
+
|
|
239
|
+
Layer 1 takes `layer0_resid2` as its input (the `prevOutput` variable in the generator) and produces `layer1_resid2`, and so on through all 24 layers.
|
|
240
|
+
|
|
241
|
+
---
|
|
242
|
+
|
|
243
|
+
## How the Executor Interprets the IR
|
|
244
|
+
|
|
245
|
+
The executor's `forward()` method iterates through `graph.executionOrder` -- the topologically sorted list of node IDs:
|
|
246
|
+
|
|
247
|
+
```typescript
|
|
248
|
+
for (const nodeId of this.graph.executionOrder) {
|
|
249
|
+
const node = this.graph.nodes.find(n => n.id === nodeId)!;
|
|
250
|
+
this.dispatchOp(pass, node, inputIdsBuffer, resolvedShapes);
|
|
251
|
+
}
|
|
252
|
+
```
|
|
253
|
+
|
|
254
|
+
For each node, `dispatchOp()`:
|
|
255
|
+
|
|
256
|
+
1. **Looks up the kernel**: `KERNEL_REGISTRY[node.opType]` returns a `KernelSpec` with the WGSL source code, binding layout, parameter builder, and dispatch size calculator
|
|
257
|
+
2. **Gets or creates the pipeline**: Shader compilation is cached by code+entryPoint
|
|
258
|
+
3. **Builds uniforms**: `spec.buildParams(node, resolvedShapes)` converts the node's attributes and resolved tensor shapes into the packed `ArrayBuffer` that the kernel expects as `@group(0) @binding(N) var<uniform> params`
|
|
259
|
+
4. **Gathers buffers**: For each binding in the kernel spec, finds the corresponding GPU buffer from weight buffers, activation buffers, KV cache buffers, or the input IDs buffer
|
|
260
|
+
5. **Creates bind group**: Binds buffers to their numbered bindings
|
|
261
|
+
6. **Dispatches**: Calls `pass.dispatchWorkgroups(wgX, wgY, wgZ)` with workgroup counts computed by `spec.getDispatchSize(node, resolvedShapes)`
|
|
262
|
+
|
|
263
|
+
---
|
|
264
|
+
|
|
265
|
+
## Adding a New Op Type
|
|
266
|
+
|
|
267
|
+
To add a new operation (e.g., `Conv2d` for vision models):
|
|
268
|
+
|
|
269
|
+
### Step 1: Add the OpType
|
|
270
|
+
|
|
271
|
+
In `ir.ts`, add the new type to the `OpType` union:
|
|
272
|
+
|
|
273
|
+
```typescript
|
|
274
|
+
export type OpType =
|
|
275
|
+
// ... existing ops ...
|
|
276
|
+
| "Conv2d";
|
|
277
|
+
```
|
|
278
|
+
|
|
279
|
+
### Step 2: Write the WGSL Kernel
|
|
280
|
+
|
|
281
|
+
Create `kernels/wgsl/conv2d.wgsl` with the compute shader. Follow the conventions:
|
|
282
|
+
- Last binding is always the uniform params struct
|
|
283
|
+
- Bindings follow the order: inputs first, then outputs, then uniforms
|
|
284
|
+
- Use `@compute @workgroup_size(...)` to declare workgroup size
|
|
285
|
+
|
|
286
|
+
### Step 3: Register the Kernel
|
|
287
|
+
|
|
288
|
+
In the kernel registry (when implemented), add a `KernelSpec` entry:
|
|
289
|
+
|
|
290
|
+
```typescript
|
|
291
|
+
KERNEL_REGISTRY["Conv2d"] = {
|
|
292
|
+
shaderCode: conv2dWGSL,
|
|
293
|
+
entryPoint: "main",
|
|
294
|
+
bindings: [
|
|
295
|
+
{ type: "storage-read" }, // input
|
|
296
|
+
{ type: "storage-read" }, // weight
|
|
297
|
+
{ type: "storage-rw" }, // output
|
|
298
|
+
{ type: "uniform" }, // params
|
|
299
|
+
],
|
|
300
|
+
buildParams: (node, shapes) => { /* pack uniform struct */ },
|
|
301
|
+
getDispatchSize: (node, shapes) => { /* compute workgroup counts */ },
|
|
302
|
+
};
|
|
303
|
+
```
|
|
304
|
+
|
|
305
|
+
### Step 4: Use in a Graph Generator
|
|
306
|
+
|
|
307
|
+
In a graph generator (e.g., a vision encoder generator), create nodes that reference the new op:
|
|
308
|
+
|
|
309
|
+
```typescript
|
|
310
|
+
addNode({
|
|
311
|
+
id: "vision_conv0",
|
|
312
|
+
opType: "Conv2d",
|
|
313
|
+
inputs: ["patch_embed_input", "vision_conv0_weight"],
|
|
314
|
+
outputs: ["vision_conv0_out"],
|
|
315
|
+
attributes: { in_channels: 3, out_channels: 64, kernel_size: 7, stride: 2, padding: 3 },
|
|
316
|
+
});
|
|
317
|
+
```
|
|
318
|
+
|
|
319
|
+
The executor will automatically dispatch it using the registered kernel spec.
|
|
320
|
+
|
|
321
|
+
---
|
|
322
|
+
|
|
323
|
+
## ModelArchConfig Reference
|
|
324
|
+
|
|
325
|
+
The `ModelArchConfig` struct captures all the architectural dimensions needed to generate a graph:
|
|
326
|
+
|
|
327
|
+
```typescript
|
|
328
|
+
interface ModelArchConfig {
|
|
329
|
+
// Core dimensions
|
|
330
|
+
hidden_size: number; // e.g. 896
|
|
331
|
+
num_layers: number; // e.g. 24
|
|
332
|
+
num_heads: number; // e.g. 14 (query heads)
|
|
333
|
+
num_kv_heads: number; // e.g. 2 (for GQA)
|
|
334
|
+
head_dim: number; // e.g. 64
|
|
335
|
+
intermediate_size: number; // e.g. 4864 (MLP hidden)
|
|
336
|
+
vocab_size: number; // e.g. 151936
|
|
337
|
+
context_length: number; // e.g. 32768
|
|
338
|
+
|
|
339
|
+
// Normalization
|
|
340
|
+
rms_norm_eps: number; // e.g. 1e-6
|
|
341
|
+
norm_type: "rmsnorm" | "layernorm";
|
|
342
|
+
|
|
343
|
+
// Positional encoding
|
|
344
|
+
rope_base: number; // e.g. 1000000.0
|
|
345
|
+
rope_dim: number; // e.g. 64 (usually == head_dim)
|
|
346
|
+
|
|
347
|
+
// KV cache
|
|
348
|
+
kv_layout: "LHSd"; // Only LHSd is currently supported
|
|
349
|
+
|
|
350
|
+
// MoE (future)
|
|
351
|
+
is_moe: boolean;
|
|
352
|
+
num_experts?: number;
|
|
353
|
+
top_k_experts?: number;
|
|
354
|
+
|
|
355
|
+
// Vision (future)
|
|
356
|
+
has_vision_tower: boolean;
|
|
357
|
+
vision_architecture?: string;
|
|
358
|
+
vision_patch_size?: number;
|
|
359
|
+
vision_embed_dim?: number;
|
|
360
|
+
}
|
|
361
|
+
```
|
|
362
|
+
|
|
363
|
+
These values are extracted from HuggingFace `config.json` by the graph generator. For example, the Qwen2 generator maps:
|
|
364
|
+
- `config.hidden_size` -> `hidden_size`
|
|
365
|
+
- `config.num_hidden_layers` -> `num_layers`
|
|
366
|
+
- `config.num_attention_heads` -> `num_heads`
|
|
367
|
+
- `config.num_key_value_heads` -> `num_kv_heads`
|
|
368
|
+
- `config.intermediate_size` -> `intermediate_size`
|
|
369
|
+
- `config.vocab_size` -> `vocab_size`
|
|
370
|
+
- `config.max_position_embeddings` -> `context_length`
|
|
371
|
+
- `config.rms_norm_eps` -> `rms_norm_eps`
|
|
372
|
+
- `config.rope_theta` -> `rope_base`
|