@simulatte/doppler 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -8
- package/package.json +7 -4
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.json +42 -2
- package/src/config/loader.js +31 -2
- package/src/config/merge.js +18 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.js +2 -1
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +15 -2
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +1 -1
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.js +1 -2
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/inference/browser-harness.js +47 -1
- package/src/inference/pipelines/diffusion/pipeline.js +15 -6
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +68 -1
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +29 -2
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-webgpu.js +9 -87
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +137 -40
|
@@ -26,8 +26,8 @@ struct Uniforms {
|
|
|
26
26
|
start_pos: u32, // Starting position (for decode)
|
|
27
27
|
rope_base: f32, // Base frequency (default 10000)
|
|
28
28
|
rope_scale: f32, // Scaling factor for extended context
|
|
29
|
-
|
|
30
|
-
|
|
29
|
+
rotary_dim: u32, // Rotary slice within head_dim
|
|
30
|
+
interleaved: u32, // 1 = adjacent pairs, 0 = rotate-half
|
|
31
31
|
}
|
|
32
32
|
|
|
33
33
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -46,7 +46,8 @@ fn main(
|
|
|
46
46
|
let start_pos = u.start_pos;
|
|
47
47
|
|
|
48
48
|
// Global thread index (one thread per complex pair)
|
|
49
|
-
let
|
|
49
|
+
let rotary_dim = u.rotary_dim;
|
|
50
|
+
let half_dim = rotary_dim / 2u;
|
|
50
51
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
51
52
|
let idx = global_id.x;
|
|
52
53
|
|
|
@@ -68,16 +69,18 @@ fn main(
|
|
|
68
69
|
|
|
69
70
|
// Select the pair layout: rotate-half (x[i], x[i + half_dim]) by default, or adjacent pairs (x[2i], x[2i + 1]) when u.interleaved == 1
|
|
70
71
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
71
|
-
let
|
|
72
|
-
let
|
|
72
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
73
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
74
|
+
let x0 = input[base_idx + first_idx];
|
|
75
|
+
let x1 = input[base_idx + second_idx];
|
|
73
76
|
|
|
74
77
|
// Apply rotation
|
|
75
78
|
let y0 = x0 * cos_val - x1 * sin_val;
|
|
76
79
|
let y1 = x0 * sin_val + x1 * cos_val;
|
|
77
80
|
|
|
78
81
|
// Write back
|
|
79
|
-
input[base_idx +
|
|
80
|
-
input[base_idx +
|
|
82
|
+
input[base_idx + first_idx] = y0;
|
|
83
|
+
input[base_idx + second_idx] = y1;
|
|
81
84
|
}
|
|
82
85
|
|
|
83
86
|
// Compute frequencies on-the-fly (no precomputation needed)
|
|
@@ -91,9 +94,10 @@ fn rope_compute_freqs(
|
|
|
91
94
|
let start_pos = u.start_pos;
|
|
92
95
|
let rope_base = u.rope_base;
|
|
93
96
|
let rope_scale = u.rope_scale;
|
|
97
|
+
let rotary_dim = u.rotary_dim;
|
|
94
98
|
|
|
95
99
|
let idx = global_id.x;
|
|
96
|
-
let half_dim =
|
|
100
|
+
let half_dim = rotary_dim / 2u;
|
|
97
101
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
98
102
|
|
|
99
103
|
if (idx >= total_pairs) {
|
|
@@ -109,7 +113,7 @@ fn rope_compute_freqs(
|
|
|
109
113
|
let actual_pos = f32(start_pos + pos) / rope_scale;
|
|
110
114
|
|
|
111
115
|
// Compute frequency: 1 / (base^(2*pair_idx/rotary_dim))
|
|
112
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
116
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
113
117
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
114
118
|
let theta = actual_pos * freq;
|
|
115
119
|
|
|
@@ -118,12 +122,14 @@ fn rope_compute_freqs(
|
|
|
118
122
|
|
|
119
123
|
// Select the pair layout: rotate-half (x[i], x[i + half_dim]) by default, or adjacent pairs (x[2i], x[2i + 1]) when u.interleaved == 1
|
|
120
124
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
121
|
-
let
|
|
122
|
-
let
|
|
125
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
126
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
127
|
+
let x0 = input[base_idx + first_idx];
|
|
128
|
+
let x1 = input[base_idx + second_idx];
|
|
123
129
|
|
|
124
130
|
// Apply rotation
|
|
125
|
-
input[base_idx +
|
|
126
|
-
input[base_idx +
|
|
131
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
132
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
127
133
|
}
|
|
128
134
|
|
|
129
135
|
// Apply RoPE to both Q and K in one pass
|
|
@@ -138,10 +144,11 @@ fn rope_qk(
|
|
|
138
144
|
let start_pos = u.start_pos;
|
|
139
145
|
let rope_base = u.rope_base;
|
|
140
146
|
let rope_scale = u.rope_scale;
|
|
147
|
+
let rotary_dim = u.rotary_dim;
|
|
141
148
|
|
|
142
149
|
let idx = global_id.x;
|
|
143
150
|
// Each thread handles one Q-K pair at one dimension pair
|
|
144
|
-
let half_dim =
|
|
151
|
+
let half_dim = rotary_dim / 2u;
|
|
145
152
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
146
153
|
|
|
147
154
|
if (idx >= total_pairs) {
|
|
@@ -156,7 +163,7 @@ fn rope_qk(
|
|
|
156
163
|
let actual_pos = f32(start_pos + pos) / rope_scale;
|
|
157
164
|
|
|
158
165
|
// Compute frequency
|
|
159
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
166
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
160
167
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
161
168
|
let theta = actual_pos * freq;
|
|
162
169
|
|
|
@@ -168,16 +175,18 @@ fn rope_qk(
|
|
|
168
175
|
let k_base_idx = q_base_idx + head_dim; // K starts after Q
|
|
169
176
|
|
|
170
177
|
// Process Q
|
|
171
|
-
let
|
|
172
|
-
let
|
|
173
|
-
input[q_base_idx +
|
|
174
|
-
input[q_base_idx +
|
|
178
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
179
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
180
|
+
let q0 = input[q_base_idx + first_idx];
|
|
181
|
+
let q1 = input[q_base_idx + second_idx];
|
|
182
|
+
input[q_base_idx + first_idx] = q0 * cos_val - q1 * sin_val;
|
|
183
|
+
input[q_base_idx + second_idx] = q0 * sin_val + q1 * cos_val;
|
|
175
184
|
|
|
176
185
|
// Process K
|
|
177
|
-
let k0 = input[k_base_idx +
|
|
178
|
-
let k1 = input[k_base_idx +
|
|
179
|
-
input[k_base_idx +
|
|
180
|
-
input[k_base_idx +
|
|
186
|
+
let k0 = input[k_base_idx + first_idx];
|
|
187
|
+
let k1 = input[k_base_idx + second_idx];
|
|
188
|
+
input[k_base_idx + first_idx] = k0 * cos_val - k1 * sin_val;
|
|
189
|
+
input[k_base_idx + second_idx] = k0 * sin_val + k1 * cos_val;
|
|
181
190
|
}
|
|
182
191
|
|
|
183
192
|
// Precompute frequency table (run once at init)
|
|
@@ -190,9 +199,10 @@ fn precompute_freqs(
|
|
|
190
199
|
let seq_len = u.seq_len; // maxSeqLen for precomputation
|
|
191
200
|
let rope_base = u.rope_base;
|
|
192
201
|
let rope_scale = u.rope_scale;
|
|
202
|
+
let rotary_dim = u.rotary_dim;
|
|
193
203
|
|
|
194
204
|
let idx = global_id.x;
|
|
195
|
-
let half_dim =
|
|
205
|
+
let half_dim = rotary_dim / 2u;
|
|
196
206
|
let total_elements = seq_len * half_dim;
|
|
197
207
|
|
|
198
208
|
if (idx >= total_elements) {
|
|
@@ -203,7 +213,7 @@ fn precompute_freqs(
|
|
|
203
213
|
let dim_idx = idx % half_dim;
|
|
204
214
|
|
|
205
215
|
let actual_pos = f32(pos) / rope_scale;
|
|
206
|
-
let exponent = f32(dim_idx * 2u) / f32(
|
|
216
|
+
let exponent = f32(dim_idx * 2u) / f32(rotary_dim);
|
|
207
217
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
208
218
|
let theta = actual_pos * freq;
|
|
209
219
|
|
|
@@ -218,6 +228,7 @@ fn rope_ntk_scaled(
|
|
|
218
228
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
219
229
|
) {
|
|
220
230
|
let head_dim = u.head_dim;
|
|
231
|
+
let rotary_dim = u.rotary_dim;
|
|
221
232
|
let num_heads = u.num_heads;
|
|
222
233
|
let seq_len = u.seq_len;
|
|
223
234
|
let start_pos = u.start_pos;
|
|
@@ -225,7 +236,7 @@ fn rope_ntk_scaled(
|
|
|
225
236
|
let rope_scale = u.rope_scale;
|
|
226
237
|
|
|
227
238
|
let idx = global_id.x;
|
|
228
|
-
let half_dim =
|
|
239
|
+
let half_dim = rotary_dim / 2u;
|
|
229
240
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
230
241
|
|
|
231
242
|
if (idx >= total_pairs) {
|
|
@@ -234,7 +245,7 @@ fn rope_ntk_scaled(
|
|
|
234
245
|
|
|
235
246
|
// NTK scaling: increase base proportionally to scale factor
|
|
236
247
|
// This preserves high-frequency components better than linear interpolation
|
|
237
|
-
rope_base = rope_base * pow(rope_scale, f32(
|
|
248
|
+
rope_base = rope_base * pow(rope_scale, f32(rotary_dim) / (f32(rotary_dim) - 2.0));
|
|
238
249
|
|
|
239
250
|
let pos = idx / (num_heads * half_dim);
|
|
240
251
|
let remainder = idx % (num_heads * half_dim);
|
|
@@ -243,7 +254,7 @@ fn rope_ntk_scaled(
|
|
|
243
254
|
|
|
244
255
|
let actual_pos = f32(start_pos + pos);
|
|
245
256
|
|
|
246
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
257
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
247
258
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
248
259
|
let theta = actual_pos * freq;
|
|
249
260
|
|
|
@@ -251,11 +262,13 @@ fn rope_ntk_scaled(
|
|
|
251
262
|
let sin_val = sin(theta);
|
|
252
263
|
|
|
253
264
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
254
|
-
let
|
|
255
|
-
let
|
|
265
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
266
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
267
|
+
let x0 = input[base_idx + first_idx];
|
|
268
|
+
let x1 = input[base_idx + second_idx];
|
|
256
269
|
|
|
257
|
-
input[base_idx +
|
|
258
|
-
input[base_idx +
|
|
270
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
271
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
259
272
|
}
|
|
260
273
|
|
|
261
274
|
// YaRN-style RoPE with attention scaling
|
|
@@ -265,6 +278,7 @@ fn rope_yarn(
|
|
|
265
278
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
266
279
|
) {
|
|
267
280
|
let head_dim = u.head_dim;
|
|
281
|
+
let rotary_dim = u.rotary_dim;
|
|
268
282
|
let num_heads = u.num_heads;
|
|
269
283
|
let seq_len = u.seq_len;
|
|
270
284
|
let start_pos = u.start_pos;
|
|
@@ -272,7 +286,7 @@ fn rope_yarn(
|
|
|
272
286
|
let rope_scale = u.rope_scale;
|
|
273
287
|
|
|
274
288
|
let idx = global_id.x;
|
|
275
|
-
let half_dim =
|
|
289
|
+
let half_dim = rotary_dim / 2u;
|
|
276
290
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
277
291
|
|
|
278
292
|
if (idx >= total_pairs) {
|
|
@@ -292,7 +306,7 @@ fn rope_yarn(
|
|
|
292
306
|
let alpha: f32 = 1.0;
|
|
293
307
|
|
|
294
308
|
// Compute original frequency
|
|
295
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
309
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
296
310
|
let orig_freq = 1.0 / pow(rope_base, exponent);
|
|
297
311
|
|
|
298
312
|
// Compute wavelength
|
|
@@ -300,8 +314,8 @@ fn rope_yarn(
|
|
|
300
314
|
|
|
301
315
|
// Interpolation factor based on wavelength
|
|
302
316
|
var ramp: f32;
|
|
303
|
-
let low_wavelength = f32(
|
|
304
|
-
let high_wavelength = f32(
|
|
317
|
+
let low_wavelength = f32(rotary_dim) / beta_fast;
|
|
318
|
+
let high_wavelength = f32(rotary_dim) / beta_slow;
|
|
305
319
|
|
|
306
320
|
if (wavelength < low_wavelength) {
|
|
307
321
|
ramp = 0.0; // No interpolation for high frequencies
|
|
@@ -320,9 +334,11 @@ fn rope_yarn(
|
|
|
320
334
|
let sin_val = sin(theta);
|
|
321
335
|
|
|
322
336
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
323
|
-
let
|
|
324
|
-
let
|
|
337
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
338
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
339
|
+
let x0 = input[base_idx + first_idx];
|
|
340
|
+
let x1 = input[base_idx + second_idx];
|
|
325
341
|
|
|
326
|
-
input[base_idx +
|
|
327
|
-
input[base_idx +
|
|
342
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
343
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
328
344
|
}
|
|
@@ -29,7 +29,6 @@ async function runSummary(target, query, key, value, summaryBuffer, uniforms, va
|
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
|
|
32
|
-
const outputSize = uniforms.num_tokens * uniforms.hidden_size;
|
|
33
32
|
await unifiedKernelWrapper(
|
|
34
33
|
'sana_linear_attention_apply',
|
|
35
34
|
target,
|
|
@@ -45,7 +44,7 @@ async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, va
|
|
|
45
44
|
_pad1: 0,
|
|
46
45
|
_pad2: 0,
|
|
47
46
|
},
|
|
48
|
-
Math.ceil(
|
|
47
|
+
[Math.ceil(uniforms.hidden_size / WORKGROUP_SIZES.DEFAULT), uniforms.num_tokens, 1]
|
|
49
48
|
);
|
|
50
49
|
}
|
|
51
50
|
|
|
@@ -18,14 +18,13 @@ struct Uniforms {
|
|
|
18
18
|
|
|
19
19
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
20
20
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
21
|
-
let
|
|
22
|
-
let
|
|
23
|
-
if (
|
|
21
|
+
let hidden = gid.x;
|
|
22
|
+
let token = gid.y;
|
|
23
|
+
if (token >= u.num_tokens || hidden >= u.hidden_size) {
|
|
24
24
|
return;
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
-
let
|
|
28
|
-
let hidden = idx - token * u.hidden_size;
|
|
27
|
+
let idx = token * u.hidden_size + hidden;
|
|
29
28
|
let head = hidden / u.head_dim;
|
|
30
29
|
let dim = hidden - head * u.head_dim;
|
|
31
30
|
let rows_per_head = u.head_dim + 1u;
|
|
@@ -20,14 +20,13 @@ struct Uniforms {
|
|
|
20
20
|
|
|
21
21
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
22
22
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
23
|
-
let
|
|
24
|
-
let
|
|
25
|
-
if (
|
|
23
|
+
let hidden = gid.x;
|
|
24
|
+
let token = gid.y;
|
|
25
|
+
if (token >= u.num_tokens || hidden >= u.hidden_size) {
|
|
26
26
|
return;
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
-
let
|
|
30
|
-
let hidden = idx - token * u.hidden_size;
|
|
29
|
+
let idx = token * u.hidden_size + hidden;
|
|
31
30
|
let head = hidden / u.head_dim;
|
|
32
31
|
let dim = hidden - head * u.head_dim;
|
|
33
32
|
let rows_per_head = u.head_dim + 1u;
|
|
@@ -33,6 +33,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
33
33
|
|
|
34
34
|
var acc: f32 = 0.0;
|
|
35
35
|
for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
|
|
36
|
+
let query_value = query[token * u.hidden_size + hidden_base + col];
|
|
36
37
|
let key_idx = token * u.hidden_size + hidden_base + col;
|
|
37
38
|
let key_value = max(key[key_idx], 0.0);
|
|
38
39
|
let value_value = select(
|
|
@@ -40,6 +41,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
40
41
|
1.0,
|
|
41
42
|
row == u.head_dim
|
|
42
43
|
);
|
|
44
|
+
if (u.hidden_size == 0u) {
|
|
45
|
+
acc = acc + query_value;
|
|
46
|
+
}
|
|
43
47
|
acc = acc + value_value * key_value;
|
|
44
48
|
}
|
|
45
49
|
|
|
@@ -35,6 +35,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
35
35
|
|
|
36
36
|
var acc: f32 = 0.0;
|
|
37
37
|
for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
|
|
38
|
+
let query_value = f32(query[token * u.hidden_size + hidden_base + col]);
|
|
38
39
|
let key_idx = token * u.hidden_size + hidden_base + col;
|
|
39
40
|
let key_value = max(f32(key[key_idx]), 0.0);
|
|
40
41
|
let value_value = select(
|
|
@@ -42,6 +43,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
42
43
|
1.0,
|
|
43
44
|
row == u.head_dim
|
|
44
45
|
);
|
|
46
|
+
if (u.hidden_size == 0u) {
|
|
47
|
+
acc = acc + query_value;
|
|
48
|
+
}
|
|
45
49
|
acc = acc + value_value * key_value;
|
|
46
50
|
}
|
|
47
51
|
|
|
@@ -16,6 +16,7 @@ export interface SiLUOptions extends OutputBufferOptions {
|
|
|
16
16
|
size?: number | null;
|
|
17
17
|
gate?: Tensor | null;
|
|
18
18
|
gateActivation?: 'silu' | 'sigmoid';
|
|
19
|
+
inputActivation?: 'silu' | 'identity';
|
|
19
20
|
useVec4?: boolean;
|
|
20
21
|
biasOffset?: number;
|
|
21
22
|
swigluLimit: number | null;
|
package/src/gpu/kernels/silu.js
CHANGED
|
@@ -47,6 +47,18 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
|
|
|
47
47
|
];
|
|
48
48
|
}
|
|
49
49
|
|
|
50
|
+
function planSiLUDispatch(device, size, useVec4) {
|
|
51
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
52
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
53
|
+
: 65535;
|
|
54
|
+
const laneWidth = useVec4 ? 4 : 1;
|
|
55
|
+
const chunkSize = maxPerDim * WORKGROUP_SIZES.DEFAULT * laneWidth;
|
|
56
|
+
const dispatchStride = Math.min(size, chunkSize);
|
|
57
|
+
const x = Math.min(maxPerDim, Math.ceil(dispatchStride / (WORKGROUP_SIZES.DEFAULT * laneWidth)));
|
|
58
|
+
const y = Math.max(1, Math.ceil(size / chunkSize));
|
|
59
|
+
return { dispatchStride, workgroups: [x, y, 1] };
|
|
60
|
+
}
|
|
61
|
+
|
|
50
62
|
|
|
51
63
|
export async function runSiLU(
|
|
52
64
|
input,
|
|
@@ -60,6 +72,7 @@ export async function runSiLU(
|
|
|
60
72
|
useVec4 = false,
|
|
61
73
|
swigluLimit,
|
|
62
74
|
gateActivation = 'silu',
|
|
75
|
+
inputActivation = 'silu',
|
|
63
76
|
} = options;
|
|
64
77
|
const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
|
|
65
78
|
|
|
@@ -74,14 +87,17 @@ export async function runSiLU(
|
|
|
74
87
|
useSplit: false,
|
|
75
88
|
useRowsplit: false,
|
|
76
89
|
});
|
|
77
|
-
const constants =
|
|
78
|
-
|
|
79
|
-
:
|
|
90
|
+
const constants = {
|
|
91
|
+
...(overrides || {}),
|
|
92
|
+
...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
|
|
93
|
+
...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
|
|
94
|
+
};
|
|
80
95
|
const pipeline = await getPipelineFast('silu', variant, null, constants);
|
|
81
96
|
|
|
82
97
|
const inferredSize = size || (input.buffer.size / bytesPerElement);
|
|
83
98
|
const outputSize = inferredSize * bytesPerElement;
|
|
84
99
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
|
|
100
|
+
const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
|
|
85
101
|
|
|
86
102
|
// Create uniform buffer
|
|
87
103
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -89,7 +105,7 @@ export async function runSiLU(
|
|
|
89
105
|
16,
|
|
90
106
|
(view) => {
|
|
91
107
|
view.setUint32(0, inferredSize, true);
|
|
92
|
-
view.setUint32(4,
|
|
108
|
+
view.setUint32(4, dispatchPlan.dispatchStride, true);
|
|
93
109
|
view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
|
|
94
110
|
view.setFloat32(12, 0, true);
|
|
95
111
|
},
|
|
@@ -106,8 +122,7 @@ export async function runSiLU(
|
|
|
106
122
|
entries,
|
|
107
123
|
});
|
|
108
124
|
|
|
109
|
-
|
|
110
|
-
dispatch(device, pipeline, bindGroup, workgroups, 'silu');
|
|
125
|
+
dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
|
|
111
126
|
|
|
112
127
|
uniformBuffer.destroy();
|
|
113
128
|
|
|
@@ -215,7 +230,7 @@ export async function runSiLURowSplit(
|
|
|
215
230
|
],
|
|
216
231
|
});
|
|
217
232
|
|
|
218
|
-
const workgroups = Math.ceil(
|
|
233
|
+
const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
|
|
219
234
|
dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
|
|
220
235
|
|
|
221
236
|
uniformBuffer.destroy();
|
|
@@ -269,7 +284,7 @@ export async function recordSiLURowSplit(
|
|
|
269
284
|
],
|
|
270
285
|
});
|
|
271
286
|
|
|
272
|
-
const workgroups = Math.ceil(
|
|
287
|
+
const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
|
|
273
288
|
recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
|
|
274
289
|
|
|
275
290
|
return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
|
|
@@ -288,6 +303,7 @@ export async function recordSiLU(
|
|
|
288
303
|
outputBuffer = null,
|
|
289
304
|
swigluLimit,
|
|
290
305
|
gateActivation = 'silu',
|
|
306
|
+
inputActivation = 'silu',
|
|
291
307
|
} = options;
|
|
292
308
|
const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
|
|
293
309
|
|
|
@@ -302,14 +318,17 @@ export async function recordSiLU(
|
|
|
302
318
|
useSplit: false,
|
|
303
319
|
useRowsplit: false,
|
|
304
320
|
});
|
|
305
|
-
const constants =
|
|
306
|
-
|
|
307
|
-
:
|
|
321
|
+
const constants = {
|
|
322
|
+
...(overrides || {}),
|
|
323
|
+
...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
|
|
324
|
+
...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
|
|
325
|
+
};
|
|
308
326
|
const pipeline = await getPipelineFast('silu', variant, null, constants);
|
|
309
327
|
|
|
310
328
|
const inferredSize = size || (input.buffer.size / bytesPerElement);
|
|
311
329
|
const outputSize = inferredSize * bytesPerElement;
|
|
312
330
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
|
|
331
|
+
const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
|
|
313
332
|
|
|
314
333
|
// Uniform buffer
|
|
315
334
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -317,7 +336,7 @@ export async function recordSiLU(
|
|
|
317
336
|
16,
|
|
318
337
|
(view) => {
|
|
319
338
|
view.setUint32(0, inferredSize, true);
|
|
320
|
-
view.setUint32(4,
|
|
339
|
+
view.setUint32(4, dispatchPlan.dispatchStride, true);
|
|
321
340
|
view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
|
|
322
341
|
view.setFloat32(12, 0, true);
|
|
323
342
|
},
|
|
@@ -333,8 +352,7 @@ export async function recordSiLU(
|
|
|
333
352
|
entries,
|
|
334
353
|
});
|
|
335
354
|
|
|
336
|
-
|
|
337
|
-
recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu');
|
|
355
|
+
recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
|
|
338
356
|
|
|
339
357
|
return createTensor(output, input.dtype, [inferredSize], 'silu_output');
|
|
340
358
|
}
|
|
@@ -10,13 +10,14 @@
|
|
|
10
10
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
11
11
|
override HAS_GATE: bool = false;
|
|
12
12
|
override GATE_USE_SIGMOID: bool = false;
|
|
13
|
+
override INPUT_USE_IDENTITY: bool = false;
|
|
13
14
|
override USE_SPLIT: bool = false;
|
|
14
15
|
override USE_VEC4: bool = false;
|
|
15
16
|
override USE_ROWSPLIT: bool = false;
|
|
16
17
|
|
|
17
18
|
struct Uniforms {
|
|
18
19
|
size: u32, // Total output elements
|
|
19
|
-
rowsplit_dim: u32, //
|
|
20
|
+
rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
|
|
20
21
|
clamp_max: f32, // SwiGLU clamp (0 = disabled)
|
|
21
22
|
_pad1: f32,
|
|
22
23
|
}
|
|
@@ -35,6 +36,10 @@ fn silu(x: f32) -> f32 {
|
|
|
35
36
|
return x * sigmoid(x);
|
|
36
37
|
}
|
|
37
38
|
|
|
39
|
+
fn apply_input_activation(x: f32) -> f32 {
|
|
40
|
+
return select(silu(x), x, INPUT_USE_IDENTITY);
|
|
41
|
+
}
|
|
42
|
+
|
|
38
43
|
fn clamp_swiglu(x: f32) -> f32 {
|
|
39
44
|
if (u.clamp_max <= 0.0) {
|
|
40
45
|
return x;
|
|
@@ -46,8 +51,9 @@ fn clamp_swiglu(x: f32) -> f32 {
|
|
|
46
51
|
fn main(
|
|
47
52
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
48
53
|
) {
|
|
54
|
+
let dispatch_stride = max(u.rowsplit_dim, 1u);
|
|
49
55
|
if (USE_VEC4) {
|
|
50
|
-
let base_idx = global_id.x * 4u;
|
|
56
|
+
let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
|
|
51
57
|
if (base_idx >= u.size) {
|
|
52
58
|
return;
|
|
53
59
|
}
|
|
@@ -55,12 +61,12 @@ fn main(
|
|
|
55
61
|
let remaining = min(4u, u.size - base_idx);
|
|
56
62
|
for (var i: u32 = 0u; i < remaining; i = i + 1u) {
|
|
57
63
|
let x = input[base_idx + i];
|
|
58
|
-
output[base_idx + i] =
|
|
64
|
+
output[base_idx + i] = apply_input_activation(x);
|
|
59
65
|
}
|
|
60
66
|
return;
|
|
61
67
|
}
|
|
62
68
|
|
|
63
|
-
let idx = global_id.x;
|
|
69
|
+
let idx = global_id.y * dispatch_stride + global_id.x;
|
|
64
70
|
if (idx >= u.size) {
|
|
65
71
|
return;
|
|
66
72
|
}
|
|
@@ -70,12 +76,16 @@ fn main(
|
|
|
70
76
|
return;
|
|
71
77
|
}
|
|
72
78
|
let dim = u.rowsplit_dim;
|
|
73
|
-
let
|
|
74
|
-
let
|
|
79
|
+
let num_tokens = u.size / dim;
|
|
80
|
+
let token_idx = global_id.y;
|
|
81
|
+
let dim_idx = global_id.x;
|
|
82
|
+
if (token_idx >= num_tokens || dim_idx >= dim) {
|
|
83
|
+
return;
|
|
84
|
+
}
|
|
75
85
|
let row_base = token_idx * dim * 2u;
|
|
76
86
|
let g = input[row_base + dim_idx];
|
|
77
87
|
let up = input[row_base + dim + dim_idx];
|
|
78
|
-
output[
|
|
88
|
+
output[token_idx * dim + dim_idx] = clamp_swiglu(silu(g) * up);
|
|
79
89
|
return;
|
|
80
90
|
}
|
|
81
91
|
|
|
@@ -83,7 +93,7 @@ fn main(
|
|
|
83
93
|
let up = input[idx];
|
|
84
94
|
let g = gate[idx];
|
|
85
95
|
let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
|
|
86
|
-
output[idx] = clamp_swiglu(gateAct * up);
|
|
96
|
+
output[idx] = clamp_swiglu(gateAct * apply_input_activation(up));
|
|
87
97
|
return;
|
|
88
98
|
}
|
|
89
99
|
|
|
@@ -95,5 +105,5 @@ fn main(
|
|
|
95
105
|
}
|
|
96
106
|
|
|
97
107
|
let x = input[idx];
|
|
98
|
-
output[idx] =
|
|
108
|
+
output[idx] = apply_input_activation(x);
|
|
99
109
|
}
|
|
@@ -9,13 +9,14 @@ enable f16;
|
|
|
9
9
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
10
10
|
override HAS_GATE: bool = false;
|
|
11
11
|
override GATE_USE_SIGMOID: bool = false;
|
|
12
|
+
override INPUT_USE_IDENTITY: bool = false;
|
|
12
13
|
override USE_SPLIT: bool = false;
|
|
13
14
|
override USE_VEC4: bool = false;
|
|
14
15
|
override USE_ROWSPLIT: bool = false;
|
|
15
16
|
|
|
16
17
|
struct Uniforms {
|
|
17
18
|
size: u32, // Total output elements
|
|
18
|
-
rowsplit_dim: u32, //
|
|
19
|
+
rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
|
|
19
20
|
clamp_max: f32, // SwiGLU clamp (0 = disabled)
|
|
20
21
|
_pad1: f32,
|
|
21
22
|
}
|
|
@@ -34,6 +35,10 @@ fn silu(x: f32) -> f32 {
|
|
|
34
35
|
return x * sigmoid(x);
|
|
35
36
|
}
|
|
36
37
|
|
|
38
|
+
fn apply_input_activation(x: f32) -> f32 {
|
|
39
|
+
return select(silu(x), x, INPUT_USE_IDENTITY);
|
|
40
|
+
}
|
|
41
|
+
|
|
37
42
|
fn clamp_swiglu(x: f32) -> f32 {
|
|
38
43
|
if (u.clamp_max <= 0.0) {
|
|
39
44
|
return x;
|
|
@@ -45,8 +50,9 @@ fn clamp_swiglu(x: f32) -> f32 {
|
|
|
45
50
|
fn main(
|
|
46
51
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
47
52
|
) {
|
|
53
|
+
let dispatch_stride = max(u.rowsplit_dim, 1u);
|
|
48
54
|
if (USE_VEC4) {
|
|
49
|
-
let base_idx = global_id.x * 4u;
|
|
55
|
+
let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
|
|
50
56
|
if (base_idx >= u.size) {
|
|
51
57
|
return;
|
|
52
58
|
}
|
|
@@ -54,12 +60,12 @@ fn main(
|
|
|
54
60
|
let remaining = min(4u, u.size - base_idx);
|
|
55
61
|
for (var i: u32 = 0u; i < remaining; i = i + 1u) {
|
|
56
62
|
let x = f32(input[base_idx + i]);
|
|
57
|
-
output[base_idx + i] = f16(
|
|
63
|
+
output[base_idx + i] = f16(apply_input_activation(x));
|
|
58
64
|
}
|
|
59
65
|
return;
|
|
60
66
|
}
|
|
61
67
|
|
|
62
|
-
let idx = global_id.x;
|
|
68
|
+
let idx = global_id.y * dispatch_stride + global_id.x;
|
|
63
69
|
if (idx >= u.size) {
|
|
64
70
|
return;
|
|
65
71
|
}
|
|
@@ -69,12 +75,16 @@ fn main(
|
|
|
69
75
|
return;
|
|
70
76
|
}
|
|
71
77
|
let dim = u.rowsplit_dim;
|
|
72
|
-
let
|
|
73
|
-
let
|
|
78
|
+
let num_tokens = u.size / dim;
|
|
79
|
+
let token_idx = global_id.y;
|
|
80
|
+
let dim_idx = global_id.x;
|
|
81
|
+
if (token_idx >= num_tokens || dim_idx >= dim) {
|
|
82
|
+
return;
|
|
83
|
+
}
|
|
74
84
|
let row_base = token_idx * dim * 2u;
|
|
75
85
|
let g = f32(input[row_base + dim_idx]);
|
|
76
86
|
let up = f32(input[row_base + dim + dim_idx]);
|
|
77
|
-
output[
|
|
87
|
+
output[token_idx * dim + dim_idx] = f16(clamp_swiglu(silu(g) * up));
|
|
78
88
|
return;
|
|
79
89
|
}
|
|
80
90
|
|
|
@@ -82,7 +92,7 @@ fn main(
|
|
|
82
92
|
let up = f32(input[idx]);
|
|
83
93
|
let g = f32(gate[idx]);
|
|
84
94
|
let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
|
|
85
|
-
output[idx] = f16(clamp_swiglu(gateAct * up));
|
|
95
|
+
output[idx] = f16(clamp_swiglu(gateAct * apply_input_activation(up)));
|
|
86
96
|
return;
|
|
87
97
|
}
|
|
88
98
|
|
|
@@ -94,5 +104,5 @@ fn main(
|
|
|
94
104
|
}
|
|
95
105
|
|
|
96
106
|
let x = f32(input[idx]);
|
|
97
|
-
output[idx] = f16(
|
|
107
|
+
output[idx] = f16(apply_input_activation(x));
|
|
98
108
|
}
|