@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
@@ -26,8 +26,8 @@ struct Uniforms {
26
26
  start_pos: u32, // Starting position (for decode)
27
27
  rope_base: f32, // Base frequency (default 10000)
28
28
  rope_scale: f32, // Scaling factor for extended context
29
- _pad0: u32,
30
- _pad1: u32,
29
+ rotary_dim: u32, // Rotary slice within head_dim
30
+ interleaved: u32, // 1 = adjacent pairs, 0 = rotate-half
31
31
  }
32
32
 
33
33
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -46,7 +46,8 @@ fn main(
46
46
  let start_pos = u.start_pos;
47
47
 
48
48
  // Global thread index (one thread per complex pair)
49
- let half_dim = head_dim / 2u;
49
+ let rotary_dim = u.rotary_dim;
50
+ let half_dim = rotary_dim / 2u;
50
51
  let total_pairs = seq_len * num_heads * half_dim;
51
52
  let idx = global_id.x;
52
53
 
@@ -68,16 +69,18 @@ fn main(
68
69
 
69
70
  // Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
70
71
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
71
- let x0 = input[base_idx + pair_idx];
72
- let x1 = input[base_idx + pair_idx + half_dim];
72
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
73
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
74
+ let x0 = input[base_idx + first_idx];
75
+ let x1 = input[base_idx + second_idx];
73
76
 
74
77
  // Apply rotation
75
78
  let y0 = x0 * cos_val - x1 * sin_val;
76
79
  let y1 = x0 * sin_val + x1 * cos_val;
77
80
 
78
81
  // Write back
79
- input[base_idx + pair_idx] = y0;
80
- input[base_idx + pair_idx + half_dim] = y1;
82
+ input[base_idx + first_idx] = y0;
83
+ input[base_idx + second_idx] = y1;
81
84
  }
82
85
 
83
86
  // Compute frequencies on-the-fly (no precomputation needed)
@@ -91,9 +94,10 @@ fn rope_compute_freqs(
91
94
  let start_pos = u.start_pos;
92
95
  let rope_base = u.rope_base;
93
96
  let rope_scale = u.rope_scale;
97
+ let rotary_dim = u.rotary_dim;
94
98
 
95
99
  let idx = global_id.x;
96
- let half_dim = head_dim / 2u;
100
+ let half_dim = rotary_dim / 2u;
97
101
  let total_pairs = seq_len * num_heads * half_dim;
98
102
 
99
103
  if (idx >= total_pairs) {
@@ -109,7 +113,7 @@ fn rope_compute_freqs(
109
113
  let actual_pos = f32(start_pos + pos) / rope_scale;
110
114
 
111
115
  // Compute frequency: 1 / (base^(2*pair_idx/head_dim))
112
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
116
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
113
117
  let freq = 1.0 / pow(rope_base, exponent);
114
118
  let theta = actual_pos * freq;
115
119
 
@@ -118,12 +122,14 @@ fn rope_compute_freqs(
118
122
 
119
123
  // Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
120
124
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
121
- let x0 = input[base_idx + pair_idx];
122
- let x1 = input[base_idx + pair_idx + half_dim];
125
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
126
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
127
+ let x0 = input[base_idx + first_idx];
128
+ let x1 = input[base_idx + second_idx];
123
129
 
124
130
  // Apply rotation
125
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
126
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
131
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
132
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
127
133
  }
128
134
 
129
135
  // Apply RoPE to both Q and K in one pass
@@ -138,10 +144,11 @@ fn rope_qk(
138
144
  let start_pos = u.start_pos;
139
145
  let rope_base = u.rope_base;
140
146
  let rope_scale = u.rope_scale;
147
+ let rotary_dim = u.rotary_dim;
141
148
 
142
149
  let idx = global_id.x;
143
150
  // Each thread handles one Q-K pair at one dimension pair
144
- let half_dim = head_dim / 2u;
151
+ let half_dim = rotary_dim / 2u;
145
152
  let total_pairs = seq_len * num_heads * half_dim;
146
153
 
147
154
  if (idx >= total_pairs) {
@@ -156,7 +163,7 @@ fn rope_qk(
156
163
  let actual_pos = f32(start_pos + pos) / rope_scale;
157
164
 
158
165
  // Compute frequency
159
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
166
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
160
167
  let freq = 1.0 / pow(rope_base, exponent);
161
168
  let theta = actual_pos * freq;
162
169
 
@@ -168,16 +175,18 @@ fn rope_qk(
168
175
  let k_base_idx = q_base_idx + head_dim; // K starts after Q
169
176
 
170
177
  // Process Q
171
- let q0 = input[q_base_idx + pair_idx];
172
- let q1 = input[q_base_idx + pair_idx + half_dim];
173
- input[q_base_idx + pair_idx] = q0 * cos_val - q1 * sin_val;
174
- input[q_base_idx + pair_idx + half_dim] = q0 * sin_val + q1 * cos_val;
178
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
179
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
180
+ let q0 = input[q_base_idx + first_idx];
181
+ let q1 = input[q_base_idx + second_idx];
182
+ input[q_base_idx + first_idx] = q0 * cos_val - q1 * sin_val;
183
+ input[q_base_idx + second_idx] = q0 * sin_val + q1 * cos_val;
175
184
 
176
185
  // Process K
177
- let k0 = input[k_base_idx + pair_idx];
178
- let k1 = input[k_base_idx + pair_idx + half_dim];
179
- input[k_base_idx + pair_idx] = k0 * cos_val - k1 * sin_val;
180
- input[k_base_idx + pair_idx + half_dim] = k0 * sin_val + k1 * cos_val;
186
+ let k0 = input[k_base_idx + first_idx];
187
+ let k1 = input[k_base_idx + second_idx];
188
+ input[k_base_idx + first_idx] = k0 * cos_val - k1 * sin_val;
189
+ input[k_base_idx + second_idx] = k0 * sin_val + k1 * cos_val;
181
190
  }
182
191
 
183
192
  // Precompute frequency table (run once at init)
@@ -190,9 +199,10 @@ fn precompute_freqs(
190
199
  let seq_len = u.seq_len; // maxSeqLen for precomputation
191
200
  let rope_base = u.rope_base;
192
201
  let rope_scale = u.rope_scale;
202
+ let rotary_dim = u.rotary_dim;
193
203
 
194
204
  let idx = global_id.x;
195
- let half_dim = head_dim / 2u;
205
+ let half_dim = rotary_dim / 2u;
196
206
  let total_elements = seq_len * half_dim;
197
207
 
198
208
  if (idx >= total_elements) {
@@ -203,7 +213,7 @@ fn precompute_freqs(
203
213
  let dim_idx = idx % half_dim;
204
214
 
205
215
  let actual_pos = f32(pos) / rope_scale;
206
- let exponent = f32(dim_idx * 2u) / f32(head_dim);
216
+ let exponent = f32(dim_idx * 2u) / f32(rotary_dim);
207
217
  let freq = 1.0 / pow(rope_base, exponent);
208
218
  let theta = actual_pos * freq;
209
219
 
@@ -218,6 +228,7 @@ fn rope_ntk_scaled(
218
228
  @builtin(global_invocation_id) global_id: vec3<u32>
219
229
  ) {
220
230
  let head_dim = u.head_dim;
231
+ let rotary_dim = u.rotary_dim;
221
232
  let num_heads = u.num_heads;
222
233
  let seq_len = u.seq_len;
223
234
  let start_pos = u.start_pos;
@@ -225,7 +236,7 @@ fn rope_ntk_scaled(
225
236
  let rope_scale = u.rope_scale;
226
237
 
227
238
  let idx = global_id.x;
228
- let half_dim = head_dim / 2u;
239
+ let half_dim = rotary_dim / 2u;
229
240
  let total_pairs = seq_len * num_heads * half_dim;
230
241
 
231
242
  if (idx >= total_pairs) {
@@ -234,7 +245,7 @@ fn rope_ntk_scaled(
234
245
 
235
246
  // NTK scaling: increase base proportionally to scale factor
236
247
  // This preserves high-frequency components better than linear interpolation
237
- rope_base = rope_base * pow(rope_scale, f32(head_dim) / (f32(head_dim) - 2.0));
248
+ rope_base = rope_base * pow(rope_scale, f32(rotary_dim) / (f32(rotary_dim) - 2.0));
238
249
 
239
250
  let pos = idx / (num_heads * half_dim);
240
251
  let remainder = idx % (num_heads * half_dim);
@@ -243,7 +254,7 @@ fn rope_ntk_scaled(
243
254
 
244
255
  let actual_pos = f32(start_pos + pos);
245
256
 
246
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
257
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
247
258
  let freq = 1.0 / pow(rope_base, exponent);
248
259
  let theta = actual_pos * freq;
249
260
 
@@ -251,11 +262,13 @@ fn rope_ntk_scaled(
251
262
  let sin_val = sin(theta);
252
263
 
253
264
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
254
- let x0 = input[base_idx + pair_idx];
255
- let x1 = input[base_idx + pair_idx + half_dim];
265
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
266
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
267
+ let x0 = input[base_idx + first_idx];
268
+ let x1 = input[base_idx + second_idx];
256
269
 
257
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
258
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
270
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
271
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
259
272
  }
260
273
 
261
274
  // YaRN-style RoPE with attention scaling
@@ -265,6 +278,7 @@ fn rope_yarn(
265
278
  @builtin(global_invocation_id) global_id: vec3<u32>
266
279
  ) {
267
280
  let head_dim = u.head_dim;
281
+ let rotary_dim = u.rotary_dim;
268
282
  let num_heads = u.num_heads;
269
283
  let seq_len = u.seq_len;
270
284
  let start_pos = u.start_pos;
@@ -272,7 +286,7 @@ fn rope_yarn(
272
286
  let rope_scale = u.rope_scale;
273
287
 
274
288
  let idx = global_id.x;
275
- let half_dim = head_dim / 2u;
289
+ let half_dim = rotary_dim / 2u;
276
290
  let total_pairs = seq_len * num_heads * half_dim;
277
291
 
278
292
  if (idx >= total_pairs) {
@@ -292,7 +306,7 @@ fn rope_yarn(
292
306
  let alpha: f32 = 1.0;
293
307
 
294
308
  // Compute original frequency
295
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
309
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
296
310
  let orig_freq = 1.0 / pow(rope_base, exponent);
297
311
 
298
312
  // Compute wavelength
@@ -300,8 +314,8 @@ fn rope_yarn(
300
314
 
301
315
  // Interpolation factor based on wavelength
302
316
  var ramp: f32;
303
- let low_wavelength = f32(head_dim) / beta_fast;
304
- let high_wavelength = f32(head_dim) / beta_slow;
317
+ let low_wavelength = f32(rotary_dim) / beta_fast;
318
+ let high_wavelength = f32(rotary_dim) / beta_slow;
305
319
 
306
320
  if (wavelength < low_wavelength) {
307
321
  ramp = 0.0; // No interpolation for high frequencies
@@ -320,9 +334,11 @@ fn rope_yarn(
320
334
  let sin_val = sin(theta);
321
335
 
322
336
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
323
- let x0 = input[base_idx + pair_idx];
324
- let x1 = input[base_idx + pair_idx + half_dim];
337
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
338
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
339
+ let x0 = input[base_idx + first_idx];
340
+ let x1 = input[base_idx + second_idx];
325
341
 
326
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
327
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
342
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
343
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
328
344
  }
@@ -29,7 +29,6 @@ async function runSummary(target, query, key, value, summaryBuffer, uniforms, va
29
29
  }
30
30
 
31
31
  async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
32
- const outputSize = uniforms.num_tokens * uniforms.hidden_size;
33
32
  await unifiedKernelWrapper(
34
33
  'sana_linear_attention_apply',
35
34
  target,
@@ -45,7 +44,7 @@ async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, va
45
44
  _pad1: 0,
46
45
  _pad2: 0,
47
46
  },
48
- Math.ceil(outputSize / WORKGROUP_SIZES.DEFAULT)
47
+ [Math.ceil(uniforms.hidden_size / WORKGROUP_SIZES.DEFAULT), uniforms.num_tokens, 1]
49
48
  );
50
49
  }
51
50
 
@@ -18,14 +18,13 @@ struct Uniforms {
18
18
 
19
19
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
20
20
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
21
- let idx = gid.x;
22
- let total = u.num_tokens * u.hidden_size;
23
- if (idx >= total) {
21
+ let hidden = gid.x;
22
+ let token = gid.y;
23
+ if (token >= u.num_tokens || hidden >= u.hidden_size) {
24
24
  return;
25
25
  }
26
26
 
27
- let token = idx / u.hidden_size;
28
- let hidden = idx - token * u.hidden_size;
27
+ let idx = token * u.hidden_size + hidden;
29
28
  let head = hidden / u.head_dim;
30
29
  let dim = hidden - head * u.head_dim;
31
30
  let rows_per_head = u.head_dim + 1u;
@@ -20,14 +20,13 @@ struct Uniforms {
20
20
 
21
21
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
22
22
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
23
- let idx = gid.x;
24
- let total = u.num_tokens * u.hidden_size;
25
- if (idx >= total) {
23
+ let hidden = gid.x;
24
+ let token = gid.y;
25
+ if (token >= u.num_tokens || hidden >= u.hidden_size) {
26
26
  return;
27
27
  }
28
28
 
29
- let token = idx / u.hidden_size;
30
- let hidden = idx - token * u.hidden_size;
29
+ let idx = token * u.hidden_size + hidden;
31
30
  let head = hidden / u.head_dim;
32
31
  let dim = hidden - head * u.head_dim;
33
32
  let rows_per_head = u.head_dim + 1u;
@@ -33,6 +33,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
33
33
 
34
34
  var acc: f32 = 0.0;
35
35
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
36
+ let query_value = query[token * u.hidden_size + hidden_base + col];
36
37
  let key_idx = token * u.hidden_size + hidden_base + col;
37
38
  let key_value = max(key[key_idx], 0.0);
38
39
  let value_value = select(
@@ -40,6 +41,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
40
41
  1.0,
41
42
  row == u.head_dim
42
43
  );
44
+ if (u.hidden_size == 0u) {
45
+ acc = acc + query_value;
46
+ }
43
47
  acc = acc + value_value * key_value;
44
48
  }
45
49
 
@@ -35,6 +35,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
35
35
 
36
36
  var acc: f32 = 0.0;
37
37
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
38
+ let query_value = f32(query[token * u.hidden_size + hidden_base + col]);
38
39
  let key_idx = token * u.hidden_size + hidden_base + col;
39
40
  let key_value = max(f32(key[key_idx]), 0.0);
40
41
  let value_value = select(
@@ -42,6 +43,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
42
43
  1.0,
43
44
  row == u.head_dim
44
45
  );
46
+ if (u.hidden_size == 0u) {
47
+ acc = acc + query_value;
48
+ }
45
49
  acc = acc + value_value * key_value;
46
50
  }
47
51
 
@@ -16,6 +16,7 @@ export interface SiLUOptions extends OutputBufferOptions {
16
16
  size?: number | null;
17
17
  gate?: Tensor | null;
18
18
  gateActivation?: 'silu' | 'sigmoid';
19
+ inputActivation?: 'silu' | 'identity';
19
20
  useVec4?: boolean;
20
21
  biasOffset?: number;
21
22
  swigluLimit: number | null;
@@ -47,6 +47,18 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
47
47
  ];
48
48
  }
49
49
 
50
+ function planSiLUDispatch(device, size, useVec4) {
51
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
52
+ ? device.limits.maxComputeWorkgroupsPerDimension
53
+ : 65535;
54
+ const laneWidth = useVec4 ? 4 : 1;
55
+ const chunkSize = maxPerDim * WORKGROUP_SIZES.DEFAULT * laneWidth;
56
+ const dispatchStride = Math.min(size, chunkSize);
57
+ const x = Math.min(maxPerDim, Math.ceil(dispatchStride / (WORKGROUP_SIZES.DEFAULT * laneWidth)));
58
+ const y = Math.max(1, Math.ceil(size / chunkSize));
59
+ return { dispatchStride, workgroups: [x, y, 1] };
60
+ }
61
+
50
62
 
51
63
  export async function runSiLU(
52
64
  input,
@@ -60,6 +72,7 @@ export async function runSiLU(
60
72
  useVec4 = false,
61
73
  swigluLimit,
62
74
  gateActivation = 'silu',
75
+ inputActivation = 'silu',
63
76
  } = options;
64
77
  const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
65
78
 
@@ -74,14 +87,17 @@ export async function runSiLU(
74
87
  useSplit: false,
75
88
  useRowsplit: false,
76
89
  });
77
- const constants = gate && gateActivation === 'sigmoid'
78
- ? { ...(overrides || {}), GATE_USE_SIGMOID: true }
79
- : overrides;
90
+ const constants = {
91
+ ...(overrides || {}),
92
+ ...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
93
+ ...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
94
+ };
80
95
  const pipeline = await getPipelineFast('silu', variant, null, constants);
81
96
 
82
97
  const inferredSize = size || (input.buffer.size / bytesPerElement);
83
98
  const outputSize = inferredSize * bytesPerElement;
84
99
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
100
+ const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
85
101
 
86
102
  // Create uniform buffer
87
103
  const uniformBuffer = createUniformBufferWithView(
@@ -89,7 +105,7 @@ export async function runSiLU(
89
105
  16,
90
106
  (view) => {
91
107
  view.setUint32(0, inferredSize, true);
92
- view.setUint32(4, 0, true);
108
+ view.setUint32(4, dispatchPlan.dispatchStride, true);
93
109
  view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
94
110
  view.setFloat32(12, 0, true);
95
111
  },
@@ -106,8 +122,7 @@ export async function runSiLU(
106
122
  entries,
107
123
  });
108
124
 
109
- const workgroups = Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT);
110
- dispatch(device, pipeline, bindGroup, workgroups, 'silu');
125
+ dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
111
126
 
112
127
  uniformBuffer.destroy();
113
128
 
@@ -215,7 +230,7 @@ export async function runSiLURowSplit(
215
230
  ],
216
231
  });
217
232
 
218
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
233
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
219
234
  dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
220
235
 
221
236
  uniformBuffer.destroy();
@@ -269,7 +284,7 @@ export async function recordSiLURowSplit(
269
284
  ],
270
285
  });
271
286
 
272
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
287
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
273
288
  recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
274
289
 
275
290
  return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
@@ -288,6 +303,7 @@ export async function recordSiLU(
288
303
  outputBuffer = null,
289
304
  swigluLimit,
290
305
  gateActivation = 'silu',
306
+ inputActivation = 'silu',
291
307
  } = options;
292
308
  const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
293
309
 
@@ -302,14 +318,17 @@ export async function recordSiLU(
302
318
  useSplit: false,
303
319
  useRowsplit: false,
304
320
  });
305
- const constants = gate && gateActivation === 'sigmoid'
306
- ? { ...(overrides || {}), GATE_USE_SIGMOID: true }
307
- : overrides;
321
+ const constants = {
322
+ ...(overrides || {}),
323
+ ...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
324
+ ...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
325
+ };
308
326
  const pipeline = await getPipelineFast('silu', variant, null, constants);
309
327
 
310
328
  const inferredSize = size || (input.buffer.size / bytesPerElement);
311
329
  const outputSize = inferredSize * bytesPerElement;
312
330
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
331
+ const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
313
332
 
314
333
  // Uniform buffer
315
334
  const uniformBuffer = createUniformBufferWithView(
@@ -317,7 +336,7 @@ export async function recordSiLU(
317
336
  16,
318
337
  (view) => {
319
338
  view.setUint32(0, inferredSize, true);
320
- view.setUint32(4, 0, true);
339
+ view.setUint32(4, dispatchPlan.dispatchStride, true);
321
340
  view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
322
341
  view.setFloat32(12, 0, true);
323
342
  },
@@ -333,8 +352,7 @@ export async function recordSiLU(
333
352
  entries,
334
353
  });
335
354
 
336
- const workgroups = Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT);
337
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu');
355
+ recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
338
356
 
339
357
  return createTensor(output, input.dtype, [inferredSize], 'silu_output');
340
358
  }
@@ -10,13 +10,14 @@
10
10
  override WORKGROUP_SIZE: u32 = 256u;
11
11
  override HAS_GATE: bool = false;
12
12
  override GATE_USE_SIGMOID: bool = false;
13
+ override INPUT_USE_IDENTITY: bool = false;
13
14
  override USE_SPLIT: bool = false;
14
15
  override USE_VEC4: bool = false;
15
16
  override USE_ROWSPLIT: bool = false;
16
17
 
17
18
  struct Uniforms {
18
19
  size: u32, // Total output elements
19
- rowsplit_dim: u32, // Dim for rowsplit variants (0 when unused)
20
+ rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
20
21
  clamp_max: f32, // SwiGLU clamp (0 = disabled)
21
22
  _pad1: f32,
22
23
  }
@@ -35,6 +36,10 @@ fn silu(x: f32) -> f32 {
35
36
  return x * sigmoid(x);
36
37
  }
37
38
 
39
+ fn apply_input_activation(x: f32) -> f32 {
40
+ return select(silu(x), x, INPUT_USE_IDENTITY);
41
+ }
42
+
38
43
  fn clamp_swiglu(x: f32) -> f32 {
39
44
  if (u.clamp_max <= 0.0) {
40
45
  return x;
@@ -46,8 +51,9 @@ fn clamp_swiglu(x: f32) -> f32 {
46
51
  fn main(
47
52
  @builtin(global_invocation_id) global_id: vec3<u32>
48
53
  ) {
54
+ let dispatch_stride = max(u.rowsplit_dim, 1u);
49
55
  if (USE_VEC4) {
50
- let base_idx = global_id.x * 4u;
56
+ let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
51
57
  if (base_idx >= u.size) {
52
58
  return;
53
59
  }
@@ -55,12 +61,12 @@ fn main(
55
61
  let remaining = min(4u, u.size - base_idx);
56
62
  for (var i: u32 = 0u; i < remaining; i = i + 1u) {
57
63
  let x = input[base_idx + i];
58
- output[base_idx + i] = silu(x);
64
+ output[base_idx + i] = apply_input_activation(x);
59
65
  }
60
66
  return;
61
67
  }
62
68
 
63
- let idx = global_id.x;
69
+ let idx = global_id.y * dispatch_stride + global_id.x;
64
70
  if (idx >= u.size) {
65
71
  return;
66
72
  }
@@ -70,12 +76,16 @@ fn main(
70
76
  return;
71
77
  }
72
78
  let dim = u.rowsplit_dim;
73
- let token_idx = idx / dim;
74
- let dim_idx = idx % dim;
79
+ let num_tokens = u.size / dim;
80
+ let token_idx = global_id.y;
81
+ let dim_idx = global_id.x;
82
+ if (token_idx >= num_tokens || dim_idx >= dim) {
83
+ return;
84
+ }
75
85
  let row_base = token_idx * dim * 2u;
76
86
  let g = input[row_base + dim_idx];
77
87
  let up = input[row_base + dim + dim_idx];
78
- output[idx] = clamp_swiglu(silu(g) * up);
88
+ output[token_idx * dim + dim_idx] = clamp_swiglu(silu(g) * up);
79
89
  return;
80
90
  }
81
91
 
@@ -83,7 +93,7 @@ fn main(
83
93
  let up = input[idx];
84
94
  let g = gate[idx];
85
95
  let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
86
- output[idx] = clamp_swiglu(gateAct * up);
96
+ output[idx] = clamp_swiglu(gateAct * apply_input_activation(up));
87
97
  return;
88
98
  }
89
99
 
@@ -95,5 +105,5 @@ fn main(
95
105
  }
96
106
 
97
107
  let x = input[idx];
98
- output[idx] = silu(x);
108
+ output[idx] = apply_input_activation(x);
99
109
  }
@@ -9,13 +9,14 @@ enable f16;
9
9
  override WORKGROUP_SIZE: u32 = 256u;
10
10
  override HAS_GATE: bool = false;
11
11
  override GATE_USE_SIGMOID: bool = false;
12
+ override INPUT_USE_IDENTITY: bool = false;
12
13
  override USE_SPLIT: bool = false;
13
14
  override USE_VEC4: bool = false;
14
15
  override USE_ROWSPLIT: bool = false;
15
16
 
16
17
  struct Uniforms {
17
18
  size: u32, // Total output elements
18
- rowsplit_dim: u32, // Dim for rowsplit variants (0 when unused)
19
+ rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
19
20
  clamp_max: f32, // SwiGLU clamp (0 = disabled)
20
21
  _pad1: f32,
21
22
  }
@@ -34,6 +35,10 @@ fn silu(x: f32) -> f32 {
34
35
  return x * sigmoid(x);
35
36
  }
36
37
 
38
+ fn apply_input_activation(x: f32) -> f32 {
39
+ return select(silu(x), x, INPUT_USE_IDENTITY);
40
+ }
41
+
37
42
  fn clamp_swiglu(x: f32) -> f32 {
38
43
  if (u.clamp_max <= 0.0) {
39
44
  return x;
@@ -45,8 +50,9 @@ fn clamp_swiglu(x: f32) -> f32 {
45
50
  fn main(
46
51
  @builtin(global_invocation_id) global_id: vec3<u32>
47
52
  ) {
53
+ let dispatch_stride = max(u.rowsplit_dim, 1u);
48
54
  if (USE_VEC4) {
49
- let base_idx = global_id.x * 4u;
55
+ let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
50
56
  if (base_idx >= u.size) {
51
57
  return;
52
58
  }
@@ -54,12 +60,12 @@ fn main(
54
60
  let remaining = min(4u, u.size - base_idx);
55
61
  for (var i: u32 = 0u; i < remaining; i = i + 1u) {
56
62
  let x = f32(input[base_idx + i]);
57
- output[base_idx + i] = f16(silu(x));
63
+ output[base_idx + i] = f16(apply_input_activation(x));
58
64
  }
59
65
  return;
60
66
  }
61
67
 
62
- let idx = global_id.x;
68
+ let idx = global_id.y * dispatch_stride + global_id.x;
63
69
  if (idx >= u.size) {
64
70
  return;
65
71
  }
@@ -69,12 +75,16 @@ fn main(
69
75
  return;
70
76
  }
71
77
  let dim = u.rowsplit_dim;
72
- let token_idx = idx / dim;
73
- let dim_idx = idx % dim;
78
+ let num_tokens = u.size / dim;
79
+ let token_idx = global_id.y;
80
+ let dim_idx = global_id.x;
81
+ if (token_idx >= num_tokens || dim_idx >= dim) {
82
+ return;
83
+ }
74
84
  let row_base = token_idx * dim * 2u;
75
85
  let g = f32(input[row_base + dim_idx]);
76
86
  let up = f32(input[row_base + dim + dim_idx]);
77
- output[idx] = f16(clamp_swiglu(silu(g) * up));
87
+ output[token_idx * dim + dim_idx] = f16(clamp_swiglu(silu(g) * up));
78
88
  return;
79
89
  }
80
90
 
@@ -82,7 +92,7 @@ fn main(
82
92
  let up = f32(input[idx]);
83
93
  let g = f32(gate[idx]);
84
94
  let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
85
- output[idx] = f16(clamp_swiglu(gateAct * up));
95
+ output[idx] = f16(clamp_swiglu(gateAct * apply_input_activation(up)));
86
96
  return;
87
97
  }
88
98
 
@@ -94,5 +104,5 @@ fn main(
94
104
  }
95
105
 
96
106
  let x = f32(input[idx]);
97
- output[idx] = f16(silu(x));
107
+ output[idx] = f16(apply_input_activation(x));
98
108
  }