@simulatte/doppler 0.1.7 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +19 -0
- package/package.json +21 -36
- package/src/browser/browser-converter.js +5 -0
- package/src/client/doppler-registry.json +1 -17
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +13 -0
- package/src/config/kernels/registry.json +74 -0
- package/src/config/loader.js +3 -0
- package/src/config/merge-contract-check.js +7 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +14 -0
- package/src/config/presets/models/gemma2.json +2 -1
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/manifest.schema.d.ts +1 -1
- package/src/config/schema/manifest.schema.js +1 -1
- package/src/config/schema/storage.schema.js +1 -1
- package/src/converter/conversion-plan.js +10 -2
- package/src/converter/core.js +2 -0
- package/src/converter/manifest-inference.js +12 -22
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +5 -1
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/distribution/shard-delivery.js +6 -1
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +14 -1
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +1 -1
- package/src/gpu/kernels/rmsnorm.js +9 -2
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/inference/browser-harness.d.ts +2 -0
- package/src/inference/browser-harness.js +20 -1
- package/src/inference/pipelines/diffusion/helpers.js +3 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +41 -11
- package/src/inference/pipelines/text/attention/record.js +15 -6
- package/src/inference/pipelines/text/attention/run.js +50 -6
- package/src/inference/pipelines/text/config.js +14 -0
- package/src/inference/pipelines/text/execution-plan.js +5 -4
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +6 -0
- package/src/inference/pipelines/text/generator-steps.js +43 -15
- package/src/inference/pipelines/text/generator.js +50 -17
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +16 -5
- package/src/inference/pipelines/text/layer.js +1 -0
- package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
- package/src/inference/pipelines/text/linear-attention.js +33 -3
- package/src/inference/pipelines/text/logits/gpu.js +2 -2
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +3 -1
- package/src/inference/pipelines/text/model-load.js +3 -0
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/test-harness.js +2 -2
- package/src/loader/final-weights-loader.js +2 -0
- package/src/loader/shard-cache.js +3 -2
- package/src/loader/tensors/tensor-loader.js +6 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +2 -2
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +2 -0
- package/src/storage/downloader.js +2 -1
- package/src/storage/shard-manager.js +4 -3
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/node-converter.js +3 -0
- package/src/tooling/node-source-runtime.js +36 -0
- package/src/types/model.d.ts +5 -0
- package/tools/doppler-cli.js +6 -1
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Dump intermediate values from Qwen3.5 linear attention (GatedDeltaNet) for comparison with Doppler.
|
|
4
|
+
|
|
5
|
+
Usage:
|
|
6
|
+
HF_HOME=/media/x/models/huggingface_cache python3 src/debug/reference/hf_qwen35_linear_attn_debug.py
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import torch
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
os.environ.setdefault("HF_HOME", "/media/x/models/huggingface_cache")
|
|
14
|
+
|
|
15
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
16
|
+
|
|
17
|
+
# Hugging Face checkpoint whose layer-0 linear attention is traced.
MODEL_ID: str = "Qwen/Qwen3.5-0.8B"
# Text that is tokenized and run through the model once.
PROMPT: str = "Hello"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def stats(name, tensor):
    """Print a one-line summary of ``tensor`` followed by its first 8 values.

    The summary reports the original shape plus min, max, mean and absolute
    max of the flattened float32 copy; all numbers use 6 decimal places.
    """
    flat = tensor.float().detach().flatten()
    lowest, highest = flat.min().item(), flat.max().item()
    average = flat.mean().item()
    peak = flat.abs().max().item()
    print(f" {name}: shape={list(tensor.shape)}, "
          f"min={lowest:.6f}, max={highest:.6f}, "
          f"mean={average:.6f}, absMax={peak:.6f}")
    leading = [f"{value:.6f}" for value in flat[:8].tolist()]
    print(f" first8: {leading}")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def main():
    """Trace Qwen3.5 layer-0 linear attention (GatedDeltaNet) for comparison with Doppler.

    Loads MODEL_ID in float32, runs PROMPT through it once, and prints:
      * layer-0 GatedDeltaNet parameters (A_log, dt_bias, conv1d, norm),
      * layer-0 activations captured with forward hooks,
      * per-layer summaries of the last token's hidden state and the top-5 logits,
      * a manual token-by-token re-implementation of the layer-0 gated delta
        rule, diffed against the hooked module output.
    """
    print(f"Loading {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model.eval()

    inputs = tokenizer(PROMPT, return_tensors="pt")
    input_ids = inputs["input_ids"]
    print(f"Prompt: '{PROMPT}', Token IDs: {input_ids[0].tolist()}")
    num_tokens = input_ids.shape[1]

    # Dump key weight values for layer 0
    layer0 = model.model.layers[0]
    attn = layer0.linear_attn

    print("\n=== Layer 0 weights ===")
    if hasattr(attn, 'A_log'):
        a_log = attn.A_log.detach().float()
        a_neg_exp = -torch.exp(a_log)
        stats("A_log", a_log)
        stats("a_neg_exp", a_neg_exp)
    if hasattr(attn, 'dt_bias'):
        stats("dt_bias", attn.dt_bias.detach().float())
    stats("conv1d.weight", attn.conv1d.weight.detach().float())
    stats("norm.weight", attn.norm.weight.detach().float())

    # Hook into the linear_attn module (and its sub-layers) to capture activations
    captured = {}

    def hook_linear_attn_input(_module, args, kwargs):
        # The caller may pass hidden_states by keyword, leaving `args` empty;
        # check both so the input is captured either way.
        hidden = args[0] if len(args) > 0 else kwargs.get('hidden_states')
        if hidden is not None:
            captured['linear_attn_input'] = hidden.detach().clone()
        return None

    def hook_linear_attn_output(_module, _args, _kwargs, output):
        result = output[0] if isinstance(output, tuple) else output
        captured['linear_attn_output'] = result.detach().clone()
        return None

    def make_hook(name):
        def hook(_module, _inputs, output):
            captured[name] = output.detach().clone()
        return hook

    hooks = [
        attn.register_forward_pre_hook(hook_linear_attn_input, with_kwargs=True),
        attn.register_forward_hook(hook_linear_attn_output, with_kwargs=True),
        attn.in_proj_qkv.register_forward_hook(make_hook('qkv_proj')),
        attn.in_proj_z.register_forward_hook(make_hook('z_proj')),
        attn.in_proj_a.register_forward_hook(make_hook('a_proj')),
        attn.in_proj_b.register_forward_hook(make_hook('b_proj')),
        attn.out_proj.register_forward_hook(make_hook('out_proj')),
        attn.conv1d.register_forward_hook(make_hook('conv1d_raw')),
        attn.norm.register_forward_hook(make_hook('gated_norm')),
        layer0.input_layernorm.register_forward_hook(make_hook('input_layernorm')),
    ]

    print("\n=== Running forward pass ===")
    try:
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
    finally:
        # Remove hooks even if the forward pass raises
        for h in hooks:
            h.remove()

    print("\n=== Captured intermediates ===")
    for name in ['input_layernorm', 'qkv_proj', 'z_proj', 'a_proj', 'b_proj',
                 'conv1d_raw', 'gated_norm', 'linear_attn_input', 'linear_attn_output', 'out_proj']:
        if name in captured:
            stats(name, captured[name])
        else:
            print(f" {name}: NOT CAPTURED")

    # Hidden states per layer; index 0 is skipped (presumably the embedding output)
    print("\n=== Hidden states per layer (last token) ===")
    for i in range(min(6, len(outputs.hidden_states) - 1)):
        t = outputs.hidden_states[i + 1][0, -1]  # last token
        vals = t[:8].tolist()
        max_abs = t.abs().max().item()
        mean_abs = t.abs().mean().item()
        attn_type = "linear" if hasattr(model.model.layers[i], 'linear_attn') else "full"
        print(f" Layer {i} ({attn_type}): first8={[f'{v:.4f}' for v in vals]}, "
              f"maxAbs={max_abs:.4f}, meanAbs={mean_abs:.4f}")

    # Logits
    logits = outputs.logits[0, -1]
    top5 = torch.topk(logits, 5)
    print(f"\nTop-5 logits: {[(tokenizer.decode([idx.item()]), f'{val.item():.2f}') for val, idx in zip(top5.values, top5.indices)]}")

    # Also trace through the linear attention manually to compare with Doppler's kernel
    print("\n=== Manual linear attention trace (layer 0) ===")

    # Head geometry: read from the module when exposed, otherwise fall back to
    # the values this script originally hard-coded (16 heads x 128 dims).
    num_k_heads = getattr(attn, 'num_k_heads', 16)
    head_k_dim = getattr(attn, 'head_k_dim', 128)
    num_v_heads = getattr(attn, 'num_v_heads', 16)
    head_v_dim = getattr(attn, 'head_v_dim', 128)
    # Each Q/K head is shared by this many consecutive V heads
    # (repeat_interleave-style grouping, so V head h uses K head h // ratio).
    v_heads_per_k = max(1, num_v_heads // num_k_heads)
    q_size = num_k_heads * head_k_dim
    k_size = q_size
    v_size = num_v_heads * head_v_dim
    rms_eps = getattr(attn, 'layer_norm_epsilon', 1e-6)

    with torch.no_grad():
        embed = model.model.embed_tokens(input_ids)
        normed = layer0.input_layernorm(embed)
        stats("normed_input", normed)
        if 'linear_attn_input' in captured:
            in_diff = (normed - captured['linear_attn_input']).abs().max().item()
            print(f" Diff vs captured input: maxDiff={in_diff:.6f}")

        qkv = attn.in_proj_qkv(normed)
        stats("qkv", qkv)

        # The HF Qwen3.5 GatedDeltaNet does conv1d on the QKV, then applies SiLU.
        # conv1d expects [batch, channels, seq_len].
        qkv_t = qkv.transpose(1, 2)  # [1, q_size + k_size + v_size, num_tokens]

        # Use the conv1d module directly (it has padding configured)
        conv_raw = attn.conv1d(qkv_t)
        stats("conv_raw (from module)", conv_raw.transpose(1, 2))

        # Truncate to seq_len (causal conv padding)
        conv_causal = conv_raw[..., :num_tokens]
        stats("conv_causal (truncated)", conv_causal.transpose(1, 2))

        conv_silu = torch.nn.functional.silu(conv_causal)
        stats("conv_silu", conv_silu.transpose(1, 2))

        # Split Q, K, V along the channel axis
        conv_out = conv_silu.transpose(1, 2)  # [1, num_tokens, channels]
        q = conv_out[..., :q_size]
        k = conv_out[..., q_size:q_size + k_size]
        v = conv_out[..., q_size + k_size:q_size + k_size + v_size]
        stats("Q (raw)", q)
        stats("K (raw)", k)
        stats("V (raw)", v)

        # Per-head views: [batch, seq, heads, head_dim]
        q_heads = q.reshape(1, num_tokens, num_k_heads, head_k_dim)
        k_heads = k.reshape(1, num_tokens, num_k_heads, head_k_dim)
        v_heads = v.reshape(1, num_tokens, num_v_heads, head_v_dim)

        # L2 normalize Q and K per head, then scale Q by 1/sqrt(head_k_dim)
        eps = 1e-6
        q_norm = torch.nn.functional.normalize(q_heads, p=2, dim=-1, eps=eps)
        k_norm = torch.nn.functional.normalize(k_heads, p=2, dim=-1, eps=eps)
        head_scale = 1.0 / (head_k_dim ** 0.5)
        q_scaled = q_norm * head_scale

        stats("Q_normed_scaled (per-head)", q_scaled.reshape(1, num_tokens, -1))
        stats("K_normed (per-head)", k_norm.reshape(1, num_tokens, -1))

        # Projections for gating
        z = attn.in_proj_z(normed)
        a_out = attn.in_proj_a(normed)
        b_out = attn.in_proj_b(normed)
        stats("z", z)
        stats("a", a_out)
        stats("b", b_out)

        # g = -exp(A_log) * softplus(a + dt_bias); decay factor = exp(g); beta = sigmoid(b)
        a_log = attn.A_log.detach().float()
        a_neg_exp = -torch.exp(a_log)
        dt_bias = attn.dt_bias.detach().float()
        softplus_val = torch.nn.functional.softplus(a_out[0] + dt_bias)  # [num_tokens, num_v_heads]
        g = a_neg_exp * softplus_val
        g_exp = torch.exp(g)
        beta = torch.sigmoid(b_out[0])  # [num_tokens, num_v_heads]

        stats("softplus(a + dt_bias)", softplus_val.unsqueeze(0))
        stats("g (decay)", g.unsqueeze(0))
        stats("g_exp (decay factor)", g_exp.unsqueeze(0))
        stats("beta (sigmoid(b))", beta.unsqueeze(0))

        # Gated delta rule, one token at a time from a zero state, per V head:
        #   state  = state * exp(g)
        #   delta  = (v - state^T @ k) * beta
        #   state += outer(k, delta)
        #   out    = state^T @ q
        state = torch.zeros(num_v_heads, head_k_dim, head_v_dim)
        output_per_head = torch.zeros(1, num_tokens, num_v_heads, head_v_dim)
        for t in range(num_tokens):
            for head in range(num_v_heads):
                k_idx = head // v_heads_per_k
                k_head = k_norm[0, t, k_idx]
                q_head = q_scaled[0, t, k_idx]
                v_head = v_heads[0, t, head]

                state[head] *= g_exp[t, head]
                kv_mem = state[head].t() @ k_head  # [head_v_dim]
                delta = (v_head - kv_mem) * beta[t, head]
                state[head] += torch.outer(k_head, delta)
                output_per_head[0, t, head] = state[head].t() @ q_head

        raw_out = output_per_head.reshape(1, num_tokens, num_v_heads * head_v_dim)
        stats("Recurrent output (raw)", raw_out)

        # Per-head RMSNorm (one weight vector shared by all heads) then SiLU(z) gate
        z_reshaped = z.reshape(1, num_tokens, num_v_heads, head_v_dim)
        norm_weight = attn.norm.weight.detach().float()  # [head_v_dim]
        inv_rms = torch.rsqrt(output_per_head.pow(2).mean(dim=-1, keepdim=True) + rms_eps)
        z_gate = torch.nn.functional.silu(z_reshaped)
        gated_out = (output_per_head * inv_rms * norm_weight * z_gate).reshape(
            1, num_tokens, num_v_heads * head_v_dim)
        stats("After RMSNorm + SiLU gate", gated_out)

        # Output projection
        o_result = torch.nn.functional.linear(gated_out, attn.out_proj.weight)
        stats("After out_proj", o_result)

        # Compare with captured output (all tokens)
        if 'linear_attn_output' in captured:
            diff = (o_result - captured['linear_attn_output']).abs()
            print(f"\n Diff vs captured output: maxDiff={diff.max().item():.6f}")


if __name__ == "__main__":
    main()
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import { log } from '../debug/index.js';
|
|
2
|
+
import { getExpectedShardHash } from '../formats/rdrr/index.js';
|
|
2
3
|
import {
|
|
3
4
|
computeHash,
|
|
4
5
|
createStreamingHasher,
|
|
@@ -2018,7 +2019,11 @@ export async function downloadShard(
|
|
|
2018
2019
|
onDeliveryMetrics,
|
|
2019
2020
|
signal,
|
|
2020
2021
|
requiredEncoding: requiredEncoding ?? activeConfig.requiredContentEncoding ?? null,
|
|
2021
|
-
expectedHash:
|
|
2022
|
+
expectedHash:
|
|
2023
|
+
options.expectedHash
|
|
2024
|
+
?? getExpectedShardHash(shardInfo, algorithm)
|
|
2025
|
+
?? activeConfig.expectedHash
|
|
2026
|
+
?? null,
|
|
2022
2027
|
expectedSize: expectedSize ?? shardInfo?.size ?? null,
|
|
2023
2028
|
expectedManifestVersionSet: options.expectedManifestVersionSet ?? null,
|
|
2024
2029
|
writeToStore,
|
|
@@ -7,6 +7,10 @@
|
|
|
7
7
|
import type { RDRRManifest, ShardInfo, TensorMap } from './types.js';
|
|
8
8
|
|
|
9
9
|
/** Parse RDRR manifest JSON text into an `RDRRManifest`. */
export declare function parseManifest(jsonString: string): RDRRManifest;

/**
 * Resolve the expected content hash for one shard entry.
 *
 * When `manifestHashAlgorithm` is `blake3` (compared case-insensitively after
 * trimming), the shard's `blake3` field is preferred over `hash`; otherwise
 * `hash` is preferred over `blake3`. Returns `''` when `shard` is not a plain
 * object or neither field yields a value.
 */
export declare function getExpectedShardHash(shard: Partial<ShardInfo> | Record<string, unknown> | null | undefined, manifestHashAlgorithm?: string | null): string;

/** Parse tensor-map JSON text into a `TensorMap`. */
export declare function parseTensorMap(jsonString: string): TensorMap;
|
|
12
16
|
|
|
@@ -4,6 +4,19 @@ import { validateManifest } from './validation.js';
|
|
|
4
4
|
|
|
5
5
|
let currentManifest = null;

/**
 * Resolve the expected content hash for a shard entry.
 *
 * When the manifest-level algorithm is `blake3` (trimmed, case-insensitive),
 * the shard's `blake3` field is tried before `hash`; for any other algorithm
 * (or none) `hash` is tried before `blake3`. Only non-empty string values are
 * accepted, so the result always honours the declared `string` return type
 * and is never a stray number/object that could be compared against a digest.
 *
 * @param shard - Shard entry from a manifest; non-objects and arrays yield ''.
 * @param manifestHashAlgorithm - Manifest `hashAlgorithm`, or null when absent.
 * @returns The chosen hash string, or '' when no usable value exists.
 */
export function getExpectedShardHash(shard, manifestHashAlgorithm = null) {
  if (!shard || typeof shard !== 'object' || Array.isArray(shard)) {
    return '';
  }
  const algorithm = typeof manifestHashAlgorithm === 'string'
    ? manifestHashAlgorithm.trim().toLowerCase()
    : '';
  const candidates = algorithm === 'blake3'
    ? [shard.blake3, shard.hash]
    : [shard.hash, shard.blake3];
  for (const candidate of candidates) {
    if (typeof candidate === 'string' && candidate.length > 0) {
      return candidate;
    }
  }
  return '';
}
|
|
19
|
+
|
|
7
20
|
export function parseManifest(jsonString) {
|
|
8
21
|
let manifest;
|
|
9
22
|
|
|
@@ -21,7 +34,7 @@ export function parseManifest(jsonString) {
|
|
|
21
34
|
index: shard.index ?? i,
|
|
22
35
|
filename: shard.filename || shard.fileName || '',
|
|
23
36
|
size: shard.size,
|
|
24
|
-
hash: shard
|
|
37
|
+
hash: getExpectedShardHash(shard, manifest.hashAlgorithm),
|
|
25
38
|
blake3: shard.blake3 || shard.hash,
|
|
26
39
|
offset: shard.offset ?? offset,
|
|
27
40
|
hashAlgorithm: shard.hashAlgorithm,
|
|
@@ -326,6 +326,14 @@ export {
|
|
|
326
326
|
type SplitQKVResult,
|
|
327
327
|
} from './split_qkv.js';
|
|
328
328
|
|
|
329
|
+
// Split Q and Gate (de-interleave attentionOutputGate q_proj output)
|
|
330
|
+
export {
|
|
331
|
+
runSplitQG,
|
|
332
|
+
recordSplitQG,
|
|
333
|
+
type SplitQGOptions,
|
|
334
|
+
type SplitQGResult,
|
|
335
|
+
} from './split_qg.js';
|
|
336
|
+
|
|
329
337
|
// Transpose
|
|
330
338
|
export {
|
|
331
339
|
runTranspose,
|
package/src/gpu/kernels/index.js
CHANGED
|
@@ -268,6 +268,12 @@ export {
|
|
|
268
268
|
recordSplitQKV,
|
|
269
269
|
} from './split_qkv.js';
|
|
270
270
|
|
|
271
|
+
// Split Q and Gate (de-interleave attentionOutputGate q_proj output)
|
|
272
|
+
export {
|
|
273
|
+
runSplitQG,
|
|
274
|
+
recordSplitQG,
|
|
275
|
+
} from './split_qg.js';
|
|
276
|
+
|
|
271
277
|
// Transpose
|
|
272
278
|
export {
|
|
273
279
|
runTranspose,
|
|
@@ -29,7 +29,13 @@ function selectQ4KFusedVariant(isM1, wantF16Output, aDtype) {
|
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
export function resolveMatmulPhase(M) {
|
|
32
|
+
/**
 * Resolve the matmul phase ('decode' or 'prefill') used for kernel-path lookup.
 *
 * An explicit `phaseOverride` wins over the row-count rule but must be one of
 * the two known phases; otherwise the kernel rule table decides from whether
 * the matmul has a single row (M === 1).
 */
export function resolveMatmulPhase(M, phaseOverride = null) {
  if (phaseOverride === null || phaseOverride === undefined) {
    return selectKernelRuleValue('matmul', 'phase', { isDecode: M === 1 });
  }
  const isKnownPhase = phaseOverride === 'decode' || phaseOverride === 'prefill';
  if (!isKnownPhase) {
    throw new Error(`[Matmul] Invalid phase override "${phaseOverride}". Expected "decode" or "prefill".`);
  }
  return phaseOverride;
}
|
|
35
41
|
|
|
@@ -125,7 +131,9 @@ export function selectMatmulKernel(options = {}) {
|
|
|
125
131
|
const { tiledPrefillMinRows } = getKernelThresholds().matmul;
|
|
126
132
|
|
|
127
133
|
const inputsAreF16 = aDtype === 'f16' && bDtype === 'f16';
|
|
128
|
-
|
|
134
|
+
// F16 weights needing F32a path: weights are F16 and either activation is already F32,
|
|
135
|
+
// or both inputs are F16 but output is F32 (activation will be cast to F32 by executeMatmul)
|
|
136
|
+
const weightsAreF16 = bDtype === 'f16' && (aDtype !== 'f16' || outputDtype !== 'f16');
|
|
129
137
|
const useF16Matmul = outputDtype === 'f16' && preferF16 && inputsAreF16 && capabilities.hasF16;
|
|
130
138
|
const useF16wF32a = preferF16 && weightsAreF16 && capabilities.hasF16;
|
|
131
139
|
const useTiled = isPrefill
|
|
@@ -244,6 +252,30 @@ export function requiresF32Input(variant) {
|
|
|
244
252
|
return !supportsF16Input(variant);
|
|
245
253
|
}
|
|
246
254
|
|
|
255
|
+
/**
 * Map a matmul override config to the weight dtype its shader expects.
 *
 * The shader file name is taken from `shaderFile`, falling back to `wgsl`.
 * Returns 'q4k' for fused Q4 shaders, 'f16' or 'f32' for the known dense
 * shaders, and null when no shader is named or the shader is not recognised.
 */
function resolveRequiredWeightDtype(config) {
  const source = config?.shaderFile ?? config?.wgsl ?? '';
  const shaderFile = String(source);
  if (shaderFile.length === 0) {
    return null;
  }
  if (shaderFile.startsWith('fused_matmul_q4')) {
    return 'q4k';
  }
  switch (shaderFile) {
    case 'matmul_f16.wgsl':
    case 'matmul_f16_tiled.wgsl':
    case 'matmul_f16w_f32a.wgsl':
    case 'matmul_f16w_f32a_tiled.wgsl':
    case 'matmul_gemv_subgroup.wgsl':
    case 'matmul_gemv_subgroup_f16a.wgsl':
      return 'f16';
    case 'matmul_f32.wgsl':
      return 'f32';
    default:
      return null;
  }
}
|
|
278
|
+
|
|
247
279
|
|
|
248
280
|
function resolveMatmulOverride(
|
|
249
281
|
variantOverride,
|
|
@@ -287,6 +319,16 @@ function resolveMatmulOverride(
|
|
|
287
319
|
);
|
|
288
320
|
}
|
|
289
321
|
|
|
322
|
+
const requiredWeightDtype = resolveRequiredWeightDtype(config);
|
|
323
|
+
const weightDtypeOk = !requiredWeightDtype
|
|
324
|
+
|| bDtype === requiredWeightDtype
|
|
325
|
+
|| (requiredWeightDtype === 'f16' && bDtype === 'q4k');
|
|
326
|
+
if (!weightDtypeOk) {
|
|
327
|
+
return failOrWarn(
|
|
328
|
+
`Matmul kernel "${variantOverride}" requires ${requiredWeightDtype} weights but B dtype is ${bDtype}.`
|
|
329
|
+
);
|
|
330
|
+
}
|
|
331
|
+
|
|
290
332
|
if (supportsF16Input(override) && aDtype !== 'f16') {
|
|
291
333
|
return failOrWarn(`Matmul kernel "${variantOverride}" requires f16 activations but A dtype is ${aDtype}.`);
|
|
292
334
|
}
|
|
@@ -341,7 +383,7 @@ function selectGemvVariant(useF16Gemv, useF32Gemv, hasSubgroups, useVec4, N, mul
|
|
|
341
383
|
export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, transposeB, requestedOutputDtype, options) {
|
|
342
384
|
const capabilities = getKernelCapabilities();
|
|
343
385
|
const strict = getKernelPathStrict();
|
|
344
|
-
const phase = resolveMatmulPhase(M);
|
|
386
|
+
const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
|
|
345
387
|
let pathVariant = getKernelPathMatmulVariant(options.role, phase, options.layerIdx, options.kernelPath);
|
|
346
388
|
const hadPathVariant = Boolean(pathVariant);
|
|
347
389
|
|
|
@@ -426,7 +468,8 @@ export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, trans
|
|
|
426
468
|
|
|
427
469
|
const canGemv = M === 1 && effectiveBDtype === 'f16' && capabilities.hasF16;
|
|
428
470
|
const useF16Gemv = canGemv && aDtype === 'f16' && wantF16Output;
|
|
429
|
-
|
|
471
|
+
// F32 GEMV: activation is F32, or activation is F16 with F32 output (will be cast to F32)
|
|
472
|
+
const useF32Gemv = canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output));
|
|
430
473
|
const useGemv = useF16Gemv || useF32Gemv;
|
|
431
474
|
const useVec4 = (K % 4 === 0);
|
|
432
475
|
const { multicolThreshold } = getKernelThresholds().matmul;
|
|
@@ -23,6 +23,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
|
|
|
23
23
|
layerIdx?: number;
|
|
24
24
|
/** Explicit kernel path context for variant selection (avoids global path state). */
|
|
25
25
|
kernelPath?: KernelPathSchema | null;
|
|
26
|
+
/** Optional explicit phase for kernel-path lookup when the runtime rewrites rows (for example prefill last-position logits). */
|
|
27
|
+
phaseOverride?: 'decode' | 'prefill' | null;
|
|
26
28
|
/**
|
|
27
29
|
* Whether B matrix is stored transposed.
|
|
28
30
|
* - true: B is [N,K] (SafeTensors/row-major), needs transpose
|
|
@@ -165,7 +165,7 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
165
165
|
options
|
|
166
166
|
);
|
|
167
167
|
|
|
168
|
-
const phase = resolveMatmulPhase(M);
|
|
168
|
+
const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
|
|
169
169
|
const constants = resolveMatmulConstants(options, phase);
|
|
170
170
|
|
|
171
171
|
let matmulInput = A;
|
|
@@ -9,6 +9,9 @@ import { selectRuleValue as selectLoaderRule } from '../../rules/rule-registry.j
|
|
|
9
9
|
import { getBuffer, getWeightDtype, getBufferDtype } from '../weight-buffer.js';
|
|
10
10
|
import { unifiedKernelWrapper } from './utils.js';
|
|
11
11
|
|
|
12
|
+
/**
 * Dtype reported for RMSNorm weights when neither a dtype tag nor the buffer
 * size identifies one; f32 is chosen as the conservative, lossless option.
 */
const DEFAULT_DTYPE = 'f32';
|
|
14
|
+
|
|
12
15
|
function inferHiddenSize(input, hiddenSize) {
|
|
13
16
|
if (hiddenSize != null) return hiddenSize;
|
|
14
17
|
const shape = input?.shape;
|
|
@@ -39,9 +42,12 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
39
42
|
return taggedDtype;
|
|
40
43
|
}
|
|
41
44
|
|
|
45
|
+
// Conservative fallback: f32 avoids precision loss when dtype cannot be determined.
|
|
46
|
+
// This path fires for non-GPU buffers or missing hiddenSize, both of which prevent
|
|
47
|
+
// size-based dtype inference below.
|
|
42
48
|
const hasGPUBufferType = typeof GPUBuffer !== 'undefined';
|
|
43
49
|
if (!hasGPUBufferType || !(weightBuffer instanceof GPUBuffer) || hiddenSize == null || hiddenSize <= 0) {
|
|
44
|
-
return
|
|
50
|
+
return DEFAULT_DTYPE;
|
|
45
51
|
}
|
|
46
52
|
|
|
47
53
|
const byteSize = getBufferRequestedSize(weightBuffer);
|
|
@@ -55,7 +61,8 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
55
61
|
sizeMatchesF32,
|
|
56
62
|
});
|
|
57
63
|
}
|
|
58
|
-
|
|
64
|
+
// Buffer size matches neither f16 nor f32 for given hiddenSize; fall back to f32.
|
|
65
|
+
return DEFAULT_DTYPE;
|
|
59
66
|
}
|
|
60
67
|
|
|
61
68
|
function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Split Q and Gate Kernel
|
|
3
|
+
*
|
|
4
|
+
* De-interleaves Q and Gate projections from q_proj output for attentionOutputGate models.
|
|
5
|
+
* Models like Qwen 3.5 store q_proj weights in per-head interleaved layout:
|
|
6
|
+
* rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
|
|
7
|
+
* rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
|
|
8
|
+
* This kernel separates the full matmul output into contiguous Q and Gate tensors.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import type { Tensor } from '../tensor.js';
|
|
12
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
13
|
+
|
|
14
|
+
/** Geometry and optional pre-allocated outputs for a Q/Gate split. */
export interface SplitQGOptions {
  /** Number of token rows in the q_proj output. */
  numTokens: number;
  /** Number of attention heads. */
  numHeads: number;
  /** Per-head dimension; each input row holds numHeads * headDim * 2 values. */
  headDim: number;
  /** Destination for Q; when null/omitted an output buffer is acquired internally. */
  qTensor?: Tensor | null;
  /** Destination for the gate; when null/omitted an output buffer is acquired internally. */
  gTensor?: Tensor | null;
}

/** Contiguous split outputs, each shaped [numTokens, numHeads * headDim]. */
export interface SplitQGResult {
  /** Query half of every head. */
  Q: Tensor;
  /** Output-gate half of every head. */
  G: Tensor;
}

/**
 * Split an interleaved q_proj output ([numTokens, numHeads * headDim * 2],
 * laid out per head as [Q_h | G_h]) into separate Q and Gate tensors,
 * dispatching without a command recorder.
 */
export declare function runSplitQG(qgTensor: Tensor, options: SplitQGOptions): Promise<SplitQGResult>;

/**
 * Same split as `runSplitQG`, but recorded onto `recorder` for batched
 * submission instead of being dispatched on its own.
 */
export declare function recordSplitQG(recorder: CommandRecorder, qgTensor: Tensor, options: SplitQGOptions): Promise<SplitQGResult>;
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
5
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
6
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
7
|
+
|
|
8
|
+
/**
 * Shared implementation behind runSplitQG / recordSplitQG.
 *
 * De-interleaves a q_proj output laid out per head as [Q_h | G_h] into two
 * contiguous tensors Q and G, each [numTokens, numHeads * headDim], using the
 * same dtype as `qgTensor`.
 *
 * @param target - Command recorder, or null to dispatch without one.
 * @param qgTensor - Interleaved input, [numTokens, numHeads * headDim * 2].
 * @param options - numTokens / numHeads / headDim plus optional qTensor / gTensor destinations.
 * @returns `{ Q, G }`; caller-supplied tensors are returned as-is.
 * @throws If the geometry is not positive integers, if a supplied output tensor
 *   has no buffer, or if buffer acquisition / dispatch fails. Buffers acquired
 *   here are released before the error is rethrown.
 */
async function _splitQG(target, qgTensor, options) {
  const { numTokens, numHeads, headDim, qTensor = null, gTensor = null } = options;
  for (const [label, value] of Object.entries({ numTokens, numHeads, headDim })) {
    if (!Number.isInteger(value) || value <= 0) {
      throw new Error(`[SplitQG] ${label} must be a positive integer, got ${value}.`);
    }
  }
  // A supplied tensor without a buffer would otherwise make the kernel write
  // into a freshly acquired (and leaked) buffer while the untouched tensor is returned.
  if (qTensor != null && !qTensor.buffer) {
    throw new Error('[SplitQG] qTensor was provided without a buffer.');
  }
  if (gTensor != null && !gTensor.buffer) {
    throw new Error('[SplitQG] gTensor was provided without a buffer.');
  }

  const outputDtype = qgTensor.dtype;
  const pipelineVariant = selectRuleValue('splitQg', 'variant', { outputDtype });
  const qSize = numHeads * headDim;
  const outputBytes = numTokens * qSize * dtypeBytes(outputDtype);
  const ownsQ = qTensor == null;
  const ownsG = gTensor == null;

  let qBuffer = null;
  let gBuffer = null;
  try {
    // Acquire inside the try so a failure acquiring G still releases Q.
    qBuffer = ownsQ ? acquireBuffer(outputBytes, undefined, 'Q') : qTensor.buffer;
    gBuffer = ownsG ? acquireBuffer(outputBytes, undefined, 'Q_gate') : gTensor.buffer;

    // One invocation per output element; the shader bounds-checks the tail.
    // NOTE(review): this is a 1-D dispatch, so very long prefills could exceed
    // the device's per-dimension workgroup-count limit — confirm expected sizes.
    await unifiedKernelWrapper(
      'split_qg', target, pipelineVariant,
      [qgTensor, qBuffer, gBuffer],
      { num_tokens: numTokens, num_heads: numHeads, head_dim: headDim, _pad: 0 },
      Math.ceil((numTokens * qSize) / WORKGROUP_SIZES.DEFAULT)
    );

    const Q = ownsQ ? createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q') : qTensor;
    const G = ownsG ? createTensor(gBuffer, outputDtype, [numTokens, qSize], 'Q_gate') : gTensor;
    return { Q, G };
  } catch (error) {
    if (ownsQ && qBuffer != null) releaseBuffer(qBuffer);
    if (ownsG && gBuffer != null) releaseBuffer(gBuffer);
    throw error;
  }
}

/** Split Q and Gate, dispatching without a command recorder. */
export async function runSplitQG(qgTensor, options) {
  return _splitQG(null, qgTensor, options);
}

/** Split Q and Gate, recording the dispatch onto `recorder`. */
export async function recordSplitQG(recorder, qgTensor, options) {
  return _splitQG(recorder, qgTensor, options);
}
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
// split_qg.wgsl
|
|
2
|
+
|
|
3
|
+
/**
|
|
4
|
+
* De-interleave Q and Gate projections from q_proj output for attentionOutputGate models.
|
|
5
|
+
*
|
|
6
|
+
* Models like Qwen 3.5 store q_proj weights with interleaved head layout:
|
|
7
|
+
* rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
|
|
8
|
+
* rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
|
|
9
|
+
*
|
|
10
|
+
* A single full matmul over all 2*qSize rows produces interleaved output:
|
|
11
|
+
* input[token, h*headDim*2 : h*headDim*2+headDim] = Q head h
|
|
12
|
+
* input[token, h*headDim*2+headDim : (h+1)*headDim*2] = Gate head h
|
|
13
|
+
*
|
|
14
|
+
* This kernel separates them into contiguous Q and G outputs:
|
|
15
|
+
* Q[token, h*headDim + dim] = input[token, h*headDim*2 + dim]
|
|
16
|
+
* G[token, h*headDim + dim] = input[token, h*headDim*2 + headDim + dim]
|
|
17
|
+
*
|
|
18
|
+
* Input layout (row-major): [numTokens, numHeads * headDim * 2]
|
|
19
|
+
* Output Q layout (row-major): [numTokens, numHeads * headDim]
|
|
20
|
+
* Output G layout (row-major): [numTokens, numHeads * headDim]
|
|
21
|
+
*/
|
|
22
|
+
|
|
23
|
+
struct Params {
    num_tokens: u32,
    num_heads: u32,
    head_dim: u32,
    _pad: u32,
}

override WORKGROUP_SIZE: u32 = 256u;

@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> Q: array<f32>;
@group(0) @binding(3) var<storage, read_write> G: array<f32>;

// One invocation per output element: each copies one Q value and the
// matching gate value out of the per-head interleaved input row.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let head_dim = params.head_dim;
    let row_width = params.num_heads * head_dim;  // width of one Q (or G) output row
    let out_index = gid.x;
    if (out_index >= params.num_tokens * row_width) {
        return;
    }

    let token = out_index / row_width;
    let within_row = out_index % row_width;
    // Start of this head's [Q_h | G_h] pair inside the interleaved input row.
    let pair_base = token * (2u * row_width) + (within_row / head_dim) * (2u * head_dim);
    let lane = within_row % head_dim;

    Q[out_index] = input[pair_base + lane];
    G[out_index] = input[pair_base + head_dim + lane];
}
|