@simulatte/doppler 0.1.7 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. package/CHANGELOG.md +19 -0
  2. package/package.json +21 -36
  3. package/src/browser/browser-converter.js +5 -0
  4. package/src/client/doppler-registry.json +1 -17
  5. package/src/config/kernel-path-loader.d.ts +5 -0
  6. package/src/config/kernel-path-loader.js +13 -0
  7. package/src/config/kernels/registry.json +74 -0
  8. package/src/config/loader.js +3 -0
  9. package/src/config/merge-contract-check.js +7 -0
  10. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  11. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  12. package/src/config/presets/kernel-paths/registry.json +14 -0
  13. package/src/config/presets/models/gemma2.json +2 -1
  14. package/src/config/presets/models/gemma3.json +2 -0
  15. package/src/config/presets/models/qwen3.json +4 -3
  16. package/src/config/presets/models/qwen3_5.json +16 -0
  17. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  18. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  19. package/src/config/schema/conversion.schema.d.ts +1 -0
  20. package/src/config/schema/manifest.schema.d.ts +1 -1
  21. package/src/config/schema/manifest.schema.js +1 -1
  22. package/src/config/schema/storage.schema.js +1 -1
  23. package/src/converter/conversion-plan.js +10 -2
  24. package/src/converter/core.js +2 -0
  25. package/src/converter/manifest-inference.js +12 -22
  26. package/src/converter/parsers/transformer.js +4 -0
  27. package/src/converter/quantization-info.js +5 -1
  28. package/src/converter/quantizer.js +19 -12
  29. package/src/converter/rope-config.js +8 -6
  30. package/src/converter/tokenizer-utils.d.ts +1 -0
  31. package/src/converter/tokenizer-utils.js +4 -1
  32. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  33. package/src/distribution/shard-delivery.js +6 -1
  34. package/src/formats/rdrr/parsing.d.ts +4 -0
  35. package/src/formats/rdrr/parsing.js +14 -1
  36. package/src/gpu/kernels/index.d.ts +8 -0
  37. package/src/gpu/kernels/index.js +6 -0
  38. package/src/gpu/kernels/matmul-selection.js +47 -4
  39. package/src/gpu/kernels/matmul.d.ts +2 -0
  40. package/src/gpu/kernels/matmul.js +1 -1
  41. package/src/gpu/kernels/rmsnorm.js +9 -2
  42. package/src/gpu/kernels/split_qg.d.ts +50 -0
  43. package/src/gpu/kernels/split_qg.js +46 -0
  44. package/src/gpu/kernels/split_qg.wgsl +58 -0
  45. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  46. package/src/gpu/weight-buffer.d.ts +1 -1
  47. package/src/gpu/weight-buffer.js +1 -1
  48. package/src/inference/browser-harness.d.ts +2 -0
  49. package/src/inference/browser-harness.js +20 -1
  50. package/src/inference/pipelines/diffusion/helpers.js +3 -0
  51. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  52. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  53. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  54. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  55. package/src/inference/pipelines/text/attention/projections.js +41 -11
  56. package/src/inference/pipelines/text/attention/record.js +15 -6
  57. package/src/inference/pipelines/text/attention/run.js +50 -6
  58. package/src/inference/pipelines/text/config.js +14 -0
  59. package/src/inference/pipelines/text/execution-plan.js +5 -4
  60. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  61. package/src/inference/pipelines/text/generator-steps.d.ts +6 -0
  62. package/src/inference/pipelines/text/generator-steps.js +43 -15
  63. package/src/inference/pipelines/text/generator.js +50 -17
  64. package/src/inference/pipelines/text/init.d.ts +13 -0
  65. package/src/inference/pipelines/text/init.js +16 -5
  66. package/src/inference/pipelines/text/layer.js +1 -0
  67. package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
  68. package/src/inference/pipelines/text/linear-attention.js +33 -3
  69. package/src/inference/pipelines/text/logits/gpu.js +2 -2
  70. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  71. package/src/inference/pipelines/text/logits/index.js +3 -1
  72. package/src/inference/pipelines/text/model-load.js +3 -0
  73. package/src/inference/pipelines/text/sampling.js +52 -6
  74. package/src/inference/test-harness.js +2 -2
  75. package/src/loader/final-weights-loader.js +2 -0
  76. package/src/loader/shard-cache.js +3 -2
  77. package/src/loader/tensors/tensor-loader.js +6 -1
  78. package/src/rules/inference/dtype.rules.json +5 -0
  79. package/src/rules/inference/kernel-path.rules.json +2 -2
  80. package/src/rules/kernels/split-qg.rules.json +6 -0
  81. package/src/rules/rule-registry.js +2 -0
  82. package/src/storage/downloader.js +2 -1
  83. package/src/storage/shard-manager.js +4 -3
  84. package/src/tooling/conversion-config-materializer.js +3 -5
  85. package/src/tooling/node-converter.js +3 -0
  86. package/src/tooling/node-source-runtime.js +36 -0
  87. package/src/types/model.d.ts +5 -0
  88. package/tools/doppler-cli.js +6 -1
package/src/debug/reference/hf_qwen35_linear_attn_debug.py
@@ -0,0 +1,268 @@
+ #!/usr/bin/env python3
+ """
+ Dump intermediate values from Qwen3.5 linear attention (GatedDeltaNet) for comparison with Doppler.
+
+ Usage:
+     HF_HOME=/media/x/models/huggingface_cache python3 src/debug/reference/hf_qwen35_linear_attn_debug.py
+ """
+
+ import os
+ import torch
+ import numpy as np
+
+ os.environ.setdefault("HF_HOME", "/media/x/models/huggingface_cache")
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ MODEL_ID = "Qwen/Qwen3.5-0.8B"
+ PROMPT = "Hello"
+
+
+ def stats(name, tensor):
+     t = tensor.float().detach().flatten()
+     print(f"  {name}: shape={list(tensor.shape)}, "
+           f"min={t.min().item():.6f}, max={t.max().item():.6f}, "
+           f"mean={t.mean().item():.6f}, absMax={t.abs().max().item():.6f}")
+     first8 = t[:8].tolist()
+     print(f"    first8: {[f'{v:.6f}' for v in first8]}")
+
+
+ def main():
+     print(f"Loading {MODEL_ID}...")
+     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32)
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+     model.eval()
+
+     inputs = tokenizer(PROMPT, return_tensors="pt")
+     input_ids = inputs["input_ids"]
+     print(f"Prompt: '{PROMPT}', Token IDs: {input_ids[0].tolist()}")
+     num_tokens = input_ids.shape[1]
+
+     # Dump key weight values for layer 0
+     layer0 = model.model.layers[0]
+     attn = layer0.linear_attn
+
+     print(f"\n=== Layer 0 weights ===")
+     if hasattr(attn, 'A_log'):
+         a_log = attn.A_log.detach().float()
+         a_neg_exp = -torch.exp(a_log)
+         stats("A_log", a_log)
+         stats("a_neg_exp", a_neg_exp)
+     if hasattr(attn, 'dt_bias'):
+         stats("dt_bias", attn.dt_bias.detach().float())
+     stats("conv1d.weight", attn.conv1d.weight.detach().float())
+     stats("norm.weight", attn.norm.weight.detach().float())
+
+     # Hook into the linear_attn module to capture its input and output
+     captured = {}
+
+     def hook_linear_attn_input(module, args, kwargs):
+         if len(args) > 0:
+             captured['linear_attn_input'] = args[0].detach().clone()
+         return None
+
+     def hook_linear_attn_output(module, args, kwargs, output):
+         if isinstance(output, tuple):
+             captured['linear_attn_output'] = output[0].detach().clone()
+         else:
+             captured['linear_attn_output'] = output.detach().clone()
+         return None
+
+     # Hook into individual projection layers
+     def make_hook(name):
+         def hook(module, input, output):
+             captured[name] = output.detach().clone()
+         return hook
+
+     hooks = []
+     hooks.append(attn.register_forward_pre_hook(hook_linear_attn_input, with_kwargs=True))
+     hooks.append(attn.register_forward_hook(hook_linear_attn_output, with_kwargs=True))
+     hooks.append(attn.in_proj_qkv.register_forward_hook(make_hook('qkv_proj')))
+     hooks.append(attn.in_proj_z.register_forward_hook(make_hook('z_proj')))
+     hooks.append(attn.in_proj_a.register_forward_hook(make_hook('a_proj')))
+     hooks.append(attn.in_proj_b.register_forward_hook(make_hook('b_proj')))
+     hooks.append(attn.out_proj.register_forward_hook(make_hook('out_proj')))
+     hooks.append(attn.conv1d.register_forward_hook(make_hook('conv1d_raw')))
+     hooks.append(attn.norm.register_forward_hook(make_hook('gated_norm')))
+
+     # Also hook input_layernorm
+     hooks.append(layer0.input_layernorm.register_forward_hook(make_hook('input_layernorm')))
+
+     print(f"\n=== Running forward pass ===")
+     with torch.no_grad():
+         outputs = model(input_ids, output_hidden_states=True)
+
+     # Remove hooks
+     for h in hooks:
+         h.remove()
+
+     print(f"\n=== Captured intermediates ===")
+     for name in ['input_layernorm', 'qkv_proj', 'z_proj', 'a_proj', 'b_proj',
+                  'conv1d_raw', 'gated_norm', 'linear_attn_input', 'linear_attn_output', 'out_proj']:
+         if name in captured:
+             stats(name, captured[name])
+         else:
+             print(f"  {name}: NOT CAPTURED")
+
+     # Hidden states per layer
+     print(f"\n=== Hidden states per layer (last token) ===")
+     for i in range(min(6, len(outputs.hidden_states) - 1)):
+         hs = outputs.hidden_states[i + 1]
+         t = hs[0, -1]  # last token
+         vals = t[:8].tolist()
+         max_abs = t.abs().max().item()
+         mean_abs = t.abs().mean().item()
+         layer_type = type(model.model.layers[i]).__name__
+         attn_type = "linear" if hasattr(model.model.layers[i], 'linear_attn') else "full"
+         print(f"  Layer {i} ({attn_type}): first8={[f'{v:.4f}' for v in vals]}, "
+               f"maxAbs={max_abs:.4f}, meanAbs={mean_abs:.4f}")
+
+     # Logits
+     logits = outputs.logits[0, -1]
+     top5 = torch.topk(logits, 5)
+     print(f"\nTop-5 logits: {[(tokenizer.decode([idx.item()]), f'{val.item():.2f}') for val, idx in zip(top5.values, top5.indices)]}")
+
+     # Also trace through the linear attention manually to compare with Doppler's kernel
+     print(f"\n=== Manual linear attention trace (layer 0) ===")
+     with torch.no_grad():
+         embed = model.model.embed_tokens(input_ids)
+         normed = layer0.input_layernorm(embed)
+         stats("normed_input", normed)
+
+         qkv = attn.in_proj_qkv(normed)
+         stats("qkv", qkv)
+
+         # The HF Qwen3.5 GatedDeltaNet does conv1d on the QKV, then applies SiLU.
+         # The conv1d expects [batch, channels, seq_len] format.
+         qkv_t = qkv.transpose(1, 2)  # [1, 6144, 1]
+
+         # Use the conv1d module directly (it has padding configured)
+         conv_raw = attn.conv1d(qkv_t)
+         stats("conv_raw (from module)", conv_raw.transpose(1, 2))
+
+         # Truncate to seq_len (causal conv padding)
+         conv_causal = conv_raw[..., :num_tokens]
+         stats("conv_causal (truncated)", conv_causal.transpose(1, 2))
+
+         # Apply SiLU
+         conv_silu = torch.nn.functional.silu(conv_causal)
+         stats("conv_silu", conv_silu.transpose(1, 2))
+
+         # Split Q, K, V
+         conv_out = conv_silu.transpose(1, 2)  # [1, seq_len, 6144]
+         num_k_heads = 16
+         head_k_dim = 128
+         head_v_dim = 128
+         num_v_heads = 16
+         q_size = num_k_heads * head_k_dim  # 2048
+         k_size = q_size
+         v_size = num_v_heads * head_v_dim  # 2048
+
+         q = conv_out[..., :q_size]
+         k = conv_out[..., q_size:q_size + k_size]
+         v = conv_out[..., q_size + k_size:]
+         stats("Q (raw)", q)
+         stats("K (raw)", k)
+         stats("V (raw)", v)
+
+         # Reshape for per-head processing
+         # Q and K: [batch, seq, num_k_heads, head_k_dim]
+         q_heads = q.view(1, num_tokens, num_k_heads, head_k_dim)
+         k_heads = k.view(1, num_tokens, num_k_heads, head_k_dim)
+         v_heads = v.view(1, num_tokens, num_v_heads, head_v_dim)
+
+         # L2 normalize Q and K
+         eps = 1e-6
+         q_norm = torch.nn.functional.normalize(q_heads, p=2, dim=-1, eps=eps)
+         k_norm = torch.nn.functional.normalize(k_heads, p=2, dim=-1, eps=eps)
+
+         # Scale Q by 1/sqrt(head_k_dim)
+         head_scale = 1.0 / (head_k_dim ** 0.5)
+         q_scaled = q_norm * head_scale
+
+         stats("Q_normed_scaled (per-head)", q_scaled.reshape(1, num_tokens, -1))
+         stats("K_normed (per-head)", k_norm.reshape(1, num_tokens, -1))
+
+         # Projections for gating
+         z = attn.in_proj_z(normed)
+         a_out = attn.in_proj_a(normed)
+         b_out = attn.in_proj_b(normed)
+         stats("z", z)
+         stats("a", a_out)
+         stats("b", b_out)
+
+         # Compute gating values
+         a_log = attn.A_log.detach().float()
+         a_neg_exp = -torch.exp(a_log)
+         dt_bias = attn.dt_bias.detach().float()
+
+         softplus_input = a_out.squeeze(0).squeeze(0) + dt_bias
+         softplus_val = torch.nn.functional.softplus(softplus_input)
+         g = a_neg_exp * softplus_val
+         g_exp = torch.exp(g)
+         beta = torch.sigmoid(b_out.squeeze(0).squeeze(0))
+
+         stats("softplus(a + dt_bias)", softplus_val.unsqueeze(0).unsqueeze(0))
+         stats("g (decay)", g.unsqueeze(0).unsqueeze(0))
+         stats("g_exp (decay factor)", g_exp.unsqueeze(0).unsqueeze(0))
+         stats("beta (sigmoid(b))", beta.unsqueeze(0).unsqueeze(0))
+
+         # Recurrent state update (for the first token, the state is all zeros):
+         #   state[head, kd, vd] = state * g_exp + k[kd] * delta[vd]
+         #   where delta[vd] = (v[vd] - (state^T @ k)[vd]) * beta
+         # For zero state: delta[vd] = v[vd] * beta, state = k ⊗ delta
+         state = torch.zeros(num_v_heads, head_k_dim, head_v_dim)
+
+         # Apply decay (no-op for zero state)
+         for head in range(num_v_heads):
+             state[head] *= g_exp[head].item()
+
+             k_head = k_norm[0, 0, head % num_k_heads]  # broadcast q_rep
+             v_head = v_heads[0, 0, head]
+
+             # kv_mem = state^T @ k
+             kv_mem = state[head].t() @ k_head  # [head_v_dim]
+
+             # delta = (v - kv_mem) * beta
+             delta = (v_head - kv_mem) * beta[head].item()
+
+             # state += outer(k, delta)
+             state[head] += torch.outer(k_head, delta)
+
+         # Output: out = state^T @ q
+         output_per_head = torch.zeros(1, num_tokens, num_v_heads, head_v_dim)
+         for head in range(num_v_heads):
+             q_head = q_scaled[0, 0, head % num_k_heads]
+             out_head = state[head].t() @ q_head  # [head_v_dim]
+             output_per_head[0, 0, head] = out_head
+
+         raw_out = output_per_head.reshape(1, num_tokens, num_v_heads * head_v_dim)
+         stats("Recurrent output (raw)", raw_out)
+
+         # RMS norm per head + SiLU gate
+         z_reshaped = z.view(1, num_tokens, num_v_heads, head_v_dim)
+         norm_weight = attn.norm.weight.detach().float()  # [head_v_dim] (shared mode)
+         rms_eps = 1e-6
+
+         for head in range(num_v_heads):
+             head_out = output_per_head[0, 0, head]  # [head_v_dim]
+             mean_sq = (head_out ** 2).mean()
+             inv_rms = 1.0 / torch.sqrt(mean_sq + rms_eps)
+             z_gate = torch.nn.functional.silu(z_reshaped[0, 0, head])
+             output_per_head[0, 0, head] = head_out * inv_rms * norm_weight * z_gate
+
+         gated_out = output_per_head.reshape(1, num_tokens, num_v_heads * head_v_dim)
+         stats("After RMSNorm + SiLU gate", gated_out)
+
+         # Output projection
+         o_result = torch.nn.functional.linear(gated_out, attn.out_proj.weight)
+         stats("After out_proj", o_result)
+
+         # Compare with captured output
+         if 'linear_attn_output' in captured:
+             diff = (o_result - captured['linear_attn_output']).abs()
+             print(f"\n  Diff vs captured output: maxDiff={diff.max().item():.6f}")
+
+
+ if __name__ == "__main__":
+     main()
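
Note: the manual trace above covers a single token, so the update never exercises the decay against a non-zero state. As a reference for how the same gated delta rule chains across tokens, here is a minimal single-head sketch in plain JavaScript (illustrative only, not part of @simulatte/doppler; gatedDeltaStep and its argument names are hypothetical):

  // Hypothetical single-head sketch of the GatedDeltaNet recurrence traced above.
  // state: Dk x Dv matrix; k, q: length-Dk arrays (L2-normalized); v: length-Dv array;
  // gExp = exp(-exp(A_log) * softplus(a + dt_bias)); beta = sigmoid(b).
  function gatedDeltaStep(state, k, v, q, gExp, beta) {
    const Dk = k.length, Dv = v.length;
    for (let i = 0; i < Dk; i++)
      for (let j = 0; j < Dv; j++) state[i][j] *= gExp;           // decay old memory
    const kvMem = new Array(Dv).fill(0);                          // kv_mem = state^T @ k
    for (let i = 0; i < Dk; i++)
      for (let j = 0; j < Dv; j++) kvMem[j] += state[i][j] * k[i];
    for (let i = 0; i < Dk; i++)                                  // state += k ⊗ ((v - kv_mem) * beta)
      for (let j = 0; j < Dv; j++) state[i][j] += k[i] * (v[j] - kvMem[j]) * beta;
    const out = new Array(Dv).fill(0);                            // out = state^T @ q
    for (let i = 0; i < Dk; i++)
      for (let j = 0; j < Dv; j++) out[j] += state[i][j] * q[i];
    return out;
  }

Calling this once per token with that token's gExp and beta reproduces the zero-state loop above for token 0 and extends it to longer prompts.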
package/src/distribution/shard-delivery.js
@@ -1,4 +1,5 @@
  import { log } from '../debug/index.js';
+ import { getExpectedShardHash } from '../formats/rdrr/index.js';
  import {
    computeHash,
    createStreamingHasher,
@@ -2018,7 +2019,11 @@ export async function downloadShard(
    onDeliveryMetrics,
    signal,
    requiredEncoding: requiredEncoding ?? activeConfig.requiredContentEncoding ?? null,
-   expectedHash: options.expectedHash ?? shardInfo?.hash ?? activeConfig.expectedHash ?? null,
+   expectedHash:
+     options.expectedHash
+     ?? getExpectedShardHash(shardInfo, algorithm)
+     ?? activeConfig.expectedHash
+     ?? null,
    expectedSize: expectedSize ?? shardInfo?.size ?? null,
    expectedManifestVersionSet: options.expectedManifestVersionSet ?? null,
    writeToStore,
package/src/formats/rdrr/parsing.d.ts
@@ -7,6 +7,10 @@
  import type { RDRRManifest, ShardInfo, TensorMap } from './types.js';

  export declare function parseManifest(jsonString: string): RDRRManifest;
+ export declare function getExpectedShardHash(
+   shard: Partial<ShardInfo> | Record<string, unknown> | null | undefined,
+   manifestHashAlgorithm?: string | null
+ ): string;

  export declare function parseTensorMap(jsonString: string): TensorMap;

package/src/formats/rdrr/parsing.js
@@ -4,6 +4,19 @@ import { validateManifest } from './validation.js';

  let currentManifest = null;

+ export function getExpectedShardHash(shard, manifestHashAlgorithm = null) {
+   if (!shard || typeof shard !== 'object' || Array.isArray(shard)) {
+     return '';
+   }
+   const algorithm = typeof manifestHashAlgorithm === 'string'
+     ? manifestHashAlgorithm.trim().toLowerCase()
+     : '';
+   if (algorithm === 'blake3') {
+     return shard.blake3 || shard.hash || '';
+   }
+   return shard.hash || shard.blake3 || '';
+ }
+
  export function parseManifest(jsonString) {
    let manifest;

@@ -21,7 +34,7 @@ export function parseManifest(jsonString) {
      index: shard.index ?? i,
      filename: shard.filename || shard.fileName || '',
      size: shard.size,
-     hash: shard.hash || shard.blake3 || '',
+     hash: getExpectedShardHash(shard, manifest.hashAlgorithm),
      blake3: shard.blake3 || shard.hash,
      offset: shard.offset ?? offset,
      hashAlgorithm: shard.hashAlgorithm,
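
To make the new helper's precedence concrete, a short hypothetical call (the shard object and digest strings are placeholders, not real manifest data):

  import { getExpectedShardHash } from './parsing.js';

  const shard = { hash: 'abc123', blake3: 'def456' }; // placeholder digests
  getExpectedShardHash(shard, 'blake3'); // 'def456': blake3 manifests prefer shard.blake3
  getExpectedShardHash(shard, 'sha256'); // 'abc123': other algorithms prefer shard.hash
  getExpectedShardHash(null);            // '': non-object input

This keeps the expectedHash used by downloadShard consistent with whichever algorithm the manifest declares, instead of always reading shardInfo.hash.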
package/src/gpu/kernels/index.d.ts
@@ -326,6 +326,14 @@ export {
    type SplitQKVResult,
  } from './split_qkv.js';

+ // Split Q and Gate (de-interleave attentionOutputGate q_proj output)
+ export {
+   runSplitQG,
+   recordSplitQG,
+   type SplitQGOptions,
+   type SplitQGResult,
+ } from './split_qg.js';
+
  // Transpose
  export {
    runTranspose,
package/src/gpu/kernels/index.js
@@ -268,6 +268,12 @@ export {
    recordSplitQKV,
  } from './split_qkv.js';

+ // Split Q and Gate (de-interleave attentionOutputGate q_proj output)
+ export {
+   runSplitQG,
+   recordSplitQG,
+ } from './split_qg.js';
+
  // Transpose
  export {
    runTranspose,
package/src/gpu/kernels/matmul-selection.js
@@ -29,7 +29,13 @@ function selectQ4KFusedVariant(isM1, wantF16Output, aDtype) {
  }


- export function resolveMatmulPhase(M) {
+ export function resolveMatmulPhase(M, phaseOverride = null) {
+   if (phaseOverride != null) {
+     if (phaseOverride !== 'decode' && phaseOverride !== 'prefill') {
+       throw new Error(`[Matmul] Invalid phase override "${phaseOverride}". Expected "decode" or "prefill".`);
+     }
+     return phaseOverride;
+   }
    return selectKernelRuleValue('matmul', 'phase', { isDecode: M === 1 });
  }

@@ -125,7 +131,9 @@ export function selectMatmulKernel(options = {}) {
    const { tiledPrefillMinRows } = getKernelThresholds().matmul;

    const inputsAreF16 = aDtype === 'f16' && bDtype === 'f16';
-   const weightsAreF16 = bDtype === 'f16' && aDtype !== 'f16';
+   // F16 weights needing F32a path: weights are F16 and either activation is already F32,
+   // or both inputs are F16 but output is F32 (activation will be cast to F32 by executeMatmul)
+   const weightsAreF16 = bDtype === 'f16' && (aDtype !== 'f16' || outputDtype !== 'f16');
    const useF16Matmul = outputDtype === 'f16' && preferF16 && inputsAreF16 && capabilities.hasF16;
    const useF16wF32a = preferF16 && weightsAreF16 && capabilities.hasF16;
    const useTiled = isPrefill
@@ -244,6 +252,30 @@ export function requiresF32Input(variant) {
    return !supportsF16Input(variant);
  }

+ function resolveRequiredWeightDtype(config) {
+   const shaderFile = String(config?.shaderFile ?? config?.wgsl ?? '');
+   if (!shaderFile) {
+     return null;
+   }
+   if (shaderFile.startsWith('fused_matmul_q4')) {
+     return 'q4k';
+   }
+   if (
+     shaderFile === 'matmul_f16.wgsl'
+     || shaderFile === 'matmul_f16_tiled.wgsl'
+     || shaderFile === 'matmul_f16w_f32a.wgsl'
+     || shaderFile === 'matmul_f16w_f32a_tiled.wgsl'
+     || shaderFile === 'matmul_gemv_subgroup.wgsl'
+     || shaderFile === 'matmul_gemv_subgroup_f16a.wgsl'
+   ) {
+     return 'f16';
+   }
+   if (shaderFile === 'matmul_f32.wgsl') {
+     return 'f32';
+   }
+   return null;
+ }
+

  function resolveMatmulOverride(
    variantOverride,
@@ -287,6 +319,16 @@ function resolveMatmulOverride(
      );
    }

+   const requiredWeightDtype = resolveRequiredWeightDtype(config);
+   const weightDtypeOk = !requiredWeightDtype
+     || bDtype === requiredWeightDtype
+     || (requiredWeightDtype === 'f16' && bDtype === 'q4k');
+   if (!weightDtypeOk) {
+     return failOrWarn(
+       `Matmul kernel "${variantOverride}" requires ${requiredWeightDtype} weights but B dtype is ${bDtype}.`
+     );
+   }
+
    if (supportsF16Input(override) && aDtype !== 'f16') {
      return failOrWarn(`Matmul kernel "${variantOverride}" requires f16 activations but A dtype is ${aDtype}.`);
    }
@@ -341,7 +383,7 @@ function selectGemvVariant(useF16Gemv, useF32Gemv, hasSubgroups, useVec4, N, mul
  export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, transposeB, requestedOutputDtype, options) {
    const capabilities = getKernelCapabilities();
    const strict = getKernelPathStrict();
-   const phase = resolveMatmulPhase(M);
+   const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
    let pathVariant = getKernelPathMatmulVariant(options.role, phase, options.layerIdx, options.kernelPath);
    const hadPathVariant = Boolean(pathVariant);

@@ -426,7 +468,8 @@ export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, trans

    const canGemv = M === 1 && effectiveBDtype === 'f16' && capabilities.hasF16;
    const useF16Gemv = canGemv && aDtype === 'f16' && wantF16Output;
-   const useF32Gemv = canGemv && aDtype === 'f32';
+   // F32 GEMV: activation is F32, or activation is F16 with F32 output (will be cast to F32)
+   const useF32Gemv = canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output));
    const useGemv = useF16Gemv || useF32Gemv;
    const useVec4 = (K % 4 === 0);
    const { multicolThreshold } = getKernelThresholds().matmul;
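
A condensed standalone sketch of the widened GEMV predicate (it mirrors the two lines above; this is not the package's literal code):

  // Hypothetical restatement of the decode GEMV selection.
  function gemvPath(aDtype, bDtype, outputDtype, M, hasF16) {
    const canGemv = M === 1 && bDtype === 'f16' && hasF16;
    const wantF16Output = outputDtype === 'f16';
    if (canGemv && aDtype === 'f16' && wantF16Output) return 'f16-gemv';
    // New in 0.1.8: f16 activations with an f32 output also take the f32 GEMV.
    if (canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output))) return 'f32-gemv';
    return 'generic-matmul';
  }
  // gemvPath('f16', 'f16', 'f32', 1, true) now returns 'f32-gemv'; previously that
  // combination satisfied neither GEMV condition and fell through to a full matmul.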
package/src/gpu/kernels/matmul.d.ts
@@ -23,6 +23,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
    layerIdx?: number;
    /** Explicit kernel path context for variant selection (avoids global path state). */
    kernelPath?: KernelPathSchema | null;
+   /** Optional explicit phase for kernel-path lookup when the runtime rewrites rows (for example prefill last-position logits). */
+   phaseOverride?: 'decode' | 'prefill' | null;
    /**
     * Whether B matrix is stored transposed.
     * - true: B is [N,K] (SafeTensors/row-major), needs transpose
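
A hedged usage sketch of the new option: the entry-point and tensor names below are assumptions for illustration, but phaseOverride is typed exactly as declared above.

  // Illustrative call site (function and tensor names assumed, not verbatim package API).
  // Prefill computed logits only for the last position, so M === 1 here even though
  // the pipeline is still in the prefill phase; pin the kernel-path lookup.
  const logits = await runMatmul(lastHidden, lmHeadWeights, 1, vocabSize, hiddenSize, {
    role: 'logits',
    phaseOverride: 'prefill',
  });

Without the override, resolveMatmulPhase(M) would classify M === 1 as decode and could pick a decode-tuned kernel-path variant for a prefill-time logits matmul.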
package/src/gpu/kernels/matmul.js
@@ -165,7 +165,7 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
      options
    );

-   const phase = resolveMatmulPhase(M);
+   const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
    const constants = resolveMatmulConstants(options, phase);

    let matmulInput = A;
package/src/gpu/kernels/rmsnorm.js
@@ -9,6 +9,9 @@ import { selectRuleValue as selectLoaderRule } from '../../rules/rule-registry.j
  import { getBuffer, getWeightDtype, getBufferDtype } from '../weight-buffer.js';
  import { unifiedKernelWrapper } from './utils.js';

+ // Conservative fallback dtype for norm weight inference when metadata is unavailable.
+ const DEFAULT_DTYPE = 'f32';
+
  function inferHiddenSize(input, hiddenSize) {
    if (hiddenSize != null) return hiddenSize;
    const shape = input?.shape;
@@ -39,9 +42,12 @@ function resolveNormWeightDtype(weight, hiddenSize) {
      return taggedDtype;
    }

+   // Conservative fallback: f32 avoids precision loss when dtype cannot be determined.
+   // This path fires for non-GPU buffers or missing hiddenSize, both of which prevent
+   // size-based dtype inference below.
    const hasGPUBufferType = typeof GPUBuffer !== 'undefined';
    if (!hasGPUBufferType || !(weightBuffer instanceof GPUBuffer) || hiddenSize == null || hiddenSize <= 0) {
-     return 'f32';
+     return DEFAULT_DTYPE;
    }

    const byteSize = getBufferRequestedSize(weightBuffer);
@@ -55,7 +61,8 @@ function resolveNormWeightDtype(weight, hiddenSize) {
        sizeMatchesF32,
      });
    }
-   return 'f32';
+   // Buffer size matches neither f16 nor f32 for given hiddenSize; fall back to f32.
+   return DEFAULT_DTYPE;
  }

  function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
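
The size comparison itself sits outside these hunks; as a condensed sketch, the inference presumably reduces to checking the buffer's byte size against the two candidate element widths (the f16 branch here is an assumption inferred from the sizeMatchesF32 log field, not shown in the diff):

  // Hypothetical condensed form of the size-based norm-weight dtype inference.
  function inferNormWeightDtype(byteSize, hiddenSize) {
    if (hiddenSize == null || hiddenSize <= 0) return 'f32'; // conservative default
    if (byteSize === hiddenSize * 2) return 'f16';           // 2 bytes per element
    if (byteSize === hiddenSize * 4) return 'f32';           // 4 bytes per element
    return 'f32'; // matches neither; fall back, as in the hunk above
  }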
package/src/gpu/kernels/split_qg.d.ts
@@ -0,0 +1,50 @@
+ /**
+  * Split Q and Gate Kernel
+  *
+  * De-interleaves Q and Gate projections from q_proj output for attentionOutputGate models.
+  * Models like Qwen 3.5 store q_proj weights in per-head interleaved layout:
+  *   rows [h*headDim*2 : h*headDim*2+headDim]     = Q for head h
+  *   rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
+  * This kernel separates the full matmul output into contiguous Q and Gate tensors.
+  */
+
+ import type { Tensor } from '../tensor.js';
+ import type { CommandRecorder } from '../command-recorder.js';
+
+ /** Split Q and Gate options */
+ export interface SplitQGOptions {
+   numTokens: number;
+   numHeads: number;
+   headDim: number;
+   /** Pre-allocated Q output tensor */
+   qTensor?: Tensor | null;
+   /** Pre-allocated Gate output tensor */
+   gTensor?: Tensor | null;
+ }
+
+ /** Split Q and Gate result */
+ export interface SplitQGResult {
+   Q: Tensor;
+   G: Tensor;
+ }
+
+ /**
+  * De-interleave Q and Gate from q_proj output.
+  *
+  * @param qgTensor - Full q_proj output [numTokens, numHeads * headDim * 2] (interleaved)
+  * @param options - Split configuration
+  * @returns Separate Q and Gate tensors, each [numTokens, numHeads * headDim]
+  */
+ export declare function runSplitQG(
+   qgTensor: Tensor,
+   options: SplitQGOptions
+ ): Promise<SplitQGResult>;
+
+ /**
+  * Record split Q and Gate (batched, no submit).
+  */
+ export declare function recordSplitQG(
+   recorder: CommandRecorder,
+   qgTensor: Tensor,
+   options: SplitQGOptions
+ ): Promise<SplitQGResult>;
package/src/gpu/kernels/split_qg.js
@@ -0,0 +1,46 @@
+
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
+ import { createTensor, dtypeBytes } from '../tensor.js';
+ import { WORKGROUP_SIZES } from './constants.js';
+ import { unifiedKernelWrapper } from './utils.js';
+ import { selectRuleValue } from './rule-registry.js';
+
+ async function _splitQG(target, qgTensor, options) {
+   const { numTokens, numHeads, headDim, qTensor = null, gTensor = null } = options;
+   const ownsQ = qTensor == null;
+   const ownsG = gTensor == null;
+
+   const outputDtype = qgTensor.dtype;
+   const pipelineVariant = selectRuleValue('splitQg', 'variant', { outputDtype });
+   const bytesPerElement = dtypeBytes(outputDtype);
+   const qSize = numHeads * headDim;
+
+   const qBuffer = qTensor?.buffer || acquireBuffer(numTokens * qSize * bytesPerElement, undefined, 'Q');
+   const gBuffer = gTensor?.buffer || acquireBuffer(numTokens * qSize * bytesPerElement, undefined, 'Q_gate');
+
+   try {
+     await unifiedKernelWrapper(
+       'split_qg', target, pipelineVariant,
+       [qgTensor, qBuffer, gBuffer],
+       { num_tokens: numTokens, num_heads: numHeads, head_dim: headDim, _pad: 0 },
+       Math.ceil((numTokens * qSize) / WORKGROUP_SIZES.DEFAULT)
+     );
+
+     const Q = qTensor || createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q');
+     const G = gTensor || createTensor(gBuffer, outputDtype, [numTokens, qSize], 'Q_gate');
+
+     return { Q, G };
+   } catch (error) {
+     if (ownsQ) releaseBuffer(qBuffer);
+     if (ownsG) releaseBuffer(gBuffer);
+     throw error;
+   }
+ }
+
+ export async function runSplitQG(qgTensor, options) {
+   return _splitQG(null, qgTensor, options);
+ }
+
+ export async function recordSplitQG(recorder, qgTensor, options) {
+   return _splitQG(recorder, qgTensor, options);
+ }
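
A usage sketch of the exported helper (the head geometry is Qwen 3.5's as documented in split_qg.d.ts; the qgTensor handle is illustrative):

  import { runSplitQG } from './src/gpu/kernels/index.js';

  // qgTensor: q_proj output [numTokens, numHeads * headDim * 2], per-head interleaved.
  const { Q, G } = await runSplitQG(qgTensor, {
    numTokens: 1,   // single decode step
    numHeads: 16,
    headDim: 128,
  });
  // Q and G are each [numTokens, numHeads * headDim]; Q proceeds to attention,
  // G carries the per-head output-gate values.

recordSplitQG takes the same arguments with a CommandRecorder first, batching the dispatch without submitting.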
package/src/gpu/kernels/split_qg.wgsl
@@ -0,0 +1,58 @@
+ // split_qg.wgsl
+
+ /**
+  * De-interleave Q and Gate projections from q_proj output for attentionOutputGate models.
+  *
+  * Models like Qwen 3.5 store q_proj weights with interleaved head layout:
+  *   rows [h*headDim*2 : h*headDim*2+headDim]     = Q for head h
+  *   rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
+  *
+  * A single full matmul over all 2*qSize rows produces interleaved output:
+  *   input[token, h*headDim*2 : h*headDim*2+headDim]     = Q head h
+  *   input[token, h*headDim*2+headDim : (h+1)*headDim*2] = Gate head h
+  *
+  * This kernel separates them into contiguous Q and G outputs:
+  *   Q[token, h*headDim + dim] = input[token, h*headDim*2 + dim]
+  *   G[token, h*headDim + dim] = input[token, h*headDim*2 + headDim + dim]
+  *
+  * Input layout (row-major):    [numTokens, numHeads * headDim * 2]
+  * Output Q layout (row-major): [numTokens, numHeads * headDim]
+  * Output G layout (row-major): [numTokens, numHeads * headDim]
+  */
+
+ struct Params {
+   num_tokens: u32,
+   num_heads: u32,
+   head_dim: u32,
+   _pad: u32,
+ }
+
+ override WORKGROUP_SIZE: u32 = 256u;
+
+ @group(0) @binding(0) var<uniform> params: Params;
+ @group(0) @binding(1) var<storage, read> input: array<f32>;
+ @group(0) @binding(2) var<storage, read_write> Q: array<f32>;
+ @group(0) @binding(3) var<storage, read_write> G: array<f32>;
+
+ @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+   let idx = gid.x;
+   let q_size = params.num_heads * params.head_dim;
+   let total_elements = params.num_tokens * q_size;
+
+   if (idx >= total_elements) {
+     return;
+   }
+
+   let token = idx / q_size;
+   let elem = idx % q_size;
+   let head = elem / params.head_dim;
+   let dim = elem % params.head_dim;
+
+   // Input is interleaved per head: [Q_h (headDim elems), G_h (headDim elems)]
+   let src_q = token * (q_size * 2u) + head * (params.head_dim * 2u) + dim;
+   let src_g = src_q + params.head_dim;
+
+   Q[idx] = input[src_q];
+   G[idx] = input[src_g];
+ }
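
For a CPU cross-check of the shader's index math, an equivalent JavaScript reference (a test-style sketch, not shipped in the package):

  // Mirrors src_q / src_g in split_qg.wgsl; compare against GPU readback.
  function splitQGReference(input, numTokens, numHeads, headDim) {
    const qSize = numHeads * headDim;
    const Q = new Float32Array(numTokens * qSize);
    const G = new Float32Array(numTokens * qSize);
    for (let idx = 0; idx < numTokens * qSize; idx++) {
      const token = Math.floor(idx / qSize);
      const elem = idx % qSize;
      const head = Math.floor(elem / headDim);
      const dim = elem % headDim;
      const srcQ = token * qSize * 2 + head * headDim * 2 + dim;
      Q[idx] = input[srcQ];
      G[idx] = input[srcQ + headDim]; // gate sits headDim after Q within each head block
    }
    return { Q, G };
  }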