@tryhamster/gerbil 1.0.0-rc.9 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. package/LICENSE +1 -1
  2. package/README.md +318 -104
  3. package/dist/architectures-C1I5V3Dt.mjs +6070 -0
  4. package/dist/architectures-C1I5V3Dt.mjs.map +1 -0
  5. package/dist/browser/index.d.ts +276 -590
  6. package/dist/browser/index.d.ts.map +1 -1
  7. package/dist/browser/index.js +592 -2334
  8. package/dist/browser/index.js.map +1 -1
  9. package/dist/cli.mjs +625 -1098
  10. package/dist/cli.mjs.map +1 -1
  11. package/dist/defaults-9komdrbY.mjs +24 -0
  12. package/dist/defaults-9komdrbY.mjs.map +1 -0
  13. package/dist/frameworks/express.d.mts +1 -3
  14. package/dist/frameworks/express.d.mts.map +1 -1
  15. package/dist/frameworks/express.mjs +7 -7
  16. package/dist/frameworks/express.mjs.map +1 -1
  17. package/dist/frameworks/fastify.d.mts +1 -1
  18. package/dist/frameworks/fastify.d.mts.map +1 -1
  19. package/dist/frameworks/fastify.mjs +3 -3
  20. package/dist/frameworks/fastify.mjs.map +1 -1
  21. package/dist/frameworks/hono.d.mts +1 -1
  22. package/dist/frameworks/hono.d.mts.map +1 -1
  23. package/dist/frameworks/hono.mjs +4 -4
  24. package/dist/frameworks/hono.mjs.map +1 -1
  25. package/dist/frameworks/next.d.mts +3 -2
  26. package/dist/frameworks/next.d.mts.map +1 -1
  27. package/dist/frameworks/next.mjs +4 -4
  28. package/dist/frameworks/next.mjs.map +1 -1
  29. package/dist/frameworks/react.d.mts +1 -1
  30. package/dist/frameworks/trpc.d.mts +1 -1
  31. package/dist/frameworks/trpc.d.mts.map +1 -1
  32. package/dist/frameworks/trpc.mjs +4 -4
  33. package/dist/frameworks/trpc.mjs.map +1 -1
  34. package/dist/gerbil-BetB5xb0.d.mts +488 -0
  35. package/dist/gerbil-BetB5xb0.d.mts.map +1 -0
  36. package/dist/gerbil-CTZUa8EZ.mjs +4 -0
  37. package/dist/gerbil-DNniplr4.mjs +1656 -0
  38. package/dist/gerbil-DNniplr4.mjs.map +1 -0
  39. package/dist/gpu/hooks.d.mts +640 -0
  40. package/dist/gpu/hooks.d.mts.map +1 -0
  41. package/dist/gpu/hooks.mjs +1369 -0
  42. package/dist/gpu/hooks.mjs.map +1 -0
  43. package/dist/gpu/index.d.mts +2 -0
  44. package/dist/gpu/index.mjs +6 -0
  45. package/dist/gpu-DFuglcEx.mjs +3790 -0
  46. package/dist/gpu-DFuglcEx.mjs.map +1 -0
  47. package/dist/index-Dgmb2kE3.d.mts +245 -0
  48. package/dist/index-Dgmb2kE3.d.mts.map +1 -0
  49. package/dist/index-DukkJRMj.d.mts +2114 -0
  50. package/dist/index-DukkJRMj.d.mts.map +1 -0
  51. package/dist/index.d.mts +22 -487
  52. package/dist/index.d.mts.map +1 -1
  53. package/dist/index.mjs +13 -8
  54. package/dist/index.mjs.map +1 -1
  55. package/dist/indexeddb-store-BWIMtxxH.mjs +103 -0
  56. package/dist/indexeddb-store-BWIMtxxH.mjs.map +1 -0
  57. package/dist/indexeddb-store-ClH12Xnl.mjs +4 -0
  58. package/dist/integrations/ai-sdk.d.mts +75 -6
  59. package/dist/integrations/ai-sdk.d.mts.map +1 -1
  60. package/dist/integrations/ai-sdk.mjs +131 -15
  61. package/dist/integrations/ai-sdk.mjs.map +1 -1
  62. package/dist/integrations/langchain.d.mts +1 -1
  63. package/dist/integrations/langchain.d.mts.map +1 -1
  64. package/dist/integrations/langchain.mjs +5 -5
  65. package/dist/integrations/langchain.mjs.map +1 -1
  66. package/dist/integrations/llamaindex.d.mts +1 -1
  67. package/dist/integrations/llamaindex.d.mts.map +1 -1
  68. package/dist/integrations/llamaindex.mjs +5 -5
  69. package/dist/integrations/llamaindex.mjs.map +1 -1
  70. package/dist/integrations/mcp-client.mjs +3 -3
  71. package/dist/integrations/mcp-client.mjs.map +1 -1
  72. package/dist/integrations/mcp.d.mts +3 -2
  73. package/dist/integrations/mcp.d.mts.map +1 -1
  74. package/dist/integrations/mcp.mjs +5 -5
  75. package/dist/{mcp-BvbriaBy.mjs → mcp-D2vvH1Xc.mjs} +4 -4
  76. package/dist/mcp-D2vvH1Xc.mjs.map +1 -0
  77. package/dist/memory/index.d.mts +3 -0
  78. package/dist/memory/index.mjs +6 -0
  79. package/dist/memory-D1P7Tmda.mjs +4 -0
  80. package/dist/memory-DVN0MnIG.mjs +132 -0
  81. package/dist/memory-DVN0MnIG.mjs.map +1 -0
  82. package/dist/memory-Dj0J1v88.mjs +294 -0
  83. package/dist/memory-Dj0J1v88.mjs.map +1 -0
  84. package/dist/moonshine-stt-17dpP1kr.mjs +4 -0
  85. package/dist/moonshine-stt-4ojLtMq7.mjs +11962 -0
  86. package/dist/moonshine-stt-4ojLtMq7.mjs.map +1 -0
  87. package/dist/{one-liner-s-lD8rCC.mjs → one-liner-JhdIPxzF.mjs} +14 -16
  88. package/dist/one-liner-JhdIPxzF.mjs.map +1 -0
  89. package/dist/repl-BDRkwPGX.mjs +9 -0
  90. package/dist/skills/index.d.mts +270 -320
  91. package/dist/skills/index.d.mts.map +1 -1
  92. package/dist/skills/index.mjs +5 -5
  93. package/dist/{skills-CD3Orlex.mjs → skills-CU694Dc8.mjs} +187 -32
  94. package/dist/skills-CU694Dc8.mjs.map +1 -0
  95. package/dist/{tools-Bi1P7Xoy.mjs → tools-DQ1mPUw5.mjs} +34 -22
  96. package/dist/tools-DQ1mPUw5.mjs.map +1 -0
  97. package/dist/types-DQBe2lFo.d.mts +165 -0
  98. package/dist/types-DQBe2lFo.d.mts.map +1 -0
  99. package/dist/{types-CiTc7ez3.d.mts → types-LlyYILII.d.mts} +112 -14
  100. package/dist/types-LlyYILII.d.mts.map +1 -0
  101. package/dist/{utils-CZBZ8dgR.mjs → utils-DKO55ZmZ.mjs} +1 -1
  102. package/dist/{utils-CZBZ8dgR.mjs.map → utils-DKO55ZmZ.mjs.map} +1 -1
  103. package/dist/vector-B0panuy6.mjs +95 -0
  104. package/dist/vector-B0panuy6.mjs.map +1 -0
  105. package/docs/PROJECT-STATE.md +321 -0
  106. package/docs/adding-a-model-family.md +280 -0
  107. package/docs/ai-sdk.md +70 -61
  108. package/docs/architecture/overview.md +17 -7
  109. package/docs/browser.md +203 -8
  110. package/docs/embeddings.md +156 -0
  111. package/docs/gerbil-site-native-migration.md +217 -0
  112. package/docs/gpu-engine/architectures.md +398 -0
  113. package/docs/gpu-engine/ir.md +372 -0
  114. package/docs/gpu-engine/kernels.md +718 -0
  115. package/docs/gpu-engine/paper.html +1759 -0
  116. package/docs/gpu-engine/paper.md +2109 -0
  117. package/docs/gpu-engine/safetensors.md +312 -0
  118. package/docs/gpu-engine/tokenizer.md +302 -0
  119. package/docs/memory-rag.md +91 -0
  120. package/docs/metal-safari-intel.md +190 -0
  121. package/docs/mobile-failure-diagnosis.md +124 -0
  122. package/docs/mobile.md +99 -0
  123. package/docs/observability.md +230 -0
  124. package/docs/onnx-removal-plan.md +339 -0
  125. package/docs/research/autoresearch-portable.md +904 -0
  126. package/docs/research/dispatch-reduction-hivemind.md +84 -0
  127. package/docs/research/ios-safari-model-caching.md +117 -0
  128. package/docs/research/mobile-webgpu-speed-fusion.md +135 -0
  129. package/docs/research/native-stt-model-selection.md +49 -0
  130. package/docs/research/native-tts-model-selection.md +90 -0
  131. package/docs/research/native-vs-chromium-decision.md +152 -0
  132. package/docs/research/nemotron-mamba2-inference.md +910 -0
  133. package/docs/research/qwen35-multimodal.md +293 -0
  134. package/docs/research/qwen36-gemma4-targets.md +337 -0
  135. package/docs/research/sota-embedding-models.md +179 -0
  136. package/docs/research/sota-mobile-models-2026.md +263 -0
  137. package/docs/research/sota-modality-models.md +202 -0
  138. package/docs/research/tps-baselines.md +71 -0
  139. package/docs/research/webgpu-m4-reference.md +104 -0
  140. package/docs/site-update-plan.md +155 -0
  141. package/docs/structured-output.md +123 -0
  142. package/docs/stt.md +63 -446
  143. package/docs/tts.md +77 -499
  144. package/docs/vision.md +100 -338
  145. package/package.json +22 -7
  146. package/dist/chrome-backend-CORwaIyC.mjs +0 -1212
  147. package/dist/chrome-backend-CORwaIyC.mjs.map +0 -1
  148. package/dist/chrome-backend-DIKYoWj-.mjs +0 -3
  149. package/dist/gerbil-CJ3ifloF.mjs +0 -4
  150. package/dist/gerbil-Dw4Qj77e.mjs +0 -1631
  151. package/dist/gerbil-Dw4Qj77e.mjs.map +0 -1
  152. package/dist/gerbil-qOTe1nl2.d.mts +0 -431
  153. package/dist/gerbil-qOTe1nl2.d.mts.map +0 -1
  154. package/dist/kokoro-BNTb6egA.mjs +0 -20210
  155. package/dist/kokoro-BNTb6egA.mjs.map +0 -1
  156. package/dist/kokoro-CMOGDSgT.js +0 -20212
  157. package/dist/kokoro-CMOGDSgT.js.map +0 -1
  158. package/dist/mcp-BvbriaBy.mjs.map +0 -1
  159. package/dist/one-liner-s-lD8rCC.mjs.map +0 -1
  160. package/dist/repl-DveXw36T.mjs +0 -9
  161. package/dist/skills-CD3Orlex.mjs.map +0 -1
  162. package/dist/stt-Bu-E23Sc.js +0 -433
  163. package/dist/stt-Bu-E23Sc.js.map +0 -1
  164. package/dist/stt-CpLYbGFd.mjs +0 -433
  165. package/dist/stt-CpLYbGFd.mjs.map +0 -1
  166. package/dist/stt-DRPLEEHB.mjs +0 -3
  167. package/dist/tools-Bi1P7Xoy.mjs.map +0 -1
  168. package/dist/transformers.web-DiD1gTwk.js +0 -44695
  169. package/dist/transformers.web-DiD1gTwk.js.map +0 -1
  170. package/dist/transformers.web-u34VxRFM.js +0 -3
  171. package/dist/tts-CqroPaSK.js +0 -724
  172. package/dist/tts-CqroPaSK.js.map +0 -1
  173. package/dist/tts-DXgsKGCe.mjs +0 -3
  174. package/dist/tts-DeGANMNV.mjs +0 -730
  175. package/dist/tts-DeGANMNV.mjs.map +0 -1
  176. package/dist/types-CiTc7ez3.d.mts.map +0 -1
  177. /package/dist/{auto-update-S9s5-g0C.mjs → auto-update-BVaLXcDE.mjs} +0 -0
  178. /package/dist/{chunk-CkXuGtQK.mjs → chunk-B9cbKln6.mjs} +0 -0
  179. /package/dist/{microphone-DaMZFRuR.mjs → microphone-Bqmoz9_K.mjs} +0 -0
@@ -0,0 +1,718 @@
1
+ # WGSL Kernel Reference
2
+
3
+ Complete reference for all 12 WGSL compute shaders in `src/gpu/kernels/wgsl/`. Each section covers the algorithm, binding layout, uniform struct, dispatch formula, example dispatches, and optimization opportunities.
4
+
5
+ ---
6
+
7
+ ## Table of Contents
8
+
9
+ 1. [Embedding](#1-embedding)
10
+ 2. [MatMul (f32)](#2-matmul-f32)
11
+ 3. [MatMulInt4](#3-matmulint4)
12
+ 4. [RMSNorm](#4-rmsnorm)
13
+ 5. [LayerNorm](#5-layernorm)
14
+ 6. [RoPE](#6-rope)
15
+ 7. [Attention](#7-attention)
16
+ 8. [Softmax](#8-softmax)
17
+ 9. [SiLU](#9-silu)
18
+ 10. [GELU](#10-gelu)
19
+ 11. [Add](#11-add)
20
+ 12. [Mul](#12-mul)
21
+ 13. [Micro-Benchmarking Approach](#13-micro-benchmarking-approach)
22
+
23
+ ---
24
+
25
+ ## 1. Embedding
26
+
27
+ **File:** `embedding.wgsl`
28
+
29
+ **Computes:** `output[t, d] = weight[input_ids[t], d]` for all positions `t` and dimensions `d`.
30
+
31
+ **Algorithm:** Flat parallel gather. Each thread handles one element of the output matrix. Thread `idx` maps to position `(idx / hidden_size, idx % hidden_size)`, looks up the token ID at that position, and copies one element from the corresponding row of the embedding weight matrix.
32
+
33
+ ### Binding Layout
34
+
35
+ | Binding | Name | Type | Access | Shape |
36
+ |---------|------|------|--------|-------|
37
+ | 0 | `input_ids` | `array<u32>` | read | `[T]` |
38
+ | 1 | `weight` | `array<f32>` | read | `[vocab_size, hidden_size]` (row-major) |
39
+ | 2 | `output` | `array<f32>` | read_write | `[T, hidden_size]` |
40
+ | 3 | `params` | uniform | - | `Params` struct |
41
+
42
+ ### Uniform Struct
43
+
44
+ ```wgsl
45
+ struct Params {
46
+ seq_len: u32, // T
47
+ hidden_size: u32, // D
48
+ }
49
+ ```
50
+
51
+ ### Dispatch Formula
52
+
53
+ ```
54
+ workgroups_x = ceil(T * hidden_size / 256)
55
+ workgroups_y = 1
56
+ workgroups_z = 1
57
+ ```
58
+
59
+ ### Example Dispatch
60
+
61
+ | T | hidden_size | Total elements | Workgroups |
62
+ |---|-------------|---------------|------------|
63
+ | 1 | 896 | 896 | (4, 1, 1) |
64
+ | 50 | 896 | 44,800 | (175, 1, 1) |
65
+ | 2048 | 896 | 1,835,008 | (7,168, 1, 1) |
66
+
67
+ ### Optimization Notes
68
+
69
+ - Could use `vec4<f32>` loads to process 4 elements per thread (4x fewer threads)
70
+ - Weight matrix could be stored as f16 and converted on read for 2x bandwidth savings
71
+ - For very large vocabularies (150K+), the weight matrix itself is the bottleneck; consider caching hot rows in shared memory
72
+
73
+ ---
74
+
75
+ ## 2. MatMul (f32)
76
+
77
+ **File:** `matmul.wgsl`
78
+
79
+ **Computes:** `C[M, N] = A[M, K] * B[K, N]` (row-major matrices)
80
+
81
+ **Algorithm:** Classic 16x16 tiled matrix multiply with shared memory. Each workgroup computes a 16x16 tile of the output. For each tile of the K dimension, threads cooperatively load a 16x16 tile of A and a 16x16 tile of B into shared memory, synchronize, then compute partial dot products.
82
+
83
+ ### Binding Layout
84
+
85
+ | Binding | Name | Type | Access | Shape |
86
+ |---------|------|------|--------|-------|
87
+ | 0 | `A` | `array<f32>` | read | `[M, K]` |
88
+ | 1 | `B` | `array<f32>` | read | `[K, N]` |
89
+ | 2 | `C` | `array<f32>` | read_write | `[M, N]` |
90
+ | 3 | `params` | uniform | - | `Params` struct |
91
+
92
+ ### Uniform Struct
93
+
94
+ ```wgsl
95
+ struct Params {
96
+ M: u32,
97
+ K: u32,
98
+ N: u32,
99
+ }
100
+ ```
101
+
102
+ ### Dispatch Formula
103
+
104
+ ```
105
+ workgroups_x = ceil(N / 16)
106
+ workgroups_y = ceil(M / 16)
107
+ workgroups_z = 1
108
+ ```
109
+
110
+ ### Example Dispatch
111
+
112
+ | Operation | M | K | N | Workgroups |
113
+ |-----------|---|---|---|------------|
114
+ | Q projection (T=1) | 1 | 896 | 896 | (56, 1, 1) |
115
+ | Q projection (T=50) | 50 | 896 | 896 | (56, 4, 1) |
116
+ | Gate proj (T=1) | 1 | 896 | 4864 | (304, 1, 1) |
117
+ | Down proj (T=1) | 1 | 4864 | 896 | (56, 1, 1) |
118
+ | LM head (T=1) | 1 | 896 | 151936 | (9,496, 1, 1) |
119
+
120
+ ### Optimization Notes
121
+
122
+ - **Register blocking**: Process 4x4 or 8x8 output elements per thread to improve arithmetic intensity
123
+ - **Vectorized loads**: Use `vec4<f32>` for 128-bit loads from shared memory
124
+ - **Double buffering**: Overlap tile loading with computation by using two shared memory tiles
125
+ - **f16 weight loads**: When `shader-f16` is available, load weights as f16 and accumulate in f32 to halve weight-read bandwidth
126
+ - **Batched matmul**: Fuse multiple small matmuls (Q, K, V projections) into one kernel
127
+
128
+ ---
129
+
130
+ ## 3. MatMulInt4
131
+
132
+ **File:** `matmul_int4.wgsl`
133
+
134
+ **Computes:** `C[M, N] = A[M, K] * dequant(B_q[K, N])` where B_q is packed INT4 with per-group scales and zeros.
135
+
136
+ **Algorithm:** Each thread computes one element of C. For each element, it iterates over the K dimension, dequantizing B on-the-fly: extract a 4-bit nibble from the packed u32, apply `(nibble - zero) * scale` using per-group parameters, then multiply-accumulate with the corresponding A element.
137
+
138
+ ### Binding Layout
139
+
140
+ | Binding | Name | Type | Access | Shape |
141
+ |---------|------|------|--------|-------|
142
+ | 0 | `A` | `array<f32>` | read | `[M, K]` (activations) |
143
+ | 1 | `B_q` | `array<u32>` | read | Packed INT4 `[K, N]`, 8 values per u32 |
144
+ | 2 | `scales` | `array<f32>` | read | One per group |
145
+ | 3 | `zeros` | `array<f32>` | read | One per group |
146
+ | 4 | `C` | `array<f32>` | read_write | `[M, N]` |
147
+ | 5 | `params` | uniform | - | `Params` struct |
148
+
149
+ ### Uniform Struct
150
+
151
+ ```wgsl
152
+ struct Params {
153
+ M: u32,
154
+ K: u32,
155
+ N: u32,
156
+ group_size: u32, // typically 32 or 128
157
+ }
158
+ ```
159
+
160
+ ### INT4 Packing
161
+
162
+ Each `u32` in `B_q` stores 8 INT4 values. The `extract_nibble` function extracts a single 4-bit value:
163
+
164
+ ```wgsl
165
+ fn extract_nibble(packed: u32, pos: u32) -> f32 {
166
+ let shift = pos * 4u;
167
+ let nibble = (packed >> shift) & 0xFu;
168
+ return f32(nibble);
169
+ }
170
+ ```
171
+
172
+ ### Dequantization
173
+
174
+ For element at flat index `flat_idx` in B:
175
+ ```
176
+ group_idx = flat_idx / group_size
177
+ value = (nibble - zeros[group_idx]) * scales[group_idx]
178
+ ```
179
+
180
+ ### Dispatch Formula
181
+
182
+ ```
183
+ workgroups_x = ceil(N / 16)
184
+ workgroups_y = ceil(M / 16)
185
+ workgroups_z = 1
186
+ ```
187
+
188
+ ### Optimization Notes
189
+
190
+ - **This kernel is not tiled**: Each thread iterates over the full K dimension, resulting in poor data reuse. A tiled version with shared memory (similar to `matmul.wgsl`) would dramatically improve performance.
191
+ - **Vectorized nibble extraction**: Process 8 nibbles from one u32 load simultaneously
192
+ - **Pre-dequantize tiles**: Load a tile of B_q into shared memory, dequantize in parallel, then do the tiled multiply
193
+ - This kernel is defined but **not yet wired** to graph generators. It will be used when GPTQ/AWQ quantized models are supported.
194
+
195
+ ---
196
+
197
+ ## 4. RMSNorm
198
+
199
+ **File:** `rmsnorm.wgsl`
200
+
201
+ **Computes:** `output[t, d] = (input[t, d] / rms(input[t, :])) * weight[d]` where `rms(x) = sqrt(mean(x^2) + eps)`
202
+
203
+ **Algorithm:** One workgroup per row (token position). Each thread accumulates squared values for a subset of the hidden dimension using a strided pattern (thread `tid` processes elements `tid, tid+256, tid+512, ...`). A tree reduction in shared memory computes the total sum of squares. Then each thread normalizes and scales its assigned elements.
204
+
205
+ ### Binding Layout
206
+
207
+ | Binding | Name | Type | Access | Shape |
208
+ |---------|------|------|--------|-------|
209
+ | 0 | `input` | `array<f32>` | read | `[T, hidden_size]` |
210
+ | 1 | `weight` | `array<f32>` | read | `[hidden_size]` |
211
+ | 2 | `output` | `array<f32>` | read_write | `[T, hidden_size]` |
212
+ | 3 | `params` | uniform | - | `Params` struct |
213
+
214
+ ### Uniform Struct
215
+
216
+ ```wgsl
217
+ struct Params {
218
+ seq_len: u32, // T
219
+ hidden_size: u32, // D
220
+ eps_bits: u32, // f32 epsilon reinterpreted as u32 (for uniform alignment)
221
+ _pad: u32, // padding to 16-byte alignment
222
+ }
223
+ ```
224
+
225
+ Note: `eps` is passed as `bitcast<u32>(eps_f32)` because WebGPU uniform buffers require 16-byte alignment and mixing f32/u32 in a struct can cause alignment issues. The shader uses `bitcast<f32>(params.eps_bits)` to recover the float value.
226
+
227
+ ### Dispatch Formula
228
+
229
+ ```
230
+ workgroups_x = T // one workgroup per row
231
+ workgroups_y = 1
232
+ workgroups_z = 1
233
+ ```
234
+
235
+ ### Example Dispatch
236
+
237
+ | T | hidden_size | Workgroups |
238
+ |---|-------------|------------|
239
+ | 1 | 896 | (1, 1, 1) |
240
+ | 50 | 896 | (50, 1, 1) |
241
+ | 2048 | 896 | (2048, 1, 1) |
242
+
243
+ ### Optimization Notes
244
+
245
+ - For hidden_size < 256, the tree reduction is oversized (some of the 256 threads have no elements to accumulate and sit idle)
246
+ - Could use `vec4<f32>` accumulation for 4x fewer loop iterations
247
+ - Fusing RMSNorm with the following MatMul would save one global memory round-trip
248
+
249
+ ---
250
+
251
+ ## 5. LayerNorm
252
+
253
+ **File:** `layernorm.wgsl`
254
+
255
+ **Computes:** `output[t, d] = ((input[t, d] - mean) / sqrt(variance + eps)) * weight[d] + bias[d]`
256
+
257
+ **Algorithm:** Two-pass reduction per row. First pass: tree reduction to compute mean. Second pass: tree reduction to compute variance (sum of squared differences from mean). Then each thread normalizes, scales, and shifts its assigned elements.
258
+
259
+ ### Binding Layout
260
+
261
+ | Binding | Name | Type | Access | Shape |
262
+ |---------|------|------|--------|-------|
263
+ | 0 | `input` | `array<f32>` | read | `[T, hidden_size]` |
264
+ | 1 | `weight` | `array<f32>` | read | `[hidden_size]` |
265
+ | 2 | `bias` | `array<f32>` | read | `[hidden_size]` |
266
+ | 3 | `output` | `array<f32>` | read_write | `[T, hidden_size]` |
267
+ | 4 | `params` | uniform | - | `Params` struct |
268
+
269
+ ### Uniform Struct
270
+
271
+ ```wgsl
272
+ struct Params {
273
+ seq_len: u32,
274
+ hidden_size: u32,
275
+ eps_bits: u32,
276
+ _pad: u32,
277
+ }
278
+ ```
279
+
280
+ ### Dispatch Formula
281
+
282
+ ```
283
+ workgroups_x = T
284
+ workgroups_y = 1
285
+ workgroups_z = 1
286
+ ```
287
+
288
+ ### Optimization Notes
289
+
290
+ - Uses two separate shared memory arrays (`shared_sum` and `shared_sq_sum`); could use a single-pass Welford algorithm to compute both mean and variance simultaneously
291
+ - Same vectorization and fusion opportunities as RMSNorm
292
+
293
+ ---
294
+
295
+ ## 6. RoPE
296
+
297
+ **File:** `rope.wgsl`
298
+
299
+ **Computes:** For each head, rotates pairs of dimensions by position-dependent angles:
300
+ ```
301
+ q_out[2i] = q[2i] * cos(theta) - q[2i+1] * sin(theta)
302
+ q_out[2i+1] = q[2i] * sin(theta) + q[2i+1] * cos(theta)
303
+ ```
304
+ where `theta = pos * base^(-2i/dim)`.
305
+
306
+ **Algorithm:** Each thread handles one pair of dimensions for one head at one position. Handles Q and K separately since they may have different head counts (GQA). The kernel supports a `position_offset` for decode steps where the query is at a later position in the sequence.
307
+
308
+ **Note:** This kernel operates **in-place** -- Q and K are both input and output buffers (`read_write`).
309
+
310
+ ### Binding Layout
311
+
312
+ | Binding | Name | Type | Access | Shape |
313
+ |---------|------|------|--------|-------|
314
+ | 0 | `q` | `array<f32>` | read_write | `[T, num_q_heads * head_dim]` |
315
+ | 1 | `k` | `array<f32>` | read_write | `[T, num_kv_heads * head_dim]` |
316
+ | 2 | `params` | uniform | - | `Params` struct |
317
+
318
+ ### Uniform Struct
319
+
320
+ ```wgsl
321
+ struct Params {
322
+ seq_len: u32,
323
+ num_q_heads: u32,
324
+ num_kv_heads: u32,
325
+ head_dim: u32,
326
+ rope_base_bits: u32, // f32 base reinterpreted as u32
327
+ position_offset: u32, // starting position for decode step
328
+ }
329
+ ```
330
+
331
+ ### Dispatch Formula
332
+
333
+ ```
334
+ total_pairs = T * max(num_q_heads, num_kv_heads) * (head_dim / 2)
335
+ workgroups_x = ceil(total_pairs / 256)
336
+ workgroups_y = 1
337
+ workgroups_z = 1
338
+ ```
339
+
340
+ ### Example Dispatch
341
+
342
+ | T | num_q_heads | num_kv_heads | head_dim | Total pairs | Workgroups |
343
+ |---|-------------|-------------|----------|-------------|------------|
344
+ | 1 | 14 | 2 | 64 | 448 | (2, 1, 1) |
345
+ | 50 | 14 | 2 | 64 | 22,400 | (88, 1, 1) |
346
+
347
+ ### Optimization Notes
348
+
349
+ - Q and K are processed by the same dispatch; with GQA (num_q_heads >> num_kv_heads), the K processing finishes much earlier while Q threads are still running
350
+ - Frequency computation (`pow(base, ...)`) could be pre-computed into a lookup table in shared memory
351
+ - The sin/cos computation could use `sincos()` if available, though WGSL doesn't have it natively
352
+
353
+ ---
354
+
355
+ ## 7. Attention
356
+
357
+ **File:** `attention.wgsl`
358
+
359
+ **Computes:** Scaled dot-product attention with causal masking and GQA:
360
+ ```
361
+ scores = (Q @ K^T) / sqrt(head_dim)
362
+ scores = causal_mask(scores)
363
+ weights = softmax(scores)
364
+ output = weights @ V
365
+ ```
366
+
367
+ **Algorithm:** Each workgroup handles one (query_position, q_head) pair. **Correctness-first implementation**: thread 0 performs the full computation (dot products, max-finding, softmax, V accumulation) while other threads are idle. This is explicitly marked as a TODO for optimization.
368
+
369
+ GQA is handled by mapping each Q head to its corresponding KV head: `kv_head = q_head / (num_q_heads / num_kv_heads)`.
370
+
371
+ ### Binding Layout
372
+
373
+ | Binding | Name | Type | Access | Shape |
374
+ |---------|------|------|--------|-------|
375
+ | 0 | `Q` | `array<f32>` | read | `[T, num_q_heads * head_dim]` |
376
+ | 1 | `K` | `array<f32>` | read | `[S, num_kv_heads * head_dim]` |
377
+ | 2 | `V` | `array<f32>` | read | `[S, num_kv_heads * head_dim]` |
378
+ | 3 | `output` | `array<f32>` | read_write | `[T, num_q_heads * head_dim]` |
379
+ | 4 | `params` | uniform | - | `Params` struct |
380
+
381
+ ### Uniform Struct
382
+
383
+ ```wgsl
384
+ struct Params {
385
+ T: u32, // query seq len (prompt during prefill, 1 during decode)
386
+ S: u32, // key/value seq len (total including cache)
387
+ num_q_heads: u32,
388
+ num_kv_heads: u32,
389
+ head_dim: u32,
390
+ position_offset: u32, // start position of Q in full sequence
391
+ }
392
+ ```
393
+
394
+ ### Dispatch Formula
395
+
396
+ ```
397
+ workgroups_x = T // one per query position
398
+ workgroups_y = num_q_heads // one per Q head
399
+ workgroups_z = 1
400
+ ```
401
+
402
+ ### Example Dispatch
403
+
404
+ | T | S | num_q_heads | Workgroups |
405
+ |---|---|-------------|------------|
406
+ | 1 (decode) | 100 | 14 | (1, 14, 1) |
407
+ | 50 (prefill) | 50 | 14 | (50, 14, 1) |
408
+
409
+ ### Optimization Notes
410
+
411
+ This is the **highest priority kernel for optimization**:
412
+
413
+ - **Current**: Only thread 0 works; 255 threads per workgroup are wasted
414
+ - **Parallel dot products**: Each thread computes one dimension of the Q@K dot product, followed by a shared memory reduction
415
+ - **Tiled attention**: Process S positions in tiles, with online softmax (FlashAttention-style) to avoid materializing the full attention matrix
416
+ - **Shared memory K/V**: Load K/V tiles into shared memory for data reuse across dot products
417
+ - **Multi-query batching**: During prefill, process multiple query positions per workgroup
418
+
419
+ ---
420
+
421
+ ## 8. Softmax
422
+
423
+ **File:** `softmax.wgsl`
424
+
425
+ **Computes:** Row-wise softmax: `output[i] = exp(input[i] - max) / sum(exp(input[:] - max))`
426
+
427
+ **Algorithm:** Three-pass parallel reduction, one workgroup per row:
428
+ 1. **Pass 1 (max)**: Each thread finds the max of its assigned elements; tree reduction to get row max
429
+ 2. **Pass 2 (sum)**: Each thread computes `sum(exp(x - max))` for its elements; tree reduction to get total sum
430
+ 3. **Pass 3 (normalize)**: Each thread computes `exp(x - max) / total_sum` for its elements
431
+
432
+ ### Binding Layout
433
+
434
+ | Binding | Name | Type | Access | Shape |
435
+ |---------|------|------|--------|-------|
436
+ | 0 | `input` | `array<f32>` | read | `[num_rows, row_size]` |
437
+ | 1 | `output` | `array<f32>` | read_write | `[num_rows, row_size]` |
438
+ | 2 | `params` | uniform | - | `Params` struct |
439
+
440
+ ### Uniform Struct
441
+
442
+ ```wgsl
443
+ struct Params {
444
+ num_rows: u32,
445
+ row_size: u32,
446
+ }
447
+ ```
448
+
449
+ ### Dispatch Formula
450
+
451
+ ```
452
+ workgroups_x = num_rows
453
+ workgroups_y = 1
454
+ workgroups_z = 1
455
+ ```
456
+
457
+ ### Optimization Notes
458
+
459
+ - Pass 3 redundantly recomputes `exp(x - max)` -- could store intermediate exp values in a buffer or fuse passes 2 and 3 using online normalization
460
+ - For small row sizes (< 256), many threads are idle
461
+ - Currently a standalone kernel; in practice, softmax is most commonly needed inside the attention kernel, so fusing them eliminates a global memory round-trip
462
+
463
+ ---
464
+
465
+ ## 9. SiLU
466
+
467
+ **File:** `silu.wgsl`
468
+
469
+ **Computes:** `output[i] = input[i] / (1 + exp(-input[i]))` (equivalent to `x * sigmoid(x)`)
470
+
471
+ **Algorithm:** Pure element-wise. Each thread processes one element.
472
+
473
+ ### Binding Layout
474
+
475
+ | Binding | Name | Type | Access | Shape |
476
+ |---------|------|------|--------|-------|
477
+ | 0 | `input` | `array<f32>` | read | `[count]` |
478
+ | 1 | `output` | `array<f32>` | read_write | `[count]` |
479
+ | 2 | `params` | uniform | - | `Params` struct |
480
+
481
+ ### Uniform Struct
482
+
483
+ ```wgsl
484
+ struct Params {
485
+ count: u32,
486
+ }
487
+ ```
488
+
489
+ ### Dispatch Formula
490
+
491
+ ```
492
+ workgroups_x = ceil(count / 256)
493
+ workgroups_y = 1
494
+ workgroups_z = 1
495
+ ```
496
+
497
+ ### Example Dispatch
498
+
499
+ | T | intermediate_size | count | Workgroups |
500
+ |---|-------------------|-------|------------|
501
+ | 1 | 4864 | 4,864 | (19, 1, 1) |
502
+ | 50 | 4864 | 243,200 | (950, 1, 1) |
503
+
504
+ ### Optimization Notes
505
+
506
+ - Could process `vec4<f32>` per thread for 4x fewer threads
507
+ - Could be fused with the following Mul (SwiGLU combine): `output = silu(gate) * up`
508
+
509
+ ---
510
+
511
+ ## 10. GELU
512
+
513
+ **File:** `gelu.wgsl`
514
+
515
+ **Computes:** Approximate GELU: `output = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
516
+
517
+ **Algorithm:** Pure element-wise. Uses the tanh approximation, which is the standard for transformer models.
518
+
519
+ ### Binding Layout
520
+
521
+ | Binding | Name | Type | Access | Shape |
522
+ |---------|------|------|--------|-------|
523
+ | 0 | `input` | `array<f32>` | read | `[count]` |
524
+ | 1 | `output` | `array<f32>` | read_write | `[count]` |
525
+ | 2 | `params` | uniform | - | `Params` struct |
526
+
527
+ ### Uniform Struct
528
+
529
+ ```wgsl
530
+ struct Params {
531
+ count: u32,
532
+ }
533
+ ```
534
+
535
+ ### Dispatch Formula
536
+
537
+ ```
538
+ workgroups_x = ceil(count / 256)
539
+ ```
540
+
541
+ ### Constants
542
+
543
+ ```wgsl
544
+ const SQRT_2_OVER_PI: f32 = 0.7978845608;
545
+ const GELU_COEFF: f32 = 0.044715;
546
+ ```
547
+
548
+ ### Optimization Notes
549
+
550
+ - Same vectorization opportunity as SiLU
551
+ - Used by Phi and GPT-2 style models (not Qwen/LLaMA which use SiLU)
552
+
553
+ ---
554
+
555
+ ## 11. Add
556
+
557
+ **File:** `add.wgsl`
558
+
559
+ **Computes:** `output[i] = a[i] + b[i]`
560
+
561
+ **Algorithm:** Pure element-wise. Used for residual connections.
562
+
563
+ ### Binding Layout
564
+
565
+ | Binding | Name | Type | Access | Shape |
566
+ |---------|------|------|--------|-------|
567
+ | 0 | `a` | `array<f32>` | read | `[count]` |
568
+ | 1 | `b` | `array<f32>` | read | `[count]` |
569
+ | 2 | `output` | `array<f32>` | read_write | `[count]` |
570
+ | 3 | `params` | uniform | - | `Params` struct |
571
+
572
+ ### Uniform Struct
573
+
574
+ ```wgsl
575
+ struct Params {
576
+ count: u32,
577
+ }
578
+ ```
579
+
580
+ ### Dispatch Formula
581
+
582
+ ```
583
+ workgroups_x = ceil(count / 256)
584
+ ```
585
+
586
+ ### Optimization Notes
587
+
588
+ - Could use `vec4<f32>` for 128-bit operations
589
+ - Residual add could be fused with the preceding kernel (e.g., output projection writes directly to `resid + output`)
590
+
591
+ ---
592
+
593
+ ## 12. Mul
594
+
595
+ **File:** `mul.wgsl`
596
+
597
+ **Computes:** `output[i] = a[i] * b[i]`
598
+
599
+ **Algorithm:** Pure element-wise. Used for the SwiGLU combine step (`silu(gate) * up`).
600
+
601
+ ### Binding Layout
602
+
603
+ | Binding | Name | Type | Access | Shape |
604
+ |---------|------|------|--------|-------|
605
+ | 0 | `a` | `array<f32>` | read | `[count]` |
606
+ | 1 | `b` | `array<f32>` | read | `[count]` |
607
+ | 2 | `output` | `array<f32>` | read_write | `[count]` |
608
+ | 3 | `params` | uniform | - | `Params` struct |
609
+
610
+ ### Uniform Struct
611
+
612
+ ```wgsl
613
+ struct Params {
614
+ count: u32,
615
+ }
616
+ ```
617
+
618
+ ### Dispatch Formula
619
+
620
+ ```
621
+ workgroups_x = ceil(count / 256)
622
+ ```
623
+
624
+ ### Optimization Notes
625
+
626
+ - Same vectorization and fusion opportunities as Add
627
+ - Could be fused with preceding SiLU: single kernel computes `silu(a[i]) * b[i]`
628
+
629
+ ---
630
+
631
+ ## 13. Micro-Benchmarking Approach
632
+
633
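+ To measure individual kernel performance, use this pattern. The snippet assumes a `randomF32(n)` helper that returns random f32 data and a `matmulWGSL` string holding the source of `matmul.wgsl`; neither is imported in the example, so supply them yourself: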
634
+
635
+ ### Test Harness
636
+
637
+ ```typescript
638
+ import { initGPU, createStorageBuffer, createUniformBuffer, getOrCreatePipeline, createBindGroup } from "./device.js";
639
+
640
+ async function benchmarkMatmul(M: number, K: number, N: number, iterations: number = 100) {
641
+ const ctx = await initGPU();
642
+
643
+ // Allocate buffers with random data
644
+ const A = createStorageBuffer(ctx, "A", M * K * 4, randomF32(M * K));
645
+ const B = createStorageBuffer(ctx, "B", K * N * 4, randomF32(K * N));
646
+ const C = createStorageBuffer(ctx, "C", M * N * 4);
647
+ const params = createUniformBuffer(ctx, "params",
648
+ new Uint32Array([M, K, N]).buffer);
649
+
650
+ const pipeline = getOrCreatePipeline(ctx, "matmul", matmulWGSL, "main");
651
+ const bindGroup = createBindGroup(ctx, pipeline,
652
+ [{ buffer: A }, { buffer: B }, { buffer: C }, { buffer: params }]);
653
+
654
+ const wgX = Math.ceil(N / 16);
655
+ const wgY = Math.ceil(M / 16);
656
+
657
+ // Warmup
658
+ for (let i = 0; i < 10; i++) {
659
+ const enc = ctx.device.createCommandEncoder();
660
+ const pass = enc.beginComputePass();
661
+ pass.setPipeline(pipeline);
662
+ pass.setBindGroup(0, bindGroup);
663
+ pass.dispatchWorkgroups(wgX, wgY, 1);
664
+ pass.end();
665
+ ctx.device.queue.submit([enc.finish()]);
666
+ }
667
+ await ctx.device.queue.onSubmittedWorkDone();
668
+
669
+ // Benchmark
670
+ const start = performance.now();
671
+ for (let i = 0; i < iterations; i++) {
672
+ const enc = ctx.device.createCommandEncoder();
673
+ const pass = enc.beginComputePass();
674
+ pass.setPipeline(pipeline);
675
+ pass.setBindGroup(0, bindGroup);
676
+ pass.dispatchWorkgroups(wgX, wgY, 1);
677
+ pass.end();
678
+ ctx.device.queue.submit([enc.finish()]);
679
+ }
680
+ await ctx.device.queue.onSubmittedWorkDone();
681
+ const elapsed = performance.now() - start;
682
+
683
+ // Compute metrics
684
+ const flops = 2 * M * K * N; // multiply-add = 2 ops
685
+ const totalFlops = flops * iterations;
686
+ const gflops = totalFlops / (elapsed * 1e6); // elapsed is in ms
687
+
688
+ console.log(`MatMul ${M}x${K}x${N}:`);
689
+ console.log(` ${iterations} iterations in ${elapsed.toFixed(1)} ms`);
690
+ console.log(` ${(elapsed / iterations).toFixed(3)} ms/iter`);
691
+ console.log(` ${gflops.toFixed(1)} GFLOPS`);
692
+
693
+ A.destroy(); B.destroy(); C.destroy(); params.destroy();
694
+ }
695
+ ```
696
+
697
+ ### Expected Performance Targets
698
+
699
+ For reference, approximate theoretical peak f32 throughput (in TFLOPS) for common GPUs:
700
+
701
+ | GPU | f32 TFLOPS | Expected matmul efficiency |
702
+ |-----|-----------|---------------------------|
703
+ | Apple M1 (8-core) | ~2.6 | 30-50% with 16x16 tiling |
704
+ | Apple M2 (10-core) | ~3.6 | 30-50% |
705
+ | NVIDIA RTX 3060 | ~12.7 | 40-60% |
706
+ | Intel Arc A770 | ~17.2 | 30-50% |
707
+
708
+ The 16x16 tiling strategy without register blocking typically achieves 30-50% of peak throughput. Larger tiles (32x32, 64x64) with register blocking can push this to 70-80%.
709
+
710
+ ### Key Metrics Per Kernel
711
+
712
+ | Kernel | Metric | Formula |
713
+ |--------|--------|---------|
714
+ | MatMul | GFLOPS | `2 * M * K * N / (time_ms * 1e6)` |
715
+ | Embedding | GB/s | `T * hidden_size * 4 / (time_ms * 1e6)` |
716
+ | RMSNorm | GB/s | `T * hidden_size * 4 * 3 / (time_ms * 1e6)` (read input + weight + write output) |
717
+ | Attention | GFLOPS | `2 * T * S * head_dim * num_heads / (time_ms * 1e6)` |
718
+ | Element-wise | GB/s | `count * 4 * (num_inputs + 1) / (time_ms * 1e6)` |