@simulatte/doppler 0.1.8 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. package/CHANGELOG.md +14 -1
  2. package/README.md +25 -6
  3. package/package.json +5 -3
  4. package/src/client/doppler-api.browser.js +6 -0
  5. package/src/client/doppler-api.d.ts +3 -0
  6. package/src/client/doppler-api.js +11 -2
  7. package/src/client/doppler-registry.js +3 -5
  8. package/src/client/doppler-registry.json +16 -0
  9. package/src/config/kernels/kernel-ref-digests.js +23 -21
  10. package/src/config/kernels/moe/mixtral.paths.json +46 -0
  11. package/src/config/loader.js +6 -0
  12. package/src/config/platforms/loader.js +3 -1
  13. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
  14. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
  15. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
  16. package/src/config/presets/kernel-paths/registry.json +7 -0
  17. package/src/config/presets/models/gemma3.json +2 -1
  18. package/src/config/presets/models/gemma4.json +61 -0
  19. package/src/config/presets/models/granite-docling.json +70 -0
  20. package/src/config/presets/models/lfm2.json +6 -1
  21. package/src/config/presets/models/qwen3_vl.json +40 -0
  22. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
  23. package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
  24. package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
  25. package/src/config/presets/runtime/modes/trace-layers.json +1 -0
  26. package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
  27. package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
  28. package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
  29. package/src/config/runtime.js +3 -0
  30. package/src/config/schema/debug.schema.d.ts +40 -0
  31. package/src/config/schema/debug.schema.js +28 -0
  32. package/src/config/schema/index.js +2 -0
  33. package/src/config/schema/inference-defaults.schema.js +1 -1
  34. package/src/config/schema/kernel-path.schema.d.ts +1 -0
  35. package/src/config/schema/memory-limits.schema.js +2 -2
  36. package/src/config/schema/storage.schema.js +1 -1
  37. package/src/converter/conversion-plan.js +1 -1
  38. package/src/converter/core.js +17 -8
  39. package/src/converter/quantizer.d.ts +5 -0
  40. package/src/converter/quantizer.js +15 -0
  41. package/src/distribution/shard-delivery.js +34 -0
  42. package/src/formats/rdrr/classification.js +32 -0
  43. package/src/gpu/kernel-runtime.js +4 -2
  44. package/src/gpu/kernels/attention.js +2 -1
  45. package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
  46. package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
  47. package/src/gpu/kernels/dequant_shared.wgsl +4 -2
  48. package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
  49. package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
  50. package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
  51. package/src/gpu/kernels/gated-short-conv.js +284 -0
  52. package/src/gpu/kernels/linear-attention-core.js +37 -17
  53. package/src/gpu/kernels/matmul-selection.js +1 -0
  54. package/src/gpu/kernels/matmul.d.ts +3 -0
  55. package/src/gpu/kernels/matmul.js +70 -1
  56. package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
  57. package/src/gpu/kernels/sample.js +1 -3
  58. package/src/gpu/kernels/sample.wgsl +39 -9
  59. package/src/gpu/kernels/sample_f16.wgsl +38 -8
  60. package/src/gpu/kernels/shader-cache.js +9 -4
  61. package/src/inference/kv-cache/base.js +3 -10
  62. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  63. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +2 -1
  64. package/src/inference/pipelines/text/attention/projections.d.ts +3 -0
  65. package/src/inference/pipelines/text/attention/projections.js +13 -2
  66. package/src/inference/pipelines/text/attention/record.js +1 -0
  67. package/src/inference/pipelines/text/attention/run.js +9 -0
  68. package/src/inference/pipelines/text/config.d.ts +1 -0
  69. package/src/inference/pipelines/text/config.js +32 -4
  70. package/src/inference/pipelines/text/embed.js +26 -7
  71. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
  72. package/src/inference/pipelines/text/execution-v0.js +12 -1
  73. package/src/inference/pipelines/text/generator-helpers.js +1 -0
  74. package/src/inference/pipelines/text/generator-runtime.js +14 -0
  75. package/src/inference/pipelines/text/generator-steps.d.ts +9 -0
  76. package/src/inference/pipelines/text/generator-steps.js +46 -29
  77. package/src/inference/pipelines/text/generator.d.ts +5 -0
  78. package/src/inference/pipelines/text/generator.js +320 -166
  79. package/src/inference/pipelines/text/init.d.ts +2 -0
  80. package/src/inference/pipelines/text/init.js +19 -5
  81. package/src/inference/pipelines/text/layer.js +37 -8
  82. package/src/inference/pipelines/text/moe-gpu.js +21 -3
  83. package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
  84. package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
  85. package/src/inference/pipelines/text/ops.js +123 -53
  86. package/src/inference/pipelines/text/probes.js +1 -0
  87. package/src/inference/pipelines/text/state.js +2 -0
  88. package/src/inference/pipelines/text.d.ts +5 -0
  89. package/src/inference/pipelines/text.js +59 -1
  90. package/src/inference/pipelines/vision/encoder.js +386 -0
  91. package/src/inference/pipelines/vision/image-preprocess.js +151 -0
  92. package/src/inference/pipelines/vision/index.js +173 -0
  93. package/src/inference/pipelines/vision/ops.js +78 -0
  94. package/src/inference/pipelines/vision/patch-embed.js +151 -0
  95. package/src/inference/test-harness.js +9 -7
  96. package/src/loader/doppler-loader.d.ts +3 -0
  97. package/src/loader/doppler-loader.js +20 -3
  98. package/src/loader/experts/expert-cache.js +6 -2
  99. package/src/loader/experts/expert-loader.js +6 -2
  100. package/src/loader/layer-loader.js +42 -3
  101. package/src/loader/manifest-config.js +3 -1
  102. package/src/loader/tensors/tensor-loader.d.ts +3 -0
  103. package/src/loader/tensors/tensor-loader.js +124 -3
  104. package/src/rules/kernels/moe.rules.mixtral.json +75 -0
  105. package/src/rules/kernels/softmax.rules.json +2 -0
  106. package/src/rules/rule-registry.d.ts +1 -0
  107. package/src/rules/rule-registry.js +2 -0
  108. package/src/storage/quickstart-downloader.d.ts +3 -0
  109. package/src/storage/quickstart-downloader.js +27 -30
  110. package/src/tooling/node-converter.js +25 -7
  111. package/src/tooling/node-source-runtime.js +29 -5
  112. package/src/tooling/node-webgpu.js +24 -7
  113. package/src/utils/hf-resolve-url.d.ts +16 -0
  114. package/src/utils/hf-resolve-url.js +17 -0
  115. package/src/version.js +1 -1
  116. package/src/tooling/node-convert.d.ts +0 -54
package/src/gpu/kernels/linear-attention-core.js

@@ -10,7 +10,7 @@ import {
  import { recordDispatch } from './dispatch.js';
 
  const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;
- const HEAD_WORKGROUP_SIZE = 64;
+ const HEAD_WORKGROUP_SIZE = 128;
 
  const CONV_SHADER = /* wgsl */ `
  override WORKGROUP_SIZE: u32 = 256u;
@@ -79,7 +79,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  `;
 
  const RECURRENT_SHADER = /* wgsl */ `
- override WORKGROUP_SIZE: u32 = 64u;
+ override WORKGROUP_SIZE: u32 = 128u;
 
  struct LinearAttentionParams {
  num_tokens: u32,
@@ -111,6 +111,8 @@ struct LinearAttentionParams {
  @group(0) @binding(8) var<storage, read_write> recurrent_state: array<f32>;
  @group(0) @binding(9) var<storage, read_write> output: array<f32>;
 
+ var<workgroup> shared_sq: array<f32, WORKGROUP_SIZE>;
+
  fn softplus(x: f32) -> f32 {
  if (x > 20.0) {
  return x;
@@ -131,17 +133,19 @@ fn silu(x: f32) {
  }
 
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
- fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- let head = gid.x;
+ fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+ let head = wid.x;
+ let vd = lid.x;
  if (head >= params.num_v_heads) {
  return;
  }
 
  let head_k_dim = params.head_k_dim;
  let head_v_dim = params.head_v_dim;
+ let is_active = vd < head_v_dim;
  let head_scale = inverseSqrt(f32(head_k_dim));
  let recurrent_head_base = head * head_k_dim * head_v_dim;
- let recurrent_head_size = head_k_dim * head_v_dim;
  let q_rep = max(params.q_rep, 1u);
  let src_head = head / q_rep;
  let q_base = src_head * head_k_dim;
@@ -154,6 +158,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let ab_row_base = token_idx * params.num_v_heads + head;
  let out_row_base = token_idx * params.value_dim + head * head_v_dim;
 
+ // L2 norm for Q and K (redundant across threads but avoids shared memory)
  var q_norm_sq = 0.0;
  var k_norm_sq = 0.0;
  for (var d: u32 = 0u; d < head_k_dim; d = d + 1u) {
@@ -169,11 +174,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let g = a_neg_exp[head] * softplus(a_proj[ab_row_base] + dt_bias[head]);
  let g_exp = exp(g);
 
- for (var i: u32 = 0u; i < recurrent_head_size; i = i + 1u) {
- recurrent_state[recurrent_head_base + i] = recurrent_state[recurrent_head_base + i] * g_exp;
+ // Decay state: each thread handles head_k_dim elements at stride head_v_dim
+ if (is_active) {
+ for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
+ let state_idx = recurrent_head_base + kd * head_v_dim + vd;
+ recurrent_state[state_idx] = recurrent_state[state_idx] * g_exp;
+ }
  }
 
- for (var vd: u32 = 0u; vd < head_v_dim; vd = vd + 1u) {
+ // Delta update: each thread handles one vd slice (no cross-thread dependency)
+ if (is_active) {
  var kv_mem = 0.0;
  for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
  let k_normed = conv_out[conv_row_base + k_base + kd] * k_norm_scale;
@@ -188,21 +198,31 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  }
  }
 
- var mean_sq = 0.0;
- for (var vd: u32 = 0u; vd < head_v_dim; vd = vd + 1u) {
- var out_value = 0.0;
+ // Output: each thread computes one vd element
+ var out_value = 0.0;
+ if (is_active) {
  for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
  let q_normed = conv_out[conv_row_base + q_base + kd] * q_norm_scale;
  let state_idx = recurrent_head_base + kd * head_v_dim + vd;
  out_value = out_value + recurrent_state[state_idx] * q_normed;
  }
  output[out_row_base + vd] = out_value;
- let value = out_value;
- mean_sq = mean_sq + value * value;
  }
- let inv_rms = inverseSqrt(mean_sq / f32(head_v_dim) + params.rms_norm_eps);
 
- for (var vd: u32 = 0u; vd < head_v_dim; vd = vd + 1u) {
+ // RMS norm reduction across vd (workgroup-level)
+ shared_sq[vd] = select(0.0, out_value * out_value, is_active);
+ workgroupBarrier();
+ // Tree reduction
+ for (var stride: u32 = WORKGROUP_SIZE / 2u; stride > 0u; stride = stride / 2u) {
+ if (vd < stride) {
+ shared_sq[vd] = shared_sq[vd] + shared_sq[vd + stride];
+ }
+ workgroupBarrier();
+ }
+ let inv_rms = inverseSqrt(shared_sq[0] / f32(head_v_dim) + params.rms_norm_eps);
+
+ // Apply norm + gate
+ if (is_active) {
  let gate = silu(z_proj[z_row_base + vd]);
  let norm_index = select(vd, head * head_v_dim + vd, params.norm_mode == 1u);
  output[out_row_base + vd] = (output[out_row_base + vd] * inv_rms) * norm_weight[norm_index] * gate;
@@ -415,7 +435,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
  recorder,
  recurrentPipeline,
  recurrentBindGroup,
- [Math.ceil(layerState.numVHeads / HEAD_WORKGROUP_SIZE), 1, 1],
+ [layerState.numVHeads, 1, 1],
  'linear_attention_recurrent'
  );
 
@@ -502,7 +522,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
  const pass = encoder.beginComputePass({ label: 'linear_attention_recurrent_pass' });
  pass.setPipeline(recurrentPipeline);
  pass.setBindGroup(0, recurrentBindGroup);
- pass.dispatchWorkgroups(Math.ceil(layerState.numVHeads / HEAD_WORKGROUP_SIZE), 1, 1);
+ pass.dispatchWorkgroups(layerState.numVHeads, 1, 1);
  pass.end();
  }
 
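The recurrent kernel now dispatches one 128-thread workgroup per value head (lid.x picks the vd slice) instead of one thread per head, and the per-head RMS norm becomes a workgroup tree reduction over shared_sq. A minimal CPU-side sketch of that reduction pattern, in plain JavaScript and not part of the package:

  function treeReduceSum(values) {
    // values.length is assumed to be a power of two (WORKGROUP_SIZE in the shader);
    // inactive lanes contribute 0.0, matching the select() above.
    const shared = values.slice();
    for (let stride = shared.length >> 1; stride > 0; stride >>= 1) {
      for (let i = 0; i < stride; i++) {
        shared[i] += shared[i + stride]; // lane i folds in lane i + stride
      }
      // the WGSL version places a workgroupBarrier() between strides
    }
    return shared[0];
  }

  // e.g. invRms = 1 / Math.sqrt(treeReduceSum(outSquares) / headVDim + rmsEps)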
package/src/gpu/kernels/matmul-selection.js

@@ -92,6 +92,7 @@ export function getMatmulConfig(variant, constants) {
 
 
  export function isFusedQ4KDisabled(options = {}) {
+ if (options.disableFusedQ4K === true) return true;
  const capabilities = getKernelCapabilities();
  const hasSubgroups = capabilities?.hasSubgroups === true;
 
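A usage sketch for the new escape hatch (the call site here is hypothetical; only the exported function above is real): passing disableFusedQ4K: true short-circuits the capability probing.

  import { isFusedQ4KDisabled } from './matmul-selection.js';

  isFusedQ4KDisabled({ disableFusedQ4K: true }); // true, regardless of subgroup support
  isFusedQ4KDisabled({});                        // falls through to the capability checks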
package/src/gpu/kernels/matmul.d.ts

@@ -13,6 +13,7 @@ import type { WeightBuffer } from '../weight-buffer.js';
  import type { CommandRecorder } from '../command-recorder.js';
  import type { OutputBufferOptions, OutputDtypeOptions, Vec4Options } from './types.js';
  import type { KernelPathSchema } from '../../config/schema/index.js';
+ import type { MatmulDebugConfigSchema } from '../../config/schema/debug.schema.js';
 
  /** Matmul kernel options */
  export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions, Vec4Options {
@@ -40,6 +41,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
  preferF16?: boolean;
  /** WGSL override constants for pipeline creation */
  constants?: Record<string, number | boolean>;
+ /** Runtime debug controls for attention projection diagnostics. */
+ matmulDebug?: MatmulDebugConfigSchema | null;
  }
 
  /** Context for base matmul kernel selection rules. */
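For reference, a sketch of the shape the new option takes. The field names are taken from normalizeMatmulDebugConfig further down in this diff; the authoritative definition lives in debug.schema.js, which is not shown here.

  const options = {
    // ...existing MatmulOptions fields...
    matmulDebug: {
      enabled: true,
      forceSplitQKV: false,
      validateAttentionWeightBuffer: true,
      failOnSmallAttentionWeightBuffer: false,
      logAttentionWeightBuffer: true,
      logProjectionValues: false,
    },
  };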
package/src/gpu/kernels/matmul.js

@@ -2,7 +2,7 @@ import { getDevice, getKernelCapabilities } from '../device.js';
  import { createTensor } from '../tensor.js';
  import { getBuffer, getLayout, getWeightDtype } from '../weight-buffer.js';
  import { log, trace, isTraceEnabled } from '../../debug/index.js';
- import { releaseBuffer } from '../../memory/buffer-pool.js';
+ import { releaseBuffer, readBuffer } from '../../memory/buffer-pool.js';
  import { releaseUniformBuffer } from '../uniform-cache.js';
  import { castF16ToF32, recordCastF16ToF32 } from './cast.js';
  import {
@@ -34,6 +34,24 @@ export { createMatmulBindGroupLayout };
  let _runMatmulDebugCount = 0;
  let _recordMatmulDebugCount = 0;
 
+ function normalizeMatmulDebugConfig(config) {
+ if (!config || typeof config !== 'object') {
+ return null;
+ }
+ return {
+ enabled: config.enabled === true,
+ forceSplitQKV: config.forceSplitQKV === true,
+ validateAttentionWeightBuffer: config.validateAttentionWeightBuffer === true,
+ failOnSmallAttentionWeightBuffer: config.failOnSmallAttentionWeightBuffer === true,
+ logAttentionWeightBuffer: config.logAttentionWeightBuffer === true,
+ logProjectionValues: config.logProjectionValues === true,
+ };
+ }
+
+ function isAttentionProjectionRole(role = '') {
+ return role === 'qkv_proj' || role === 'q_proj' || role === 'k_proj' || role === 'v_proj';
+ }
+
  function getDebugCounter(isRecord) {
  return isRecord ? _recordMatmulDebugCount : _runMatmulDebugCount;
  }
@@ -126,6 +144,12 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
  const weightLabel = (B && typeof B === 'object' ? B.label : null) ?? bBuffer?.label ?? null;
  const weightLayout = getLayout(B);
  const weightShape = B?.shape ? `[${B.shape.join(', ')}]` : null;
+ const matmulDebug = normalizeMatmulDebugConfig(options.matmulDebug);
+ const debugAttention = matmulDebug?.enabled === true;
+ const isAttnProj = isAttentionProjectionRole(options.role ?? '');
+ const shouldValidateAttentionWeightBuffer = debugAttention && matmulDebug.validateAttentionWeightBuffer;
+ const shouldFailOnSmallAttentionWeightBuffer = debugAttention && matmulDebug.failOnSmallAttentionWeightBuffer;
+ const shouldLogAttentionWeightBuffer = debugAttention && matmulDebug.logAttentionWeightBuffer;
 
  if (isTraceEnabled('kernels') && getDebugCounter(isRecord) < 20) {
  incrementDebugCounter(isRecord);
@@ -201,6 +225,27 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
  bOffset
  );
  } catch (err) {
+ if (shouldValidateAttentionWeightBuffer && isAttnProj && err instanceof Error && err.message.includes('B buffer too small')) {
+ const detailParts = [
+ `role=${options.role ?? ''}`,
+ `layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'}`,
+ `M=${M}`,
+ `N=${N}`,
+ `K=${K}`,
+ ];
+ if (weightDtype) detailParts.push(`weightDtype=${weightDtype}`);
+ if (weightLayout) detailParts.push(`weightLayout=${weightLayout}`);
+ if (weightShape) detailParts.push(`shape=${weightShape}`);
+ if (weightLabel) detailParts.push(`label=${weightLabel}`);
+ if (Number.isFinite(bBuffer?.size)) detailParts.push(`bSize=${bBuffer.size}`);
+ const detail = detailParts.join(' ');
+ if (shouldLogAttentionWeightBuffer) {
+ log.warn('MatmulQKVProbe', `${err.message} | ${detail}`);
+ }
+ if (shouldFailOnSmallAttentionWeightBuffer) {
+ throw new Error(`${err.message}${detail ? ` (${detail})` : ''}`);
+ }
+ }
  if (!isRecord && err instanceof Error && err.message.includes('B buffer too small')) {
  const detailParts = [];
  if (weightLabel) detailParts.push(`label=${weightLabel}`);
@@ -226,6 +271,15 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
  trace.kernels(`MATMUL_LARGE: N=${N}, variant=${variant}, aDtype=${aDtype}, bDtype=${bDtype}, transposeB=${transposeB}`);
  }
 
+ if (isAttnProj && shouldLogAttentionWeightBuffer) {
+ log.warn('MatmulQKVProbe',
+ `role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} ` +
+ `M=${M} N=${N} K=${K} transposeB=${transposeB} bSize=${bBuffer?.size ?? 0} ` +
+ `requiredB=${bindingSizes?.bBindingSize ?? 'n/a'} weightShape=${weightShape ?? 'n/a'} ` +
+ `weightDtype=${weightDtype ?? 'unknown'} weightLayout=${weightLayout ?? 'unknown'}`
+ );
+ }
+
  const config = getMatmulConfig(variant, constants);
  const kernel = new MatmulKernel(device);
  const pipeline = await getMatmulPipeline(variant, constants);
@@ -238,6 +292,14 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
  );
  const ownsOutput = outputBuffer == null;
 
+ if (isAttnProj && shouldLogAttentionWeightBuffer) {
+ log.warn('MatmulVariantDiag',
+ `role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} mode=${mode} ` +
+ `variant=${variant} useQ4KFused=${useQ4KFused} useGemv=${useGemv} ` +
+ `aDtype=${aDtype} bDtype=${bDtype} output=${actualOutputDtype}`
+ );
+ }
+
  if (!Number.isFinite(outputSize) || outputSize <= 0) {
  throw new Error(`[${opLabel}] Invalid output size: ${outputSize} (M=${M}, N=${N})`);
  }
@@ -290,6 +352,13 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
  kernel.dispatch(pipeline, bindGroup, dispatchPlan.workgroups);
  }
  completed = true;
+ if (!isRecord && matmulDebug?.logProjectionValues && isAttnProj && M === 1 && options.layerIdx === 0) {
+ await device.queue.onSubmittedWorkDone();
+ const raw = await readBuffer(C);
+ const numVals = Math.min(8, Math.floor(raw.byteLength / 4));
+ const vals = numVals > 0 ? new Float32Array(raw, 0, numVals) : [];
+ log.warn('ProjectionProbe', `role=${options.role ?? ''} L0 M1 first8_f32: ${Array.from(vals).map(v => v.toFixed(5)).join(' ')}`);
+ }
  return createTensor(C, actualOutputDtype, [M, N], 'matmul_output');
  } finally {
  if (!isRecord && uniformBuffer) {
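Putting the new flags together, a condensed sketch of how they gate the probes in executeMatmul; this mirrors the code above rather than adding any new API.

  const dbg = normalizeMatmulDebugConfig(options.matmulDebug); // null unless a config object
  const probeThisCall = dbg?.enabled === true
    && isAttentionProjectionRole(options.role ?? ''); // qkv_proj / q_proj / k_proj / v_proj
  // probeThisCall && dbg.logAttentionWeightBuffer         -> MatmulQKVProbe / MatmulVariantDiag warnings
  // probeThisCall && dbg.failOnSmallAttentionWeightBuffer -> rethrow "B buffer too small" with details
  // dbg?.logProjectionValues (non-record, M === 1, layer 0) -> ProjectionProbe readback of the first 8 outputs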
package/src/gpu/kernels/matmul_gemv_subgroup.wgsl

@@ -5,7 +5,11 @@
  // 1. Use subgroupAdd() for reduction - much faster than shared memory
  // 2. Vectorized vec4 loads for weights
  // 3. Each workgroup processes multiple output columns
- // 4. Loop unrolling for better ILP
+ // 4. Warp-stride loop for row-major (transpose_b=1): all threads in a column
+ // step through K together, so adjacent threads load adjacent addresses.
+ // At each step, 64 threads × 8 bytes = 512 bytes from 4 consecutive cache
+ // lines → 100% cache-line utilization vs ~10% for the old contiguous-range
+ // pattern (where threads were 80 bytes apart in the same iteration).
  //
  // A is f32 (activations), B is f16 (weights), C is f32.
  // transpose_b=0: B is [K, N] (GGUF/column-major), access B[k * N + col]
@@ -69,40 +73,29 @@ fn main(
  // Each thread computes partial sum for its assigned k values
  var partial_sum: f32 = 0.0;
 
- // Only do work if this column is valid
  if (is_valid) {
- // Process K in chunks, each thread handles K/64 elements
- let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
- let k_start = thread_in_col * k_per_thread;
- let k_end = min(k_start + k_per_thread, u.K);
-
- // Main loop - process 4 elements at a time when aligned
- var k = k_start;
- let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
-
  if (u.transpose_b == 1u) {
- // B is [N, K] (SafeTensors/row-major): B[col, k] = B[col * K + k]
+ // B is [N, K] (row-major): B[col, k] = B[col * K + k]
+ // Warp-stride: step THREADS_PER_COL elements per outer iteration so that
+ // all wavefront threads load consecutive addresses simultaneously.
+ // At each step, 64 threads × 2 bytes = 128 bytes = exactly 1 cache line → 100% utilization.
  let b_row_offset = col * u.K;
-
- for (; k < k_aligned_end; k = k + 4u) {
- let a0 = A[k];
- let a1 = A[k + 1u];
- let a2 = A[k + 2u];
- let a3 = A[k + 3u];
-
- let b0 = f32(B[b_row_offset + k]);
- let b1 = f32(B[b_row_offset + k + 1u]);
- let b2 = f32(B[b_row_offset + k + 2u]);
- let b3 = f32(B[b_row_offset + k + 3u]);
-
- partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
- }
-
- for (; k < k_end; k = k + 1u) {
- partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
+ for (var k_base: u32 = 0u; k_base < u.K; k_base = k_base + THREADS_PER_COL) {
+ let k = k_base + thread_in_col;
+ if (k < u.K) {
+ partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
+ }
  }
  } else {
- // B is [K, N] (GGUF/column-major): B[k, col] = B[k * N + col]
+ // B is [K, N] (column-major): B[k, col] = B[k * N + col]
+ // Contiguous-range per thread: sequential access within each thread.
+ let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
+ let k_start = thread_in_col * k_per_thread;
+ let k_end = min(k_start + k_per_thread, u.K);
+
+ var k = k_start;
+ let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
+
  for (; k < k_aligned_end; k = k + 4u) {
  let a0 = A[k];
  let a1 = A[k + 1u];
@@ -189,38 +182,36 @@ fn main_multicol(
  var partial_sum: f32 = 0.0;
 
  if (is_valid) {
- // Each of 8 threads splits K
- let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
- let k_start = thread_in_col * k_per_thread;
- let k_end = min(k_start + k_per_thread, u.K);
-
- // Unroll by 4 for ILP
- var k = k_start;
- let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
-
  if (u.transpose_b == 1u) {
- // B is [N, K] (SafeTensors/row-major): B[col, k] = B[col * K + k]
+ // B is [N, K] (row-major): B[col, k] = B[col * K + k]
+ // Warp-stride: step MULTICOL_THREADS_PER_COL vec4 groups per outer iteration.
+ // Adjacent threads in the same column load adjacent vec4 groups → coalesced.
+ let K4 = u.K / 4u;
  let b_row_offset = col * u.K;
-
- for (; k < k_aligned_end; k = k + 4u) {
- let a0 = A[k];
- let a1 = A[k + 1u];
- let a2 = A[k + 2u];
- let a3 = A[k + 3u];
-
- let b0 = f32(B[b_row_offset + k]);
- let b1 = f32(B[b_row_offset + k + 1u]);
- let b2 = f32(B[b_row_offset + k + 2u]);
- let b3 = f32(B[b_row_offset + k + 3u]);
-
- partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
- }
-
- for (; k < k_end; k = k + 1u) {
- partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
+ for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + MULTICOL_THREADS_PER_COL) {
+ let k4 = k4_base + thread_in_col;
+ if (k4 < K4) {
+ let k = k4 * 4u;
+ let a0 = A[k];
+ let a1 = A[k + 1u];
+ let a2 = A[k + 2u];
+ let a3 = A[k + 3u];
+ let b0 = f32(B[b_row_offset + k]);
+ let b1 = f32(B[b_row_offset + k + 1u]);
+ let b2 = f32(B[b_row_offset + k + 2u]);
+ let b3 = f32(B[b_row_offset + k + 3u]);
+ partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
+ }
  }
  } else {
- // B is [K, N] (GGUF/column-major): B[k, col] = B[k * N + col]
+ // B is [K, N] (column-major): B[k, col] = B[k * N + col]
+ let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
+ let k_start = thread_in_col * k_per_thread;
+ let k_end = min(k_start + k_per_thread, u.K);
+
+ var k = k_start;
+ let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
+
  for (; k < k_aligned_end; k = k + 4u) {
  let a0 = A[k];
  let a1 = A[k + 1u];
@@ -245,7 +236,7 @@ fn main_multicol(
  multicol_wg_sums[local_id] = partial_sum;
  workgroupBarrier();
 
- // Thread 0 of each column reduces its 8 values
+ // Thread 0 of each column reduces its MULTICOL_THREADS_PER_COL values
  if (thread_in_col == 0u && is_valid) {
  var final_sum: f32 = 0.0;
  let base = col_in_wg * MULTICOL_THREADS_PER_COL;
@@ -282,30 +273,37 @@ fn main_vec4(
  if (is_valid) {
  // K is guaranteed to be multiple of 4
  let K4 = u.K / 4u;
- let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
- let k4_start = thread_in_col * k4_per_thread;
- let k4_end = min(k4_start + k4_per_thread, K4);
 
  if (u.transpose_b == 1u) {
- // B is [N, K] (SafeTensors/row-major): B[col, k] = B[col * K + k]
+ // B is [N, K] (row-major): B[col, k] = B[col * K + k]
+ // Warp-stride: step THREADS_PER_COL vec4 groups per outer iteration so that
+ // adjacent threads load adjacent groups → 100% cache-line utilization.
+ // At each step: 64 threads × 4 f16 × 2 bytes = 512 bytes from 4 consecutive
+ // cache lines, vs the old contiguous-range pattern (~10% utilization).
  let b_row_offset = col * u.K;
-
- for (var k4: u32 = k4_start; k4 < k4_end; k4 = k4 + 1u) {
- let k = k4 * 4u;
-
- let a = vec4<f32>(A[k], A[k + 1u], A[k + 2u], A[k + 3u]);
-
- let b = vec4<f32>(
- f32(B[b_row_offset + k]),
- f32(B[b_row_offset + k + 1u]),
- f32(B[b_row_offset + k + 2u]),
- f32(B[b_row_offset + k + 3u])
- );
-
- partial_sum = partial_sum + dot(a, b);
+ for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + THREADS_PER_COL) {
+ let k4 = k4_base + thread_in_col;
+ if (k4 < K4) {
+ let k = k4 * 4u;
+
+ let a = vec4<f32>(A[k], A[k + 1u], A[k + 2u], A[k + 3u]);
+
+ let b = vec4<f32>(
+ f32(B[b_row_offset + k]),
+ f32(B[b_row_offset + k + 1u]),
+ f32(B[b_row_offset + k + 2u]),
+ f32(B[b_row_offset + k + 3u])
+ );
+
+ partial_sum = partial_sum + dot(a, b);
+ }
  }
  } else {
- // B is [K, N] (GGUF/column-major): B[k, col] = B[k * N + col]
+ // B is [K, N] (column-major): B[k, col] = B[k * N + col]
+ // Contiguous-range per thread: sequential access within each thread.
+ let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
+ let k4_start = thread_in_col * k4_per_thread;
+ let k4_end = min(k4_start + k4_per_thread, K4);
  for (var k4: u32 = k4_start; k4 < k4_end; k4 = k4 + 1u) {
  let k = k4 * 4u;
 
@@ -342,4 +340,4 @@
  }
  C[col] = final_sum * u.alpha;
  }
- }
+ }
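The warp-stride comments above describe a pure indexing change. A small JavaScript illustration (not shader code) of the two patterns, using the 64 threads per column the comments assume:

  const THREADS_PER_COL = 64;

  // Old contiguous-range pattern: thread t owns its own K/64-element slice, so in a
  // given iteration the 64 threads touch addresses a full slice apart (separate cache lines).
  const contiguousIndex = (K, t, step) => t * Math.ceil(K / THREADS_PER_COL) + step;

  // New warp-stride pattern: in iteration `step`, thread t reads step * 64 + t, so the
  // 64 threads read 64 consecutive f16 values, i.e. one 128-byte cache line.
  const warpStrideIndex = (t, step) => step * THREADS_PER_COL + t;

  // For K = 4096, iteration 0:
  //   contiguous-range -> indices 0, 64, 128, ... (64 distinct cache lines touched)
  //   warp-stride      -> indices 0, 1, 2, ..., 63 (a single cache line)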
package/src/gpu/kernels/sample.js

@@ -7,7 +7,6 @@ import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout
  import { allowReadback } from '../perf-guards.js';
  import { selectRuleValue as selectKernelRuleValue } from './rule-registry.js';
  import { selectRuleValue as selectSharedRuleValue } from '../../rules/rule-registry.js';
- import { getKernelThresholds } from '../../config/schema/index.js';
 
 
  function getSampleBindGroupLayout(device) {
@@ -96,8 +95,7 @@ async function resolveArgmaxPipelines(device, vocabSize, variants) {
  const argmaxPipeline = await createSamplePipeline(device, variants.argmax);
  const numWorkgroups = Math.min(WORKGROUP_SIZES.DEFAULT, Math.ceil(vocabSize / WORKGROUP_SIZES.DEFAULT));
  const useSinglePassArgmax = numWorkgroups === 1;
- const argmaxReduceVocabThreshold = getKernelThresholds().sample.argmaxReduceVocabThreshold;
- const reducePipeline = useSinglePassArgmax || vocabSize <= argmaxReduceVocabThreshold
+ const reducePipeline = useSinglePassArgmax
  ? null
  : await createSamplePipeline(device, variants.argmaxReduce);
  const singlePassPipeline = useSinglePassArgmax
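The reduce pipeline is now compiled whenever more than one workgroup is needed, with no vocab-size threshold in between. A worked example, assuming WORKGROUP_SIZES.DEFAULT is 256 here (the real value comes from the package's constants):

  const DEFAULT = 256; // assumed value, for illustration only
  const numWorkgroups = (vocabSize) => Math.min(DEFAULT, Math.ceil(vocabSize / DEFAULT));

  numWorkgroups(200);   // 1   -> single-pass argmax, reducePipeline stays null
  numWorkgroups(32000); // 125 -> argmax + argmax_reduce pipelines are both created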
package/src/gpu/kernels/sample.wgsl

@@ -40,6 +40,16 @@ fn apply_softcap(x: f32, softcap: f32) -> f32 {
  return softcap * tanh(x / softcap);
  }
 
+ fn candidate_beats(candidate_value: f32, candidate_index: u32, best_value: f32, best_index: u32) -> bool {
+ if (candidate_value > best_value) {
+ return true;
+ }
+ if (candidate_value < best_value) {
+ return false;
+ }
+ return candidate_index < best_index;
+ }
+
  @group(0) @binding(0) var<uniform> u: Uniforms;
  @group(0) @binding(1) var<storage, read> logits: array<f32>; // [vocabSize]
  @group(0) @binding(2) var<storage, read_write> output: array<u32>; // [N] - selected tokens
@@ -87,7 +97,7 @@ fn find_topk_phase1(
  if (idx != pad_id) {
  // Apply softcapping before temperature scaling
  let val = apply_softcap(logits[idx], softcap) / temperature;
- if (val > local_max) {
+ if (candidate_beats(val, idx, local_max, local_max_idx)) {
  local_max = val;
  local_max_idx = idx;
  }
@@ -103,7 +113,12 @@
  var stride = WORKGROUP_SIZE / 2u;
  while (stride > 0u) {
  if (thread_idx < stride) {
- if (shared_values[thread_idx + stride] > shared_values[thread_idx]) {
+ if (candidate_beats(
+ shared_values[thread_idx + stride],
+ shared_indices[thread_idx + stride],
+ shared_values[thread_idx],
+ shared_indices[thread_idx]
+ )) {
  shared_values[thread_idx] = shared_values[thread_idx + stride];
  shared_indices[thread_idx] = shared_indices[thread_idx + stride];
  }
@@ -150,7 +165,7 @@ fn find_topk_phase2(
  var max_val = shared_values[k];
 
  for (var i: u32 = k + 1u; i < num_candidates; i = i + 1u) {
- if (shared_values[i] > max_val) {
+ if (candidate_beats(shared_values[i], shared_indices[i], max_val, shared_indices[max_idx])) {
  max_val = shared_values[i];
  max_idx = i;
  }
@@ -249,7 +264,7 @@ fn sample_single_pass(
  if (idx != pad_id) {
  // Apply softcapping before temperature scaling
  let val = apply_softcap(logits[idx], softcap) / temperature;
- if (val > local_max) {
+ if (candidate_beats(val, idx, local_max, local_max_idx)) {
  local_max = val;
  local_max_idx = idx;
  }
@@ -265,7 +280,12 @@
  var stride = WORKGROUP_SIZE / 2u;
  while (stride > 0u) {
  if (thread_idx < stride) {
- if (shared_values[thread_idx + stride] > shared_values[thread_idx]) {
+ if (candidate_beats(
+ shared_values[thread_idx + stride],
+ shared_indices[thread_idx + stride],
+ shared_values[thread_idx],
+ shared_indices[thread_idx]
+ )) {
  shared_values[thread_idx] = shared_values[thread_idx + stride];
  shared_indices[thread_idx] = shared_indices[thread_idx + stride];
  }
@@ -308,7 +328,7 @@ fn argmax(
  if (idx != pad_id) {
  // Apply softcapping (argmax is greedy, no temperature)
  let val = apply_softcap(logits[idx], softcap);
- if (val > local_max) {
+ if (candidate_beats(val, idx, local_max, local_max_idx)) {
  local_max = val;
  local_max_idx = idx;
  }
@@ -324,7 +344,12 @@
  var stride = WORKGROUP_SIZE / 2u;
  while (stride > 0u) {
  if (thread_idx < stride) {
- if (shared_values[thread_idx + stride] > shared_values[thread_idx]) {
+ if (candidate_beats(
+ shared_values[thread_idx + stride],
+ shared_indices[thread_idx + stride],
+ shared_values[thread_idx],
+ shared_indices[thread_idx]
+ )) {
  shared_values[thread_idx] = shared_values[thread_idx + stride];
  shared_indices[thread_idx] = shared_indices[thread_idx + stride];
  }
@@ -362,7 +387,12 @@ fn argmax_reduce(
  var stride = WORKGROUP_SIZE / 2u;
  while (stride > 0u) {
  if (thread_idx < stride) {
- if (shared_values[thread_idx + stride] > shared_values[thread_idx]) {
+ if (candidate_beats(
+ shared_values[thread_idx + stride],
+ shared_indices[thread_idx + stride],
+ shared_values[thread_idx],
+ shared_indices[thread_idx]
+ )) {
  shared_values[thread_idx] = shared_values[thread_idx + stride];
  shared_indices[thread_idx] = shared_indices[thread_idx + stride];
  }
@@ -374,4 +404,4 @@
  if (thread_idx == 0u) {
  output[u.output_index] = shared_indices[0];
  }
- }
+ }
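candidate_beats replaces the plain greater-than comparisons so that ties resolve to the lowest token index at every reduction step. A JavaScript mirror of the rule, for illustration only:

  function candidateBeats(candidateValue, candidateIndex, bestValue, bestIndex) {
    if (candidateValue > bestValue) return true;
    if (candidateValue < bestValue) return false;
    return candidateIndex < bestIndex; // exact tie: lower index wins
  }

  // Reducing logits [2.5, 7.0, 7.0, 1.0] always selects index 1, no matter in which
  // order threads or workgroups combine their partial maxima.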