@simulatte/doppler 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +32 -0
- package/README.md +25 -6
- package/package.json +25 -38
- package/src/browser/browser-converter.js +5 -0
- package/src/client/doppler-api.browser.js +6 -0
- package/src/client/doppler-api.d.ts +3 -0
- package/src/client/doppler-api.js +11 -2
- package/src/client/doppler-registry.js +3 -5
- package/src/client/doppler-registry.json +2 -2
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +13 -0
- package/src/config/kernels/kernel-ref-digests.js +23 -21
- package/src/config/kernels/moe/mixtral.paths.json +46 -0
- package/src/config/kernels/registry.json +74 -0
- package/src/config/loader.js +9 -0
- package/src/config/merge-contract-check.js +7 -0
- package/src/config/platforms/loader.js +3 -1
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +21 -0
- package/src/config/presets/models/gemma2.json +2 -1
- package/src/config/presets/models/gemma3.json +4 -1
- package/src/config/presets/models/gemma4.json +61 -0
- package/src/config/presets/models/granite-docling.json +70 -0
- package/src/config/presets/models/lfm2.json +6 -1
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/models/qwen3_vl.json +40 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
- package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
- package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/modes/trace-layers.json +1 -0
- package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
- package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
- package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
- package/src/config/runtime.js +3 -0
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +40 -0
- package/src/config/schema/debug.schema.js +28 -0
- package/src/config/schema/index.js +2 -0
- package/src/config/schema/inference-defaults.schema.js +1 -1
- package/src/config/schema/kernel-path.schema.d.ts +1 -0
- package/src/config/schema/manifest.schema.d.ts +1 -1
- package/src/config/schema/manifest.schema.js +1 -1
- package/src/config/schema/memory-limits.schema.js +2 -2
- package/src/config/schema/storage.schema.js +2 -2
- package/src/converter/conversion-plan.js +11 -3
- package/src/converter/core.js +19 -8
- package/src/converter/manifest-inference.js +12 -22
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +5 -1
- package/src/converter/quantizer.d.ts +5 -0
- package/src/converter/quantizer.js +34 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/distribution/shard-delivery.js +40 -1
- package/src/formats/rdrr/classification.js +32 -0
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +14 -1
- package/src/gpu/kernel-runtime.js +4 -2
- package/src/gpu/kernels/attention.js +2 -1
- package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
- package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
- package/src/gpu/kernels/dequant_shared.wgsl +4 -2
- package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
- package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
- package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
- package/src/gpu/kernels/gated-short-conv.js +284 -0
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/linear-attention-core.js +37 -17
- package/src/gpu/kernels/matmul-selection.js +48 -4
- package/src/gpu/kernels/matmul.d.ts +5 -0
- package/src/gpu/kernels/matmul.js +71 -2
- package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
- package/src/gpu/kernels/rmsnorm.js +9 -2
- package/src/gpu/kernels/sample.js +1 -3
- package/src/gpu/kernels/sample.wgsl +39 -9
- package/src/gpu/kernels/sample_f16.wgsl +38 -8
- package/src/gpu/kernels/shader-cache.js +9 -4
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/inference/browser-harness.d.ts +2 -0
- package/src/inference/browser-harness.js +20 -1
- package/src/inference/kv-cache/base.js +3 -10
- package/src/inference/pipelines/diffusion/helpers.js +3 -0
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +10 -3
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +13 -1
- package/src/inference/pipelines/text/attention/projections.js +54 -13
- package/src/inference/pipelines/text/attention/record.js +16 -6
- package/src/inference/pipelines/text/attention/run.js +59 -6
- package/src/inference/pipelines/text/config.d.ts +1 -0
- package/src/inference/pipelines/text/config.js +46 -4
- package/src/inference/pipelines/text/embed.js +26 -7
- package/src/inference/pipelines/text/execution-plan.js +5 -4
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
- package/src/inference/pipelines/text/execution-v0.js +12 -1
- package/src/inference/pipelines/text/generator-helpers.js +1 -0
- package/src/inference/pipelines/text/generator-runtime.js +19 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +15 -0
- package/src/inference/pipelines/text/generator-steps.js +71 -26
- package/src/inference/pipelines/text/generator.d.ts +5 -0
- package/src/inference/pipelines/text/generator.js +353 -166
- package/src/inference/pipelines/text/init.d.ts +15 -0
- package/src/inference/pipelines/text/init.js +35 -10
- package/src/inference/pipelines/text/layer.js +38 -8
- package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
- package/src/inference/pipelines/text/linear-attention.js +33 -3
- package/src/inference/pipelines/text/logits/gpu.js +2 -2
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +3 -1
- package/src/inference/pipelines/text/model-load.js +3 -0
- package/src/inference/pipelines/text/moe-gpu.js +21 -3
- package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
- package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
- package/src/inference/pipelines/text/ops.js +123 -53
- package/src/inference/pipelines/text/probes.js +1 -0
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/state.js +2 -0
- package/src/inference/pipelines/text.d.ts +5 -0
- package/src/inference/pipelines/text.js +59 -1
- package/src/inference/pipelines/vision/encoder.js +386 -0
- package/src/inference/pipelines/vision/image-preprocess.js +151 -0
- package/src/inference/pipelines/vision/index.js +173 -0
- package/src/inference/pipelines/vision/ops.js +78 -0
- package/src/inference/pipelines/vision/patch-embed.js +151 -0
- package/src/inference/test-harness.js +11 -9
- package/src/loader/doppler-loader.d.ts +3 -0
- package/src/loader/doppler-loader.js +20 -3
- package/src/loader/experts/expert-cache.js +6 -2
- package/src/loader/experts/expert-loader.js +6 -2
- package/src/loader/final-weights-loader.js +2 -0
- package/src/loader/layer-loader.js +42 -3
- package/src/loader/manifest-config.js +3 -1
- package/src/loader/shard-cache.js +3 -2
- package/src/loader/tensors/tensor-loader.d.ts +3 -0
- package/src/loader/tensors/tensor-loader.js +130 -4
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +2 -2
- package/src/rules/kernels/moe.rules.mixtral.json +75 -0
- package/src/rules/kernels/softmax.rules.json +2 -0
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.d.ts +1 -0
- package/src/rules/rule-registry.js +4 -0
- package/src/storage/downloader.js +2 -1
- package/src/storage/quickstart-downloader.d.ts +3 -0
- package/src/storage/quickstart-downloader.js +27 -30
- package/src/storage/shard-manager.js +4 -3
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/node-converter.js +28 -7
- package/src/tooling/node-source-runtime.js +65 -5
- package/src/tooling/node-webgpu.js +24 -7
- package/src/types/model.d.ts +5 -0
- package/src/utils/hf-resolve-url.d.ts +16 -0
- package/src/utils/hf-resolve-url.js +17 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +6 -1
- package/src/tooling/node-convert.d.ts +0 -54
|
@@ -10,7 +10,7 @@ import {
|
|
|
10
10
|
import { recordDispatch } from './dispatch.js';
|
|
11
11
|
|
|
12
12
|
const CONV_WORKGROUP_SIZE = WORKGROUP_SIZES.DEFAULT;
|
|
13
|
-
const HEAD_WORKGROUP_SIZE =
|
|
13
|
+
const HEAD_WORKGROUP_SIZE = 128;
|
|
14
14
|
|
|
15
15
|
const CONV_SHADER = /* wgsl */ `
|
|
16
16
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
@@ -79,7 +79,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
79
79
|
`;
|
|
80
80
|
|
|
81
81
|
const RECURRENT_SHADER = /* wgsl */ `
|
|
82
|
-
override WORKGROUP_SIZE: u32 =
|
|
82
|
+
override WORKGROUP_SIZE: u32 = 128u;
|
|
83
83
|
|
|
84
84
|
struct LinearAttentionParams {
|
|
85
85
|
num_tokens: u32,
|
|
@@ -111,6 +111,8 @@ struct LinearAttentionParams {
|
|
|
111
111
|
@group(0) @binding(8) var<storage, read_write> recurrent_state: array<f32>;
|
|
112
112
|
@group(0) @binding(9) var<storage, read_write> output: array<f32>;
|
|
113
113
|
|
|
114
|
+
var<workgroup> shared_sq: array<f32, WORKGROUP_SIZE>;
|
|
115
|
+
|
|
114
116
|
fn softplus(x: f32) -> f32 {
|
|
115
117
|
if (x > 20.0) {
|
|
116
118
|
return x;
|
|
@@ -131,17 +133,19 @@ fn silu(x: f32) -> f32 {
|
|
|
131
133
|
}
|
|
132
134
|
|
|
133
135
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
134
|
-
fn main(@builtin(
|
|
135
|
-
|
|
136
|
+
fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|
137
|
+
@builtin(local_invocation_id) lid: vec3<u32>) {
|
|
138
|
+
let head = wid.x;
|
|
139
|
+
let vd = lid.x;
|
|
136
140
|
if (head >= params.num_v_heads) {
|
|
137
141
|
return;
|
|
138
142
|
}
|
|
139
143
|
|
|
140
144
|
let head_k_dim = params.head_k_dim;
|
|
141
145
|
let head_v_dim = params.head_v_dim;
|
|
146
|
+
let is_active = vd < head_v_dim;
|
|
142
147
|
let head_scale = inverseSqrt(f32(head_k_dim));
|
|
143
148
|
let recurrent_head_base = head * head_k_dim * head_v_dim;
|
|
144
|
-
let recurrent_head_size = head_k_dim * head_v_dim;
|
|
145
149
|
let q_rep = max(params.q_rep, 1u);
|
|
146
150
|
let src_head = head / q_rep;
|
|
147
151
|
let q_base = src_head * head_k_dim;
|
|
@@ -154,6 +158,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
154
158
|
let ab_row_base = token_idx * params.num_v_heads + head;
|
|
155
159
|
let out_row_base = token_idx * params.value_dim + head * head_v_dim;
|
|
156
160
|
|
|
161
|
+
// L2 norm for Q and K (redundant across threads but avoids shared memory)
|
|
157
162
|
var q_norm_sq = 0.0;
|
|
158
163
|
var k_norm_sq = 0.0;
|
|
159
164
|
for (var d: u32 = 0u; d < head_k_dim; d = d + 1u) {
|
|
@@ -169,11 +174,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
169
174
|
let g = a_neg_exp[head] * softplus(a_proj[ab_row_base] + dt_bias[head]);
|
|
170
175
|
let g_exp = exp(g);
|
|
171
176
|
|
|
172
|
-
|
|
173
|
-
|
|
177
|
+
// Decay state — each thread handles head_k_dim elements at stride head_v_dim
|
|
178
|
+
if (is_active) {
|
|
179
|
+
for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
|
|
180
|
+
let state_idx = recurrent_head_base + kd * head_v_dim + vd;
|
|
181
|
+
recurrent_state[state_idx] = recurrent_state[state_idx] * g_exp;
|
|
182
|
+
}
|
|
174
183
|
}
|
|
175
184
|
|
|
176
|
-
|
|
185
|
+
// Delta update — each thread handles one vd slice (no cross-thread dependency)
|
|
186
|
+
if (is_active) {
|
|
177
187
|
var kv_mem = 0.0;
|
|
178
188
|
for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
|
|
179
189
|
let k_normed = conv_out[conv_row_base + k_base + kd] * k_norm_scale;
|
|
@@ -188,21 +198,31 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
188
198
|
}
|
|
189
199
|
}
|
|
190
200
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
201
|
+
// Output — each thread computes one vd element
|
|
202
|
+
var out_value = 0.0;
|
|
203
|
+
if (is_active) {
|
|
194
204
|
for (var kd: u32 = 0u; kd < head_k_dim; kd = kd + 1u) {
|
|
195
205
|
let q_normed = conv_out[conv_row_base + q_base + kd] * q_norm_scale;
|
|
196
206
|
let state_idx = recurrent_head_base + kd * head_v_dim + vd;
|
|
197
207
|
out_value = out_value + recurrent_state[state_idx] * q_normed;
|
|
198
208
|
}
|
|
199
209
|
output[out_row_base + vd] = out_value;
|
|
200
|
-
let value = out_value;
|
|
201
|
-
mean_sq = mean_sq + value * value;
|
|
202
210
|
}
|
|
203
|
-
let inv_rms = inverseSqrt(mean_sq / f32(head_v_dim) + params.rms_norm_eps);
|
|
204
211
|
|
|
205
|
-
|
|
212
|
+
// RMS norm reduction across vd (workgroup-level)
|
|
213
|
+
shared_sq[vd] = select(0.0, out_value * out_value, is_active);
|
|
214
|
+
workgroupBarrier();
|
|
215
|
+
// Tree reduction
|
|
216
|
+
for (var stride: u32 = WORKGROUP_SIZE / 2u; stride > 0u; stride = stride / 2u) {
|
|
217
|
+
if (vd < stride) {
|
|
218
|
+
shared_sq[vd] = shared_sq[vd] + shared_sq[vd + stride];
|
|
219
|
+
}
|
|
220
|
+
workgroupBarrier();
|
|
221
|
+
}
|
|
222
|
+
let inv_rms = inverseSqrt(shared_sq[0] / f32(head_v_dim) + params.rms_norm_eps);
|
|
223
|
+
|
|
224
|
+
// Apply norm + gate
|
|
225
|
+
if (is_active) {
|
|
206
226
|
let gate = silu(z_proj[z_row_base + vd]);
|
|
207
227
|
let norm_index = select(vd, head * head_v_dim + vd, params.norm_mode == 1u);
|
|
208
228
|
output[out_row_base + vd] = (output[out_row_base + vd] * inv_rms) * norm_weight[norm_index] * gate;
|
|
@@ -415,7 +435,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
|
|
|
415
435
|
recorder,
|
|
416
436
|
recurrentPipeline,
|
|
417
437
|
recurrentBindGroup,
|
|
418
|
-
[
|
|
438
|
+
[layerState.numVHeads, 1, 1],
|
|
419
439
|
'linear_attention_recurrent'
|
|
420
440
|
);
|
|
421
441
|
|
|
@@ -502,7 +522,7 @@ export async function runLinearAttentionCoreGPU(qkvTensor, zTensor, aTensor, bTe
|
|
|
502
522
|
const pass = encoder.beginComputePass({ label: 'linear_attention_recurrent_pass' });
|
|
503
523
|
pass.setPipeline(recurrentPipeline);
|
|
504
524
|
pass.setBindGroup(0, recurrentBindGroup);
|
|
505
|
-
pass.dispatchWorkgroups(
|
|
525
|
+
pass.dispatchWorkgroups(layerState.numVHeads, 1, 1);
|
|
506
526
|
pass.end();
|
|
507
527
|
}
|
|
508
528
|
|
|
@@ -29,7 +29,13 @@ function selectQ4KFusedVariant(isM1, wantF16Output, aDtype) {
|
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
export function resolveMatmulPhase(M) {
|
|
32
|
+
export function resolveMatmulPhase(M, phaseOverride = null) {
|
|
33
|
+
if (phaseOverride != null) {
|
|
34
|
+
if (phaseOverride !== 'decode' && phaseOverride !== 'prefill') {
|
|
35
|
+
throw new Error(`[Matmul] Invalid phase override "${phaseOverride}". Expected "decode" or "prefill".`);
|
|
36
|
+
}
|
|
37
|
+
return phaseOverride;
|
|
38
|
+
}
|
|
33
39
|
return selectKernelRuleValue('matmul', 'phase', { isDecode: M === 1 });
|
|
34
40
|
}
|
|
35
41
|
|
|
@@ -86,6 +92,7 @@ export function getMatmulConfig(variant, constants) {
|
|
|
86
92
|
|
|
87
93
|
|
|
88
94
|
export function isFusedQ4KDisabled(options = {}) {
|
|
95
|
+
if (options.disableFusedQ4K === true) return true;
|
|
89
96
|
const capabilities = getKernelCapabilities();
|
|
90
97
|
const hasSubgroups = capabilities?.hasSubgroups === true;
|
|
91
98
|
|
|
@@ -125,7 +132,9 @@ export function selectMatmulKernel(options = {}) {
|
|
|
125
132
|
const { tiledPrefillMinRows } = getKernelThresholds().matmul;
|
|
126
133
|
|
|
127
134
|
const inputsAreF16 = aDtype === 'f16' && bDtype === 'f16';
|
|
128
|
-
|
|
135
|
+
// F16 weights needing F32a path: weights are F16 and either activation is already F32,
|
|
136
|
+
// or both inputs are F16 but output is F32 (activation will be cast to F32 by executeMatmul)
|
|
137
|
+
const weightsAreF16 = bDtype === 'f16' && (aDtype !== 'f16' || outputDtype !== 'f16');
|
|
129
138
|
const useF16Matmul = outputDtype === 'f16' && preferF16 && inputsAreF16 && capabilities.hasF16;
|
|
130
139
|
const useF16wF32a = preferF16 && weightsAreF16 && capabilities.hasF16;
|
|
131
140
|
const useTiled = isPrefill
|
|
@@ -244,6 +253,30 @@ export function requiresF32Input(variant) {
|
|
|
244
253
|
return !supportsF16Input(variant);
|
|
245
254
|
}
|
|
246
255
|
|
|
256
|
+
function resolveRequiredWeightDtype(config) {
|
|
257
|
+
const shaderFile = String(config?.shaderFile ?? config?.wgsl ?? '');
|
|
258
|
+
if (!shaderFile) {
|
|
259
|
+
return null;
|
|
260
|
+
}
|
|
261
|
+
if (shaderFile.startsWith('fused_matmul_q4')) {
|
|
262
|
+
return 'q4k';
|
|
263
|
+
}
|
|
264
|
+
if (
|
|
265
|
+
shaderFile === 'matmul_f16.wgsl'
|
|
266
|
+
|| shaderFile === 'matmul_f16_tiled.wgsl'
|
|
267
|
+
|| shaderFile === 'matmul_f16w_f32a.wgsl'
|
|
268
|
+
|| shaderFile === 'matmul_f16w_f32a_tiled.wgsl'
|
|
269
|
+
|| shaderFile === 'matmul_gemv_subgroup.wgsl'
|
|
270
|
+
|| shaderFile === 'matmul_gemv_subgroup_f16a.wgsl'
|
|
271
|
+
) {
|
|
272
|
+
return 'f16';
|
|
273
|
+
}
|
|
274
|
+
if (shaderFile === 'matmul_f32.wgsl') {
|
|
275
|
+
return 'f32';
|
|
276
|
+
}
|
|
277
|
+
return null;
|
|
278
|
+
}
|
|
279
|
+
|
|
247
280
|
|
|
248
281
|
function resolveMatmulOverride(
|
|
249
282
|
variantOverride,
|
|
@@ -287,6 +320,16 @@ function resolveMatmulOverride(
|
|
|
287
320
|
);
|
|
288
321
|
}
|
|
289
322
|
|
|
323
|
+
const requiredWeightDtype = resolveRequiredWeightDtype(config);
|
|
324
|
+
const weightDtypeOk = !requiredWeightDtype
|
|
325
|
+
|| bDtype === requiredWeightDtype
|
|
326
|
+
|| (requiredWeightDtype === 'f16' && bDtype === 'q4k');
|
|
327
|
+
if (!weightDtypeOk) {
|
|
328
|
+
return failOrWarn(
|
|
329
|
+
`Matmul kernel "${variantOverride}" requires ${requiredWeightDtype} weights but B dtype is ${bDtype}.`
|
|
330
|
+
);
|
|
331
|
+
}
|
|
332
|
+
|
|
290
333
|
if (supportsF16Input(override) && aDtype !== 'f16') {
|
|
291
334
|
return failOrWarn(`Matmul kernel "${variantOverride}" requires f16 activations but A dtype is ${aDtype}.`);
|
|
292
335
|
}
|
|
@@ -341,7 +384,7 @@ function selectGemvVariant(useF16Gemv, useF32Gemv, hasSubgroups, useVec4, N, mul
|
|
|
341
384
|
export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, transposeB, requestedOutputDtype, options) {
|
|
342
385
|
const capabilities = getKernelCapabilities();
|
|
343
386
|
const strict = getKernelPathStrict();
|
|
344
|
-
const phase = resolveMatmulPhase(M);
|
|
387
|
+
const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
|
|
345
388
|
let pathVariant = getKernelPathMatmulVariant(options.role, phase, options.layerIdx, options.kernelPath);
|
|
346
389
|
const hadPathVariant = Boolean(pathVariant);
|
|
347
390
|
|
|
@@ -426,7 +469,8 @@ export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, trans
|
|
|
426
469
|
|
|
427
470
|
const canGemv = M === 1 && effectiveBDtype === 'f16' && capabilities.hasF16;
|
|
428
471
|
const useF16Gemv = canGemv && aDtype === 'f16' && wantF16Output;
|
|
429
|
-
|
|
472
|
+
// F32 GEMV: activation is F32, or activation is F16 with F32 output (will be cast to F32)
|
|
473
|
+
const useF32Gemv = canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output));
|
|
430
474
|
const useGemv = useF16Gemv || useF32Gemv;
|
|
431
475
|
const useVec4 = (K % 4 === 0);
|
|
432
476
|
const { multicolThreshold } = getKernelThresholds().matmul;
|
|
@@ -13,6 +13,7 @@ import type { WeightBuffer } from '../weight-buffer.js';
|
|
|
13
13
|
import type { CommandRecorder } from '../command-recorder.js';
|
|
14
14
|
import type { OutputBufferOptions, OutputDtypeOptions, Vec4Options } from './types.js';
|
|
15
15
|
import type { KernelPathSchema } from '../../config/schema/index.js';
|
|
16
|
+
import type { MatmulDebugConfigSchema } from '../../config/schema/debug.schema.js';
|
|
16
17
|
|
|
17
18
|
/** Matmul kernel options */
|
|
18
19
|
export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions, Vec4Options {
|
|
@@ -23,6 +24,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
|
|
|
23
24
|
layerIdx?: number;
|
|
24
25
|
/** Explicit kernel path context for variant selection (avoids global path state). */
|
|
25
26
|
kernelPath?: KernelPathSchema | null;
|
|
27
|
+
/** Optional explicit phase for kernel-path lookup when the runtime rewrites rows (for example prefill last-position logits). */
|
|
28
|
+
phaseOverride?: 'decode' | 'prefill' | null;
|
|
26
29
|
/**
|
|
27
30
|
* Whether B matrix is stored transposed.
|
|
28
31
|
* - true: B is [N,K] (SafeTensors/row-major), needs transpose
|
|
@@ -38,6 +41,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
|
|
|
38
41
|
preferF16?: boolean;
|
|
39
42
|
/** WGSL override constants for pipeline creation */
|
|
40
43
|
constants?: Record<string, number | boolean>;
|
|
44
|
+
/** Runtime debug controls for attention projection diagnostics. */
|
|
45
|
+
matmulDebug?: MatmulDebugConfigSchema | null;
|
|
41
46
|
}
|
|
42
47
|
|
|
43
48
|
/** Context for base matmul kernel selection rules. */
|
|
@@ -2,7 +2,7 @@ import { getDevice, getKernelCapabilities } from '../device.js';
|
|
|
2
2
|
import { createTensor } from '../tensor.js';
|
|
3
3
|
import { getBuffer, getLayout, getWeightDtype } from '../weight-buffer.js';
|
|
4
4
|
import { log, trace, isTraceEnabled } from '../../debug/index.js';
|
|
5
|
-
import { releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
|
+
import { releaseBuffer, readBuffer } from '../../memory/buffer-pool.js';
|
|
6
6
|
import { releaseUniformBuffer } from '../uniform-cache.js';
|
|
7
7
|
import { castF16ToF32, recordCastF16ToF32 } from './cast.js';
|
|
8
8
|
import {
|
|
@@ -34,6 +34,24 @@ export { createMatmulBindGroupLayout };
|
|
|
34
34
|
let _runMatmulDebugCount = 0;
|
|
35
35
|
let _recordMatmulDebugCount = 0;
|
|
36
36
|
|
|
37
|
+
function normalizeMatmulDebugConfig(config) {
|
|
38
|
+
if (!config || typeof config !== 'object') {
|
|
39
|
+
return null;
|
|
40
|
+
}
|
|
41
|
+
return {
|
|
42
|
+
enabled: config.enabled === true,
|
|
43
|
+
forceSplitQKV: config.forceSplitQKV === true,
|
|
44
|
+
validateAttentionWeightBuffer: config.validateAttentionWeightBuffer === true,
|
|
45
|
+
failOnSmallAttentionWeightBuffer: config.failOnSmallAttentionWeightBuffer === true,
|
|
46
|
+
logAttentionWeightBuffer: config.logAttentionWeightBuffer === true,
|
|
47
|
+
logProjectionValues: config.logProjectionValues === true,
|
|
48
|
+
};
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
function isAttentionProjectionRole(role = '') {
|
|
52
|
+
return role === 'qkv_proj' || role === 'q_proj' || role === 'k_proj' || role === 'v_proj';
|
|
53
|
+
}
|
|
54
|
+
|
|
37
55
|
function getDebugCounter(isRecord) {
|
|
38
56
|
return isRecord ? _recordMatmulDebugCount : _runMatmulDebugCount;
|
|
39
57
|
}
|
|
@@ -126,6 +144,12 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
126
144
|
const weightLabel = (B && typeof B === 'object' ? B.label : null) ?? bBuffer?.label ?? null;
|
|
127
145
|
const weightLayout = getLayout(B);
|
|
128
146
|
const weightShape = B?.shape ? `[${B.shape.join(', ')}]` : null;
|
|
147
|
+
const matmulDebug = normalizeMatmulDebugConfig(options.matmulDebug);
|
|
148
|
+
const debugAttention = matmulDebug?.enabled === true;
|
|
149
|
+
const isAttnProj = isAttentionProjectionRole(options.role ?? '');
|
|
150
|
+
const shouldValidateAttentionWeightBuffer = debugAttention && matmulDebug.validateAttentionWeightBuffer;
|
|
151
|
+
const shouldFailOnSmallAttentionWeightBuffer = debugAttention && matmulDebug.failOnSmallAttentionWeightBuffer;
|
|
152
|
+
const shouldLogAttentionWeightBuffer = debugAttention && matmulDebug.logAttentionWeightBuffer;
|
|
129
153
|
|
|
130
154
|
if (isTraceEnabled('kernels') && getDebugCounter(isRecord) < 20) {
|
|
131
155
|
incrementDebugCounter(isRecord);
|
|
@@ -165,7 +189,7 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
165
189
|
options
|
|
166
190
|
);
|
|
167
191
|
|
|
168
|
-
const phase = resolveMatmulPhase(M);
|
|
192
|
+
const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
|
|
169
193
|
const constants = resolveMatmulConstants(options, phase);
|
|
170
194
|
|
|
171
195
|
let matmulInput = A;
|
|
@@ -201,6 +225,27 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
201
225
|
bOffset
|
|
202
226
|
);
|
|
203
227
|
} catch (err) {
|
|
228
|
+
if (shouldValidateAttentionWeightBuffer && isAttnProj && err instanceof Error && err.message.includes('B buffer too small')) {
|
|
229
|
+
const detailParts = [
|
|
230
|
+
`role=${options.role ?? ''}`,
|
|
231
|
+
`layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'}`,
|
|
232
|
+
`M=${M}`,
|
|
233
|
+
`N=${N}`,
|
|
234
|
+
`K=${K}`,
|
|
235
|
+
];
|
|
236
|
+
if (weightDtype) detailParts.push(`weightDtype=${weightDtype}`);
|
|
237
|
+
if (weightLayout) detailParts.push(`weightLayout=${weightLayout}`);
|
|
238
|
+
if (weightShape) detailParts.push(`shape=${weightShape}`);
|
|
239
|
+
if (weightLabel) detailParts.push(`label=${weightLabel}`);
|
|
240
|
+
if (Number.isFinite(bBuffer?.size)) detailParts.push(`bSize=${bBuffer.size}`);
|
|
241
|
+
const detail = detailParts.join(' ');
|
|
242
|
+
if (shouldLogAttentionWeightBuffer) {
|
|
243
|
+
log.warn('MatmulQKVProbe', `${err.message} | ${detail}`);
|
|
244
|
+
}
|
|
245
|
+
if (shouldFailOnSmallAttentionWeightBuffer) {
|
|
246
|
+
throw new Error(`${err.message}${detail ? ` (${detail})` : ''}`);
|
|
247
|
+
}
|
|
248
|
+
}
|
|
204
249
|
if (!isRecord && err instanceof Error && err.message.includes('B buffer too small')) {
|
|
205
250
|
const detailParts = [];
|
|
206
251
|
if (weightLabel) detailParts.push(`label=${weightLabel}`);
|
|
@@ -226,6 +271,15 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
226
271
|
trace.kernels(`MATMUL_LARGE: N=${N}, variant=${variant}, aDtype=${aDtype}, bDtype=${bDtype}, transposeB=${transposeB}`);
|
|
227
272
|
}
|
|
228
273
|
|
|
274
|
+
if (isAttnProj && shouldLogAttentionWeightBuffer) {
|
|
275
|
+
log.warn('MatmulQKVProbe',
|
|
276
|
+
`role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} ` +
|
|
277
|
+
`M=${M} N=${N} K=${K} transposeB=${transposeB} bSize=${bBuffer?.size ?? 0} ` +
|
|
278
|
+
`requiredB=${bindingSizes?.bBindingSize ?? 'n/a'} weightShape=${weightShape ?? 'n/a'} ` +
|
|
279
|
+
`weightDtype=${weightDtype ?? 'unknown'} weightLayout=${weightLayout ?? 'unknown'}`
|
|
280
|
+
);
|
|
281
|
+
}
|
|
282
|
+
|
|
229
283
|
const config = getMatmulConfig(variant, constants);
|
|
230
284
|
const kernel = new MatmulKernel(device);
|
|
231
285
|
const pipeline = await getMatmulPipeline(variant, constants);
|
|
@@ -238,6 +292,14 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
238
292
|
);
|
|
239
293
|
const ownsOutput = outputBuffer == null;
|
|
240
294
|
|
|
295
|
+
if (isAttnProj && shouldLogAttentionWeightBuffer) {
|
|
296
|
+
log.warn('MatmulVariantDiag',
|
|
297
|
+
`role=${options.role ?? ''} layer=${Number.isFinite(options.layerIdx) ? options.layerIdx : '?'} mode=${mode} ` +
|
|
298
|
+
`variant=${variant} useQ4KFused=${useQ4KFused} useGemv=${useGemv} ` +
|
|
299
|
+
`aDtype=${aDtype} bDtype=${bDtype} output=${actualOutputDtype}`
|
|
300
|
+
);
|
|
301
|
+
}
|
|
302
|
+
|
|
241
303
|
if (!Number.isFinite(outputSize) || outputSize <= 0) {
|
|
242
304
|
throw new Error(`[${opLabel}] Invalid output size: ${outputSize} (M=${M}, N=${N})`);
|
|
243
305
|
}
|
|
@@ -290,6 +352,13 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
290
352
|
kernel.dispatch(pipeline, bindGroup, dispatchPlan.workgroups);
|
|
291
353
|
}
|
|
292
354
|
completed = true;
|
|
355
|
+
if (!isRecord && matmulDebug?.logProjectionValues && isAttnProj && M === 1 && options.layerIdx === 0) {
|
|
356
|
+
await device.queue.onSubmittedWorkDone();
|
|
357
|
+
const raw = await readBuffer(C);
|
|
358
|
+
const numVals = Math.min(8, Math.floor(raw.byteLength / 4));
|
|
359
|
+
const vals = numVals > 0 ? new Float32Array(raw, 0, numVals) : [];
|
|
360
|
+
log.warn('ProjectionProbe', `role=${options.role ?? ''} L0 M1 first8_f32: ${Array.from(vals).map(v => v.toFixed(5)).join(' ')}`);
|
|
361
|
+
}
|
|
293
362
|
return createTensor(C, actualOutputDtype, [M, N], 'matmul_output');
|
|
294
363
|
} finally {
|
|
295
364
|
if (!isRecord && uniformBuffer) {
|
|
@@ -5,7 +5,11 @@
|
|
|
5
5
|
// 1. Use subgroupAdd() for reduction - much faster than shared memory
|
|
6
6
|
// 2. Vectorized vec4 loads for weights
|
|
7
7
|
// 3. Each workgroup processes multiple output columns
|
|
8
|
-
// 4.
|
|
8
|
+
// 4. Warp-stride loop for row-major (transpose_b=1): all threads in a column
|
|
9
|
+
// step through K together, so adjacent threads load adjacent addresses.
|
|
10
|
+
// At each step, 64 threads × 8 bytes = 512 bytes from 4 consecutive cache
|
|
11
|
+
// lines → 100% cache-line utilization vs ~10% for the old contiguous-range
|
|
12
|
+
// pattern (where threads were 80 bytes apart in the same iteration).
|
|
9
13
|
//
|
|
10
14
|
// A is f32 (activations), B is f16 (weights), C is f32.
|
|
11
15
|
// transpose_b=0: B is [K, N] (GGUF/column-major), access B[k * N + col]
|
|
@@ -69,40 +73,29 @@ fn main(
|
|
|
69
73
|
// Each thread computes partial sum for its assigned k values
|
|
70
74
|
var partial_sum: f32 = 0.0;
|
|
71
75
|
|
|
72
|
-
// Only do work if this column is valid
|
|
73
76
|
if (is_valid) {
|
|
74
|
-
// Process K in chunks, each thread handles K/64 elements
|
|
75
|
-
let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
76
|
-
let k_start = thread_in_col * k_per_thread;
|
|
77
|
-
let k_end = min(k_start + k_per_thread, u.K);
|
|
78
|
-
|
|
79
|
-
// Main loop - process 4 elements at a time when aligned
|
|
80
|
-
var k = k_start;
|
|
81
|
-
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
82
|
-
|
|
83
77
|
if (u.transpose_b == 1u) {
|
|
84
|
-
// B is [N, K] (
|
|
78
|
+
// B is [N, K] (row-major): B[col, k] = B[col * K + k]
|
|
79
|
+
// Warp-stride: step THREADS_PER_COL elements per outer iteration so that
|
|
80
|
+
// all wavefront threads load consecutive addresses simultaneously.
|
|
81
|
+
// At each step, 64 threads × 2 bytes = 128 bytes = exactly 1 cache line → 100% utilization.
|
|
85
82
|
let b_row_offset = col * u.K;
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
let a3 = A[k + 3u];
|
|
92
|
-
|
|
93
|
-
let b0 = f32(B[b_row_offset + k]);
|
|
94
|
-
let b1 = f32(B[b_row_offset + k + 1u]);
|
|
95
|
-
let b2 = f32(B[b_row_offset + k + 2u]);
|
|
96
|
-
let b3 = f32(B[b_row_offset + k + 3u]);
|
|
97
|
-
|
|
98
|
-
partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
|
|
99
|
-
}
|
|
100
|
-
|
|
101
|
-
for (; k < k_end; k = k + 1u) {
|
|
102
|
-
partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
|
|
83
|
+
for (var k_base: u32 = 0u; k_base < u.K; k_base = k_base + THREADS_PER_COL) {
|
|
84
|
+
let k = k_base + thread_in_col;
|
|
85
|
+
if (k < u.K) {
|
|
86
|
+
partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
|
|
87
|
+
}
|
|
103
88
|
}
|
|
104
89
|
} else {
|
|
105
|
-
// B is [K, N] (
|
|
90
|
+
// B is [K, N] (column-major): B[k, col] = B[k * N + col]
|
|
91
|
+
// Contiguous-range per thread: sequential access within each thread.
|
|
92
|
+
let k_per_thread = (u.K + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
93
|
+
let k_start = thread_in_col * k_per_thread;
|
|
94
|
+
let k_end = min(k_start + k_per_thread, u.K);
|
|
95
|
+
|
|
96
|
+
var k = k_start;
|
|
97
|
+
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
98
|
+
|
|
106
99
|
for (; k < k_aligned_end; k = k + 4u) {
|
|
107
100
|
let a0 = A[k];
|
|
108
101
|
let a1 = A[k + 1u];
|
|
@@ -189,38 +182,36 @@ fn main_multicol(
|
|
|
189
182
|
var partial_sum: f32 = 0.0;
|
|
190
183
|
|
|
191
184
|
if (is_valid) {
|
|
192
|
-
// Each of 8 threads splits K
|
|
193
|
-
let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
|
|
194
|
-
let k_start = thread_in_col * k_per_thread;
|
|
195
|
-
let k_end = min(k_start + k_per_thread, u.K);
|
|
196
|
-
|
|
197
|
-
// Unroll by 4 for ILP
|
|
198
|
-
var k = k_start;
|
|
199
|
-
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
200
|
-
|
|
201
185
|
if (u.transpose_b == 1u) {
|
|
202
|
-
// B is [N, K] (
|
|
186
|
+
// B is [N, K] (row-major): B[col, k] = B[col * K + k]
|
|
187
|
+
// Warp-stride: step MULTICOL_THREADS_PER_COL vec4 groups per outer iteration.
|
|
188
|
+
// Adjacent threads in the same column load adjacent vec4 groups → coalesced.
|
|
189
|
+
let K4 = u.K / 4u;
|
|
203
190
|
let b_row_offset = col * u.K;
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
for (; k < k_end; k = k + 1u) {
|
|
220
|
-
partial_sum = partial_sum + A[k] * f32(B[b_row_offset + k]);
|
|
191
|
+
for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + MULTICOL_THREADS_PER_COL) {
|
|
192
|
+
let k4 = k4_base + thread_in_col;
|
|
193
|
+
if (k4 < K4) {
|
|
194
|
+
let k = k4 * 4u;
|
|
195
|
+
let a0 = A[k];
|
|
196
|
+
let a1 = A[k + 1u];
|
|
197
|
+
let a2 = A[k + 2u];
|
|
198
|
+
let a3 = A[k + 3u];
|
|
199
|
+
let b0 = f32(B[b_row_offset + k]);
|
|
200
|
+
let b1 = f32(B[b_row_offset + k + 1u]);
|
|
201
|
+
let b2 = f32(B[b_row_offset + k + 2u]);
|
|
202
|
+
let b3 = f32(B[b_row_offset + k + 3u]);
|
|
203
|
+
partial_sum = partial_sum + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
|
|
204
|
+
}
|
|
221
205
|
}
|
|
222
206
|
} else {
|
|
223
|
-
// B is [K, N] (
|
|
207
|
+
// B is [K, N] (column-major): B[k, col] = B[k * N + col]
|
|
208
|
+
let k_per_thread = (u.K + MULTICOL_THREADS_PER_COL - 1u) / MULTICOL_THREADS_PER_COL;
|
|
209
|
+
let k_start = thread_in_col * k_per_thread;
|
|
210
|
+
let k_end = min(k_start + k_per_thread, u.K);
|
|
211
|
+
|
|
212
|
+
var k = k_start;
|
|
213
|
+
let k_aligned_end = k_start + ((k_end - k_start) / 4u) * 4u;
|
|
214
|
+
|
|
224
215
|
for (; k < k_aligned_end; k = k + 4u) {
|
|
225
216
|
let a0 = A[k];
|
|
226
217
|
let a1 = A[k + 1u];
|
|
@@ -245,7 +236,7 @@ fn main_multicol(
|
|
|
245
236
|
multicol_wg_sums[local_id] = partial_sum;
|
|
246
237
|
workgroupBarrier();
|
|
247
238
|
|
|
248
|
-
// Thread 0 of each column reduces its
|
|
239
|
+
// Thread 0 of each column reduces its MULTICOL_THREADS_PER_COL values
|
|
249
240
|
if (thread_in_col == 0u && is_valid) {
|
|
250
241
|
var final_sum: f32 = 0.0;
|
|
251
242
|
let base = col_in_wg * MULTICOL_THREADS_PER_COL;
|
|
@@ -282,30 +273,37 @@ fn main_vec4(
|
|
|
282
273
|
if (is_valid) {
|
|
283
274
|
// K is guaranteed to be multiple of 4
|
|
284
275
|
let K4 = u.K / 4u;
|
|
285
|
-
let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
286
|
-
let k4_start = thread_in_col * k4_per_thread;
|
|
287
|
-
let k4_end = min(k4_start + k4_per_thread, K4);
|
|
288
276
|
|
|
289
277
|
if (u.transpose_b == 1u) {
|
|
290
|
-
// B is [N, K] (
|
|
278
|
+
// B is [N, K] (row-major): B[col, k] = B[col * K + k]
|
|
279
|
+
// Warp-stride: step THREADS_PER_COL vec4 groups per outer iteration so that
|
|
280
|
+
// adjacent threads load adjacent groups → 100% cache-line utilization.
|
|
281
|
+
// At each step: 64 threads × 4 f16 × 2 bytes = 512 bytes from 4 consecutive
|
|
282
|
+
// cache lines, vs the old contiguous-range pattern (~10% utilization).
|
|
291
283
|
let b_row_offset = col * u.K;
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
f32(
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
284
|
+
for (var k4_base: u32 = 0u; k4_base < K4; k4_base = k4_base + THREADS_PER_COL) {
|
|
285
|
+
let k4 = k4_base + thread_in_col;
|
|
286
|
+
if (k4 < K4) {
|
|
287
|
+
let k = k4 * 4u;
|
|
288
|
+
|
|
289
|
+
let a = vec4<f32>(A[k], A[k + 1u], A[k + 2u], A[k + 3u]);
|
|
290
|
+
|
|
291
|
+
let b = vec4<f32>(
|
|
292
|
+
f32(B[b_row_offset + k]),
|
|
293
|
+
f32(B[b_row_offset + k + 1u]),
|
|
294
|
+
f32(B[b_row_offset + k + 2u]),
|
|
295
|
+
f32(B[b_row_offset + k + 3u])
|
|
296
|
+
);
|
|
297
|
+
|
|
298
|
+
partial_sum = partial_sum + dot(a, b);
|
|
299
|
+
}
|
|
306
300
|
}
|
|
307
301
|
} else {
|
|
308
|
-
// B is [K, N] (
|
|
302
|
+
// B is [K, N] (column-major): B[k, col] = B[k * N + col]
|
|
303
|
+
// Contiguous-range per thread: sequential access within each thread.
|
|
304
|
+
let k4_per_thread = (K4 + THREADS_PER_COL - 1u) / THREADS_PER_COL;
|
|
305
|
+
let k4_start = thread_in_col * k4_per_thread;
|
|
306
|
+
let k4_end = min(k4_start + k4_per_thread, K4);
|
|
309
307
|
for (var k4: u32 = k4_start; k4 < k4_end; k4 = k4 + 1u) {
|
|
310
308
|
let k = k4 * 4u;
|
|
311
309
|
|
|
@@ -342,4 +340,4 @@ fn main_vec4(
|
|
|
342
340
|
}
|
|
343
341
|
C[col] = final_sum * u.alpha;
|
|
344
342
|
}
|
|
345
|
-
}
|
|
343
|
+
}
|
|
@@ -9,6 +9,9 @@ import { selectRuleValue as selectLoaderRule } from '../../rules/rule-registry.j
|
|
|
9
9
|
import { getBuffer, getWeightDtype, getBufferDtype } from '../weight-buffer.js';
|
|
10
10
|
import { unifiedKernelWrapper } from './utils.js';
|
|
11
11
|
|
|
12
|
+
// Conservative fallback dtype for norm weight inference when metadata is unavailable.
|
|
13
|
+
const DEFAULT_DTYPE = 'f32';
|
|
14
|
+
|
|
12
15
|
function inferHiddenSize(input, hiddenSize) {
|
|
13
16
|
if (hiddenSize != null) return hiddenSize;
|
|
14
17
|
const shape = input?.shape;
|
|
@@ -39,9 +42,12 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
39
42
|
return taggedDtype;
|
|
40
43
|
}
|
|
41
44
|
|
|
45
|
+
// Conservative fallback: f32 avoids precision loss when dtype cannot be determined.
|
|
46
|
+
// This path fires for non-GPU buffers or missing hiddenSize, both of which prevent
|
|
47
|
+
// size-based dtype inference below.
|
|
42
48
|
const hasGPUBufferType = typeof GPUBuffer !== 'undefined';
|
|
43
49
|
if (!hasGPUBufferType || !(weightBuffer instanceof GPUBuffer) || hiddenSize == null || hiddenSize <= 0) {
|
|
44
|
-
return
|
|
50
|
+
return DEFAULT_DTYPE;
|
|
45
51
|
}
|
|
46
52
|
|
|
47
53
|
const byteSize = getBufferRequestedSize(weightBuffer);
|
|
@@ -55,7 +61,8 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
55
61
|
sizeMatchesF32,
|
|
56
62
|
});
|
|
57
63
|
}
|
|
58
|
-
|
|
64
|
+
// Buffer size matches neither f16 nor f32 for given hiddenSize; fall back to f32.
|
|
65
|
+
return DEFAULT_DTYPE;
|
|
59
66
|
}
|
|
60
67
|
|
|
61
68
|
function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
|
|
@@ -7,7 +7,6 @@ import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout
|
|
|
7
7
|
import { allowReadback } from '../perf-guards.js';
|
|
8
8
|
import { selectRuleValue as selectKernelRuleValue } from './rule-registry.js';
|
|
9
9
|
import { selectRuleValue as selectSharedRuleValue } from '../../rules/rule-registry.js';
|
|
10
|
-
import { getKernelThresholds } from '../../config/schema/index.js';
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
function getSampleBindGroupLayout(device) {
|
|
@@ -96,8 +95,7 @@ async function resolveArgmaxPipelines(device, vocabSize, variants) {
|
|
|
96
95
|
const argmaxPipeline = await createSamplePipeline(device, variants.argmax);
|
|
97
96
|
const numWorkgroups = Math.min(WORKGROUP_SIZES.DEFAULT, Math.ceil(vocabSize / WORKGROUP_SIZES.DEFAULT));
|
|
98
97
|
const useSinglePassArgmax = numWorkgroups === 1;
|
|
99
|
-
const
|
|
100
|
-
const reducePipeline = useSinglePassArgmax || vocabSize <= argmaxReduceVocabThreshold
|
|
98
|
+
const reducePipeline = useSinglePassArgmax
|
|
101
99
|
? null
|
|
102
100
|
: await createSamplePipeline(device, variants.argmaxReduce);
|
|
103
101
|
const singlePassPipeline = useSinglePassArgmax
|