@tryhamster/gerbil 1.0.0-rc.8 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +247 -84
- package/dist/architectures-C1I5V3Dt.mjs +6070 -0
- package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
- package/dist/browser/index.d.ts +264 -588
- package/dist/browser/index.d.ts.map +1 -1
- package/dist/browser/index.js +585 -2334
- package/dist/browser/index.js.map +1 -1
- package/dist/cli.mjs +625 -1098
- package/dist/cli.mjs.map +1 -1
- package/dist/defaults-9komdrbY.mjs +24 -0
- package/dist/defaults-9komdrbY.mjs.map +1 -0
- package/dist/frameworks/express.d.mts +1 -3
- package/dist/frameworks/express.d.mts.map +1 -1
- package/dist/frameworks/express.mjs +7 -7
- package/dist/frameworks/express.mjs.map +1 -1
- package/dist/frameworks/fastify.d.mts +1 -1
- package/dist/frameworks/fastify.d.mts.map +1 -1
- package/dist/frameworks/fastify.mjs +3 -3
- package/dist/frameworks/fastify.mjs.map +1 -1
- package/dist/frameworks/hono.d.mts +1 -1
- package/dist/frameworks/hono.d.mts.map +1 -1
- package/dist/frameworks/hono.mjs +4 -4
- package/dist/frameworks/hono.mjs.map +1 -1
- package/dist/frameworks/next.d.mts +3 -2
- package/dist/frameworks/next.d.mts.map +1 -1
- package/dist/frameworks/next.mjs +4 -4
- package/dist/frameworks/next.mjs.map +1 -1
- package/dist/frameworks/react.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts +1 -1
- package/dist/frameworks/trpc.d.mts.map +1 -1
- package/dist/frameworks/trpc.mjs +4 -4
- package/dist/frameworks/trpc.mjs.map +1 -1
- package/dist/gerbil-BHrJJIa4.mjs +1656 -0
- package/dist/gerbil-BHrJJIa4.mjs.map +1 -0
- package/dist/gerbil-BT9fCydo.d.mts +488 -0
- package/dist/gerbil-BT9fCydo.d.mts.map +1 -0
- package/dist/gerbil-DomNfIr1.mjs +4 -0
- package/dist/gpu/hooks.d.mts +520 -0
- package/dist/gpu/hooks.d.mts.map +1 -0
- package/dist/gpu/hooks.mjs +1188 -0
- package/dist/gpu/hooks.mjs.map +1 -0
- package/dist/gpu/index.d.mts +2 -0
- package/dist/gpu/index.mjs +6 -0
- package/dist/gpu-33qCAtHW.mjs +3615 -0
- package/dist/gpu-33qCAtHW.mjs.map +1 -0
- package/dist/index-Dgmb2kE3.d.mts +245 -0
- package/dist/index-Dgmb2kE3.d.mts.map +1 -0
- package/dist/index-jEAL2s-A.d.mts +2022 -0
- package/dist/index-jEAL2s-A.d.mts.map +1 -0
- package/dist/index.d.mts +22 -487
- package/dist/index.d.mts.map +1 -1
- package/dist/index.mjs +13 -8
- package/dist/index.mjs.map +1 -1
- package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
- package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
- package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
- package/dist/integrations/ai-sdk.d.mts +75 -6
- package/dist/integrations/ai-sdk.d.mts.map +1 -1
- package/dist/integrations/ai-sdk.mjs +131 -15
- package/dist/integrations/ai-sdk.mjs.map +1 -1
- package/dist/integrations/langchain.d.mts +1 -1
- package/dist/integrations/langchain.d.mts.map +1 -1
- package/dist/integrations/langchain.mjs +5 -5
- package/dist/integrations/langchain.mjs.map +1 -1
- package/dist/integrations/llamaindex.d.mts +1 -1
- package/dist/integrations/llamaindex.d.mts.map +1 -1
- package/dist/integrations/llamaindex.mjs +5 -5
- package/dist/integrations/llamaindex.mjs.map +1 -1
- package/dist/integrations/mcp-client.mjs +3 -3
- package/dist/integrations/mcp-client.mjs.map +1 -1
- package/dist/integrations/mcp.d.mts +3 -2
- package/dist/integrations/mcp.d.mts.map +1 -1
- package/dist/integrations/mcp.mjs +5 -5
- package/dist/{mcp-BvbriaBy.mjs → mcp-1DaMsaBc.mjs} +4 -4
- package/dist/mcp-1DaMsaBc.mjs.map +1 -0
- package/dist/memory/index.d.mts +3 -0
- package/dist/memory/index.mjs +6 -0
- package/dist/memory-D1P7Tmda.mjs +4 -0
- package/dist/memory-DVN0MnIG.mjs +132 -0
- package/dist/memory-DVN0MnIG.mjs.map +1 -0
- package/dist/memory-Dj0J1v88.mjs +294 -0
- package/dist/memory-Dj0J1v88.mjs.map +1 -0
- package/dist/moonshine-stt-BLyVoRpB.mjs +4 -0
- package/dist/moonshine-stt-v_P_Ci_m.mjs +11936 -0
- package/dist/moonshine-stt-v_P_Ci_m.mjs.map +1 -0
- package/dist/{one-liner-s-lD8rCC.mjs → one-liner-DnQn7HJK.mjs} +14 -16
- package/dist/one-liner-DnQn7HJK.mjs.map +1 -0
- package/dist/repl-jV5gcJFA.mjs +9 -0
- package/dist/skills/index.d.mts +270 -320
- package/dist/skills/index.d.mts.map +1 -1
- package/dist/skills/index.mjs +5 -5
- package/dist/{skills-CD3Orlex.mjs → skills-DX8D59UH.mjs} +187 -32
- package/dist/skills-DX8D59UH.mjs.map +1 -0
- package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
- package/dist/tools-DQ1mPUw5.mjs.map +1 -0
- package/dist/{types-CiTc7ez3.d.mts → types-D6FiR_oh.d.mts} +106 -12
- package/dist/types-D6FiR_oh.d.mts.map +1 -0
- package/dist/types-DQBe2lFo.d.mts +165 -0
- package/dist/types-DQBe2lFo.d.mts.map +1 -0
- package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
- package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
- package/dist/vector-B0panuy6.mjs +95 -0
- package/dist/vector-B0panuy6.mjs.map +1 -0
- package/docs/PROJECT-STATE.md +321 -0
- package/docs/adding-a-model-family.md +280 -0
- package/docs/ai-sdk.md +70 -61
- package/docs/architecture/overview.md +17 -7
- package/docs/browser.md +203 -8
- package/docs/embeddings.md +156 -0
- package/docs/gerbil-site-native-migration.md +217 -0
- package/docs/gpu-engine/architectures.md +398 -0
- package/docs/gpu-engine/ir.md +372 -0
- package/docs/gpu-engine/kernels.md +718 -0
- package/docs/gpu-engine/paper.html +1759 -0
- package/docs/gpu-engine/paper.md +2109 -0
- package/docs/gpu-engine/safetensors.md +312 -0
- package/docs/gpu-engine/tokenizer.md +302 -0
- package/docs/memory-rag.md +91 -0
- package/docs/metal-safari-intel.md +190 -0
- package/docs/mobile-failure-diagnosis.md +124 -0
- package/docs/mobile.md +99 -0
- package/docs/observability.md +230 -0
- package/docs/onnx-removal-plan.md +339 -0
- package/docs/research/autoresearch-portable.md +904 -0
- package/docs/research/dispatch-reduction-hivemind.md +84 -0
- package/docs/research/ios-safari-model-caching.md +117 -0
- package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
- package/docs/research/native-stt-model-selection.md +49 -0
- package/docs/research/native-tts-model-selection.md +90 -0
- package/docs/research/native-vs-chromium-decision.md +152 -0
- package/docs/research/nemotron-mamba2-inference.md +910 -0
- package/docs/research/qwen35-multimodal.md +293 -0
- package/docs/research/qwen36-gemma4-targets.md +337 -0
- package/docs/research/sota-embedding-models.md +179 -0
- package/docs/research/sota-mobile-models-2026.md +263 -0
- package/docs/research/sota-modality-models.md +202 -0
- package/docs/research/tps-baselines.md +71 -0
- package/docs/research/webgpu-m4-reference.md +104 -0
- package/docs/site-update-plan.md +155 -0
- package/docs/structured-output.md +123 -0
- package/docs/stt.md +63 -446
- package/docs/tts.md +77 -499
- package/docs/vision.md +100 -338
- package/package.json +22 -7
- package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
- package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
- package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
- package/dist/gerbil-CJ3ifloF.mjs +0 -4
- package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
- package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
- package/dist/gerbil-qOTe1nl2.d.mts +0 -431
- package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
- package/dist/kokoro-BNTb6egA.mjs +0 -20210
- package/dist/kokoro-BNTb6egA.mjs.map +0 -1
- package/dist/kokoro-DFRQ1OeM.js +0 -20212
- package/dist/kokoro-DFRQ1OeM.js.map +0 -1
- package/dist/mcp-BvbriaBy.mjs.map +0 -1
- package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
- package/dist/repl-DveXw36T.mjs +0 -9
- package/dist/skills-CD3Orlex.mjs.map +0 -1
- package/dist/stt-CpLYbGFd.mjs +0 -433
- package/dist/stt-CpLYbGFd.mjs.map +0 -1
- package/dist/stt-DRPLEEHB.mjs +0 -3
- package/dist/stt-Te8Qz-Ay.js +0 -433
- package/dist/stt-Te8Qz-Ay.js.map +0 -1
- package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
- package/dist/transformers.web-DokyH3rP.js +0 -3
- package/dist/transformers.web-M6mCnEYJ.js +0 -30382
- package/dist/transformers.web-M6mCnEYJ.js.map +0 -1
- package/dist/tts-C0xx3CtE.js +0 -724
- package/dist/tts-C0xx3CtE.js.map +0 -1
- package/dist/tts-DXgsKGCe.mjs +0 -3
- package/dist/tts-DeGANMNV.mjs +0 -730
- package/dist/tts-DeGANMNV.mjs.map +0 -1
- package/dist/types-CiTc7ez3.d.mts.map +0 -1
- /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
- /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
- /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
# WGSL Kernel Reference
|
|
2
|
+
|
|
3
|
+
Complete reference for all 12 WGSL compute shaders in `src/gpu/kernels/wgsl/`. Each section covers the algorithm, binding layout, uniform struct, dispatch formula, example dispatches, and optimization opportunities.
|
|
4
|
+
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
## Table of Contents
|
|
8
|
+
|
|
9
|
+
1. [Embedding](#1-embedding)
|
|
10
|
+
2. [MatMul (f32)](#2-matmul-f32)
|
|
11
|
+
3. [MatMulInt4](#3-matmulint4)
|
|
12
|
+
4. [RMSNorm](#4-rmsnorm)
|
|
13
|
+
5. [LayerNorm](#5-layernorm)
|
|
14
|
+
6. [RoPE](#6-rope)
|
|
15
|
+
7. [Attention](#7-attention)
|
|
16
|
+
8. [Softmax](#8-softmax)
|
|
17
|
+
9. [SiLU](#9-silu)
|
|
18
|
+
10. [GELU](#10-gelu)
|
|
19
|
+
11. [Add](#11-add)
|
|
20
|
+
12. [Mul](#12-mul)
|
|
21
|
+
13. [Micro-Benchmarking Approach](#13-micro-benchmarking-approach)
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## 1. Embedding
|
|
26
|
+
|
|
27
|
+
**File:** `embedding.wgsl`
|
|
28
|
+
|
|
29
|
+
**Computes:** `output[t, d] = weight[input_ids[t], d]` for all positions `t` and dimensions `d`.
|
|
30
|
+
|
|
31
|
+
**Algorithm:** Flat parallel gather. Each thread handles one element of the output matrix. Thread `idx` maps to position `(idx / hidden_size, idx % hidden_size)`, looks up the token ID at that position, and copies one element from the corresponding row of the embedding weight matrix.
|
|
32
|
+
|
|
33
|
+
### Binding Layout
|
|
34
|
+
|
|
35
|
+
| Binding | Name | Type | Access | Shape |
|
|
36
|
+
|---------|------|------|--------|-------|
|
|
37
|
+
| 0 | `input_ids` | `array<u32>` | read | `[T]` |
|
|
38
|
+
| 1 | `weight` | `array<f32>` | read | `[vocab_size, hidden_size]` (row-major) |
|
|
39
|
+
| 2 | `output` | `array<f32>` | read_write | `[T, hidden_size]` |
|
|
40
|
+
| 3 | `params` | uniform | - | `Params` struct |
|
|
41
|
+
|
|
42
|
+
### Uniform Struct
|
|
43
|
+
|
|
44
|
+
```wgsl
|
|
45
|
+
struct Params {
|
|
46
|
+
seq_len: u32, // T
|
|
47
|
+
hidden_size: u32, // D
|
|
48
|
+
}
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
### Dispatch Formula
|
|
52
|
+
|
|
53
|
+
```
|
|
54
|
+
workgroups_x = ceil(T * hidden_size / 256)
|
|
55
|
+
workgroups_y = 1
|
|
56
|
+
workgroups_z = 1
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
### Example Dispatch
|
|
60
|
+
|
|
61
|
+
| T | hidden_size | Total elements | Workgroups |
|
|
62
|
+
|---|-------------|---------------|------------|
|
|
63
|
+
| 1 | 896 | 896 | (4, 1, 1) |
|
|
64
|
+
| 50 | 896 | 44,800 | (175, 1, 1) |
|
|
65
|
+
| 2048 | 896 | 1,835,008 | (7,168, 1, 1) |
|
|
66
|
+
|
|
67
|
+
### Optimization Notes
|
|
68
|
+
|
|
69
|
+
- Could use `vec4<f32>` loads to process 4 elements per thread (4x fewer threads)
|
|
70
|
+
- Weight matrix could be stored as f16 and converted on read for 2x bandwidth savings
|
|
71
|
+
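- For very large vocabularies (150K+), the weight matrix dominates the memory footprint (151,936 × 896 f32 ≈ 545 MB), but the gather itself reads only `T` rows per dispatch. Reducing storage precision (f16 or quantized) helps more than caching; workgroup shared memory is far too small to hold hot rows and does not persist across dispatches.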
|
|
72
|
+
|
|
73
|
+
---
|
|
74
|
+
|
|
75
|
+
## 2. MatMul (f32)
|
|
76
|
+
|
|
77
|
+
**File:** `matmul.wgsl`
|
|
78
|
+
|
|
79
|
+
**Computes:** `C[M, N] = A[M, K] * B[K, N]` (row-major matrices)
|
|
80
|
+
|
|
81
|
+
**Algorithm:** Classic 16x16 tiled matrix multiply with shared memory. Each workgroup computes a 16x16 tile of the output. For each tile of the K dimension, threads cooperatively load a 16x16 tile of A and a 16x16 tile of B into shared memory, synchronize, then compute partial dot products.
|
|
82
|
+
|
|
83
|
+
### Binding Layout
|
|
84
|
+
|
|
85
|
+
| Binding | Name | Type | Access | Shape |
|
|
86
|
+
|---------|------|------|--------|-------|
|
|
87
|
+
| 0 | `A` | `array<f32>` | read | `[M, K]` |
|
|
88
|
+
| 1 | `B` | `array<f32>` | read | `[K, N]` |
|
|
89
|
+
| 2 | `C` | `array<f32>` | read_write | `[M, N]` |
|
|
90
|
+
| 3 | `params` | uniform | - | `Params` struct |
|
|
91
|
+
|
|
92
|
+
### Uniform Struct
|
|
93
|
+
|
|
94
|
+
```wgsl
|
|
95
|
+
struct Params {
|
|
96
|
+
M: u32,
|
|
97
|
+
K: u32,
|
|
98
|
+
N: u32,
|
|
99
|
+
}
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
### Dispatch Formula
|
|
103
|
+
|
|
104
|
+
```
|
|
105
|
+
workgroups_x = ceil(N / 16)
|
|
106
|
+
workgroups_y = ceil(M / 16)
|
|
107
|
+
workgroups_z = 1
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
### Example Dispatch
|
|
111
|
+
|
|
112
|
+
| Operation | M | K | N | Workgroups |
|
|
113
|
+
|-----------|---|---|---|------------|
|
|
114
|
+
| Q projection (T=1) | 1 | 896 | 896 | (56, 1, 1) |
|
|
115
|
+
| Q projection (T=50) | 50 | 896 | 896 | (56, 4, 1) |
|
|
116
|
+
| Gate proj (T=1) | 1 | 896 | 4864 | (304, 1, 1) |
|
|
117
|
+
| Down proj (T=1) | 1 | 4864 | 896 | (56, 1, 1) |
|
|
118
|
+
| LM head (T=1) | 1 | 896 | 151936 | (9,496, 1, 1) |
|
|
119
|
+
|
|
120
|
+
### Optimization Notes
|
|
121
|
+
|
|
122
|
+
- **Register blocking**: Process 4x4 or 8x8 output elements per thread to improve arithmetic intensity
|
|
123
|
+
- **Vectorized loads**: Use `vec4<f32>` for 128-bit loads from shared memory
|
|
124
|
+
- **Double buffering**: Overlap tile loading with computation by using two shared memory tiles
|
|
125
|
+
- **f16 weights**: When `shader-f16` is available, store and load weights as f16 while accumulating in f32, halving weight-memory bandwidth
|
|
126
|
+
- **Batched matmul**: Fuse multiple small matmuls (Q, K, V projections) into one kernel
|
|
127
|
+
|
|
128
|
+
---
|
|
129
|
+
|
|
130
|
+
## 3. MatMulInt4
|
|
131
|
+
|
|
132
|
+
**File:** `matmul_int4.wgsl`
|
|
133
|
+
|
|
134
|
+
**Computes:** `C[M, N] = A[M, K] * dequant(B_q[K, N])` where B_q is packed INT4 with per-group scales and zeros.
|
|
135
|
+
|
|
136
|
+
**Algorithm:** Each thread computes one element of C. For each element, it iterates over the K dimension, dequantizing B on-the-fly: extract a 4-bit nibble from the packed u32, apply `(nibble - zero) * scale` using per-group parameters, then multiply-accumulate with the corresponding A element.
|
|
137
|
+
|
|
138
|
+
### Binding Layout
|
|
139
|
+
|
|
140
|
+
| Binding | Name | Type | Access | Shape |
|
|
141
|
+
|---------|------|------|--------|-------|
|
|
142
|
+
| 0 | `A` | `array<f32>` | read | `[M, K]` (activations) |
|
|
143
|
+
| 1 | `B_q` | `array<u32>` | read | Packed INT4 `[K, N]`, 8 values per u32 |
|
|
144
|
+
| 2 | `scales` | `array<f32>` | read | One per group |
|
|
145
|
+
| 3 | `zeros` | `array<f32>` | read | One per group |
|
|
146
|
+
| 4 | `C` | `array<f32>` | read_write | `[M, N]` |
|
|
147
|
+
| 5 | `params` | uniform | - | `Params` struct |
|
|
148
|
+
|
|
149
|
+
### Uniform Struct
|
|
150
|
+
|
|
151
|
+
```wgsl
|
|
152
|
+
struct Params {
|
|
153
|
+
M: u32,
|
|
154
|
+
K: u32,
|
|
155
|
+
N: u32,
|
|
156
|
+
group_size: u32, // typically 32 or 128
|
|
157
|
+
}
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
### INT4 Packing
|
|
161
|
+
|
|
162
|
+
Each `u32` in `B_q` stores 8 INT4 values. The `extract_nibble` function extracts a single 4-bit value:
|
|
163
|
+
|
|
164
|
+
```wgsl
|
|
165
|
+
fn extract_nibble(packed: u32, pos: u32) -> f32 {
|
|
166
|
+
let shift = pos * 4u;
|
|
167
|
+
let nibble = (packed >> shift) & 0xFu;
|
|
168
|
+
return f32(nibble);
|
|
169
|
+
}
|
|
170
|
+
```
|
|
171
|
+
|
|
172
|
+
### Dequantization
|
|
173
|
+
|
|
174
|
+
For element at flat index `flat_idx` in B:
|
|
175
|
+
```
|
|
176
|
+
group_idx = flat_idx / group_size
|
|
177
|
+
value = (nibble - zeros[group_idx]) * scales[group_idx]
|
|
178
|
+
```
|
|
179
|
+
|
|
180
|
+
### Dispatch Formula
|
|
181
|
+
|
|
182
|
+
```
|
|
183
|
+
workgroups_x = ceil(N / 16)
|
|
184
|
+
workgroups_y = ceil(M / 16)
|
|
185
|
+
workgroups_z = 1
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
### Optimization Notes
|
|
189
|
+
|
|
190
|
+
- **This kernel is not tiled**: Each thread iterates over the full K dimension, resulting in poor data reuse. A tiled version with shared memory (similar to `matmul.wgsl`) would dramatically improve performance.
|
|
191
|
+
- **Vectorized nibble extraction**: Process 8 nibbles from one u32 load simultaneously
|
|
192
|
+
- **Pre-dequantize tiles**: Load a tile of B_q into shared memory, dequantize in parallel, then do the tiled multiply
|
|
193
|
+
- This kernel is defined but **not yet wired** to graph generators. It will be used when GPTQ/AWQ quantized models are supported.
|
|
194
|
+
|
|
195
|
+
---
|
|
196
|
+
|
|
197
|
+
## 4. RMSNorm
|
|
198
|
+
|
|
199
|
+
**File:** `rmsnorm.wgsl`
|
|
200
|
+
|
|
201
|
+
**Computes:** `output[t, d] = (input[t, d] / rms(input[t, :])) * weight[d]` where `rms(x) = sqrt(mean(x^2) + eps)`
|
|
202
|
+
|
|
203
|
+
**Algorithm:** One workgroup per row (token position). Each thread accumulates squared values for a subset of the hidden dimension using a strided pattern (thread `tid` processes elements `tid, tid+256, tid+512, ...`). A tree reduction in shared memory computes the total sum of squares. Then each thread normalizes and scales its assigned elements.
|
|
204
|
+
|
|
205
|
+
### Binding Layout
|
|
206
|
+
|
|
207
|
+
| Binding | Name | Type | Access | Shape |
|
|
208
|
+
|---------|------|------|--------|-------|
|
|
209
|
+
| 0 | `input` | `array<f32>` | read | `[T, hidden_size]` |
|
|
210
|
+
| 1 | `weight` | `array<f32>` | read | `[hidden_size]` |
|
|
211
|
+
| 2 | `output` | `array<f32>` | read_write | `[T, hidden_size]` |
|
|
212
|
+
| 3 | `params` | uniform | - | `Params` struct |
|
|
213
|
+
|
|
214
|
+
### Uniform Struct
|
|
215
|
+
|
|
216
|
+
```wgsl
|
|
217
|
+
struct Params {
|
|
218
|
+
seq_len: u32, // T
|
|
219
|
+
hidden_size: u32, // D
|
|
220
|
+
eps_bits: u32, // f32 epsilon reinterpreted as u32 (for uniform alignment)
|
|
221
|
+
_pad: u32, // padding to 16-byte alignment
|
|
222
|
+
}
|
|
223
|
+
```
|
|
224
|
+
|
|
225
|
+
Note: `eps` is passed as `bitcast<u32>(eps_f32)`, and the shader uses `bitcast<f32>(params.eps_bits)` to recover the float value. This is a host-side packing convenience (every field can be written as a u32 word), not an alignment requirement: f32 and u32 have identical 4-byte size and alignment in WGSL. The `_pad` field rounds the struct up to 16 bytes.
|
|
226
|
+
|
|
227
|
+
### Dispatch Formula
|
|
228
|
+
|
|
229
|
+
```
|
|
230
|
+
workgroups_x = T // one workgroup per row
|
|
231
|
+
workgroups_y = 1
|
|
232
|
+
workgroups_z = 1
|
|
233
|
+
```
|
|
234
|
+
|
|
235
|
+
### Example Dispatch
|
|
236
|
+
|
|
237
|
+
| T | hidden_size | Workgroups |
|
|
238
|
+
|---|-------------|------------|
|
|
239
|
+
| 1 | 896 | (1, 1, 1) |
|
|
240
|
+
| 50 | 896 | (50, 1, 1) |
|
|
241
|
+
| 2048 | 896 | (2048, 1, 1) |
|
|
242
|
+
|
|
243
|
+
### Optimization Notes
|
|
244
|
+
|
|
245
|
+
- For hidden_size <= 256, the tree reduction is oversized (many idle threads)
|
|
246
|
+
- Could use `vec4<f32>` accumulation for 4x fewer loop iterations
|
|
247
|
+
- Fusing RMSNorm with the following MatMul would save one global memory round-trip
|
|
248
|
+
|
|
249
|
+
---
|
|
250
|
+
|
|
251
|
+
## 5. LayerNorm
|
|
252
|
+
|
|
253
|
+
**File:** `layernorm.wgsl`
|
|
254
|
+
|
|
255
|
+
**Computes:** `output[t, d] = ((input[t, d] - mean) / sqrt(variance + eps)) * weight[d] + bias[d]`
|
|
256
|
+
|
|
257
|
+
**Algorithm:** Two-pass reduction per row. First pass: tree reduction to compute mean. Second pass: tree reduction to compute variance (sum of squared differences from mean). Then each thread normalizes, scales, and shifts its assigned elements.
|
|
258
|
+
|
|
259
|
+
### Binding Layout
|
|
260
|
+
|
|
261
|
+
| Binding | Name | Type | Access | Shape |
|
|
262
|
+
|---------|------|------|--------|-------|
|
|
263
|
+
| 0 | `input` | `array<f32>` | read | `[T, hidden_size]` |
|
|
264
|
+
| 1 | `weight` | `array<f32>` | read | `[hidden_size]` |
|
|
265
|
+
| 2 | `bias` | `array<f32>` | read | `[hidden_size]` |
|
|
266
|
+
| 3 | `output` | `array<f32>` | read_write | `[T, hidden_size]` |
|
|
267
|
+
| 4 | `params` | uniform | - | `Params` struct |
|
|
268
|
+
|
|
269
|
+
### Uniform Struct
|
|
270
|
+
|
|
271
|
+
```wgsl
|
|
272
|
+
struct Params {
|
|
273
|
+
seq_len: u32,
|
|
274
|
+
hidden_size: u32,
|
|
275
|
+
eps_bits: u32,
|
|
276
|
+
_pad: u32,
|
|
277
|
+
}
|
|
278
|
+
```
|
|
279
|
+
|
|
280
|
+
### Dispatch Formula
|
|
281
|
+
|
|
282
|
+
```
|
|
283
|
+
workgroups_x = T
|
|
284
|
+
workgroups_y = 1
|
|
285
|
+
workgroups_z = 1
|
|
286
|
+
```
|
|
287
|
+
|
|
288
|
+
### Optimization Notes
|
|
289
|
+
|
|
290
|
+
- Uses two separate shared memory arrays (`shared_sum` and `shared_sq_sum`); could use a single-pass Welford algorithm to compute both mean and variance simultaneously
|
|
291
|
+
- Same vectorization and fusion opportunities as RMSNorm
|
|
292
|
+
|
|
293
|
+
---
|
|
294
|
+
|
|
295
|
+
## 6. RoPE
|
|
296
|
+
|
|
297
|
+
**File:** `rope.wgsl`
|
|
298
|
+
|
|
299
|
+
**Computes:** For each head, rotates pairs of dimensions by position-dependent angles:
|
|
300
|
+
```
|
|
301
|
+
q_out[2i] = q[2i] * cos(theta) - q[2i+1] * sin(theta)
|
|
302
|
+
q_out[2i+1] = q[2i] * sin(theta) + q[2i+1] * cos(theta)
|
|
303
|
+
```
|
|
304
|
+
where `theta = pos * base^(-2i/dim)`.
|
|
305
|
+
|
|
306
|
+
**Algorithm:** Each thread handles one pair of dimensions for one head at one position. Handles Q and K separately since they may have different head counts (GQA). The kernel supports a `position_offset` for decode steps where the query is at a later position in the sequence.
|
|
307
|
+
|
|
308
|
+
**Note:** This kernel operates **in-place** -- Q and K are both input and output buffers (`read_write`).
|
|
309
|
+
|
|
310
|
+
### Binding Layout
|
|
311
|
+
|
|
312
|
+
| Binding | Name | Type | Access | Shape |
|
|
313
|
+
|---------|------|------|--------|-------|
|
|
314
|
+
| 0 | `q` | `array<f32>` | read_write | `[T, num_q_heads * head_dim]` |
|
|
315
|
+
| 1 | `k` | `array<f32>` | read_write | `[T, num_kv_heads * head_dim]` |
|
|
316
|
+
| 2 | `params` | uniform | - | `Params` struct |
|
|
317
|
+
|
|
318
|
+
### Uniform Struct
|
|
319
|
+
|
|
320
|
+
```wgsl
|
|
321
|
+
struct Params {
|
|
322
|
+
seq_len: u32,
|
|
323
|
+
num_q_heads: u32,
|
|
324
|
+
num_kv_heads: u32,
|
|
325
|
+
head_dim: u32,
|
|
326
|
+
rope_base_bits: u32, // f32 base reinterpreted as u32
|
|
327
|
+
position_offset: u32, // starting position for decode step
|
|
328
|
+
}
|
|
329
|
+
```
|
|
330
|
+
|
|
331
|
+
### Dispatch Formula
|
|
332
|
+
|
|
333
|
+
```
|
|
334
|
+
total_pairs = T * max(num_q_heads, num_kv_heads) * (head_dim / 2)
|
|
335
|
+
workgroups_x = ceil(total_pairs / 256)
|
|
336
|
+
workgroups_y = 1
|
|
337
|
+
workgroups_z = 1
|
|
338
|
+
```
|
|
339
|
+
|
|
340
|
+
### Example Dispatch
|
|
341
|
+
|
|
342
|
+
| T | num_q_heads | num_kv_heads | head_dim | Total pairs | Workgroups |
|
|
343
|
+
|---|-------------|-------------|----------|-------------|------------|
|
|
344
|
+
| 1 | 14 | 2 | 64 | 448 | (2, 1, 1) |
|
|
345
|
+
| 50 | 14 | 2 | 64 | 22,400 | (88, 1, 1) |
|
|
346
|
+
|
|
347
|
+
### Optimization Notes
|
|
348
|
+
|
|
349
|
+
- Q and K are processed by the same dispatch; with GQA (num_q_heads >> num_kv_heads), the K processing finishes much earlier while Q threads are still running
|
|
350
|
+
- Frequency computation (`pow(base, ...)`) could be pre-computed into a lookup table in shared memory
|
|
351
|
+
- WGSL has no combined `sincos()` built-in, so `sin()` and `cos()` must be evaluated separately; a precomputed sin/cos table would remove both calls from the hot path
|
|
352
|
+
|
|
353
|
+
---
|
|
354
|
+
|
|
355
|
+
## 7. Attention
|
|
356
|
+
|
|
357
|
+
**File:** `attention.wgsl`
|
|
358
|
+
|
|
359
|
+
**Computes:** Scaled dot-product attention with causal masking and GQA:
|
|
360
|
+
```
|
|
361
|
+
scores = (Q @ K^T) / sqrt(head_dim)
|
|
362
|
+
scores = causal_mask(scores)
|
|
363
|
+
weights = softmax(scores)
|
|
364
|
+
output = weights @ V
|
|
365
|
+
```
|
|
366
|
+
|
|
367
|
+
**Algorithm:** Each workgroup handles one (query_position, q_head) pair. **Correctness-first implementation**: thread 0 performs the full computation (dot products, max-finding, softmax, V accumulation) while other threads are idle. This is explicitly marked as a TODO for optimization.
|
|
368
|
+
|
|
369
|
+
GQA is handled by mapping each Q head to its corresponding KV head: `kv_head = q_head / (num_q_heads / num_kv_heads)`.
|
|
370
|
+
|
|
371
|
+
### Binding Layout
|
|
372
|
+
|
|
373
|
+
| Binding | Name | Type | Access | Shape |
|
|
374
|
+
|---------|------|------|--------|-------|
|
|
375
|
+
| 0 | `Q` | `array<f32>` | read | `[T, num_q_heads * head_dim]` |
|
|
376
|
+
| 1 | `K` | `array<f32>` | read | `[S, num_kv_heads * head_dim]` |
|
|
377
|
+
| 2 | `V` | `array<f32>` | read | `[S, num_kv_heads * head_dim]` |
|
|
378
|
+
| 3 | `output` | `array<f32>` | read_write | `[T, num_q_heads * head_dim]` |
|
|
379
|
+
| 4 | `params` | uniform | - | `Params` struct |
|
|
380
|
+
|
|
381
|
+
### Uniform Struct
|
|
382
|
+
|
|
383
|
+
```wgsl
|
|
384
|
+
struct Params {
|
|
385
|
+
T: u32, // query seq len (prompt during prefill, 1 during decode)
|
|
386
|
+
S: u32, // key/value seq len (total including cache)
|
|
387
|
+
num_q_heads: u32,
|
|
388
|
+
num_kv_heads: u32,
|
|
389
|
+
head_dim: u32,
|
|
390
|
+
position_offset: u32, // start position of Q in full sequence
|
|
391
|
+
}
|
|
392
|
+
```
|
|
393
|
+
|
|
394
|
+
### Dispatch Formula
|
|
395
|
+
|
|
396
|
+
```
|
|
397
|
+
workgroups_x = T // one per query position
|
|
398
|
+
workgroups_y = num_q_heads // one per Q head
|
|
399
|
+
workgroups_z = 1
|
|
400
|
+
```
|
|
401
|
+
|
|
402
|
+
### Example Dispatch
|
|
403
|
+
|
|
404
|
+
| T | S | num_q_heads | Workgroups |
|
|
405
|
+
|---|---|-------------|------------|
|
|
406
|
+
| 1 (decode) | 100 | 14 | (1, 14, 1) |
|
|
407
|
+
| 50 (prefill) | 50 | 14 | (50, 14, 1) |
|
|
408
|
+
|
|
409
|
+
### Optimization Notes
|
|
410
|
+
|
|
411
|
+
This is the **highest priority kernel for optimization**:
|
|
412
|
+
|
|
413
|
+
- **Current**: Only thread 0 works; 255 threads per workgroup are wasted
|
|
414
|
+
- **Parallel dot products**: Each thread computes one dimension of the Q@K dot product, followed by a shared memory reduction
|
|
415
|
+
- **Tiled attention**: Process S positions in tiles, with online softmax (FlashAttention-style) to avoid materializing the full attention matrix
|
|
416
|
+
- **Shared memory K/V**: Load K/V tiles into shared memory for data reuse across dot products
|
|
417
|
+
- **Multi-query batching**: During prefill, process multiple query positions per workgroup
|
|
418
|
+
|
|
419
|
+
---
|
|
420
|
+
|
|
421
|
+
## 8. Softmax
|
|
422
|
+
|
|
423
|
+
**File:** `softmax.wgsl`
|
|
424
|
+
|
|
425
|
+
**Computes:** Row-wise softmax: `output[i] = exp(input[i] - max) / sum(exp(input[:] - max))`
|
|
426
|
+
|
|
427
|
+
**Algorithm:** Three passes per row, one workgroup per row (two parallel tree reductions followed by a normalization pass):
|
|
428
|
+
1. **Pass 1 (max)**: Each thread finds the max of its assigned elements; tree reduction to get row max
|
|
429
|
+
2. **Pass 2 (sum)**: Each thread computes `sum(exp(x - max))` for its elements; tree reduction to get total sum
|
|
430
|
+
3. **Pass 3 (normalize)**: Each thread computes `exp(x - max) / total_sum` for its elements
|
|
431
|
+
|
|
432
|
+
### Binding Layout
|
|
433
|
+
|
|
434
|
+
| Binding | Name | Type | Access | Shape |
|
|
435
|
+
|---------|------|------|--------|-------|
|
|
436
|
+
| 0 | `input` | `array<f32>` | read | `[num_rows, row_size]` |
|
|
437
|
+
| 1 | `output` | `array<f32>` | read_write | `[num_rows, row_size]` |
|
|
438
|
+
| 2 | `params` | uniform | - | `Params` struct |
|
|
439
|
+
|
|
440
|
+
### Uniform Struct
|
|
441
|
+
|
|
442
|
+
```wgsl
|
|
443
|
+
struct Params {
|
|
444
|
+
num_rows: u32,
|
|
445
|
+
row_size: u32,
|
|
446
|
+
}
|
|
447
|
+
```
|
|
448
|
+
|
|
449
|
+
### Dispatch Formula
|
|
450
|
+
|
|
451
|
+
```
|
|
452
|
+
workgroups_x = num_rows
|
|
453
|
+
workgroups_y = 1
|
|
454
|
+
workgroups_z = 1
|
|
455
|
+
```
|
|
456
|
+
|
|
457
|
+
### Optimization Notes
|
|
458
|
+
|
|
459
|
+
- Pass 3 redundantly recomputes `exp(x - max)` -- could store intermediate exp values in a buffer or fuse passes 2 and 3 using online normalization
|
|
460
|
+
- For small row sizes (< 256), many threads are idle
|
|
461
|
+
- Currently a standalone kernel; in practice, softmax is most commonly needed inside the attention kernel, so fusing them eliminates a global memory round-trip
|
|
462
|
+
|
|
463
|
+
---
|
|
464
|
+
|
|
465
|
+
## 9. SiLU
|
|
466
|
+
|
|
467
|
+
**File:** `silu.wgsl`
|
|
468
|
+
|
|
469
|
+
**Computes:** `output[i] = input[i] / (1 + exp(-input[i]))` (equivalent to `x * sigmoid(x)`)
|
|
470
|
+
|
|
471
|
+
**Algorithm:** Pure element-wise. Each thread processes one element.
|
|
472
|
+
|
|
473
|
+
### Binding Layout
|
|
474
|
+
|
|
475
|
+
| Binding | Name | Type | Access | Shape |
|
|
476
|
+
|---------|------|------|--------|-------|
|
|
477
|
+
| 0 | `input` | `array<f32>` | read | `[count]` |
|
|
478
|
+
| 1 | `output` | `array<f32>` | read_write | `[count]` |
|
|
479
|
+
| 2 | `params` | uniform | - | `Params` struct |
|
|
480
|
+
|
|
481
|
+
### Uniform Struct
|
|
482
|
+
|
|
483
|
+
```wgsl
|
|
484
|
+
struct Params {
|
|
485
|
+
count: u32,
|
|
486
|
+
}
|
|
487
|
+
```
|
|
488
|
+
|
|
489
|
+
### Dispatch Formula
|
|
490
|
+
|
|
491
|
+
```
|
|
492
|
+
workgroups_x = ceil(count / 256)
|
|
493
|
+
workgroups_y = 1
|
|
494
|
+
workgroups_z = 1
|
|
495
|
+
```
|
|
496
|
+
|
|
497
|
+
### Example Dispatch
|
|
498
|
+
|
|
499
|
+
| T | intermediate_size | count | Workgroups |
|
|
500
|
+
|---|-------------------|-------|------------|
|
|
501
|
+
| 1 | 4864 | 4,864 | (19, 1, 1) |
|
|
502
|
+
| 50 | 4864 | 243,200 | (950, 1, 1) |
|
|
503
|
+
|
|
504
|
+
### Optimization Notes
|
|
505
|
+
|
|
506
|
+
- Could process `vec4<f32>` per thread for 4x fewer threads
|
|
507
|
+
- Could be fused with the following Mul (SwiGLU combine): `output = silu(gate) * up`
|
|
508
|
+
|
|
509
|
+
---
|
|
510
|
+
|
|
511
|
+
## 10. GELU
|
|
512
|
+
|
|
513
|
+
**File:** `gelu.wgsl`
|
|
514
|
+
|
|
515
|
+
**Computes:** Approximate GELU: `output = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
516
|
+
|
|
517
|
+
**Algorithm:** Pure element-wise. Uses the tanh approximation (as in GPT-2), a common choice for transformer models; it differs slightly from the exact erf-based GELU.
|
|
518
|
+
|
|
519
|
+
### Binding Layout
|
|
520
|
+
|
|
521
|
+
| Binding | Name | Type | Access | Shape |
|
|
522
|
+
|---------|------|------|--------|-------|
|
|
523
|
+
| 0 | `input` | `array<f32>` | read | `[count]` |
|
|
524
|
+
| 1 | `output` | `array<f32>` | read_write | `[count]` |
|
|
525
|
+
| 2 | `params` | uniform | - | `Params` struct |
|
|
526
|
+
|
|
527
|
+
### Uniform Struct
|
|
528
|
+
|
|
529
|
+
```wgsl
|
|
530
|
+
struct Params {
|
|
531
|
+
count: u32,
|
|
532
|
+
}
|
|
533
|
+
```
|
|
534
|
+
|
|
535
|
+
### Dispatch Formula
|
|
536
|
+
|
|
537
|
+
```
|
|
538
|
+
workgroups_x = ceil(count / 256)
|
|
539
|
+
```
|
|
540
|
+
|
|
541
|
+
### Constants
|
|
542
|
+
|
|
543
|
+
```wgsl
|
|
544
|
+
const SQRT_2_OVER_PI: f32 = 0.7978845608;
|
|
545
|
+
const GELU_COEFF: f32 = 0.044715;
|
|
546
|
+
```
|
|
547
|
+
|
|
548
|
+
### Optimization Notes
|
|
549
|
+
|
|
550
|
+
- Same vectorization opportunity as SiLU
|
|
551
|
+
- Used by Phi and GPT-2 style models (not Qwen/LLaMA which use SiLU)
|
|
552
|
+
|
|
553
|
+
---
|
|
554
|
+
|
|
555
|
+
## 11. Add
|
|
556
|
+
|
|
557
|
+
**File:** `add.wgsl`
|
|
558
|
+
|
|
559
|
+
**Computes:** `output[i] = a[i] + b[i]`
|
|
560
|
+
|
|
561
|
+
**Algorithm:** Pure element-wise. Used for residual connections.
|
|
562
|
+
|
|
563
|
+
### Binding Layout
|
|
564
|
+
|
|
565
|
+
| Binding | Name | Type | Access | Shape |
|
|
566
|
+
|---------|------|------|--------|-------|
|
|
567
|
+
| 0 | `a` | `array<f32>` | read | `[count]` |
|
|
568
|
+
| 1 | `b` | `array<f32>` | read | `[count]` |
|
|
569
|
+
| 2 | `output` | `array<f32>` | read_write | `[count]` |
|
|
570
|
+
| 3 | `params` | uniform | - | `Params` struct |
|
|
571
|
+
|
|
572
|
+
### Uniform Struct
|
|
573
|
+
|
|
574
|
+
```wgsl
|
|
575
|
+
struct Params {
|
|
576
|
+
count: u32,
|
|
577
|
+
}
|
|
578
|
+
```
|
|
579
|
+
|
|
580
|
+
### Dispatch Formula
|
|
581
|
+
|
|
582
|
+
```
|
|
583
|
+
workgroups_x = ceil(count / 256)
|
|
584
|
+
```
|
|
585
|
+
|
|
586
|
+
### Optimization Notes
|
|
587
|
+
|
|
588
|
+
- Could use `vec4<f32>` for 128-bit operations
|
|
589
|
+
- The residual add could be fused into the preceding kernel. For example, the output-projection matmul could add the residual in its epilogue and write `resid + output` directly.
|
|
590
|
+
|
|
591
|
+
---
|
|
592
|
+
|
|
593
|
+
## 12. Mul
|
|
594
|
+
|
|
595
|
+
**File:** `mul.wgsl`
|
|
596
|
+
|
|
597
|
+
**Computes:** `output[i] = a[i] * b[i]`
|
|
598
|
+
|
|
599
|
+
**Algorithm:** Pure element-wise. Used for the SwiGLU combine step (`silu(gate) * up`).
|
|
600
|
+
|
|
601
|
+
### Binding Layout
|
|
602
|
+
|
|
603
|
+
| Binding | Name | Type | Access | Shape |
|
|
604
|
+
|---------|------|------|--------|-------|
|
|
605
|
+
| 0 | `a` | `array<f32>` | read | `[count]` |
|
|
606
|
+
| 1 | `b` | `array<f32>` | read | `[count]` |
|
|
607
|
+
| 2 | `output` | `array<f32>` | read_write | `[count]` |
|
|
608
|
+
| 3 | `params` | uniform | - | `Params` struct |
|
|
609
|
+
|
|
610
|
+
### Uniform Struct
|
|
611
|
+
|
|
612
|
+
```wgsl
|
|
613
|
+
struct Params {
|
|
614
|
+
count: u32,
|
|
615
|
+
}
|
|
616
|
+
```
|
|
617
|
+
|
|
618
|
+
### Dispatch Formula
|
|
619
|
+
|
|
620
|
+
```
|
|
621
|
+
workgroups_x = ceil(count / 256)
|
|
622
|
+
```
|
|
623
|
+
|
|
624
|
+
### Optimization Notes
|
|
625
|
+
|
|
626
|
+
- Same vectorization and fusion opportunities as Add
|
|
627
|
+
- Could be fused with preceding SiLU: single kernel computes `silu(a[i]) * b[i]`
|
|
628
|
+
|
|
629
|
+
---
|
|
630
|
+
|
|
631
|
+
## 13. Micro-Benchmarking Approach
|
|
632
|
+
|
|
633
|
+
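To measure individual kernel performance, use the pattern below. `randomF32` (which fills a `Float32Array` with random values) and `matmulWGSL` (the kernel source) are assumed to be defined elsewhere. The harness times whole submissions with `performance.now()`, so the result includes command-encoding and submission overhead. For kernel-only timing, use GPU timestamp queries where the `timestamp-query` feature is available.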
|
|
634
|
+
|
|
635
|
+
### Test Harness
|
|
636
|
+
|
|
637
|
+
```typescript
|
|
638
|
+
import { initGPU, createStorageBuffer, createUniformBuffer, getOrCreatePipeline, createBindGroup } from "./device.js";
|
|
639
|
+
|
|
640
|
+
async function benchmarkMatmul(M: number, K: number, N: number, iterations: number = 100) {
|
|
641
|
+
const ctx = await initGPU();
|
|
642
|
+
|
|
643
|
+
// Allocate buffers with random data
|
|
644
|
+
const A = createStorageBuffer(ctx, "A", M * K * 4, randomF32(M * K));
|
|
645
|
+
const B = createStorageBuffer(ctx, "B", K * N * 4, randomF32(K * N));
|
|
646
|
+
const C = createStorageBuffer(ctx, "C", M * N * 4);
|
|
647
|
+
const params = createUniformBuffer(ctx, "params",
|
|
648
|
+
new Uint32Array([M, K, N]).buffer);
|
|
649
|
+
|
|
650
|
+
const pipeline = getOrCreatePipeline(ctx, "matmul", matmulWGSL, "main");
|
|
651
|
+
const bindGroup = createBindGroup(ctx, pipeline,
|
|
652
|
+
[{ buffer: A }, { buffer: B }, { buffer: C }, { buffer: params }]);
|
|
653
|
+
|
|
654
|
+
const wgX = Math.ceil(N / 16);
|
|
655
|
+
const wgY = Math.ceil(M / 16);
|
|
656
|
+
|
|
657
|
+
// Warmup
|
|
658
|
+
for (let i = 0; i < 10; i++) {
|
|
659
|
+
const enc = ctx.device.createCommandEncoder();
|
|
660
|
+
const pass = enc.beginComputePass();
|
|
661
|
+
pass.setPipeline(pipeline);
|
|
662
|
+
pass.setBindGroup(0, bindGroup);
|
|
663
|
+
pass.dispatchWorkgroups(wgX, wgY, 1);
|
|
664
|
+
pass.end();
|
|
665
|
+
ctx.device.queue.submit([enc.finish()]);
|
|
666
|
+
}
|
|
667
|
+
await ctx.device.queue.onSubmittedWorkDone();
|
|
668
|
+
|
|
669
|
+
// Benchmark
|
|
670
|
+
const start = performance.now();
|
|
671
|
+
for (let i = 0; i < iterations; i++) {
|
|
672
|
+
const enc = ctx.device.createCommandEncoder();
|
|
673
|
+
const pass = enc.beginComputePass();
|
|
674
|
+
pass.setPipeline(pipeline);
|
|
675
|
+
pass.setBindGroup(0, bindGroup);
|
|
676
|
+
pass.dispatchWorkgroups(wgX, wgY, 1);
|
|
677
|
+
pass.end();
|
|
678
|
+
ctx.device.queue.submit([enc.finish()]);
|
|
679
|
+
}
|
|
680
|
+
await ctx.device.queue.onSubmittedWorkDone();
|
|
681
|
+
const elapsed = performance.now() - start;
|
|
682
|
+
|
|
683
|
+
// Compute metrics
|
|
684
|
+
const flops = 2 * M * K * N; // multiply-add = 2 ops
|
|
685
|
+
const totalFlops = flops * iterations;
|
|
686
|
+
const gflops = totalFlops / (elapsed * 1e6); // elapsed is in ms
|
|
687
|
+
|
|
688
|
+
console.log(`MatMul ${M}x${K}x${N}:`);
|
|
689
|
+
console.log(` ${iterations} iterations in ${elapsed.toFixed(1)} ms`);
|
|
690
|
+
console.log(` ${(elapsed / iterations).toFixed(3)} ms/iter`);
|
|
691
|
+
console.log(` ${gflops.toFixed(1)} GFLOPS`);
|
|
692
|
+
|
|
693
|
+
A.destroy(); B.destroy(); C.destroy(); params.destroy();
|
|
694
|
+
}
|
|
695
|
+
```
|
|
696
|
+
|
|
697
|
+
### Expected Performance Targets
|
|
698
|
+
|
|
699
|
+
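For reference, the table lists approximate theoretical peak f32 throughput (in TFLOPS) for common GPUs. It also gives the fraction of that peak a matmul kernel can be expected to reach: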
|
|
700
|
+
|
|
701
|
+
| GPU | f32 TFLOPS | Expected matmul efficiency |
|
|
702
|
+
|-----|-----------|---------------------------|
|
|
703
|
+
| Apple M1 (8-core) | ~2.6 | 30-50% with 16x16 tiling |
|
|
704
|
+
| Apple M2 (10-core) | ~3.6 | 30-50% |
|
|
705
|
+
| NVIDIA RTX 3060 | ~12.7 | 40-60% |
|
|
706
|
+
| Intel Arc A770 | ~17.2 | 30-50% |
|
|
707
|
+
|
|
708
|
+
The 16x16 tiling strategy without register blocking typically achieves roughly 30-60% of peak throughput, depending on the GPU. Larger tiles (32x32, 64x64) with register blocking can push this to 70-80%.
|
|
709
|
+
|
|
710
|
+
### Key Metrics Per Kernel
|
|
711
|
+
|
|
712
|
+
| Kernel | Metric | Formula |
|
|
713
|
+
|--------|--------|---------|
|
|
714
|
+
| MatMul | GFLOPS | `2 * M * K * N / (time_ms * 1e6)` |
|
|
715
|
+
| Embedding | GB/s | `T * hidden_size * 4 * 2 / (time_ms * 1e6)` (read embedding rows + write output) |
|
|
716
|
+
| RMSNorm | GB/s | `(2 * T + 1) * hidden_size * 4 / (time_ms * 1e6)` (read input + write output per token, plus the weight vector once) |
|
|
717
|
+
| Attention | GFLOPS | `4 * T * S * head_dim * num_heads / (time_ms * 1e6)` (QKᵀ and scores·V, each 2 FLOPs per multiply-add) |
|
|
718
|
+
| Element-wise | GB/s | `count * 4 * (num_inputs + 1) / (time_ms * 1e6)` |
|