@simulatte/doppler 0.1.7 → 0.1.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +32 -0
- package/README.md +25 -6
- package/package.json +25 -38
- package/src/browser/browser-converter.js +5 -0
- package/src/client/doppler-api.browser.js +6 -0
- package/src/client/doppler-api.d.ts +3 -0
- package/src/client/doppler-api.js +11 -2
- package/src/client/doppler-registry.js +3 -5
- package/src/client/doppler-registry.json +2 -2
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +13 -0
- package/src/config/kernels/kernel-ref-digests.js +23 -21
- package/src/config/kernels/moe/mixtral.paths.json +46 -0
- package/src/config/kernels/registry.json +74 -0
- package/src/config/loader.js +9 -0
- package/src/config/merge-contract-check.js +7 -0
- package/src/config/platforms/loader.js +3 -1
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +21 -0
- package/src/config/presets/models/gemma2.json +2 -1
- package/src/config/presets/models/gemma3.json +4 -1
- package/src/config/presets/models/gemma4.json +61 -0
- package/src/config/presets/models/granite-docling.json +70 -0
- package/src/config/presets/models/lfm2.json +6 -1
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/models/qwen3_vl.json +40 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
- package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
- package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/modes/trace-layers.json +1 -0
- package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
- package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
- package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
- package/src/config/runtime.js +3 -0
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +40 -0
- package/src/config/schema/debug.schema.js +28 -0
- package/src/config/schema/index.js +2 -0
- package/src/config/schema/inference-defaults.schema.js +1 -1
- package/src/config/schema/kernel-path.schema.d.ts +1 -0
- package/src/config/schema/manifest.schema.d.ts +1 -1
- package/src/config/schema/manifest.schema.js +1 -1
- package/src/config/schema/memory-limits.schema.js +2 -2
- package/src/config/schema/storage.schema.js +2 -2
- package/src/converter/conversion-plan.js +11 -3
- package/src/converter/core.js +19 -8
- package/src/converter/manifest-inference.js +12 -22
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +5 -1
- package/src/converter/quantizer.d.ts +5 -0
- package/src/converter/quantizer.js +34 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/distribution/shard-delivery.js +40 -1
- package/src/formats/rdrr/classification.js +32 -0
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +14 -1
- package/src/gpu/kernel-runtime.js +4 -2
- package/src/gpu/kernels/attention.js +2 -1
- package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
- package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
- package/src/gpu/kernels/dequant_shared.wgsl +4 -2
- package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
- package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
- package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
- package/src/gpu/kernels/gated-short-conv.js +284 -0
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/linear-attention-core.js +37 -17
- package/src/gpu/kernels/matmul-selection.js +48 -4
- package/src/gpu/kernels/matmul.d.ts +5 -0
- package/src/gpu/kernels/matmul.js +71 -2
- package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
- package/src/gpu/kernels/rmsnorm.js +9 -2
- package/src/gpu/kernels/sample.js +1 -3
- package/src/gpu/kernels/sample.wgsl +39 -9
- package/src/gpu/kernels/sample_f16.wgsl +38 -8
- package/src/gpu/kernels/shader-cache.js +9 -4
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/inference/browser-harness.d.ts +2 -0
- package/src/inference/browser-harness.js +20 -1
- package/src/inference/kv-cache/base.js +3 -10
- package/src/inference/pipelines/diffusion/helpers.js +3 -0
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +10 -3
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +13 -1
- package/src/inference/pipelines/text/attention/projections.js +54 -13
- package/src/inference/pipelines/text/attention/record.js +16 -6
- package/src/inference/pipelines/text/attention/run.js +59 -6
- package/src/inference/pipelines/text/config.d.ts +1 -0
- package/src/inference/pipelines/text/config.js +46 -4
- package/src/inference/pipelines/text/embed.js +26 -7
- package/src/inference/pipelines/text/execution-plan.js +5 -4
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
- package/src/inference/pipelines/text/execution-v0.js +12 -1
- package/src/inference/pipelines/text/generator-helpers.js +1 -0
- package/src/inference/pipelines/text/generator-runtime.js +19 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +15 -0
- package/src/inference/pipelines/text/generator-steps.js +71 -26
- package/src/inference/pipelines/text/generator.d.ts +5 -0
- package/src/inference/pipelines/text/generator.js +353 -166
- package/src/inference/pipelines/text/init.d.ts +15 -0
- package/src/inference/pipelines/text/init.js +35 -10
- package/src/inference/pipelines/text/layer.js +38 -8
- package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
- package/src/inference/pipelines/text/linear-attention.js +33 -3
- package/src/inference/pipelines/text/logits/gpu.js +2 -2
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +3 -1
- package/src/inference/pipelines/text/model-load.js +3 -0
- package/src/inference/pipelines/text/moe-gpu.js +21 -3
- package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
- package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
- package/src/inference/pipelines/text/ops.js +123 -53
- package/src/inference/pipelines/text/probes.js +1 -0
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/state.js +2 -0
- package/src/inference/pipelines/text.d.ts +5 -0
- package/src/inference/pipelines/text.js +59 -1
- package/src/inference/pipelines/vision/encoder.js +386 -0
- package/src/inference/pipelines/vision/image-preprocess.js +151 -0
- package/src/inference/pipelines/vision/index.js +173 -0
- package/src/inference/pipelines/vision/ops.js +78 -0
- package/src/inference/pipelines/vision/patch-embed.js +151 -0
- package/src/inference/test-harness.js +11 -9
- package/src/loader/doppler-loader.d.ts +3 -0
- package/src/loader/doppler-loader.js +20 -3
- package/src/loader/experts/expert-cache.js +6 -2
- package/src/loader/experts/expert-loader.js +6 -2
- package/src/loader/final-weights-loader.js +2 -0
- package/src/loader/layer-loader.js +42 -3
- package/src/loader/manifest-config.js +3 -1
- package/src/loader/shard-cache.js +3 -2
- package/src/loader/tensors/tensor-loader.d.ts +3 -0
- package/src/loader/tensors/tensor-loader.js +130 -4
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +2 -2
- package/src/rules/kernels/moe.rules.mixtral.json +75 -0
- package/src/rules/kernels/softmax.rules.json +2 -0
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.d.ts +1 -0
- package/src/rules/rule-registry.js +4 -0
- package/src/storage/downloader.js +2 -1
- package/src/storage/quickstart-downloader.d.ts +3 -0
- package/src/storage/quickstart-downloader.js +27 -30
- package/src/storage/shard-manager.js +4 -3
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/node-converter.js +28 -7
- package/src/tooling/node-source-runtime.js +65 -5
- package/src/tooling/node-webgpu.js +24 -7
- package/src/types/model.d.ts +5 -0
- package/src/utils/hf-resolve-url.d.ts +16 -0
- package/src/utils/hf-resolve-url.js +17 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +6 -1
- package/src/tooling/node-convert.d.ts +0 -54
|
@@ -117,7 +117,10 @@ function isLikelyEmbeddingGemma(rawConfig, architectureHint) {
|
|
|
117
117
|
|
|
118
118
|
export function inferSourceWeightQuantization(tensors) {
|
|
119
119
|
if (!Array.isArray(tensors) || tensors.length === 0) {
|
|
120
|
-
|
|
120
|
+
throw new Error(
|
|
121
|
+
'Cannot infer source weight quantization: no tensors provided. ' +
|
|
122
|
+
'Set converterConfig.quantization.weights explicitly.'
|
|
123
|
+
);
|
|
121
124
|
}
|
|
122
125
|
const weightTensors = [];
|
|
123
126
|
for (const tensor of tensors) {
|
|
@@ -128,7 +131,12 @@ export function inferSourceWeightQuantization(tensors) {
|
|
|
128
131
|
weightTensors.push({ name, dtype });
|
|
129
132
|
}
|
|
130
133
|
const dtypes = new Set(weightTensors.map((tensor) => tensor.dtype));
|
|
131
|
-
if (dtypes.size === 0)
|
|
134
|
+
if (dtypes.size === 0) {
|
|
135
|
+
throw new Error(
|
|
136
|
+
'Cannot infer source weight quantization: no recognizable weight dtypes found. ' +
|
|
137
|
+
'Set converterConfig.quantization.weights explicitly.'
|
|
138
|
+
);
|
|
139
|
+
}
|
|
132
140
|
if (dtypes.size > 1) {
|
|
133
141
|
const detail = Array.from(dtypes)
|
|
134
142
|
.sort()
|
|
@@ -465,7 +473,7 @@ export function resolveConversionPlan(options) {
|
|
|
465
473
|
// role dtypes should not change kernel-path selection when explicit compute precision is targeted.
|
|
466
474
|
const embedDtypeRaw = normalizeWeightDtype(findTensorDtypeByRole(tensors, 'embedding'));
|
|
467
475
|
const lmHeadDtypeRaw = normalizeWeightDtype(findTensorDtypeByRole(tensors, 'lm_head'));
|
|
468
|
-
const hasVision = hasAnyTensorPattern(tensors, ['vision_', 'vision_tower', 'vision_model', 'image_encoder']);
|
|
476
|
+
const hasVision = hasAnyTensorPattern(tensors, ['vision_', 'vision_tower', 'vision_model', 'image_encoder', 'visual.']);
|
|
469
477
|
const hasAudio = hasAnyTensorPattern(tensors, ['audio_', 'audio_encoder', 'whisper', 'wav2vec']);
|
|
470
478
|
const hasProjector = hasAnyTensorPattern(tensors, ['multi_modal_projector', 'mm_projector', 'projector']);
|
|
471
479
|
const quantizationInfo = buildQuantizationInfo(
|
package/src/converter/core.js
CHANGED
|
@@ -114,6 +114,15 @@ export function resolveTensorTargetQuant(tensorName, fallbackQuant, quantization
|
|
|
114
114
|
const headQuant = quantizationInfo.lmHead ?? quantizationInfo.embeddings ?? fallback;
|
|
115
115
|
return normalizeStorageQuant(headQuant) ?? fallback;
|
|
116
116
|
}
|
|
117
|
+
if (role === 'vision') {
|
|
118
|
+
return normalizeStorageQuant(quantizationInfo.vision ?? fallback) ?? fallback;
|
|
119
|
+
}
|
|
120
|
+
if (role === 'projector') {
|
|
121
|
+
return normalizeStorageQuant(quantizationInfo.projector ?? fallback) ?? fallback;
|
|
122
|
+
}
|
|
123
|
+
if (role === 'audio') {
|
|
124
|
+
return normalizeStorageQuant(quantizationInfo.audio ?? fallback) ?? fallback;
|
|
125
|
+
}
|
|
117
126
|
return normalizeStorageQuant(quantizationInfo.weights ?? fallback) ?? fallback;
|
|
118
127
|
}
|
|
119
128
|
|
|
@@ -819,11 +828,11 @@ export function extractArchitecture(config, ggufConfig) {
|
|
|
819
828
|
vocabSize,
|
|
820
829
|
maxSeqLen,
|
|
821
830
|
ropeTheta,
|
|
822
|
-
linearNumKeyHeads
|
|
823
|
-
linearNumValueHeads
|
|
824
|
-
linearKeyHeadDim
|
|
825
|
-
linearValueHeadDim
|
|
826
|
-
linearConvKernelDim
|
|
831
|
+
linearNumKeyHeads,
|
|
832
|
+
linearNumValueHeads,
|
|
833
|
+
linearKeyHeadDim,
|
|
834
|
+
linearValueHeadDim,
|
|
835
|
+
linearConvKernelDim,
|
|
827
836
|
linearNormMode,
|
|
828
837
|
};
|
|
829
838
|
}
|
|
@@ -983,6 +992,7 @@ export function createManifest(
|
|
|
983
992
|
isDiffusion ? 'diffusion' : extractArchitecture(model.config, model.ggufConfig)
|
|
984
993
|
);
|
|
985
994
|
const rawConfig = model.config || {};
|
|
995
|
+
const generationConfig = model.generationConfig ?? null;
|
|
986
996
|
const resolvedArchitecture = isDiffusion
|
|
987
997
|
? architecture
|
|
988
998
|
: resolveIntermediateSizeFromTensors(architecture, model, tensorLocations, rawConfig, modelId);
|
|
@@ -1037,6 +1047,7 @@ export function createManifest(
|
|
|
1037
1047
|
? null
|
|
1038
1048
|
: resolveEosTokenId({
|
|
1039
1049
|
config: rawConfig,
|
|
1050
|
+
generationConfig,
|
|
1040
1051
|
tokenizer: model.tokenizer ?? model.tokenizerConfig ?? null,
|
|
1041
1052
|
tokenizerJson: model.tokenizerJson ?? null,
|
|
1042
1053
|
});
|
|
@@ -1054,7 +1065,7 @@ export function createManifest(
|
|
|
1054
1065
|
modelId,
|
|
1055
1066
|
modelType: resolvedModelType,
|
|
1056
1067
|
quantization: resolvedQuantization,
|
|
1057
|
-
quantizationInfo: options.quantizationInfo
|
|
1068
|
+
quantizationInfo: options.quantizationInfo,
|
|
1058
1069
|
architecture: resolvedArchitecture,
|
|
1059
1070
|
moeConfig,
|
|
1060
1071
|
inference,
|
|
@@ -1063,8 +1074,8 @@ export function createManifest(
|
|
|
1063
1074
|
totalSize: shards.reduce((sum, s) => sum + s.size, 0),
|
|
1064
1075
|
hashAlgorithm,
|
|
1065
1076
|
eos_token_id: eosTokenId,
|
|
1066
|
-
config: isDiffusion ? rawConfig : undefined,
|
|
1067
|
-
conversion: options.conversionInfo
|
|
1077
|
+
config: isDiffusion ? rawConfig : (rawConfig.vision_config ? { vision_config: rawConfig.vision_config } : undefined),
|
|
1078
|
+
conversion: options.conversionInfo,
|
|
1068
1079
|
metadata: {
|
|
1069
1080
|
source,
|
|
1070
1081
|
convertedAt: resolveConvertedAt(
|
|
@@ -240,16 +240,6 @@ function detectAttentionOutputGate(presetInference, modelConfig, defaults) {
|
|
|
240
240
|
return modelConfig.attn_output_gate;
|
|
241
241
|
}
|
|
242
242
|
|
|
243
|
-
const modelType = normalizeLayerTypeName(modelConfig?.model_type);
|
|
244
|
-
const hasLinearAttentionLayers = Array.isArray(modelConfig?.layer_types)
|
|
245
|
-
&& modelConfig.layer_types.some((entry) => normalizeCustomLayerType(entry) === 'linear_attention');
|
|
246
|
-
if (
|
|
247
|
-
hasLinearAttentionLayers
|
|
248
|
-
&& (modelType === 'qwen2' || modelType === 'qwen3_5' || modelType === 'qwen3_5_text')
|
|
249
|
-
) {
|
|
250
|
-
return true;
|
|
251
|
-
}
|
|
252
|
-
|
|
253
243
|
return defaults.attention.attentionOutputGate;
|
|
254
244
|
}
|
|
255
245
|
|
|
@@ -259,21 +249,18 @@ function resolveQueryPreAttnScalar(preset, modelConfig, headDim) {
|
|
|
259
249
|
return explicit;
|
|
260
250
|
}
|
|
261
251
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
252
|
+
// Standard attention scaling: attnScale = 1/sqrt(queryPreAttnScalar).
|
|
253
|
+
// For standard transformers queryPreAttnScalar = headDim, giving 1/sqrt(headDim).
|
|
254
|
+
// Preset may override for non-standard models.
|
|
255
|
+
const presetScalar = Number(preset?.inference?.attention?.queryPreAttnScalar);
|
|
256
|
+
if (Number.isFinite(presetScalar) && presetScalar > 0) {
|
|
257
|
+
return presetScalar;
|
|
266
258
|
}
|
|
267
259
|
|
|
268
|
-
return
|
|
260
|
+
return headDim;
|
|
269
261
|
}
|
|
270
262
|
|
|
271
263
|
function detectRmsNormWeightOffset(presetInference, modelConfig, defaults) {
|
|
272
|
-
const modelType = normalizeLayerTypeName(modelConfig?.model_type);
|
|
273
|
-
if (modelType === 'qwen3_5' || modelType === 'qwen3_5_text') {
|
|
274
|
-
return true;
|
|
275
|
-
}
|
|
276
|
-
|
|
277
264
|
if (typeof presetInference?.normalization?.rmsNormWeightOffset === 'boolean') {
|
|
278
265
|
return presetInference.normalization.rmsNormWeightOffset;
|
|
279
266
|
}
|
|
@@ -385,8 +372,8 @@ export function buildManifestInference(preset, config, headDim = 64, quantizatio
|
|
|
385
372
|
queryPreAttnScalar: resolveQueryPreAttnScalar(preset, modelConfig, headDim),
|
|
386
373
|
attnLogitSoftcapping: presetInference.attention?.attnLogitSoftcapping ??
|
|
387
374
|
modelConfig.attn_logit_softcapping ?? defaults.attention.attnLogitSoftcapping,
|
|
388
|
-
slidingWindow:
|
|
389
|
-
|
|
375
|
+
slidingWindow: modelConfig.sliding_window ??
|
|
376
|
+
presetInference.attention?.slidingWindow ?? defaults.attention.slidingWindow,
|
|
390
377
|
queryKeyNorm: presetInference.attention?.queryKeyNorm ?? defaults.attention.queryKeyNorm,
|
|
391
378
|
attentionOutputGate: detectAttentionOutputGate(presetInference, modelConfig, defaults),
|
|
392
379
|
causal: detectedCausalAttention ?? presetInference.attention?.causal ?? defaults.attention.causal,
|
|
@@ -459,6 +446,9 @@ export function buildManifestInference(preset, config, headDim = 64, quantizatio
|
|
|
459
446
|
);
|
|
460
447
|
}
|
|
461
448
|
globalPattern = null;
|
|
449
|
+
// Default offset 0 means first global layer at index 0 (most common pattern).
|
|
450
|
+
// This is the every_n pattern default, distinct from layerPattern.offset=null
|
|
451
|
+
// which means "not applicable" in the schema.
|
|
462
452
|
offset = (
|
|
463
453
|
detectEveryNOffsetFromLayerTypes(modelConfig.layer_types, period)
|
|
464
454
|
?? normalizeEveryNOffset(presetPattern.offset, period)
|
|
@@ -7,6 +7,9 @@ export async function parseTransformerModel(adapter) {
|
|
|
7
7
|
} = adapter;
|
|
8
8
|
|
|
9
9
|
const config = await readJson('config.json', 'config.json');
|
|
10
|
+
const generationConfig = await fileExists('generation_config.json')
|
|
11
|
+
? await readJson('generation_config.json', 'generation_config.json')
|
|
12
|
+
: null;
|
|
10
13
|
const architectureHint = config.architectures?.[0] ?? config.model_type ?? '';
|
|
11
14
|
|
|
12
15
|
let tensors = null;
|
|
@@ -19,6 +22,7 @@ export async function parseTransformerModel(adapter) {
|
|
|
19
22
|
|
|
20
23
|
return {
|
|
21
24
|
config,
|
|
25
|
+
generationConfig,
|
|
22
26
|
tensors,
|
|
23
27
|
architectureHint,
|
|
24
28
|
};
|
|
@@ -2,6 +2,10 @@
|
|
|
2
2
|
import { DEFAULT_QUANTIZATION_DEFAULTS, DEFAULT_Q4K_LAYOUT } from '../config/index.js';
|
|
3
3
|
import { classifyTensorRole } from '../formats/rdrr/index.js';
|
|
4
4
|
|
|
5
|
+
// Default quantization tag when no explicit dtype is provided.
|
|
6
|
+
// F16 is the canonical unquantized storage format for WebGPU inference.
|
|
7
|
+
const DEFAULT_QUANT_TAG = 'f16';
|
|
8
|
+
|
|
5
9
|
// Quantization tag aliases mapped to canonical names.
|
|
6
10
|
// Add new aliases here rather than adding if/else branches.
|
|
7
11
|
const QUANT_TAG_ALIASES = {
|
|
@@ -47,7 +51,7 @@ const QUANT_TAG_ALIASES = {
|
|
|
47
51
|
};
|
|
48
52
|
|
|
49
53
|
export function normalizeQuantTag(value) {
|
|
50
|
-
if (!value) return
|
|
54
|
+
if (!value) return DEFAULT_QUANT_TAG;
|
|
51
55
|
const lower = value.toLowerCase();
|
|
52
56
|
return QUANT_TAG_ALIASES[lower] ?? lower;
|
|
53
57
|
}
|
|
@@ -73,6 +73,11 @@ export declare function dequantizeQ4KM(
|
|
|
73
73
|
shape: number[]
|
|
74
74
|
): Float32Array;
|
|
75
75
|
|
|
76
|
+
export declare function dequantizeQ4KMRowWise(
|
|
77
|
+
quantized: Uint8Array,
|
|
78
|
+
shape: [number, number]
|
|
79
|
+
): Float32Array;
|
|
80
|
+
|
|
76
81
|
export declare function calculateQuantizationError(
|
|
77
82
|
original: Float32Array,
|
|
78
83
|
reconstructed: Float32Array
|
|
@@ -74,9 +74,10 @@ function findMinMax(data, offset, length) {
|
|
|
74
74
|
return { min, max };
|
|
75
75
|
}
|
|
76
76
|
|
|
77
|
-
|
|
77
|
+
function quantizeQ4KBlockWithValidLength(data, offset, validLength = QK_K) {
|
|
78
78
|
const block = new Uint8Array(QK4_K_BLOCK_SIZE);
|
|
79
79
|
const blockView = new DataView(block.buffer);
|
|
80
|
+
const clampedValidLength = Math.max(0, Math.min(QK_K, Math.trunc(validLength)));
|
|
80
81
|
|
|
81
82
|
const scales = new Float32Array(8);
|
|
82
83
|
const minOffsets = new Float32Array(8);
|
|
@@ -84,14 +85,22 @@ export function quantizeQ4KBlock(data, offset) {
|
|
|
84
85
|
|
|
85
86
|
for (let sb = 0; sb < 8; sb++) {
|
|
86
87
|
const sbOffset = offset + sb * 32;
|
|
87
|
-
const
|
|
88
|
+
const subblockStart = sb * 32;
|
|
89
|
+
const validInSubblock = Math.max(0, Math.min(32, clampedValidLength - subblockStart));
|
|
90
|
+
if (validInSubblock === 0) {
|
|
91
|
+
scales[sb] = 0;
|
|
92
|
+
minOffsets[sb] = 0;
|
|
93
|
+
continue;
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
const { min, max } = findMinMax(data, sbOffset, validInSubblock);
|
|
88
97
|
|
|
89
98
|
minOffsets[sb] = -min;
|
|
90
99
|
const range = max - min;
|
|
91
100
|
scales[sb] = range > 0 ? range / 15 : 0;
|
|
92
101
|
|
|
93
102
|
const invScale = scales[sb] > 0 ? 1 / scales[sb] : 0;
|
|
94
|
-
for (let i = 0; i <
|
|
103
|
+
for (let i = 0; i < validInSubblock; i++) {
|
|
95
104
|
const val = data[sbOffset + i];
|
|
96
105
|
let q = Math.round((val - min) * invScale);
|
|
97
106
|
q = Math.max(0, Math.min(15, q));
|
|
@@ -155,6 +164,10 @@ export function quantizeQ4KBlock(data, offset) {
|
|
|
155
164
|
return block;
|
|
156
165
|
}
|
|
157
166
|
|
|
167
|
+
export function quantizeQ4KBlock(data, offset) {
|
|
168
|
+
return quantizeQ4KBlockWithValidLength(data, offset, QK_K);
|
|
169
|
+
}
|
|
170
|
+
|
|
158
171
|
function dequantizeQ4KBlock(block) {
|
|
159
172
|
const blockView = new DataView(block.buffer, block.byteOffset);
|
|
160
173
|
const result = new Float32Array(256);
|
|
@@ -245,22 +258,16 @@ export function quantizeToQ4KMRowWise(data, shape) {
|
|
|
245
258
|
}
|
|
246
259
|
|
|
247
260
|
const blocksPerRow = Math.ceil(cols / QK_K);
|
|
248
|
-
const paddedColsPerRow = blocksPerRow * QK_K;
|
|
249
261
|
const totalBlocks = rows * blocksPerRow;
|
|
250
262
|
|
|
251
263
|
const quantized = new Uint8Array(totalBlocks * QK4_K_BLOCK_SIZE);
|
|
252
264
|
|
|
253
265
|
for (let row = 0; row < rows; row++) {
|
|
254
|
-
// Extract and pad this row
|
|
255
|
-
const rowData = new Float32Array(paddedColsPerRow);
|
|
256
|
-
const srcOffset = row * cols;
|
|
257
|
-
for (let c = 0; c < cols; c++) {
|
|
258
|
-
rowData[c] = data[srcOffset + c];
|
|
259
|
-
}
|
|
260
|
-
|
|
261
266
|
// Quantize each block in this row
|
|
262
267
|
for (let b = 0; b < blocksPerRow; b++) {
|
|
263
|
-
const
|
|
268
|
+
const validLength = Math.max(0, Math.min(QK_K, cols - b * QK_K));
|
|
269
|
+
const srcOffset = row * cols + b * QK_K;
|
|
270
|
+
const block = quantizeQ4KBlockWithValidLength(data, srcOffset, validLength);
|
|
264
271
|
const dstOffset = (row * blocksPerRow + b) * QK4_K_BLOCK_SIZE;
|
|
265
272
|
quantized.set(block, dstOffset);
|
|
266
273
|
}
|
|
@@ -348,6 +355,21 @@ export function dequantizeQ4KM(quantized, numBlocks, shape) {
|
|
|
348
355
|
return result;
|
|
349
356
|
}
|
|
350
357
|
|
|
358
|
+
export function dequantizeQ4KMRowWise(quantized, shape) {
|
|
359
|
+
const [rows, cols] = shape;
|
|
360
|
+
const blocksPerRow = Math.ceil(cols / QK_K);
|
|
361
|
+
const result = new Float32Array(rows * cols);
|
|
362
|
+
|
|
363
|
+
for (let row = 0; row < rows; row++) {
|
|
364
|
+
const rowOffset = row * blocksPerRow * QK4_K_BLOCK_SIZE;
|
|
365
|
+
const rowBytes = quantized.slice(rowOffset, rowOffset + (blocksPerRow * QK4_K_BLOCK_SIZE));
|
|
366
|
+
const rowDequantized = dequantizeQ4KM(rowBytes, blocksPerRow, [1, cols]);
|
|
367
|
+
result.set(rowDequantized, row * cols);
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
return result;
|
|
371
|
+
}
|
|
372
|
+
|
|
351
373
|
export function calculateQuantizationError(original, reconstructed) {
|
|
352
374
|
if (original.length !== reconstructed.length) {
|
|
353
375
|
throw new Error('Length mismatch');
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import { DEFAULT_MANIFEST_INFERENCE } from '../config/schema/index.js';
|
|
2
|
+
|
|
1
3
|
function asObject(value) {
|
|
2
4
|
if (value == null || typeof value !== 'object' || Array.isArray(value)) {
|
|
3
5
|
return null;
|
|
@@ -50,7 +52,7 @@ function resolveScalingConfig(ropeScalingConfig, options = {}) {
|
|
|
50
52
|
}
|
|
51
53
|
return {
|
|
52
54
|
ropeScalingType: null,
|
|
53
|
-
ropeScalingFactor:
|
|
55
|
+
ropeScalingFactor: DEFAULT_MANIFEST_INFERENCE.rope.ropeScalingFactor,
|
|
54
56
|
yarnBetaFast: null,
|
|
55
57
|
yarnBetaSlow: null,
|
|
56
58
|
yarnOriginalMaxPos: null,
|
|
@@ -58,7 +60,7 @@ function resolveScalingConfig(ropeScalingConfig, options = {}) {
|
|
|
58
60
|
}
|
|
59
61
|
|
|
60
62
|
let ropeScalingType = scalingType;
|
|
61
|
-
let ropeScalingFactor =
|
|
63
|
+
let ropeScalingFactor = DEFAULT_MANIFEST_INFERENCE.rope.ropeScalingFactor;
|
|
62
64
|
let yarnBetaFast = null;
|
|
63
65
|
let yarnBetaSlow = null;
|
|
64
66
|
let yarnOriginalMaxPos = null;
|
|
@@ -110,7 +112,7 @@ function hasScalingDirective(ropeScalingConfig) {
|
|
|
110
112
|
function hasMeaningfulScalingConfig(resolvedScaling) {
|
|
111
113
|
if (!resolvedScaling) return false;
|
|
112
114
|
return resolvedScaling.ropeScalingType != null
|
|
113
|
-
|| resolvedScaling.ropeScalingFactor !==
|
|
115
|
+
|| resolvedScaling.ropeScalingFactor !== DEFAULT_MANIFEST_INFERENCE.rope.ropeScalingFactor
|
|
114
116
|
|| resolvedScaling.yarnBetaFast != null
|
|
115
117
|
|| resolvedScaling.yarnBetaSlow != null
|
|
116
118
|
|| resolvedScaling.yarnOriginalMaxPos != null;
|
|
@@ -159,7 +161,7 @@ export function buildRoPEConfig(presetInference, config) {
|
|
|
159
161
|
?? null,
|
|
160
162
|
ropeScalingFactor: presetRoPE.ropeScalingFactor
|
|
161
163
|
?? presetAttn?.ropeScalingFactor // Deprecated location
|
|
162
|
-
??
|
|
164
|
+
?? DEFAULT_MANIFEST_INFERENCE.rope.ropeScalingFactor,
|
|
163
165
|
yarnBetaFast: presetRoPE.yarnBetaFast ?? null,
|
|
164
166
|
yarnBetaSlow: presetRoPE.yarnBetaSlow ?? null,
|
|
165
167
|
yarnOriginalMaxPos: presetRoPE.yarnOriginalMaxPos ?? null,
|
|
@@ -223,7 +225,7 @@ export function buildRoPEConfig(presetInference, config) {
|
|
|
223
225
|
?? asFiniteNumber(flatRoPEParameters?.rope_theta)
|
|
224
226
|
?? asFiniteNumber(config.rope_theta)
|
|
225
227
|
?? presetInference.rope?.ropeTheta
|
|
226
|
-
??
|
|
228
|
+
?? DEFAULT_MANIFEST_INFERENCE.rope.ropeTheta;
|
|
227
229
|
|
|
228
230
|
// For Gemma 3, local sliding attention theta comes from rope_parameters.sliding_attention.
|
|
229
231
|
const ropeLocalTheta = asFiniteNumber(slidingAttentionRoPE?.rope_theta)
|
|
@@ -232,7 +234,7 @@ export function buildRoPEConfig(presetInference, config) {
|
|
|
232
234
|
|
|
233
235
|
const mropeInterleaved = asBoolean(flatRoPEParameters?.mrope_interleaved)
|
|
234
236
|
?? presetInference.rope?.mropeInterleaved
|
|
235
|
-
??
|
|
237
|
+
?? DEFAULT_MANIFEST_INFERENCE.rope.mropeInterleaved;
|
|
236
238
|
const mropeSection = asNumberArray(flatRoPEParameters?.mrope_section)
|
|
237
239
|
?? presetInference.rope?.mropeSection
|
|
238
240
|
?? null;
|
|
@@ -1,6 +1,8 @@
|
|
|
1
|
-
export function resolveEosTokenId({ config, tokenizer, tokenizerJson }) {
|
|
1
|
+
export function resolveEosTokenId({ config, generationConfig, tokenizer, tokenizerJson }) {
|
|
2
2
|
const nestedTextConfig = getNestedTextConfig(config);
|
|
3
3
|
const candidateSources = [
|
|
4
|
+
generationConfig?.eos_token_id,
|
|
5
|
+
generationConfig?.eos_token_ids,
|
|
4
6
|
tokenizer?.eosTokenId,
|
|
5
7
|
tokenizer?.eos_token_id,
|
|
6
8
|
tokenizerJson?.specialTokens?.eos,
|
|
@@ -19,6 +21,7 @@ export function resolveEosTokenId({ config, tokenizer, tokenizerJson }) {
|
|
|
19
21
|
}
|
|
20
22
|
|
|
21
23
|
const eosTokenStringCandidates = [
|
|
24
|
+
generationConfig?.eos_token,
|
|
22
25
|
tokenizer?.eosToken,
|
|
23
26
|
tokenizer?.eos_token,
|
|
24
27
|
tokenizerJson?.specialTokens?.eos_token,
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Dump intermediate values from Qwen3.5 linear attention (GatedDeltaNet) for comparison with Doppler.
|
|
4
|
+
|
|
5
|
+
Usage:
|
|
6
|
+
HF_HOME=/media/x/models/huggingface_cache python3 src/debug/reference/hf_qwen35_linear_attn_debug.py
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import torch
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
os.environ.setdefault("HF_HOME", "/media/x/models/huggingface_cache")
|
|
14
|
+
|
|
15
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
16
|
+
|
|
17
|
+
MODEL_ID = "Qwen/Qwen3.5-0.8B"
|
|
18
|
+
PROMPT = "Hello"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def stats(name, tensor):
|
|
22
|
+
t = tensor.float().detach().flatten()
|
|
23
|
+
print(f" {name}: shape={list(tensor.shape)}, "
|
|
24
|
+
f"min={t.min().item():.6f}, max={t.max().item():.6f}, "
|
|
25
|
+
f"mean={t.mean().item():.6f}, absMax={t.abs().max().item():.6f}")
|
|
26
|
+
first8 = t[:8].tolist()
|
|
27
|
+
print(f" first8: {[f'{v:.6f}' for v in first8]}")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def main():
    """Dump reference activations for the model's gated linear-attention layer.

    Loads MODEL_ID in float32, runs PROMPT through it once with forward hooks
    on layer 0's `linear_attn` sub-modules, prints per-tensor stats for every
    captured intermediate, then re-derives the gated delta-net recurrence by
    hand and diffs the manual result against the hooked module output.
    Intended as a ground-truth trace to validate an external kernel
    implementation (Doppler) against HF transformers.

    NOTE(review): the manual trace below only replays the recurrence for the
    FIRST token (index 0); multi-token prompts are traced end-to-end by the
    hooked forward pass but not by the manual section.
    """
    print(f"Loading {MODEL_ID}...")
    # float32 everywhere so the manual replay matches the model bit-for-bit.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model.eval()

    inputs = tokenizer(PROMPT, return_tensors="pt")
    input_ids = inputs["input_ids"]
    print(f"Prompt: '{PROMPT}', Token IDs: {input_ids[0].tolist()}")
    num_tokens = input_ids.shape[1]

    # Dump key weight values for layer 0
    layer0 = model.model.layers[0]
    attn = layer0.linear_attn

    print(f"\n=== Layer 0 weights ===")
    if hasattr(attn, 'A_log'):
        # A is stored in log space; the effective decay coefficient is -exp(A_log).
        a_log = attn.A_log.detach().float()
        a_neg_exp = -torch.exp(a_log)
        stats("A_log", a_log)
        stats("a_neg_exp", a_neg_exp)
    if hasattr(attn, 'dt_bias'):
        stats("dt_bias", attn.dt_bias.detach().float())
    stats("conv1d.weight", attn.conv1d.weight.detach().float())
    stats("norm.weight", attn.norm.weight.detach().float())

    # Hook into the linear_attn module to capture its input and output
    captured = {}

    def hook_linear_attn_input(module, args, kwargs):
        # Forward pre-hook (with_kwargs=True): grab the first positional arg,
        # i.e. the hidden states entering linear_attn.
        if len(args) > 0:
            captured['linear_attn_input'] = args[0].detach().clone()
        return None

    def hook_linear_attn_output(module, args, kwargs, output):
        # Forward hook (with_kwargs=True): the module may return a tuple
        # (hidden_states, ...) or a bare tensor — keep only the hidden states.
        if isinstance(output, tuple):
            captured['linear_attn_output'] = output[0].detach().clone()
        else:
            captured['linear_attn_output'] = output.detach().clone()
        return None

    # Hook into individual projection layers
    def make_hook(name):
        # Factory so each hook closes over its own capture key.
        def hook(module, input, output):
            captured[name] = output.detach().clone()
        return hook

    hooks = []
    hooks.append(attn.register_forward_pre_hook(hook_linear_attn_input, with_kwargs=True))
    hooks.append(attn.register_forward_hook(hook_linear_attn_output, with_kwargs=True))
    hooks.append(attn.in_proj_qkv.register_forward_hook(make_hook('qkv_proj')))
    hooks.append(attn.in_proj_z.register_forward_hook(make_hook('z_proj')))
    hooks.append(attn.in_proj_a.register_forward_hook(make_hook('a_proj')))
    hooks.append(attn.in_proj_b.register_forward_hook(make_hook('b_proj')))
    hooks.append(attn.out_proj.register_forward_hook(make_hook('out_proj')))
    hooks.append(attn.conv1d.register_forward_hook(make_hook('conv1d_raw')))
    hooks.append(attn.norm.register_forward_hook(make_hook('gated_norm')))

    # Also hook input_layernorm
    hooks.append(layer0.input_layernorm.register_forward_hook(make_hook('input_layernorm')))

    print(f"\n=== Running forward pass ===")
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    # Remove hooks
    for h in hooks:
        h.remove()

    print(f"\n=== Captured intermediates ===")
    for name in ['input_layernorm', 'qkv_proj', 'z_proj', 'a_proj', 'b_proj',
                 'conv1d_raw', 'gated_norm', 'linear_attn_input', 'linear_attn_output', 'out_proj']:
        if name in captured:
            stats(name, captured[name])
        else:
            print(f" {name}: NOT CAPTURED")

    # Hidden states per layer
    print(f"\n=== Hidden states per layer (last token) ===")
    # hidden_states[0] is the embedding output, so index i+1 is the output of
    # layer i; only the first few layers are printed.
    for i in range(min(6, len(outputs.hidden_states) - 1)):
        hs = outputs.hidden_states[i + 1]
        t = hs[0, -1]  # last token
        vals = t[:8].tolist()
        max_abs = t.abs().max().item()
        mean_abs = t.abs().mean().item()
        layer_type = type(model.model.layers[i]).__name__  # currently unused; kept for debugging
        attn_type = "linear" if hasattr(model.model.layers[i], 'linear_attn') else "full"
        print(f" Layer {i} ({attn_type}): first8={[f'{v:.4f}' for v in vals]}, "
              f"maxAbs={max_abs:.4f}, meanAbs={mean_abs:.4f}")

    # Logits
    logits = outputs.logits[0, -1]
    top5 = torch.topk(logits, 5)
    print(f"\nTop-5 logits: {[(tokenizer.decode([idx.item()]), f'{val.item():.2f}') for val, idx in zip(top5.values, top5.indices)]}")

    # Also trace through the linear attention manually to compare with Doppler's kernel
    print(f"\n=== Manual linear attention trace (layer 0) ===")
    with torch.no_grad():
        embed = model.model.embed_tokens(input_ids)
        normed = layer0.input_layernorm(embed)
        stats("normed_input", normed)

        qkv = attn.in_proj_qkv(normed)
        stats("qkv", qkv)

        # The HF Qwen3.5 GatedDeltaNet does conv1d on the QKV, then applies SiLU
        # The conv1d expects [batch, channels, seq_len] format
        qkv_t = qkv.transpose(1, 2)  # [1, 6144, 1]

        # Use the conv1d module directly (it has padding configured)
        conv_raw = attn.conv1d(qkv_t)
        stats("conv_raw (from module)", conv_raw.transpose(1, 2))

        # Truncate to seq_len (causal conv padding)
        conv_causal = conv_raw[..., :num_tokens]
        stats("conv_causal (truncated)", conv_causal.transpose(1, 2))

        # Apply SiLU
        conv_silu = torch.nn.functional.silu(conv_causal)
        stats("conv_silu", conv_silu.transpose(1, 2))

        # Split Q, K, V
        conv_out = conv_silu.transpose(1, 2)  # [1, seq_len, 6144]
        # Head geometry — hard-coded to match this checkpoint's config;
        # TODO(review): read these from model.config instead of literals.
        num_k_heads = 16
        head_k_dim = 128
        head_v_dim = 128
        num_v_heads = 16
        q_size = num_k_heads * head_k_dim  # 2048
        k_size = q_size
        v_size = num_v_heads * head_v_dim  # 2048

        q = conv_out[..., :q_size]
        k = conv_out[..., q_size:q_size + k_size]
        v = conv_out[..., q_size + k_size:]
        stats("Q (raw)", q)
        stats("K (raw)", k)
        stats("V (raw)", v)

        # Reshape for per-head processing
        # Q and K: [batch, seq, num_k_heads, head_k_dim]
        q_heads = q.view(1, num_tokens, num_k_heads, head_k_dim)
        k_heads = k.view(1, num_tokens, num_k_heads, head_k_dim)
        v_heads = v.view(1, num_tokens, num_v_heads, head_v_dim)

        # L2 normalize Q and K
        eps = 1e-6
        q_norm = torch.nn.functional.normalize(q_heads, p=2, dim=-1, eps=eps)
        k_norm = torch.nn.functional.normalize(k_heads, p=2, dim=-1, eps=eps)

        # Scale Q by 1/sqrt(head_k_dim)
        head_scale = 1.0 / (head_k_dim ** 0.5)
        q_scaled = q_norm * head_scale

        stats("Q_normed_scaled (per-head)", q_scaled.reshape(1, num_tokens, -1))
        stats("K_normed (per-head)", k_norm.reshape(1, num_tokens, -1))

        # Projections for gating
        z = attn.in_proj_z(normed)       # SiLU gate input for the output norm
        a_out = attn.in_proj_a(normed)   # per-head decay pre-activation
        b_out = attn.in_proj_b(normed)   # per-head beta pre-activation
        stats("z", z)
        stats("a", a_out)
        stats("b", b_out)

        # Compute gating values
        a_log = attn.A_log.detach().float()
        a_neg_exp = -torch.exp(a_log)
        dt_bias = attn.dt_bias.detach().float()

        # g = -exp(A_log) * softplus(a + dt_bias); exp(g) is the per-head
        # multiplicative decay applied to the recurrent state.
        softplus_input = a_out.squeeze(0).squeeze(0) + dt_bias
        softplus_val = torch.nn.functional.softplus(softplus_input)
        g = a_neg_exp * softplus_val
        g_exp = torch.exp(g)
        beta = torch.sigmoid(b_out.squeeze(0).squeeze(0))

        stats("softplus(a + dt_bias)", softplus_val.unsqueeze(0).unsqueeze(0))
        stats("g (decay)", g.unsqueeze(0).unsqueeze(0))
        stats("g_exp (decay factor)", g_exp.unsqueeze(0).unsqueeze(0))
        stats("beta (sigmoid(b))", beta.unsqueeze(0).unsqueeze(0))

        # Recurrent state update (for first token, state is all zeros)
        # state[head, kd, vd] = state * g_exp + k[kd] * delta[vd]
        # where delta[vd] = (v[vd] - state^T @ k) * beta
        # For zero state: delta[vd] = v[vd] * beta, state = k ⊗ delta
        state = torch.zeros(num_v_heads, head_k_dim, head_v_dim)

        # Apply decay (no-op for zero state)
        for head in range(num_v_heads):
            state[head] *= g_exp[head].item()

            # head % num_k_heads maps each V head onto its K head (here a
            # 1:1 mapping since num_k_heads == num_v_heads).
            k_head = k_norm[0, 0, head % num_k_heads]  # broadcast q_rep
            v_head = v_heads[0, 0, head]

            # kv_mem = state @ k
            kv_mem = state[head].t() @ k_head  # [head_v_dim]

            # delta = (v - kv_mem) * beta
            delta = (v_head - kv_mem) * beta[head].item()

            # state += outer(k, delta)
            state[head] += torch.outer(k_head, delta)

        # Output: out = state^T @ q
        output_per_head = torch.zeros(1, num_tokens, num_v_heads, head_v_dim)
        for head in range(num_v_heads):
            q_head = q_scaled[0, 0, head % num_k_heads]
            out_head = state[head].t() @ q_head  # [head_v_dim]
            output_per_head[0, 0, head] = out_head

        raw_out = output_per_head.reshape(1, num_tokens, num_v_heads * head_v_dim)
        stats("Recurrent output (raw)", raw_out)

        # RMS norm per head + SiLU gate
        z_reshaped = z.view(1, num_tokens, num_v_heads, head_v_dim)
        norm_weight = attn.norm.weight.detach().float()  # [head_v_dim] (shared mode)
        rms_eps = 1e-6

        for head in range(num_v_heads):
            head_out = output_per_head[0, 0, head]  # [head_v_dim]
            mean_sq = (head_out ** 2).mean()
            inv_rms = 1.0 / torch.sqrt(mean_sq + rms_eps)
            z_gate = torch.nn.functional.silu(z_reshaped[0, 0, head])
            # In-place overwrite: raw recurrent output is replaced by its
            # normalized, gated version.
            output_per_head[0, 0, head] = head_out * inv_rms * norm_weight * z_gate

        gated_out = output_per_head.reshape(1, num_tokens, num_v_heads * head_v_dim)
        stats("After RMSNorm + SiLU gate", gated_out)

        # Output projection
        # NOTE(review): bias is not passed — assumes out_proj has no bias;
        # confirm against the checkpoint config.
        o_result = torch.nn.functional.linear(gated_out, attn.out_proj.weight)
        stats("After out_proj", o_result)

        # Compare with captured output
        if 'linear_attn_output' in captured:
            diff = (o_result - captured['linear_attn_output']).abs()
            print(f"\n Diff vs captured output: maxDiff={diff.max().item():.6f}")
|
267
|
+
# Script entry point: only run the trace when executed directly, not on import.
if __name__ == "__main__":
    main()