@simulatte/doppler 0.1.8 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +14 -1
- package/README.md +25 -6
- package/package.json +5 -3
- package/src/client/doppler-api.browser.js +6 -0
- package/src/client/doppler-api.d.ts +3 -0
- package/src/client/doppler-api.js +11 -2
- package/src/client/doppler-registry.js +3 -5
- package/src/client/doppler-registry.json +16 -0
- package/src/config/kernels/kernel-ref-digests.js +23 -21
- package/src/config/kernels/moe/mixtral.paths.json +46 -0
- package/src/config/loader.js +6 -0
- package/src/config/platforms/loader.js +3 -1
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +7 -0
- package/src/config/presets/models/gemma3.json +2 -1
- package/src/config/presets/models/gemma4.json +61 -0
- package/src/config/presets/models/granite-docling.json +70 -0
- package/src/config/presets/models/lfm2.json +6 -1
- package/src/config/presets/models/qwen3_vl.json +40 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
- package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
- package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
- package/src/config/presets/runtime/modes/trace-layers.json +1 -0
- package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
- package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
- package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
- package/src/config/runtime.js +3 -0
- package/src/config/schema/debug.schema.d.ts +40 -0
- package/src/config/schema/debug.schema.js +28 -0
- package/src/config/schema/index.js +2 -0
- package/src/config/schema/inference-defaults.schema.js +1 -1
- package/src/config/schema/kernel-path.schema.d.ts +1 -0
- package/src/config/schema/memory-limits.schema.js +2 -2
- package/src/config/schema/storage.schema.js +1 -1
- package/src/converter/conversion-plan.js +1 -1
- package/src/converter/core.js +17 -8
- package/src/converter/quantizer.d.ts +5 -0
- package/src/converter/quantizer.js +15 -0
- package/src/distribution/shard-delivery.js +34 -0
- package/src/formats/rdrr/classification.js +32 -0
- package/src/gpu/kernel-runtime.js +4 -2
- package/src/gpu/kernels/attention.js +2 -1
- package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
- package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
- package/src/gpu/kernels/dequant_shared.wgsl +4 -2
- package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
- package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
- package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
- package/src/gpu/kernels/gated-short-conv.js +284 -0
- package/src/gpu/kernels/linear-attention-core.js +37 -17
- package/src/gpu/kernels/matmul-selection.js +1 -0
- package/src/gpu/kernels/matmul.d.ts +3 -0
- package/src/gpu/kernels/matmul.js +70 -1
- package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
- package/src/gpu/kernels/sample.js +1 -3
- package/src/gpu/kernels/sample.wgsl +39 -9
- package/src/gpu/kernels/sample_f16.wgsl +38 -8
- package/src/gpu/kernels/shader-cache.js +9 -4
- package/src/inference/kv-cache/base.js +3 -10
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +2 -1
- package/src/inference/pipelines/text/attention/projections.d.ts +3 -0
- package/src/inference/pipelines/text/attention/projections.js +13 -2
- package/src/inference/pipelines/text/attention/record.js +1 -0
- package/src/inference/pipelines/text/attention/run.js +9 -0
- package/src/inference/pipelines/text/config.d.ts +1 -0
- package/src/inference/pipelines/text/config.js +32 -4
- package/src/inference/pipelines/text/embed.js +26 -7
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
- package/src/inference/pipelines/text/execution-v0.js +12 -1
- package/src/inference/pipelines/text/generator-helpers.js +1 -0
- package/src/inference/pipelines/text/generator-runtime.js +14 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +9 -0
- package/src/inference/pipelines/text/generator-steps.js +46 -29
- package/src/inference/pipelines/text/generator.d.ts +5 -0
- package/src/inference/pipelines/text/generator.js +320 -166
- package/src/inference/pipelines/text/init.d.ts +2 -0
- package/src/inference/pipelines/text/init.js +19 -5
- package/src/inference/pipelines/text/layer.js +37 -8
- package/src/inference/pipelines/text/moe-gpu.js +21 -3
- package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
- package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
- package/src/inference/pipelines/text/ops.js +123 -53
- package/src/inference/pipelines/text/probes.js +1 -0
- package/src/inference/pipelines/text/state.js +2 -0
- package/src/inference/pipelines/text.d.ts +5 -0
- package/src/inference/pipelines/text.js +59 -1
- package/src/inference/pipelines/vision/encoder.js +386 -0
- package/src/inference/pipelines/vision/image-preprocess.js +151 -0
- package/src/inference/pipelines/vision/index.js +173 -0
- package/src/inference/pipelines/vision/ops.js +78 -0
- package/src/inference/pipelines/vision/patch-embed.js +151 -0
- package/src/inference/test-harness.js +9 -7
- package/src/loader/doppler-loader.d.ts +3 -0
- package/src/loader/doppler-loader.js +20 -3
- package/src/loader/experts/expert-cache.js +6 -2
- package/src/loader/experts/expert-loader.js +6 -2
- package/src/loader/layer-loader.js +42 -3
- package/src/loader/manifest-config.js +3 -1
- package/src/loader/tensors/tensor-loader.d.ts +3 -0
- package/src/loader/tensors/tensor-loader.js +124 -3
- package/src/rules/kernels/moe.rules.mixtral.json +75 -0
- package/src/rules/kernels/softmax.rules.json +2 -0
- package/src/rules/rule-registry.d.ts +1 -0
- package/src/rules/rule-registry.js +2 -0
- package/src/storage/quickstart-downloader.d.ts +3 -0
- package/src/storage/quickstart-downloader.js +27 -30
- package/src/tooling/node-converter.js +25 -7
- package/src/tooling/node-source-runtime.js +29 -5
- package/src/tooling/node-webgpu.js +24 -7
- package/src/utils/hf-resolve-url.d.ts +16 -0
- package/src/utils/hf-resolve-url.js +17 -0
- package/src/version.js +1 -1
- package/src/tooling/node-convert.d.ts +0 -54
|
@@ -10,7 +10,7 @@ import {
|
|
|
10
10
|
import { recordDispatch } from './dispatch.js';
|
|
11
11
|
|
|
12
12
|
const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;
|
|
13
|
-
const HEAD_WORKGROUP_SIZE =
|
|
13
|
+
const HEAD_WORKGROUP_SIZE = 128;
|
|
14
14
|
|
|
15
15
|
const CONV_SHADER = /* wgsl */ `
|
|
16
16
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
@@ -79,7 +79,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
79
79
|
`;
|
|
80
80
|
|
|
81
81
|
const RECURRENT_SHADER = /* wgsl */ `
|
|
82
|
-
override WORKGROUP_SIZE: u32 =
|
|
82
|
+
override WORKGROUP_SIZE: u32 = 128u;
|
|
83
83
|
|
|
84
84
|
struct LinearAttentionParams {
|
|
85
85
|
num_tokens: u32,
|
|
@@ -111,6 +111,8 @@ struct LinearAttentionParams {
|
|
|
111
111
|
@group(0) @binding(8) var<storage, read_write> recurrent_state: array<f32>;
|
|
112
112
|
@group(0) @binding(9) var<storage, read_write> output: array<f32>;
|
|
113
113
|
|
|
114
|
+
var<workgroup> shared_sq: array<f32, WORKGROUP_SIZE>;
|
|
115
|
+
|
|
114
116
|
fn softplus(x: f32) -> f32 {
|
|
115
117
|
if (x > 20.0) {
|
|
116
118
|
return x;
|
|
@@ -131,17 +133,19 @@ fn silu(x: f32) -> f32 {
|
|
|
131
133
|
}
|
|
132
134
|
|
|
133
135
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
134
|
-
fn main(@builtin(
|
|
135
|
-
|
|
136
|
+
fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|
137
|
+
@builtin(local_invocation_id) lid: vec3<u32>) {
|
|
138
|
+
let head = wid.x;
|
|
139
|
+
let vd = lid.x;
|
|
136
140
|
if (head >= params.num_v_heads) {
|
|
137
141
|
return;
|
|
138
142
|
}
|
|
139
143
|
|
|
140
144
|
let head_k_dim = params.head_k_dim;
|
|
141
145
|
let head_v_dim = params.head_v_dim;
|
|
146
|
+
let is_active = vd < head_v_dim;
|
|
142
147
|
let head_scale = inverseSqrt(f32(head_k_dim));
|
|
143
148
|
let recurrent_head_base = head * head_k_dim * head_v_dim;
|
|
144
|
-
let recurrent_head_size = head_k_dim * head_v_dim;
|
|
145
149
|
let q_rep = max(params.q_rep, 1u);
|
|
146
150
|
let src_head = head / q_rep;
|
|
147
151
|
let q_base = src_head * head_k_dim;
|
|
@@ -154,6 +158,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
154
158
|
let ab_row_base = token_idx * params.num_v_heads + head;
|
|
155
159
|
let out_row_base = token_idx * params.value_dim + head * head_v_dim;
|
|
156
160
|
|
|
161
|
+
// L2 norm for Q and K (redundant across threads but avoids shared memory)
|
|
157
162
|
var q_norm_sq = 0.0;
|
|
158
163
|
var k_norm_sq = 0.0;
|
|
159
164
|
for (var d: u32 = 0u; d < head_k_dim; d = d + 1u) {
|
|
@@ -169,11 +174,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
169
174
|
let g = a_neg_exp[head] * softplus(a_proj[ab_row_base] + dt_bias[head]);
|
|
170
175
|
let g_exp = exp(g);
|
|
171
176
|
|
|
172
|
-
|
|
173
|
-
|
|
177
|
+
// Decay state — each thread handles head_k_dim elements at stride head_v_dim
|
|
178
|
+
if (is_active) {
|
|
179
|
+
for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
|
|
180
|
+
let state_idx = recurrent_head_base + kd * head_v_dim + vd;
|
|
181
|
+
recurrent_state[state_idx] = recurrent_state[state_idx] * g_exp;
|
|
182
|
+
}
|
|
174
183
|
}
|
|
175
184
|
|
|
176
|
-
|
|
185
|
+
// Delta update — each thread handles one vd slice (no cross-thread dependency)
|
|
186
|
+
if (is_active) {
|
|
177
187
|
var kv_mem = 0.0;
|
|
178
188
|
for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
|
|
179
189
|
let k_normed = conv_out[conv_row_base + k_base + kd] * k_norm_scale;
|
|
@@ -188,21 +198,31 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
188
198
|
}
|
|
189
199
|
}
|
|
190
200
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
201
|
+
// Output — each thread computes one vd element
|
|
202
|
+
var out_value = 0.0;
|
|
203
|
+
if (is_active) {
|
|
194
204
|
for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
|
|
195
205
|
let q_normed = conv_out[conv_row_base + q_base + kd] * q_norm_scale;
|
|
196
206
|
let state_idx = recurrent_head_base + kd * head_v_dim + vd;
|
|
197
207
|
out_value = out_value + recurrent_state[state_idx] * q_normed;
|
|
198
208
|
}
|
|
199
209
|
output[out_row_base + vd] = out_value;
|
|
200
|
-
let value = out_value;
|
|
201
|
-
mean_sq = mean_sq + value * value;
|
|
202
210
|
}
|
|
203
|
-
let inv_rms = inverseSqrt(mean_sq / f32(head_v_dim) + params.rms_norm_eps);
|
|
204
211
|
|
|
205
|
-
|
|
212
|
+
// RMS norm reduction across vd (workgroup-level)
|
|
213
|
+
shared_sq[vd] = select(0.0, out_value * out_value, is_active);
|
|
214
|
+
workgroupBarrier();
|
|
215
|
+
// Tree reduction
|
|
216
|
+
for (var stride: u32 = WORKGROUP_SIZE / 2u; stride > 0u; stride = stride / 2u) {
|
|
217
|
+
if (vd < stride) {
|
|
218
|
+
shared_sq[vd] = shared_sq[vd] + shared_sq[vd + stride];
|
|
219
|
+
}
|
|
220
|
+
workgroupBarrier();
|
|
221
|
+
}
|
|
222
|
+
let inv_rms = inverseSqrt(shared_sq[0] / f32(head_v_dim) + params.rms_norm_eps);
|
|
223
|
+
|
|
224
|
+
// Apply norm + gate
|
|
225
|
+
if (is_active) {
|
|
206
226
|
let gate = silu(z_proj[z_row_base + vd]);
|
|
207
227
|
let norm_index = select(vd, head * head_v_dim + vd, params.norm_mode == 1u);
|
|
208
228
|
output[out_row_base + vd] = (output[out_row_base + vd] * inv_rms) * norm_weight[norm_index] * gate;
|
|
@@ -415,7 +435,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
|
|
|
415
435
|
recorder,
|
|
416
436
|
recurrentPipeline,
|
|
417
437
|
recurrentBindGroup,
|
|
418
|
-
[
|
|
438
|
+
[layerState.numVHeads, 1, 1],
|
|
419
439
|
'linear_attention_recurrent'
|
|
420
440
|
);
|
|
421
441
|
|
|
@@ -502,7 +522,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
|
|
|
502
522
|
const pass = encoder.beginComputePass({ label: 'linear_attention_recurrent_pass' });
|
|
503
523
|
pass.setPipeline(recurrentPipeline);
|
|
504
524
|
pass.setBindGroup(0, recurrentBindGroup);
|
|
505
|
-
pass.dispatchWorkgroups(
|
|
525
|
+
pass.dispatchWorkgroups(layerState.numVHeads, 1, 1);
|
|
506
526
|
pass.end();
|
|
507
527
|
}
|
|
508
528
|
|
|
@@ -92,6 +92,7 @@ export function getMatmulConfig(variant, constants) {
|
|
|
92
92
|
|
|
93
93
|
|
|
94
94
|
export function isFusedQ4KDisabled(options = {}) {
|
|
95
|
+
if (options.disableFusedQ4K === true) return true;
|
|
95
96
|
const capabilities = getKernelCapabilities();
|
|
96
97
|
const hasSubgroups = capabilities?.hasSubgroups === true;
|
|
97
98
|
|
|
@@ -13,6 +13,7 @@ import type { WeightBuffer } from '../weight-buffer.js';
|
|
|
13
13
|
import type { CommandRecorder } from '../command-recorder.js';
|
|
14
14
|
import type { OutputBufferOptions, OutputDtypeOptions, Vec4Options } from './types.js';
|
|
15
15
|
import type { KernelPathSchema } from '../../config/schema/index.js';
|
|
16
|
+
import type { MatmulDebugConfigSchema } from '../../config/schema/debug.schema.js';
|
|
16
17
|
|
|
17
18
|
/** Matmul kernel options */
|
|
18
19
|
export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions, Vec4Options {
|
|
@@ -40,6 +41,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
|
|
|
40
41
|
preferF16?: boolean;
|
|
41
42
|
/** WGSL override constants for pipeline creation */
|
|
42
43
|
constants?: Record<string, number | boolean>;
|
|
44
|
+
/** Runtime debug controls for attention projection diagnostics. */
|
|
45
|
+
matmulDebug?: MatmulDebugConfigSchema | null;
|
|
43
46
|
}
|
|
44
47
|
|
|
45
48
|
/** Context for base matmul kernel selection rules. */
|
|
@@ -2,7 +2,7 @@ import { getDevice, getKernelCapabilities } from '../device.js';
|
|
|
2
2
|
import { createTensor } from '../tensor.js';
|
|
3
3
|
import { getBuffer, getLayout, getWeightDtype } from '../weight-buffer.js';
|
|
4
4
|
import { log, trace, isTraceEnabled } from '../../debug/index.js';
|
|
5
|
-
import { releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
|
+
import { releaseBuffer, readBuffer } from '../../memory/buffer-pool.js';
|
|
6
6
|
import { releaseUniformBuffer } from '../uniform-cache.js';
|
|
7
7
|
import { castF16ToF32, recordCastF16ToF32 } from './cast.js';
|
|
8
8
|
import {
|
|
@@ -34,6 +34,24 @@ export { createMatmulBindGroupLayout };
|
|
|
34
34
|
let _runMatmulDebugCount = 0;
|
|
35
35
|
let _recordMatmulDebugCount = 0;
|
|
36
36
|
|
|
37
|
+
function normalizeMatmulDebugConfig(config) {
|
|
38
|
+
if (!config || typeof config !== 'object') {
|
|
39
|
+
return null;
|
|
40
|
+
}
|
|
41
|
+
return {
|
|
42
|
+
enabled: config.enabled === true,
|
|
43
|
+
forceSplitQKV: config.forceSplitQKV === true,
|
|
44
|
+
validateAttentionWeightBuffer: config.validateAttentionWeightBuffer === true,
|
|
45
|
+
failOnSmallAttentionWeightBuffer: config.failOnSmallAttentionWeightBuffer === true,
|
|
46
|
+
logAttentionWeightBuffer: config.logAttentionWeightBuffer === true,
|
|
47
|
+
logProjectionValues: config.logProjectionValues === true,
|
|
48
|
+
};
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
function isAttentionProjectionRole(role = '') {
|
|
52
|
+
return role === 'qkv_proj' || role === 'q_proj' || role === 'k_proj' || role === 'v_proj';
|
|
53
|
+
}
|
|
54
|
+
|
|
37
55
|
function getDebugCounter(isRecord) {
|
|
38
56
|
return isRecord ? _recordMatmulDebugCount : _runMatmulDebugCount;
|
|
39
57
|
}
|
|
@@ -126,6 +144,12 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
126
144
|
const weightLabel = (B && typeof B === 'object' ? B.label : null) ?? bBuffer?.label ?? null;
|
|
127
145
|
const weightLayout = getLayout(B);
|
|
128
146
|
const weightShape = B?.shape ? `[${B.shape.join(', ')}]` : null;
|
|
147
|
+
const matmulDebug = normalizeMatmulDebugConfig(options.matmulDebug);
|
|
148
|
+
const debugAttention = matmulDebug?.enabled === true;
|
|
149
|
+
const isAttnProj = isAttentionProjectionRole(options.role ?? '');
|
|
150
|
+
const shouldValidateAttentionWeightBuffer = debugAttention && matmulDebug.validateAttentionWeightBuffer;
|
|
151
|
+
const shouldFailOnSmallAttentionWeightBuffer = debugAttention && matmulDebug.failOnSmallAttentionWeightBuffer;
|
|
152
|
+
const shouldLogAttentionWeightBuffer = debugAttention && matmulDebug.logAttentionWeightBuffer;
|
|
129
153
|
|
|
130
154
|
if (isTraceEnabled('kernels') && getDebugCounter(isRecord) < 20) {
|
|
131
155
|
incrementDebugCounter(isRecord);
|
|
@@ -201,6 +225,27 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
201
225
|
bOffset
|
|
202
226
|
);
|
|
203
227
|
} catch (err) {
|
|
228
|
+
if (shouldValidateAttentionWeightBuffer && isAttnProj && err instanceof Error && err.message.includes('B buffer too small')) {
|
|
229
|
+
const detailParts = [
|
|
230
|
+
`role=${options.role ?? ''}`,
|
|
231
|
+
`layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'}`,
|
|
232
|
+
`M=${M}`,
|
|
233
|
+
`N=${N}`,
|
|
234
|
+
`K=${K}`,
|
|
235
|
+
];
|
|
236
|
+
if (weightDtype) detailParts.push(`weightDtype=${weightDtype}`);
|
|
237
|
+
if (weightLayout) detailParts.push(`weightLayout=${weightLayout}`);
|
|
238
|
+
if (weightShape) detailParts.push(`shape=${weightShape}`);
|
|
239
|
+
if (weightLabel) detailParts.push(`label=${weightLabel}`);
|
|
240
|
+
if (Number.isFinite(bBuffer?.size)) detailParts.push(`bSize=${bBuffer.size}`);
|
|
241
|
+
const detail = detailParts.join(' ');
|
|
242
|
+
if (shouldLogAttentionWeightBuffer) {
|
|
243
|
+
log.warn('MatmulQKVProbe', `${err.message} | ${detail}`);
|
|
244
|
+
}
|
|
245
|
+
if (shouldFailOnSmallAttentionWeightBuffer) {
|
|
246
|
+
throw new Error(`${err.message}${detail ? ` (${detail})` : ''}`);
|
|
247
|
+
}
|
|
248
|
+
}
|
|
204
249
|
if (!isRecord && err instanceof Error && err.message.includes('B buffer too small')) {
|
|
205
250
|
const detailParts = [];
|
|
206
251
|
if (weightLabel) detailParts.push(`label=${weightLabel}`);
|
|
@@ -226,6 +271,15 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
226
271
|
trace.kernels(`MATMUL_LARGE: N=${N}, variant=${variant}, aDtype=${aDtype}, bDtype=${bDtype}, transposeB=${transposeB}`);
|
|
227
272
|
}
|
|
228
273
|
|
|
274
|
+
if (isAttnProj && shouldLogAttentionWeightBuffer) {
|
|
275
|
+
log.warn('MatmulQKVProbe',
|
|
276
|
+
`role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} ` +
|
|
277
|
+
`M=${M} N=${N} K=${K} transposeB=${transposeB} bSize=${bBuffer?.size ?? 0} ` +
|
|
278
|
+
`requiredB=${bindingSizes?.bBindingSize ?? 'n/a'} weightShape=${weightShape ?? 'n/a'} ` +
|
|
279
|
+
`weightDtype=${weightDtype ?? 'unknown'} weightLayout=${weightLayout ?? 'unknown'}`
|
|
280
|
+
);
|
|
281
|
+
}
|
|
282
|
+
|
|
229
283
|
const config = getMatmulConfig(variant, constants);
|
|
230
284
|
const kernel = new MatmulKernel(device);
|
|
231
285
|
const pipeline = await getMatmulPipeline(variant, constants);
|
|
@@ -238,6 +292,14 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
238
292
|
);
|
|
239
293
|
const ownsOutput = outputBuffer == null;
|
|
240
294
|
|
|
295
|
+
if (isAttnProj && shouldLogAttentionWeightBuffer) {
|
|
296
|
+
log.warn('MatmulVariantDiag',
|
|
297
|
+
`role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} mode=${mode} ` +
|
|
298
|
+
`variant=${variant} useQ4KFused=${useQ4KFused} useGemv=${useGemv} ` +
|
|
299
|
+
`aDtype=${aDtype} bDtype=${bDtype} output=${actualOutputDtype}`
|
|
300
|
+
);
|
|
301
|
+
}
|
|
302
|
+
|
|
241
303
|
if (!Number.isFinite(outputSize) || outputSize <= 0) {
|
|
242
304
|
throw new Error(`[${opLabel}] Invalid output size: ${outputSize} (M=${M}, N=${N})`);
|
|
243
305
|
}
|
|
@@ -290,6 +352,13 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
290
352
|
kernel.dispatch(pipeline, bindGroup, dispatchPlan.workgroups);
|
|
291
353
|
}
|
|
292
354
|
completed = true;
|
|
355
|
+
if (!isRecord && matmulDebug?.logProjectionValues && isAttnProj && M === 1 && options.layerIdx === 0) {
|
|
356
|
+
await device.queue.onSubmittedWorkDone();
|
|
357
|
+
const raw = await readBuffer(C);
|
|
358
|
+
const numVals = Math.min(8, Math.floor(raw.byteLength / 4));
|
|
359
|
+
const vals = numVals > 0 ? new Float32Array(raw, 0, numVals) : [];
|
|
360
|
+
log.warn('ProjectionProbe', `role=${options.role ?? ''} L0 M1 first8_f32: ${Array.from(vals).map(v => v.toFixed(5)).join(' ')}`);
|
|
361
|
+
}
|
|
293
362
|
return createTensor(C, actualOutputDtype, [M, N], 'matmul_output');
|
|
294
363
|
} finally {
|
|
295
364
|
if (!isRecord && uniformBuffer) {
|
|
@@ -5,7 +5,11 @@
|
|
|
5
5
|
// 1. Use subgroupAdd() for reduction - much faster than shared memory
|
|
6
6
|
// 2. Vectorized vec4 loads for weights
|
|
7
7
|
// 3. Each workgroup processes multiple output columns
|
|
8
|
-
// 4.
|
|
8
|
+
// 4. Warp-stride loop for row-major (transpose_b=1): all threads in a column
|
|
9
|
+
// step through K together, so adjacent threads load adjacent addresses.
|
|
10
|
+
// At each step, 64 threads × 8 bytes = 512 bytes from 4 consecutive cache
|
|
11
|
+
// lines → 100% cache-line utilization vs ~10% for the old contiguous-range
|
|
12
|
+
// pattern (where threads were 80 bytes apart in the same iteration).
|
|
9
13
|
//
|
|
10
14
|
// A is f32 (activations), B is f16 (weights), C is f32.
|
|
11
15
|
// transpose_b=0: B is [K, N] (GGUF/column-major), access B[k * N + col]
|
|
@@ -69,40 +73,29 @@ fn main(
|
|
|
69
73
|
// Each thread computes partial sum for its assigned k values
|
|
70
74
|
var partial_sum: f32 = 0.0;
|
|
71
75
|
|
|
72
|
-
// Only do work if this column is valid
|
|
73
76
|
if (is_valid) {
|
|
74
|
-
// Process K in chunks, each thread handles K/64 elements
|
|
75
|
-
let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
76
|
-
let k_start = thread_in_col * k_per_thread;
|
|
77
|
-
let k_end = min(k_start + k_per_thread, u.K);
|
|
78
|
-
|
|
79
|
-
// Main loop - process 4 elements at a time when aligned
|
|
80
|
-
var k = k_start;
|
|
81
|
-
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
82
|
-
|
|
83
77
|
if (u.transpose_b == 1u) {
|
|
84
|
-
// B is [N, K] (
|
|
78
|
+
// B is [N, K] (row-major): B[col, k] = B[col * K + k]
|
|
79
|
+
// Warp-stride: step THREADS_PER_COL elements per outer iteration so that
|
|
80
|
+
// all wavefront threads load consecutive addresses simultaneously.
|
|
81
|
+
// At each step, 64 threads × 2 bytes = 128 bytes = exactly 1 cache line → 100% utilization.
|
|
85
82
|
let b_row_offset = col * u.K;
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
let a3 = A[k + 3u];
|
|
92
|
-
|
|
93
|
-
let b0 = f32(B[b_row_offset + k]);
|
|
94
|
-
let b1 = f32(B[b_row_offset + k + 1u]);
|
|
95
|
-
let b2 = f32(B[b_row_offset + k + 2u]);
|
|
96
|
-
let b3 = f32(B[b_row_offset + k + 3u]);
|
|
97
|
-
|
|
98
|
-
partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
|
|
99
|
-
}
|
|
100
|
-
|
|
101
|
-
for (; k < k_end; k = k + 1u) {
|
|
102
|
-
partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
|
|
83
|
+
for (var k_base: u32 = 0u; k_base < u.K; k_base = k_base + THREADS_PER_COL) {
|
|
84
|
+
let k = k_base + thread_in_col;
|
|
85
|
+
if (k < u.K) {
|
|
86
|
+
partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
|
|
87
|
+
}
|
|
103
88
|
}
|
|
104
89
|
} else {
|
|
105
|
-
// B is [K, N] (
|
|
90
|
+
// B is [K, N] (column-major): B[k, col] = B[k * N + col]
|
|
91
|
+
// Contiguous-range per thread: sequential access within each thread.
|
|
92
|
+
let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
93
|
+
let k_start = thread_in_col * k_per_thread;
|
|
94
|
+
let k_end = min(k_start + k_per_thread, u.K);
|
|
95
|
+
|
|
96
|
+
var k = k_start;
|
|
97
|
+
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
98
|
+
|
|
106
99
|
for (; k < k_aligned_end; k = k + 4u) {
|
|
107
100
|
let a0 = A[k];
|
|
108
101
|
let a1 = A[k + 1u];
|
|
@@ -189,38 +182,36 @@ fn main_multicol(
|
|
|
189
182
|
var partial_sum: f32 = 0.0;
|
|
190
183
|
|
|
191
184
|
if (is_valid) {
|
|
192
|
-
// Each of 8 threads splits K
|
|
193
|
-
let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
|
|
194
|
-
let k_start = thread_in_col * k_per_thread;
|
|
195
|
-
let k_end = min(k_start + k_per_thread, u.K);
|
|
196
|
-
|
|
197
|
-
// Unroll by 4 for ILP
|
|
198
|
-
var k = k_start;
|
|
199
|
-
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
200
|
-
|
|
201
185
|
if (u.transpose_b == 1u) {
|
|
202
|
-
// B is [N, K] (
|
|
186
|
+
// B is [N, K] (row-major): B[col, k] = B[col * K + k]
|
|
187
|
+
// Warp-stride: step MULTICOL_THREADS_PER_COL vec4 groups per outer iteration.
|
|
188
|
+
// Adjacent threads in the same column load adjacent vec4 groups → coalesced.
|
|
189
|
+
let K4 = u.K / 4u;
|
|
203
190
|
let b_row_offset = col * u.K;
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
for (; k < k_end; k = k + 1u) {
|
|
220
|
-
partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
|
|
191
|
+
for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + MULTICOL_THREADS_PER_COL) {
|
|
192
|
+
let k4 = k4_base + thread_in_col;
|
|
193
|
+
if (k4 < K4) {
|
|
194
|
+
let k = k4 * 4u;
|
|
195
|
+
let a0 = A[k];
|
|
196
|
+
let a1 = A[k + 1u];
|
|
197
|
+
let a2 = A[k + 2u];
|
|
198
|
+
let a3 = A[k + 3u];
|
|
199
|
+
let b0 = f32(B[b_row_offset + k]);
|
|
200
|
+
let b1 = f32(B[b_row_offset + k + 1u]);
|
|
201
|
+
let b2 = f32(B[b_row_offset + k + 2u]);
|
|
202
|
+
let b3 = f32(B[b_row_offset + k + 3u]);
|
|
203
|
+
partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
|
|
204
|
+
}
|
|
221
205
|
}
|
|
222
206
|
} else {
|
|
223
|
-
// B is [K, N] (
|
|
207
|
+
// B is [K, N] (column-major): B[k, col] = B[k * N + col]
|
|
208
|
+
let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
|
|
209
|
+
let k_start = thread_in_col * k_per_thread;
|
|
210
|
+
let k_end = min(k_start + k_per_thread, u.K);
|
|
211
|
+
|
|
212
|
+
var k = k_start;
|
|
213
|
+
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
214
|
+
|
|
224
215
|
for (; k < k_aligned_end; k = k + 4u) {
|
|
225
216
|
let a0 = A[k];
|
|
226
217
|
let a1 = A[k + 1u];
|
|
@@ -245,7 +236,7 @@ fn main_multicol(
|
|
|
245
236
|
multicol_wg_sums[local_id] = partial_sum;
|
|
246
237
|
workgroupBarrier();
|
|
247
238
|
|
|
248
|
-
// Thread 0 of each column reduces its
|
|
239
|
+
// Thread 0 of each column reduces its MULTICOL_THREADS_PER_COL values
|
|
249
240
|
if (thread_in_col == 0u && is_valid) {
|
|
250
241
|
var final_sum: f32 = 0.0;
|
|
251
242
|
let base = col_in_wg * MULTICOL_THREADS_PER_COL;
|
|
@@ -282,30 +273,37 @@ fn main_vec4(
|
|
|
282
273
|
if (is_valid) {
|
|
283
274
|
// K is guaranteed to be multiple of 4
|
|
284
275
|
let K4 = u.K / 4u;
|
|
285
|
-
let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
286
|
-
let k4_start = thread_in_col * k4_per_thread;
|
|
287
|
-
let k4_end = min(k4_start + k4_per_thread, K4);
|
|
288
276
|
|
|
289
277
|
if (u.transpose_b == 1u) {
|
|
290
|
-
// B is [N, K] (
|
|
278
|
+
// B is [N, K] (row-major): B[col, k] = B[col * K + k]
|
|
279
|
+
// Warp-stride: step THREADS_PER_COL vec4 groups per outer iteration so that
|
|
280
|
+
// adjacent threads load adjacent groups → 100% cache-line utilization.
|
|
281
|
+
// At each step: 64 threads × 4 f16 × 2 bytes = 512 bytes from 4 consecutive
|
|
282
|
+
// cache lines, vs the old contiguous-range pattern (~10% utilization).
|
|
291
283
|
let b_row_offset = col * u.K;
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
f32(
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
284
|
+
for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + THREADS_PER_COL) {
|
|
285
|
+
let k4 = k4_base + thread_in_col;
|
|
286
|
+
if (k4 < K4) {
|
|
287
|
+
let k = k4 * 4u;
|
|
288
|
+
|
|
289
|
+
let a = vec4<f32>(A[k], A[k + 1u], A[k + 2u], A[k + 3u]);
|
|
290
|
+
|
|
291
|
+
let b = vec4<f32>(
|
|
292
|
+
f32(B[b_row_offset + k]),
|
|
293
|
+
f32(B[b_row_offset + k + 1u]),
|
|
294
|
+
f32(B[b_row_offset + k + 2u]),
|
|
295
|
+
f32(B[b_row_offset + k + 3u])
|
|
296
|
+
);
|
|
297
|
+
|
|
298
|
+
partial_sum = partial_sum + dot(a, b);
|
|
299
|
+
}
|
|
306
300
|
}
|
|
307
301
|
} else {
|
|
308
|
-
// B is [K, N] (
|
|
302
|
+
// B is [K, N] (column-major): B[k, col] = B[k * N + col]
|
|
303
|
+
// Contiguous-range per thread: sequential access within each thread.
|
|
304
|
+
let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
305
|
+
let k4_start = thread_in_col * k4_per_thread;
|
|
306
|
+
let k4_end = min(k4_start + k4_per_thread, K4);
|
|
309
307
|
for (var k4: u32 = k4_start; k4 < k4_end; k4 = k4 + 1u) {
|
|
310
308
|
let k = k4 * 4u;
|
|
311
309
|
|
|
@@ -342,4 +340,4 @@ fn main_vec4(
|
|
|
342
340
|
}
|
|
343
341
|
C[col] = final_sum * u.alpha;
|
|
344
342
|
}
|
|
345
|
-
}
|
|
343
|
+
}
|
|
@@ -7,7 +7,6 @@ import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout
|
|
|
7
7
|
import { allowReadback } from '../perf-guards.js';
|
|
8
8
|
import { selectRuleValue as selectKernelRuleValue } from './rule-registry.js';
|
|
9
9
|
import { selectRuleValue as selectSharedRuleValue } from '../../rules/rule-registry.js';
|
|
10
|
-
import { getKernelThresholds } from '../../config/schema/index.js';
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
function getSampleBindGroupLayout(device) {
|
|
@@ -96,8 +95,7 @@ async function resolveArgmaxPipelines(device, vocabSize, variants) {
|
|
|
96
95
|
const argmaxPipeline = await createSamplePipeline(device, variants.argmax);
|
|
97
96
|
const numWorkgroups = Math.min(WORKGROUP_SIZES.DEFAULT, Math.ceil(vocabSize / WORKGROUP_SIZES.DEFAULT));
|
|
98
97
|
const useSinglePassArgmax = numWorkgroups === 1;
|
|
99
|
-
const
|
|
100
|
-
const reducePipeline = useSinglePassArgmax || vocabSize <= argmaxReduceVocabThreshold
|
|
98
|
+
const reducePipeline = useSinglePassArgmax
|
|
101
99
|
? null
|
|
102
100
|
: await createSamplePipeline(device, variants.argmaxReduce);
|
|
103
101
|
const singlePassPipeline = useSinglePassArgmax
|
|
@@ -40,6 +40,16 @@ fn apply_softcap(x: f32, softcap: f32) -> f32 {
|
|
|
40
40
|
return softcap * tanh(x / softcap);
|
|
41
41
|
}
|
|
42
42
|
|
|
43
|
+
fn candidate_beats(candidate_value: f32, candidate_index: u32, best_value: f32, best_index: u32) -> bool {
|
|
44
|
+
if (candidate_value > best_value) {
|
|
45
|
+
return true;
|
|
46
|
+
}
|
|
47
|
+
if (candidate_value < best_value) {
|
|
48
|
+
return false;
|
|
49
|
+
}
|
|
50
|
+
return candidate_index < best_index;
|
|
51
|
+
}
|
|
52
|
+
|
|
43
53
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
44
54
|
@group(0) @binding(1) var<storage, read> logits: array<f32>; // [vocabSize]
|
|
45
55
|
@group(0) @binding(2) var<storage, read_write> output: array<u32>; // [N] - selected tokens
|
|
@@ -87,7 +97,7 @@ fn find_topk_phase1(
|
|
|
87
97
|
if (idx != pad_id) {
|
|
88
98
|
// Apply softcapping before temperature scaling
|
|
89
99
|
let val = apply_softcap(logits[idx], softcap) / temperature;
|
|
90
|
-
if (val
|
|
100
|
+
if (candidate_beats(val, idx, local_max, local_max_idx)) {
|
|
91
101
|
local_max = val;
|
|
92
102
|
local_max_idx = idx;
|
|
93
103
|
}
|
|
@@ -103,7 +113,12 @@ fn find_topk_phase1(
|
|
|
103
113
|
var stride = WORKGROUP_SIZE / 2u;
|
|
104
114
|
while (stride > 0u) {
|
|
105
115
|
if (thread_idx < stride) {
|
|
106
|
-
if (
|
|
116
|
+
if (candidate_beats(
|
|
117
|
+
shared_values[thread_idx + stride],
|
|
118
|
+
shared_indices[thread_idx + stride],
|
|
119
|
+
shared_values[thread_idx],
|
|
120
|
+
shared_indices[thread_idx]
|
|
121
|
+
)) {
|
|
107
122
|
shared_values[thread_idx] = shared_values[thread_idx + stride];
|
|
108
123
|
shared_indices[thread_idx] = shared_indices[thread_idx + stride];
|
|
109
124
|
}
|
|
@@ -150,7 +165,7 @@ fn find_topk_phase2(
|
|
|
150
165
|
var max_val = shared_values[k];
|
|
151
166
|
|
|
152
167
|
for (var i: u32 = k + 1u; i < num_candidates; i = i + 1u) {
|
|
153
|
-
if (shared_values[i]
|
|
168
|
+
if (candidate_beats(shared_values[i], shared_indices[i], max_val, shared_indices[max_idx])) {
|
|
154
169
|
max_val = shared_values[i];
|
|
155
170
|
max_idx = i;
|
|
156
171
|
}
|
|
@@ -249,7 +264,7 @@ fn sample_single_pass(
|
|
|
249
264
|
if (idx != pad_id) {
|
|
250
265
|
// Apply softcapping before temperature scaling
|
|
251
266
|
let val = apply_softcap(logits[idx], softcap) / temperature;
|
|
252
|
-
if (val
|
|
267
|
+
if (candidate_beats(val, idx, local_max, local_max_idx)) {
|
|
253
268
|
local_max = val;
|
|
254
269
|
local_max_idx = idx;
|
|
255
270
|
}
|
|
@@ -265,7 +280,12 @@ fn sample_single_pass(
|
|
|
265
280
|
var stride = WORKGROUP_SIZE / 2u;
|
|
266
281
|
while (stride > 0u) {
|
|
267
282
|
if (thread_idx < stride) {
|
|
268
|
-
if (
|
|
283
|
+
if (candidate_beats(
|
|
284
|
+
shared_values[thread_idx + stride],
|
|
285
|
+
shared_indices[thread_idx + stride],
|
|
286
|
+
shared_values[thread_idx],
|
|
287
|
+
shared_indices[thread_idx]
|
|
288
|
+
)) {
|
|
269
289
|
shared_values[thread_idx] = shared_values[thread_idx + stride];
|
|
270
290
|
shared_indices[thread_idx] = shared_indices[thread_idx + stride];
|
|
271
291
|
}
|
|
@@ -308,7 +328,7 @@ fn argmax(
|
|
|
308
328
|
if (idx != pad_id) {
|
|
309
329
|
// Apply softcapping (argmax is greedy, no temperature)
|
|
310
330
|
let val = apply_softcap(logits[idx], softcap);
|
|
311
|
-
if (val
|
|
331
|
+
if (candidate_beats(val, idx, local_max, local_max_idx)) {
|
|
312
332
|
local_max = val;
|
|
313
333
|
local_max_idx = idx;
|
|
314
334
|
}
|
|
@@ -324,7 +344,12 @@ fn argmax(
|
|
|
324
344
|
var stride = WORKGROUP_SIZE / 2u;
|
|
325
345
|
while (stride > 0u) {
|
|
326
346
|
if (thread_idx < stride) {
|
|
327
|
-
if (
|
|
347
|
+
if (candidate_beats(
|
|
348
|
+
shared_values[thread_idx + stride],
|
|
349
|
+
shared_indices[thread_idx + stride],
|
|
350
|
+
shared_values[thread_idx],
|
|
351
|
+
shared_indices[thread_idx]
|
|
352
|
+
)) {
|
|
328
353
|
shared_values[thread_idx] = shared_values[thread_idx + stride];
|
|
329
354
|
shared_indices[thread_idx] = shared_indices[thread_idx + stride];
|
|
330
355
|
}
|
|
@@ -362,7 +387,12 @@ fn argmax_reduce(
|
|
|
362
387
|
var stride = WORKGROUP_SIZE / 2u;
|
|
363
388
|
while (stride > 0u) {
|
|
364
389
|
if (thread_idx < stride) {
|
|
365
|
-
if (
|
|
390
|
+
if (candidate_beats(
|
|
391
|
+
shared_values[thread_idx + stride],
|
|
392
|
+
shared_indices[thread_idx + stride],
|
|
393
|
+
shared_values[thread_idx],
|
|
394
|
+
shared_indices[thread_idx]
|
|
395
|
+
)) {
|
|
366
396
|
shared_values[thread_idx] = shared_values[thread_idx + stride];
|
|
367
397
|
shared_indices[thread_idx] = shared_indices[thread_idx + stride];
|
|
368
398
|
}
|
|
@@ -374,4 +404,4 @@ fn argmax_reduce(
|
|
|
374
404
|
if (thread_idx == 0u) {
|
|
375
405
|
output[u.output_index] = shared_indices[0];
|
|
376
406
|
}
|
|
377
|
-
}
|
|
407
|
+
}
|