@simulatte/doppler 0.1.7 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. package/CHANGELOG.md +32 -0
  2. package/README.md +25 -6
  3. package/package.json +25 -38
  4. package/src/browser/browser-converter.js +5 -0
  5. package/src/client/doppler-api.browser.js +6 -0
  6. package/src/client/doppler-api.d.ts +3 -0
  7. package/src/client/doppler-api.js +11 -2
  8. package/src/client/doppler-registry.js +3 -5
  9. package/src/client/doppler-registry.json +2 -2
  10. package/src/config/kernel-path-loader.d.ts +5 -0
  11. package/src/config/kernel-path-loader.js +13 -0
  12. package/src/config/kernels/kernel-ref-digests.js +23 -21
  13. package/src/config/kernels/moe/mixtral.paths.json +46 -0
  14. package/src/config/kernels/registry.json +74 -0
  15. package/src/config/loader.js +9 -0
  16. package/src/config/merge-contract-check.js +7 -0
  17. package/src/config/platforms/loader.js +3 -1
  18. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
  19. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
  20. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
  21. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  22. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  23. package/src/config/presets/kernel-paths/registry.json +21 -0
  24. package/src/config/presets/models/gemma2.json +2 -1
  25. package/src/config/presets/models/gemma3.json +4 -1
  26. package/src/config/presets/models/gemma4.json +61 -0
  27. package/src/config/presets/models/granite-docling.json +70 -0
  28. package/src/config/presets/models/lfm2.json +6 -1
  29. package/src/config/presets/models/qwen3.json +4 -3
  30. package/src/config/presets/models/qwen3_5.json +16 -0
  31. package/src/config/presets/models/qwen3_vl.json +40 -0
  32. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
  33. package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
  34. package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
  35. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  36. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  37. package/src/config/presets/runtime/modes/trace-layers.json +1 -0
  38. package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
  39. package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
  40. package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
  41. package/src/config/runtime.js +3 -0
  42. package/src/config/schema/conversion.schema.d.ts +1 -0
  43. package/src/config/schema/debug.schema.d.ts +40 -0
  44. package/src/config/schema/debug.schema.js +28 -0
  45. package/src/config/schema/index.js +2 -0
  46. package/src/config/schema/inference-defaults.schema.js +1 -1
  47. package/src/config/schema/kernel-path.schema.d.ts +1 -0
  48. package/src/config/schema/manifest.schema.d.ts +1 -1
  49. package/src/config/schema/manifest.schema.js +1 -1
  50. package/src/config/schema/memory-limits.schema.js +2 -2
  51. package/src/config/schema/storage.schema.js +2 -2
  52. package/src/converter/conversion-plan.js +11 -3
  53. package/src/converter/core.js +19 -8
  54. package/src/converter/manifest-inference.js +12 -22
  55. package/src/converter/parsers/transformer.js +4 -0
  56. package/src/converter/quantization-info.js +5 -1
  57. package/src/converter/quantizer.d.ts +5 -0
  58. package/src/converter/quantizer.js +34 -12
  59. package/src/converter/rope-config.js +8 -6
  60. package/src/converter/tokenizer-utils.d.ts +1 -0
  61. package/src/converter/tokenizer-utils.js +4 -1
  62. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  63. package/src/distribution/shard-delivery.js +40 -1
  64. package/src/formats/rdrr/classification.js +32 -0
  65. package/src/formats/rdrr/parsing.d.ts +4 -0
  66. package/src/formats/rdrr/parsing.js +14 -1
  67. package/src/gpu/kernel-runtime.js +4 -2
  68. package/src/gpu/kernels/attention.js +2 -1
  69. package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
  70. package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
  71. package/src/gpu/kernels/dequant_shared.wgsl +4 -2
  72. package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
  73. package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
  74. package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
  75. package/src/gpu/kernels/gated-short-conv.js +284 -0
  76. package/src/gpu/kernels/index.d.ts +8 -0
  77. package/src/gpu/kernels/index.js +6 -0
  78. package/src/gpu/kernels/linear-attention-core.js +37 -17
  79. package/src/gpu/kernels/matmul-selection.js +48 -4
  80. package/src/gpu/kernels/matmul.d.ts +5 -0
  81. package/src/gpu/kernels/matmul.js +71 -2
  82. package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
  83. package/src/gpu/kernels/rmsnorm.js +9 -2
  84. package/src/gpu/kernels/sample.js +1 -3
  85. package/src/gpu/kernels/sample.wgsl +39 -9
  86. package/src/gpu/kernels/sample_f16.wgsl +38 -8
  87. package/src/gpu/kernels/shader-cache.js +9 -4
  88. package/src/gpu/kernels/split_qg.d.ts +50 -0
  89. package/src/gpu/kernels/split_qg.js +46 -0
  90. package/src/gpu/kernels/split_qg.wgsl +58 -0
  91. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  92. package/src/gpu/weight-buffer.d.ts +1 -1
  93. package/src/gpu/weight-buffer.js +1 -1
  94. package/src/inference/browser-harness.d.ts +2 -0
  95. package/src/inference/browser-harness.js +20 -1
  96. package/src/inference/kv-cache/base.js +3 -10
  97. package/src/inference/pipelines/diffusion/helpers.js +3 -0
  98. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  99. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +10 -3
  100. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  101. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  102. package/src/inference/pipelines/text/attention/projections.d.ts +13 -1
  103. package/src/inference/pipelines/text/attention/projections.js +54 -13
  104. package/src/inference/pipelines/text/attention/record.js +16 -6
  105. package/src/inference/pipelines/text/attention/run.js +59 -6
  106. package/src/inference/pipelines/text/config.d.ts +1 -0
  107. package/src/inference/pipelines/text/config.js +46 -4
  108. package/src/inference/pipelines/text/embed.js +26 -7
  109. package/src/inference/pipelines/text/execution-plan.js +5 -4
  110. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
  111. package/src/inference/pipelines/text/execution-v0.js +12 -1
  112. package/src/inference/pipelines/text/generator-helpers.js +1 -0
  113. package/src/inference/pipelines/text/generator-runtime.js +19 -0
  114. package/src/inference/pipelines/text/generator-steps.d.ts +15 -0
  115. package/src/inference/pipelines/text/generator-steps.js +71 -26
  116. package/src/inference/pipelines/text/generator.d.ts +5 -0
  117. package/src/inference/pipelines/text/generator.js +353 -166
  118. package/src/inference/pipelines/text/init.d.ts +15 -0
  119. package/src/inference/pipelines/text/init.js +35 -10
  120. package/src/inference/pipelines/text/layer.js +38 -8
  121. package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
  122. package/src/inference/pipelines/text/linear-attention.js +33 -3
  123. package/src/inference/pipelines/text/logits/gpu.js +2 -2
  124. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  125. package/src/inference/pipelines/text/logits/index.js +3 -1
  126. package/src/inference/pipelines/text/model-load.js +3 -0
  127. package/src/inference/pipelines/text/moe-gpu.js +21 -3
  128. package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
  129. package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
  130. package/src/inference/pipelines/text/ops.js +123 -53
  131. package/src/inference/pipelines/text/probes.js +1 -0
  132. package/src/inference/pipelines/text/sampling.js +52 -6
  133. package/src/inference/pipelines/text/state.js +2 -0
  134. package/src/inference/pipelines/text.d.ts +5 -0
  135. package/src/inference/pipelines/text.js +59 -1
  136. package/src/inference/pipelines/vision/encoder.js +386 -0
  137. package/src/inference/pipelines/vision/image-preprocess.js +151 -0
  138. package/src/inference/pipelines/vision/index.js +173 -0
  139. package/src/inference/pipelines/vision/ops.js +78 -0
  140. package/src/inference/pipelines/vision/patch-embed.js +151 -0
  141. package/src/inference/test-harness.js +11 -9
  142. package/src/loader/doppler-loader.d.ts +3 -0
  143. package/src/loader/doppler-loader.js +20 -3
  144. package/src/loader/experts/expert-cache.js +6 -2
  145. package/src/loader/experts/expert-loader.js +6 -2
  146. package/src/loader/final-weights-loader.js +2 -0
  147. package/src/loader/layer-loader.js +42 -3
  148. package/src/loader/manifest-config.js +3 -1
  149. package/src/loader/shard-cache.js +3 -2
  150. package/src/loader/tensors/tensor-loader.d.ts +3 -0
  151. package/src/loader/tensors/tensor-loader.js +130 -4
  152. package/src/rules/inference/dtype.rules.json +5 -0
  153. package/src/rules/inference/kernel-path.rules.json +2 -2
  154. package/src/rules/kernels/moe.rules.mixtral.json +75 -0
  155. package/src/rules/kernels/softmax.rules.json +2 -0
  156. package/src/rules/kernels/split-qg.rules.json +6 -0
  157. package/src/rules/rule-registry.d.ts +1 -0
  158. package/src/rules/rule-registry.js +4 -0
  159. package/src/storage/downloader.js +2 -1
  160. package/src/storage/quickstart-downloader.d.ts +3 -0
  161. package/src/storage/quickstart-downloader.js +27 -30
  162. package/src/storage/shard-manager.js +4 -3
  163. package/src/tooling/conversion-config-materializer.js +3 -5
  164. package/src/tooling/node-converter.js +28 -7
  165. package/src/tooling/node-source-runtime.js +65 -5
  166. package/src/tooling/node-webgpu.js +24 -7
  167. package/src/types/model.d.ts +5 -0
  168. package/src/utils/hf-resolve-url.d.ts +16 -0
  169. package/src/utils/hf-resolve-url.js +17 -0
  170. package/src/version.js +1 -1
  171. package/tools/doppler-cli.js +6 -1
  172. package/src/tooling/node-convert.d.ts +0 -54
@@ -10,7 +10,7 @@ import {
10
10
  import { recordDispatch } from './dispatch.js';
11
11
 
12
12
  const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;
13
- const HEAD_WORKGROUP_SIZE = 64;
13
+ const HEAD_WORKGROUP_SIZE = 128;
14
14
 
15
15
  const CONV_SHADER = /* wgsl */ `
16
16
  override WORKGROUP_SIZE: u32 = 256u;
@@ -79,7 +79,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
79
79
  `;
80
80
 
81
81
  const RECURRENT_SHADER = /* wgsl */ `
82
- override WORKGROUP_SIZE: u32 = 64u;
82
+ override WORKGROUP_SIZE: u32 = 128u;
83
83
 
84
84
  struct LinearAttentionParams {
85
85
  num_tokens: u32,
@@ -111,6 +111,8 @@ struct LinearAttentionParams {
111
111
  @group(0) @binding(8) var<storage, read_write> recurrent_state: array<f32>;
112
112
  @group(0) @binding(9) var<storage, read_write> output: array<f32>;
113
113
 
114
+ var<workgroup> shared_sq: array<f32, WORKGROUP_SIZE>;
115
+
114
116
  fn softplus(x: f32) -> f32 {
115
117
  if (x > 20.0) {
116
118
  return x;
@@ -131,17 +133,19 @@ fn silu(x: f32) -> f32 {
131
133
  }
132
134
 
133
135
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
134
- fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
135
- let head = gid.x;
136
+ fn main(@builtin(workgroup_id) wid: vec3<u32>,
137
+ @builtin(local_invocation_id) lid: vec3<u32>) {
138
+ let head = wid.x;
139
+ let vd = lid.x;
136
140
  if (head >= params.num_v_heads) {
137
141
  return;
138
142
  }
139
143
 
140
144
  let head_k_dim = params.head_k_dim;
141
145
  let head_v_dim = params.head_v_dim;
146
+ let is_active = vd < head_v_dim;
142
147
  let head_scale = inverseSqrt(f32(head_k_dim));
143
148
  let recurrent_head_base = head * head_k_dim * head_v_dim;
144
- let recurrent_head_size = head_k_dim * head_v_dim;
145
149
  let q_rep = max(params.q_rep, 1u);
146
150
  let src_head = head / q_rep;
147
151
  let q_base = src_head * head_k_dim;
@@ -154,6 +158,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
154
158
  let ab_row_base = token_idx * params.num_v_heads + head;
155
159
  let out_row_base = token_idx * params.value_dim + head * head_v_dim;
156
160
 
161
+ // L2 norm for Q and K (redundant across threads but avoids shared memory)
157
162
  var q_norm_sq = 0.0;
158
163
  var k_norm_sq = 0.0;
159
164
  for (var d: u32 = 0u; d < head_k_dim; d = d + 1u) {
@@ -169,11 +174,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
169
174
  let g = a_neg_exp[head] * softplus(a_proj[ab_row_base] + dt_bias[head]);
170
175
  let g_exp = exp(g);
171
176
 
172
- for (var i: u32 = 0u; i < recurrent_head_size; i = i + 1u) {
173
- recurrent_state[recurrent_head_base + i] = recurrent_state[recurrent_head_base + i] * g_exp;
177
+ // Decay state: each thread handles head_k_dim elements at stride head_v_dim
178
+ if (is_active) {
179
+ for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
180
+ let state_idx = recurrent_head_base + kd * head_v_dim + vd;
181
+ recurrent_state[state_idx] = recurrent_state[state_idx] * g_exp;
182
+ }
174
183
  }
175
184
 
176
- for (var vd: u32 = 0u; vd < head_v_dim; vd = vd + 1u) {
185
+ // Delta update: each thread handles one vd slice (no cross-thread dependency)
186
+ if (is_active) {
177
187
  var kv_mem = 0.0;
178
188
  for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
179
189
  let k_normed = conv_out[conv_row_base + k_base + kd] * k_norm_scale;
@@ -188,21 +198,31 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
188
198
  }
189
199
  }
190
200
 
191
- var mean_sq = 0.0;
192
- for (var vd: u32 = 0u; vd < head_v_dim; vd = vd + 1u) {
193
- var out_value = 0.0;
201
+ // Output: each thread computes one vd element
202
+ var out_value = 0.0;
203
+ if (is_active) {
194
204
  for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
195
205
  let q_normed = conv_out[conv_row_base + q_base + kd] * q_norm_scale;
196
206
  let state_idx = recurrent_head_base + kd * head_v_dim + vd;
197
207
  out_value = out_value + recurrent_state[state_idx] * q_normed;
198
208
  }
199
209
  output[out_row_base + vd] = out_value;
200
- let value = out_value;
201
- mean_sq = mean_sq + value * value;
202
210
  }
203
- let inv_rms = inverseSqrt(mean_sq / f32(head_v_dim) + params.rms_norm_eps);
204
211
 
205
- for (var vd: u32 = 0u; vd < head_v_dim; vd = vd + 1u) {
212
+ // RMS norm reduction across vd (workgroup-level)
213
+ shared_sq[vd] = select(0.0, out_value * out_value, is_active);
214
+ workgroupBarrier();
215
+ // Tree reduction
216
+ for (var stride: u32 = WORKGROUP_SIZE / 2u; stride > 0u; stride = stride / 2u) {
217
+ if (vd < stride) {
218
+ shared_sq[vd] = shared_sq[vd] + shared_sq[vd + stride];
219
+ }
220
+ workgroupBarrier();
221
+ }
222
+ let inv_rms = inverseSqrt(shared_sq[0] / f32(head_v_dim) + params.rms_norm_eps);
223
+
224
+ // Apply norm + gate
225
+ if (is_active) {
206
226
  let gate = silu(z_proj[z_row_base + vd]);
207
227
  let norm_index = select(vd, head * head_v_dim + vd, params.norm_mode == 1u);
208
228
  output[out_row_base + vd] = (output[out_row_base + vd] * inv_rms) * norm_weight[norm_index] * gate;
@@ -415,7 +435,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
415
435
  recorder,
416
436
  recurrentPipeline,
417
437
  recurrentBindGroup,
418
- [Math.ceil(layerState.numVHeads / HEAD_WORKGROUP_SIZE), 1, 1],
438
+ [layerState.numVHeads, 1, 1],
419
439
  'linear_attention_recurrent'
420
440
  );
421
441
 
@@ -502,7 +522,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
502
522
  const pass = encoder.beginComputePass({ label: 'linear_attention_recurrent_pass' });
503
523
  pass.setPipeline(recurrentPipeline);
504
524
  pass.setBindGroup(0, recurrentBindGroup);
505
- pass.dispatchWorkgroups(Math.ceil(layerState.numVHeads / HEAD_WORKGROUP_SIZE), 1, 1);
525
+ pass.dispatchWorkgroups(layerState.numVHeads, 1, 1);
506
526
  pass.end();
507
527
  }
508
528
 
@@ -29,7 +29,13 @@ function selectQ4KFusedVariant(isM1, wantF16Output, aDtype) {
29
29
  }
30
30
 
31
31
 
32
- export function resolveMatmulPhase(M) {
32
+ export function resolveMatmulPhase(M, phaseOverride = null) {
33
+ if (phaseOverride != null) {
34
+ if (phaseOverride !== 'decode' && phaseOverride !== 'prefill') {
35
+ throw new Error(`[Matmul] Invalid phase override "${phaseOverride}". Expected "decode" or "prefill".`);
36
+ }
37
+ return phaseOverride;
38
+ }
33
39
  return selectKernelRuleValue('matmul', 'phase', { isDecode: M === 1 });
34
40
  }
35
41
 
@@ -86,6 +92,7 @@ export function getMatmulConfig(variant, constants) {
86
92
 
87
93
 
88
94
  export function isFusedQ4KDisabled(options = {}) {
95
+ if (options.disableFusedQ4K === true) return true;
89
96
  const capabilities = getKernelCapabilities();
90
97
  const hasSubgroups = capabilities?.hasSubgroups === true;
91
98
 
@@ -125,7 +132,9 @@ export function selectMatmulKernel(options = {}) {
125
132
  const { tiledPrefillMinRows } = getKernelThresholds().matmul;
126
133
 
127
134
  const inputsAreF16 = aDtype === 'f16' && bDtype === 'f16';
128
- const weightsAreF16 = bDtype === 'f16' && aDtype !== 'f16';
135
+ // F16 weights needing F32a path: weights are F16 and either activation is already F32,
136
+ // or both inputs are F16 but output is F32 (activation will be cast to F32 by executeMatmul)
137
+ const weightsAreF16 = bDtype === 'f16' && (aDtype !== 'f16' || outputDtype !== 'f16');
129
138
  const useF16Matmul = outputDtype === 'f16' && preferF16 && inputsAreF16 && capabilities.hasF16;
130
139
  const useF16wF32a = preferF16 && weightsAreF16 && capabilities.hasF16;
131
140
  const useTiled = isPrefill
@@ -244,6 +253,30 @@ export function requiresF32Input(variant) {
244
253
  return !supportsF16Input(variant);
245
254
  }
246
255
 
256
+ function resolveRequiredWeightDtype(config) {
257
+ const shaderFile = String(config?.shaderFile ?? config?.wgsl ?? '');
258
+ if (!shaderFile) {
259
+ return null;
260
+ }
261
+ if (shaderFile.startsWith('fused_matmul_q4')) {
262
+ return 'q4k';
263
+ }
264
+ if (
265
+ shaderFile === 'matmul_f16.wgsl'
266
+ || shaderFile === 'matmul_f16_tiled.wgsl'
267
+ || shaderFile === 'matmul_f16w_f32a.wgsl'
268
+ || shaderFile === 'matmul_f16w_f32a_tiled.wgsl'
269
+ || shaderFile === 'matmul_gemv_subgroup.wgsl'
270
+ || shaderFile === 'matmul_gemv_subgroup_f16a.wgsl'
271
+ ) {
272
+ return 'f16';
273
+ }
274
+ if (shaderFile === 'matmul_f32.wgsl') {
275
+ return 'f32';
276
+ }
277
+ return null;
278
+ }
279
+
247
280
 
248
281
  function resolveMatmulOverride(
249
282
  variantOverride,
@@ -287,6 +320,16 @@ function resolveMatmulOverride(
287
320
  );
288
321
  }
289
322
 
323
+ const requiredWeightDtype = resolveRequiredWeightDtype(config);
324
+ const weightDtypeOk = !requiredWeightDtype
325
+ || bDtype === requiredWeightDtype
326
+ || (requiredWeightDtype === 'f16' && bDtype === 'q4k');
327
+ if (!weightDtypeOk) {
328
+ return failOrWarn(
329
+ `Matmul kernel "${variantOverride}" requires ${requiredWeightDtype} weights but B dtype is ${bDtype}.`
330
+ );
331
+ }
332
+
290
333
  if (supportsF16Input(override) && aDtype !== 'f16') {
291
334
  return failOrWarn(`Matmul kernel "${variantOverride}" requires f16 activations but A dtype is ${aDtype}.`);
292
335
  }
@@ -341,7 +384,7 @@ function selectGemvVariant(useF16Gemv, useF32Gemv, hasSubgroups, useVec4, N, mul
341
384
  export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, transposeB, requestedOutputDtype, options) {
342
385
  const capabilities = getKernelCapabilities();
343
386
  const strict = getKernelPathStrict();
344
- const phase = resolveMatmulPhase(M);
387
+ const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
345
388
  let pathVariant = getKernelPathMatmulVariant(options.role, phase, options.layerIdx, options.kernelPath);
346
389
  const hadPathVariant = Boolean(pathVariant);
347
390
 
@@ -426,7 +469,8 @@ export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, trans
426
469
 
427
470
  const canGemv = M === 1 && effectiveBDtype === 'f16' && capabilities.hasF16;
428
471
  const useF16Gemv = canGemv && aDtype === 'f16' && wantF16Output;
429
- const useF32Gemv = canGemv && aDtype === 'f32';
472
+ // F32 GEMV: activation is F32, or activation is F16 with F32 output (will be cast to F32)
473
+ const useF32Gemv = canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output));
430
474
  const useGemv = useF16Gemv || useF32Gemv;
431
475
  const useVec4 = (K % 4 === 0);
432
476
  const { multicolThreshold } = getKernelThresholds().matmul;
@@ -13,6 +13,7 @@ import type { WeightBuffer } from '../weight-buffer.js';
13
13
  import type { CommandRecorder } from '../command-recorder.js';
14
14
  import type { OutputBufferOptions, OutputDtypeOptions, Vec4Options } from './types.js';
15
15
  import type { KernelPathSchema } from '../../config/schema/index.js';
16
+ import type { MatmulDebugConfigSchema } from '../../config/schema/debug.schema.js';
16
17
 
17
18
  /** Matmul kernel options */
18
19
  export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions, Vec4Options {
@@ -23,6 +24,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
23
24
  layerIdx?: number;
24
25
  /** Explicit kernel path context for variant selection (avoids global path state). */
25
26
  kernelPath?: KernelPathSchema | null;
27
+ /** Optional explicit phase for kernel-path lookup when the runtime rewrites rows (for example prefill last-position logits). */
28
+ phaseOverride?: 'decode' | 'prefill' | null;
26
29
  /**
27
30
  * Whether B matrix is stored transposed.
28
31
  * - true: B is [N,K] (SafeTensors/row-major), needs transpose
@@ -38,6 +41,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
38
41
  preferF16?: boolean;
39
42
  /** WGSL override constants for pipeline creation */
40
43
  constants?: Record<string, number | boolean>;
44
+ /** Runtime debug controls for attention projection diagnostics. */
45
+ matmulDebug?: MatmulDebugConfigSchema | null;
41
46
  }
42
47
 
43
48
  /** Context for base matmul kernel selection rules. */
@@ -2,7 +2,7 @@ import { getDevice, getKernelCapabilities } from '../device.js';
2
2
  import { createTensor } from '../tensor.js';
3
3
  import { getBuffer, getLayout, getWeightDtype } from '../weight-buffer.js';
4
4
  import { log, trace, isTraceEnabled } from '../../debug/index.js';
5
- import { releaseBuffer } from '../../memory/buffer-pool.js';
5
+ import { releaseBuffer, readBuffer } from '../../memory/buffer-pool.js';
6
6
  import { releaseUniformBuffer } from '../uniform-cache.js';
7
7
  import { castF16ToF32, recordCastF16ToF32 } from './cast.js';
8
8
  import {
@@ -34,6 +34,24 @@ export { createMatmulBindGroupLayout };
34
34
  let _runMatmulDebugCount = 0;
35
35
  let _recordMatmulDebugCount = 0;
36
36
 
37
+ function normalizeMatmulDebugConfig(config) {
38
+ if (!config || typeof config !== 'object') {
39
+ return null;
40
+ }
41
+ return {
42
+ enabled: config.enabled === true,
43
+ forceSplitQKV: config.forceSplitQKV === true,
44
+ validateAttentionWeightBuffer: config.validateAttentionWeightBuffer === true,
45
+ failOnSmallAttentionWeightBuffer: config.failOnSmallAttentionWeightBuffer === true,
46
+ logAttentionWeightBuffer: config.logAttentionWeightBuffer === true,
47
+ logProjectionValues: config.logProjectionValues === true,
48
+ };
49
+ }
50
+
51
+ function isAttentionProjectionRole(role = '') {
52
+ return role === 'qkv_proj' || role === 'q_proj' || role === 'k_proj' || role === 'v_proj';
53
+ }
54
+
37
55
  function getDebugCounter(isRecord) {
38
56
  return isRecord ? _recordMatmulDebugCount : _runMatmulDebugCount;
39
57
  }
@@ -126,6 +144,12 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
126
144
  const weightLabel = (B && typeof B === 'object' ? B.label : null) ?? bBuffer?.label ?? null;
127
145
  const weightLayout = getLayout(B);
128
146
  const weightShape = B?.shape ? `[${B.shape.join(', ')}]` : null;
147
+ const matmulDebug = normalizeMatmulDebugConfig(options.matmulDebug);
148
+ const debugAttention = matmulDebug?.enabled === true;
149
+ const isAttnProj = isAttentionProjectionRole(options.role ?? '');
150
+ const shouldValidateAttentionWeightBuffer = debugAttention && matmulDebug.validateAttentionWeightBuffer;
151
+ const shouldFailOnSmallAttentionWeightBuffer = debugAttention && matmulDebug.failOnSmallAttentionWeightBuffer;
152
+ const shouldLogAttentionWeightBuffer = debugAttention && matmulDebug.logAttentionWeightBuffer;
129
153
 
130
154
  if (isTraceEnabled('kernels') && getDebugCounter(isRecord) < 20) {
131
155
  incrementDebugCounter(isRecord);
@@ -165,7 +189,7 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
165
189
  options
166
190
  );
167
191
 
168
- const phase = resolveMatmulPhase(M);
192
+ const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
169
193
  const constants = resolveMatmulConstants(options, phase);
170
194
 
171
195
  let matmulInput = A;
@@ -201,6 +225,27 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
201
225
  bOffset
202
226
  );
203
227
  } catch (err) {
228
+ if (shouldValidateAttentionWeightBuffer && isAttnProj && err instanceof Error && err.message.includes('B buffer too small')) {
229
+ const detailParts = [
230
+ `role=${options.role ?? ''}`,
231
+ `layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'}`,
232
+ `M=${M}`,
233
+ `N=${N}`,
234
+ `K=${K}`,
235
+ ];
236
+ if (weightDtype) detailParts.push(`weightDtype=${weightDtype}`);
237
+ if (weightLayout) detailParts.push(`weightLayout=${weightLayout}`);
238
+ if (weightShape) detailParts.push(`shape=${weightShape}`);
239
+ if (weightLabel) detailParts.push(`label=${weightLabel}`);
240
+ if (Number.isFinite(bBuffer?.size)) detailParts.push(`bSize=${bBuffer.size}`);
241
+ const detail = detailParts.join(' ');
242
+ if (shouldLogAttentionWeightBuffer) {
243
+ log.warn('MatmulQKVProbe', `${err.message} | ${detail}`);
244
+ }
245
+ if (shouldFailOnSmallAttentionWeightBuffer) {
246
+ throw new Error(`${err.message}${detail ? ` (${detail})` : ''}`);
247
+ }
248
+ }
204
249
  if (!isRecord && err instanceof Error && err.message.includes('B buffer too small')) {
205
250
  const detailParts = [];
206
251
  if (weightLabel) detailParts.push(`label=${weightLabel}`);
@@ -226,6 +271,15 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
226
271
  trace.kernels(`MATMUL_LARGE: N=${N}, variant=${variant}, aDtype=${aDtype}, bDtype=${bDtype}, transposeB=${transposeB}`);
227
272
  }
228
273
 
274
+ if (isAttnProj && shouldLogAttentionWeightBuffer) {
275
+ log.warn('MatmulQKVProbe',
276
+ `role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} ` +
277
+ `M=${M} N=${N} K=${K} transposeB=${transposeB} bSize=${bBuffer?.size ?? 0} ` +
278
+ `requiredB=${bindingSizes?.bBindingSize ?? 'n/a'} weightShape=${weightShape ?? 'n/a'} ` +
279
+ `weightDtype=${weightDtype ?? 'unknown'} weightLayout=${weightLayout ?? 'unknown'}`
280
+ );
281
+ }
282
+
229
283
  const config = getMatmulConfig(variant, constants);
230
284
  const kernel = new MatmulKernel(device);
231
285
  const pipeline = await getMatmulPipeline(variant, constants);
@@ -238,6 +292,14 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
238
292
  );
239
293
  const ownsOutput = outputBuffer == null;
240
294
 
295
+ if (isAttnProj && shouldLogAttentionWeightBuffer) {
296
+ log.warn('MatmulVariantDiag',
297
+ `role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} mode=${mode} ` +
298
+ `variant=${variant} useQ4KFused=${useQ4KFused} useGemv=${useGemv} ` +
299
+ `aDtype=${aDtype} bDtype=${bDtype} output=${actualOutputDtype}`
300
+ );
301
+ }
302
+
241
303
  if (!Number.isFinite(outputSize) || outputSize <= 0) {
242
304
  throw new Error(`[${opLabel}] Invalid output size: ${outputSize} (M=${M}, N=${N})`);
243
305
  }
@@ -290,6 +352,13 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
290
352
  kernel.dispatch(pipeline, bindGroup, dispatchPlan.workgroups);
291
353
  }
292
354
  completed = true;
355
+ if (!isRecord && matmulDebug?.logProjectionValues && isAttnProj && M === 1 && options.layerIdx === 0) {
356
+ await device.queue.onSubmittedWorkDone();
357
+ const raw = await readBuffer(C);
358
+ const numVals = Math.min(8, Math.floor(raw.byteLength / 4));
359
+ const vals = numVals > 0 ? new Float32Array(raw, 0, numVals) : [];
360
+ log.warn('ProjectionProbe', `role=${options.role ?? ''} L0 M1 first8_f32: ${Array.from(vals).map(v => v.toFixed(5)).join(' ')}`);
361
+ }
293
362
  return createTensor(C, actualOutputDtype, [M, N], 'matmul_output');
294
363
  } finally {
295
364
  if (!isRecord && uniformBuffer) {
@@ -5,7 +5,11 @@
5
5
  // 1. Use subgroupAdd() for reduction - much faster than shared memory
6
6
  // 2. Vectorized vec4 loads for weights
7
7
  // 3. Each workgroup processes multiple output columns
8
- // 4. Loop unrolling for better ILP
8
+ // 4. Warp-stride loop for row-major (transpose_b=1): all threads in a column
9
+ // step through K together, so adjacent threads load adjacent addresses.
10
+ // At each step, 64 threads × 2 bytes (one f16 weight each) = 128 bytes of
11
+ // consecutive addresses → 100% cache-line utilization vs ~10% for the old contiguous-range
12
+ // pattern (where threads were 80 bytes apart in the same iteration).
9
13
  //
10
14
  // A is f32 (activations), B is f16 (weights), C is f32.
11
15
  // transpose_b=0: B is [K, N] (GGUF/column-major), access B[k * N + col]
@@ -69,40 +73,29 @@ fn main(
69
73
  // Each thread computes partial sum for its assigned k values
70
74
  var partial_sum: f32 = 0.0;
71
75
 
72
- // Only do work if this column is valid
73
76
  if (is_valid) {
74
- // Process K in chunks, each thread handles K/64 elements
75
- let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
76
- let k_start = thread_in_col * k_per_thread;
77
- let k_end = min(k_start + k_per_thread, u.K);
78
-
79
- // Main loop - process 4 elements at a time when aligned
80
- var k = k_start;
81
- let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
82
-
83
77
  if (u.transpose_b == 1u) {
84
- // B is [N, K] (SafeTensors/row-major): B[col, k] = B[col * K + k]
78
+ // B is [N, K] (row-major): B[col, k] = B[col * K + k]
79
+ // Warp-stride: step THREADS_PER_COL elements per outer iteration so that
80
+ // all wavefront threads load consecutive addresses simultaneously.
81
+ // At each step, 64 threads × 2 bytes = 128 bytes = exactly 1 cache line → 100% utilization.
85
82
  let b_row_offset = col * u.K;
86
-
87
- for (; k < k_aligned_end; k = k + 4u) {
88
- let a0 = A[k];
89
- let a1 = A[k + 1u];
90
- let a2 = A[k + 2u];
91
- let a3 = A[k + 3u];
92
-
93
- let b0 = f32(B[b_row_offset + k]);
94
- let b1 = f32(B[b_row_offset + k + 1u]);
95
- let b2 = f32(B[b_row_offset + k + 2u]);
96
- let b3 = f32(B[b_row_offset + k + 3u]);
97
-
98
- partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
99
- }
100
-
101
- for (; k < k_end; k = k + 1u) {
102
- partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
83
+ for (var k_base: u32 = 0u; k_base < u.K; k_base = k_base + THREADS_PER_COL) {
84
+ let k = k_base + thread_in_col;
85
+ if (k < u.K) {
86
+ partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
87
+ }
103
88
  }
104
89
  } else {
105
- // B is [K, N] (GGUF/column-major): B[k, col] = B[k * N + col]
90
+ // B is [K, N] (column-major): B[k, col] = B[k * N + col]
91
+ // Contiguous-range per thread: sequential access within each thread.
92
+ let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
93
+ let k_start = thread_in_col * k_per_thread;
94
+ let k_end = min(k_start + k_per_thread, u.K);
95
+
96
+ var k = k_start;
97
+ let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
98
+
106
99
  for (; k < k_aligned_end; k = k + 4u) {
107
100
  let a0 = A[k];
108
101
  let a1 = A[k + 1u];
@@ -189,38 +182,36 @@ fn main_multicol(
189
182
  var partial_sum: f32 = 0.0;
190
183
 
191
184
  if (is_valid) {
192
- // Each of 8 threads splits K
193
- let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
194
- let k_start = thread_in_col * k_per_thread;
195
- let k_end = min(k_start + k_per_thread, u.K);
196
-
197
- // Unroll by 4 for ILP
198
- var k = k_start;
199
- let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
200
-
201
185
  if (u.transpose_b == 1u) {
202
- // B is [N, K] (SafeTensors/row-major): B[col, k] = B[col * K + k]
186
+ // B is [N, K] (row-major): B[col, k] = B[col * K + k]
187
+ // Warp-stride: step MULTICOL_THREADS_PER_COL vec4 groups per outer iteration.
188
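+ // Adjacent threads in the same column load adjacent vec4 groups → coalesced. NOTE(review): only the K / 4 full vec4 groups are covered; the previous scalar tail loop for K % 4 leftovers was removed — confirm K is always a multiple of 4 on this path.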
189
+ let K4 = u.K / 4u;
203
190
  let b_row_offset = col * u.K;
204
-
205
- for (; k < k_aligned_end; k = k + 4u) {
206
- let a0 = A[k];
207
- let a1 = A[k + 1u];
208
- let a2 = A[k + 2u];
209
- let a3 = A[k + 3u];
210
-
211
- let b0 = f32(B[b_row_offset + k]);
212
- let b1 = f32(B[b_row_offset + k + 1u]);
213
- let b2 = f32(B[b_row_offset + k + 2u]);
214
- let b3 = f32(B[b_row_offset + k + 3u]);
215
-
216
- partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
217
- }
218
-
219
- for (; k < k_end; k = k + 1u) {
220
- partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
191
+ for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + MULTICOL_THREADS_PER_COL) {
192
+ let k4 = k4_base + thread_in_col;
193
+ if (k4 < K4) {
194
+ let k = k4 * 4u;
195
+ let a0 = A[k];
196
+ let a1 = A[k + 1u];
197
+ let a2 = A[k + 2u];
198
+ let a3 = A[k + 3u];
199
+ let b0 = f32(B[b_row_offset + k]);
200
+ let b1 = f32(B[b_row_offset + k + 1u]);
201
+ let b2 = f32(B[b_row_offset + k + 2u]);
202
+ let b3 = f32(B[b_row_offset + k + 3u]);
203
+ partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
204
+ }
221
205
  }
222
206
  } else {
223
- // B is [K, N] (GGUF/column-major): B[k, col] = B[k * N + col]
207
+ // B is [K, N] (column-major): B[k, col] = B[k * N + col]
208
+ let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
209
+ let k_start = thread_in_col * k_per_thread;
210
+ let k_end = min(k_start + k_per_thread, u.K);
211
+
212
+ var k = k_start;
213
+ let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
214
+
224
215
  for (; k < k_aligned_end; k = k + 4u) {
225
216
  let a0 = A[k];
226
217
  let a1 = A[k + 1u];
@@ -245,7 +236,7 @@ fn main_multicol(
245
236
  multicol_wg_sums[local_id] = partial_sum;
246
237
  workgroupBarrier();
247
238
 
248
- // Thread 0 of each column reduces its 8 values
239
+ // Thread 0 of each column reduces its MULTICOL_THREADS_PER_COL values
249
240
  if (thread_in_col == 0u && is_valid) {
250
241
  var final_sum: f32 = 0.0;
251
242
  let base = col_in_wg * MULTICOL_THREADS_PER_COL;
@@ -282,30 +273,37 @@ fn main_vec4(
282
273
  if (is_valid) {
283
274
  // K is guaranteed to be multiple of 4
284
275
  let K4 = u.K / 4u;
285
- let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
286
- let k4_start = thread_in_col * k4_per_thread;
287
- let k4_end = min(k4_start + k4_per_thread, K4);
288
276
 
289
277
  if (u.transpose_b == 1u) {
290
- // B is [N, K] (SafeTensors/row-major): B[col, k] = B[col * K + k]
278
+ // B is [N, K] (row-major): B[col, k] = B[col * K + k]
279
+ // Warp-stride: step THREADS_PER_COL vec4 groups per outer iteration so that
280
+ // adjacent threads load adjacent groups → 100% cache-line utilization.
281
+ // At each step: 64 threads × 4 f16 × 2 bytes = 512 bytes from 4 consecutive
282
+ // cache lines, vs the old contiguous-range pattern (~10% utilization).
291
283
  let b_row_offset = col * u.K;
292
-
293
- for (var k4: u32 = k4_start; k4 < k4_end; k4 = k4 + 1u) {
294
- let k = k4 * 4u;
295
-
296
- let a = vec4<f32>(A[k], A[k + 1u], A[k + 2u], A[k + 3u]);
297
-
298
- let b = vec4<f32>(
299
- f32(B[b_row_offset + k]),
300
- f32(B[b_row_offset + k + 1u]),
301
- f32(B[b_row_offset + k + 2u]),
302
- f32(B[b_row_offset + k + 3u])
303
- );
304
-
305
- partial_sum = partial_sum + dot(a, b);
284
+ for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + THREADS_PER_COL) {
285
+ let k4 = k4_base + thread_in_col;
286
+ if (k4 < K4) {
287
+ let k = k4 * 4u;
288
+
289
+ let a = vec4<f32>(A[k], A[k + 1u], A[k + 2u], A[k + 3u]);
290
+
291
+ let b = vec4<f32>(
292
+ f32(B[b_row_offset + k]),
293
+ f32(B[b_row_offset + k + 1u]),
294
+ f32(B[b_row_offset + k + 2u]),
295
+ f32(B[b_row_offset + k + 3u])
296
+ );
297
+
298
+ partial_sum = partial_sum + dot(a, b);
299
+ }
306
300
  }
307
301
  } else {
308
- // B is [K, N] (GGUF/column-major): B[k, col] = B[k * N + col]
302
+ // B is [K, N] (column-major): B[k, col] = B[k * N + col]
303
+ // Contiguous-range per thread: sequential access within each thread.
304
+ let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
305
+ let k4_start = thread_in_col * k4_per_thread;
306
+ let k4_end = min(k4_start + k4_per_thread, K4);
309
307
  for (var k4: u32 = k4_start; k4 < k4_end; k4 = k4 + 1u) {
310
308
  let k = k4 * 4u;
311
309
 
@@ -342,4 +340,4 @@ fn main_vec4(
342
340
  }
343
341
  C[col] = final_sum * u.alpha;
344
342
  }
345
- }
343
+ }
@@ -9,6 +9,9 @@ import { selectRuleValue as selectLoaderRule } from '../../rules/rule-registry.j
9
9
  import { getBuffer, getWeightDtype, getBufferDtype } from '../weight-buffer.js';
10
10
  import { unifiedKernelWrapper } from './utils.js';
11
11
 
12
+ // Conservative fallback dtype for norm weight inference when metadata is unavailable.
13
+ const DEFAULT_DTYPE = 'f32';
14
+
12
15
  function inferHiddenSize(input, hiddenSize) {
13
16
  if (hiddenSize != null) return hiddenSize;
14
17
  const shape = input?.shape;
@@ -39,9 +42,12 @@ function resolveNormWeightDtype(weight, hiddenSize) {
39
42
  return taggedDtype;
40
43
  }
41
44
 
45
+ // Conservative fallback: f32 avoids precision loss when dtype cannot be determined.
46
+ // This path fires for non-GPU buffers or a missing/non-positive hiddenSize, both of which prevent
47
+ // size-based dtype inference below.
42
48
  const hasGPUBufferType = typeof GPUBuffer !== 'undefined';
43
49
  if (!hasGPUBufferType || !(weightBuffer instanceof GPUBuffer) || hiddenSize == null || hiddenSize <= 0) {
44
- return 'f32';
50
+ return DEFAULT_DTYPE;
45
51
  }
46
52
 
47
53
  const byteSize = getBufferRequestedSize(weightBuffer);
@@ -55,7 +61,8 @@ function resolveNormWeightDtype(weight, hiddenSize) {
55
61
  sizeMatchesF32,
56
62
  });
57
63
  }
58
- return 'f32';
64
+ // Buffer size matches neither f16 nor f32 for given hiddenSize; fall back to f32.
65
+ return DEFAULT_DTYPE;
59
66
  }
60
67
 
61
68
  function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
@@ -7,7 +7,6 @@ import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout
7
7
  import { allowReadback } from '../perf-guards.js';
8
8
  import { selectRuleValue as selectKernelRuleValue } from './rule-registry.js';
9
9
  import { selectRuleValue as selectSharedRuleValue } from '../../rules/rule-registry.js';
10
- import { getKernelThresholds } from '../../config/schema/index.js';
11
10
 
12
11
 
13
12
  function getSampleBindGroupLayout(device) {
@@ -96,8 +95,7 @@ async function resolveArgmaxPipelines(device, vocabSize, variants) {
96
95
  const argmaxPipeline = await createSamplePipeline(device, variants.argmax);
97
96
  const numWorkgroups = Math.min(WORKGROUP_SIZES.DEFAULT, Math.ceil(vocabSize / WORKGROUP_SIZES.DEFAULT));
98
97
  const useSinglePassArgmax = numWorkgroups === 1;
99
- const argmaxReduceVocabThreshold = getKernelThresholds().sample.argmaxReduceVocabThreshold;
100
- const reducePipeline = useSinglePassArgmax || vocabSize <= argmaxReduceVocabThreshold
98
+ const reducePipeline = useSinglePassArgmax
101
99
  ? null
102
100
  : await createSamplePipeline(device, variants.argmaxReduce);
103
101
  const singlePassPipeline = useSinglePassArgmax