@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
|
@@ -46,7 +46,16 @@ export function recordAttentionInputs(
|
|
|
46
46
|
info: AttentionInputInfo | null | undefined
|
|
47
47
|
): void;
|
|
48
48
|
|
|
49
|
-
export function
|
|
49
|
+
export function shouldForceF32AttentionProjectionForRoPE(options: {
|
|
50
|
+
attentionInputDtype: string;
|
|
51
|
+
headDim: number;
|
|
52
|
+
rotaryDim?: number;
|
|
53
|
+
interleaved?: boolean;
|
|
54
|
+
}): boolean;
|
|
55
|
+
export function resolveAttentionProjectionOutputDtype(
|
|
56
|
+
attentionInputDtype: string,
|
|
57
|
+
options?: { forceF32?: boolean }
|
|
58
|
+
): 'f16' | 'f32' | string;
|
|
50
59
|
export function resolveProjectionSliceOffsetBytes(
|
|
51
60
|
weightBuffer: WeightBuffer | Tensor | GPUBuffer | null | undefined,
|
|
52
61
|
outputRows: number,
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { releaseBuffer } from '../../../../memory/buffer-pool.js';
|
|
2
2
|
import { isWeightBuffer, getLayout, getWeightDtype } from '../../../../gpu/weight-buffer.js';
|
|
3
3
|
import {
|
|
4
4
|
runMatmul,
|
|
5
5
|
recordMatmul,
|
|
6
6
|
runSplitQKV,
|
|
7
7
|
recordSplitQKV,
|
|
8
|
+
runSplitQG,
|
|
9
|
+
recordSplitQG,
|
|
8
10
|
runRMSNorm,
|
|
9
11
|
recordRMSNorm,
|
|
10
12
|
} from '../../../../gpu/kernel-selector.js';
|
|
@@ -28,6 +30,13 @@ function getSplitRunner(recorder) {
|
|
|
28
30
|
return (qkvTensor, options) => recordSplitQKV(recorder, qkvTensor, options);
|
|
29
31
|
}
|
|
30
32
|
|
|
33
|
+
function getSplitQGRunner(recorder) {
|
|
34
|
+
if (!recorder) {
|
|
35
|
+
return (qgTensor, options) => runSplitQG(qgTensor, options);
|
|
36
|
+
}
|
|
37
|
+
return (qgTensor, options) => recordSplitQG(recorder, qgTensor, options);
|
|
38
|
+
}
|
|
39
|
+
|
|
31
40
|
function getRmsNormRunner(recorder) {
|
|
32
41
|
if (!recorder) {
|
|
33
42
|
return (input, weight, eps, options) => runRMSNorm(input, weight, eps, options);
|
|
@@ -36,7 +45,7 @@ function getRmsNormRunner(recorder) {
|
|
|
36
45
|
}
|
|
37
46
|
|
|
38
47
|
function releaseOwnedWeightBuffer(layerWeight, resolvedWeightBuffer, releaseTemporary) {
|
|
39
|
-
if (layerWeight instanceof GPUBuffer || isWeightBuffer(layerWeight)) {
|
|
48
|
+
if ((typeof GPUBuffer !== 'undefined' && layerWeight instanceof GPUBuffer) || isWeightBuffer(layerWeight)) {
|
|
40
49
|
return;
|
|
41
50
|
}
|
|
42
51
|
if (!resolvedWeightBuffer) {
|
|
@@ -66,10 +75,16 @@ async function projectSingleQkvTensor({
|
|
|
66
75
|
}) {
|
|
67
76
|
const runMatmulForMode = getMatmulRunner(recorder);
|
|
68
77
|
const layerWeight = layerWeights?.[weightKey];
|
|
69
|
-
|
|
78
|
+
if (!layerWeight) {
|
|
79
|
+
throw new Error(`Attention projection requires ${weightKey}.`);
|
|
80
|
+
}
|
|
81
|
+
if (!getWeightBuffer) {
|
|
82
|
+
throw new Error(`Attention projection requires getWeightBuffer for ${role}.`);
|
|
83
|
+
}
|
|
70
84
|
|
|
71
|
-
|
|
72
|
-
|
|
85
|
+
let projected;
|
|
86
|
+
const projBuffer = getWeightBuffer(layerWeight, role);
|
|
87
|
+
try {
|
|
73
88
|
projected = await runMatmulForMode(normed, projBuffer, numTokens, outputSize, hiddenSize, {
|
|
74
89
|
transposeB: 'auto',
|
|
75
90
|
role,
|
|
@@ -77,26 +92,31 @@ async function projectSingleQkvTensor({
|
|
|
77
92
|
kernelPath,
|
|
78
93
|
outputDtype: matmulOutputDtype,
|
|
79
94
|
});
|
|
95
|
+
} finally {
|
|
80
96
|
releaseOwnedWeightBuffer(layerWeight, projBuffer, releaseTemporary);
|
|
81
|
-
} else {
|
|
82
|
-
const fallback = acquireBuffer(numTokens * outputSize * 4, undefined, outputLabel);
|
|
83
|
-
projected = createTensor(fallback, normed.dtype, [numTokens, outputSize], outputLabel);
|
|
84
97
|
}
|
|
85
98
|
|
|
86
99
|
const loraModule = getLoRAModule(lora, layerIdx, loraKey);
|
|
87
100
|
if (loraModule && getWeightBuffer) {
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
101
|
+
try {
|
|
102
|
+
const combined = await applyLoRA(
|
|
103
|
+
normed,
|
|
104
|
+
projected,
|
|
105
|
+
loraModule,
|
|
106
|
+
{ M: numTokens, N: outputSize, K: hiddenSize },
|
|
107
|
+
getWeightBuffer,
|
|
108
|
+
recorder ?? undefined,
|
|
109
|
+
{ kernelPath }
|
|
110
|
+
);
|
|
111
|
+
if (combined.buffer !== projected.buffer) {
|
|
112
|
+
releaseTemporary(projected.buffer);
|
|
113
|
+
projected = combined;
|
|
114
|
+
}
|
|
115
|
+
} catch (error) {
|
|
116
|
+
if (projected?.buffer) {
|
|
117
|
+
releaseTemporary(projected.buffer);
|
|
118
|
+
}
|
|
119
|
+
throw error;
|
|
100
120
|
}
|
|
101
121
|
}
|
|
102
122
|
|
|
@@ -190,13 +210,17 @@ async function projectQueryWithOptionalGate({
|
|
|
190
210
|
return { qTensor, qGateTensor: null };
|
|
191
211
|
}
|
|
192
212
|
|
|
213
|
+
// q_proj weights are stored with interleaved head layout: for head h,
|
|
214
|
+
// rows [h*headDim*2 : h*headDim*2+headDim] = Q, rows [h*headDim*2+headDim : (h+1)*headDim*2] = gate.
|
|
215
|
+
// Compute the full 2*qSize matmul, then de-interleave into separate Q and gate tensors.
|
|
193
216
|
const runMatmulForMode = getMatmulRunner(recorder);
|
|
217
|
+
const runSplitQGForMode = getSplitQGRunner(recorder);
|
|
194
218
|
const qWeightBuffer = getWeightBuffer(qWeight, 'q_proj');
|
|
195
|
-
|
|
219
|
+
let fullQGTensor = null;
|
|
196
220
|
let qTensor = null;
|
|
197
221
|
let qGateTensor = null;
|
|
198
222
|
try {
|
|
199
|
-
|
|
223
|
+
fullQGTensor = await runMatmulForMode(normed, qWeightBuffer, numTokens, qSize * 2, hiddenSize, {
|
|
200
224
|
transposeB: 'auto',
|
|
201
225
|
role: 'q_proj',
|
|
202
226
|
layerIdx,
|
|
@@ -204,32 +228,54 @@ async function projectQueryWithOptionalGate({
|
|
|
204
228
|
outputDtype: matmulOutputDtype,
|
|
205
229
|
});
|
|
206
230
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
kernelPath,
|
|
212
|
-
bOffset: gateOffset,
|
|
213
|
-
outputDtype: matmulOutputDtype,
|
|
231
|
+
const split = await runSplitQGForMode(fullQGTensor, {
|
|
232
|
+
numTokens,
|
|
233
|
+
numHeads,
|
|
234
|
+
headDim,
|
|
214
235
|
});
|
|
236
|
+
releaseTemporary(fullQGTensor.buffer);
|
|
237
|
+
fullQGTensor = null;
|
|
238
|
+
qTensor = split.Q;
|
|
239
|
+
qGateTensor = split.G;
|
|
240
|
+
} catch (error) {
|
|
241
|
+
if (fullQGTensor) {
|
|
242
|
+
releaseTemporary(fullQGTensor.buffer);
|
|
243
|
+
}
|
|
244
|
+
if (qTensor) {
|
|
245
|
+
releaseTemporary(qTensor.buffer);
|
|
246
|
+
}
|
|
247
|
+
if (qGateTensor) {
|
|
248
|
+
releaseTemporary(qGateTensor.buffer);
|
|
249
|
+
}
|
|
250
|
+
throw error;
|
|
215
251
|
} finally {
|
|
216
252
|
releaseOwnedWeightBuffer(qWeight, qWeightBuffer, releaseTemporary);
|
|
217
253
|
}
|
|
218
254
|
|
|
219
255
|
const loraModule = getLoRAModule(lora, layerIdx, 'q_proj');
|
|
220
256
|
if (loraModule && getWeightBuffer) {
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
257
|
+
try {
|
|
258
|
+
const combined = await applyLoRA(
|
|
259
|
+
normed,
|
|
260
|
+
qTensor,
|
|
261
|
+
loraModule,
|
|
262
|
+
{ M: numTokens, N: qSize, K: hiddenSize },
|
|
263
|
+
getWeightBuffer,
|
|
264
|
+
recorder ?? undefined,
|
|
265
|
+
{ kernelPath }
|
|
266
|
+
);
|
|
267
|
+
if (combined.buffer !== qTensor.buffer) {
|
|
268
|
+
releaseTemporary(qTensor.buffer);
|
|
269
|
+
qTensor = combined;
|
|
270
|
+
}
|
|
271
|
+
} catch (error) {
|
|
272
|
+
if (qTensor?.buffer) {
|
|
273
|
+
releaseTemporary(qTensor.buffer);
|
|
274
|
+
}
|
|
275
|
+
if (qGateTensor?.buffer) {
|
|
276
|
+
releaseTemporary(qGateTensor.buffer);
|
|
277
|
+
}
|
|
278
|
+
throw error;
|
|
233
279
|
}
|
|
234
280
|
}
|
|
235
281
|
|
|
@@ -248,9 +294,22 @@ export function recordAttentionInputs(state, info) {
|
|
|
248
294
|
state.stats.attentionInputs.push(info);
|
|
249
295
|
}
|
|
250
296
|
|
|
251
|
-
export function
|
|
297
|
+
export function shouldForceF32AttentionProjectionForRoPE({
|
|
298
|
+
attentionInputDtype,
|
|
299
|
+
headDim,
|
|
300
|
+
rotaryDim = headDim,
|
|
301
|
+
interleaved = false,
|
|
302
|
+
}) {
|
|
303
|
+
return attentionInputDtype === 'f16'
|
|
304
|
+
&& Number.isFinite(headDim)
|
|
305
|
+
&& Number.isFinite(rotaryDim)
|
|
306
|
+
&& (rotaryDim !== headDim || interleaved === true);
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
export function resolveAttentionProjectionOutputDtype(attentionInputDtype, options = {}) {
|
|
252
310
|
const useF16Activations = attentionInputDtype === 'f16';
|
|
253
|
-
return selectRuleValue('
|
|
311
|
+
return selectRuleValue('inference', 'dtype', 'attentionProjectionOutputDtype', {
|
|
312
|
+
forceF32: options.forceF32 === true,
|
|
254
313
|
useF16: useF16Activations,
|
|
255
314
|
fallback: attentionInputDtype,
|
|
256
315
|
});
|
|
@@ -289,82 +348,103 @@ export async function projectAttentionQKV({
|
|
|
289
348
|
if (useFusedQKV && layerWeights.qkvProj && layerWeights.qkvSizes) {
|
|
290
349
|
const [qSizeFused, kSizeFused, vSizeFused] = layerWeights.qkvSizes;
|
|
291
350
|
const qkvSizeTotal = qSizeFused + kSizeFused + vSizeFused;
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
351
|
+
let qkvTensor = null;
|
|
352
|
+
try {
|
|
353
|
+
qkvTensor = await runMatmulForMode(normed, layerWeights.qkvProj, numTokens, qkvSizeTotal, hiddenSize, {
|
|
354
|
+
transposeB: 'auto',
|
|
355
|
+
role: 'qkv_proj',
|
|
356
|
+
layerIdx,
|
|
357
|
+
kernelPath,
|
|
358
|
+
outputDtype: matmulOutputDtype,
|
|
359
|
+
});
|
|
360
|
+
const split = await runSplitForMode(qkvTensor, {
|
|
361
|
+
numTokens,
|
|
362
|
+
qSize: qSizeFused,
|
|
363
|
+
kSize: kSizeFused,
|
|
364
|
+
vSize: vSizeFused,
|
|
365
|
+
});
|
|
366
|
+
releaseTemporary(qkvTensor.buffer);
|
|
367
|
+
if (onFusedQKV) {
|
|
368
|
+
onFusedQKV({ qSize: qSizeFused, kSize: kSizeFused, vSize: vSizeFused, totalSize: qkvSizeTotal });
|
|
369
|
+
}
|
|
370
|
+
return { qTensor: split.Q, qGateTensor: null, kTensor: split.K, vTensor: split.V, usedFusedQKV: true };
|
|
371
|
+
} catch (error) {
|
|
372
|
+
if (qkvTensor) {
|
|
373
|
+
releaseTemporary(qkvTensor.buffer);
|
|
374
|
+
}
|
|
375
|
+
throw error;
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
let qTensor = null;
|
|
380
|
+
let qGateTensor = null;
|
|
381
|
+
let kTensor = null;
|
|
382
|
+
let vTensor = null;
|
|
383
|
+
try {
|
|
384
|
+
({ qTensor, qGateTensor } = await projectQueryWithOptionalGate({
|
|
385
|
+
recorder,
|
|
386
|
+
normed,
|
|
387
|
+
layerWeights,
|
|
388
|
+
numTokens,
|
|
389
|
+
numHeads,
|
|
390
|
+
headDim,
|
|
391
|
+
hiddenSize,
|
|
295
392
|
layerIdx,
|
|
296
393
|
kernelPath,
|
|
297
|
-
|
|
394
|
+
matmulOutputDtype,
|
|
395
|
+
getWeightBuffer,
|
|
396
|
+
lora,
|
|
397
|
+
releaseTemporary,
|
|
398
|
+
attentionOutputGate,
|
|
399
|
+
}));
|
|
400
|
+
|
|
401
|
+
kTensor = await projectSingleQkvTensor({
|
|
402
|
+
recorder,
|
|
403
|
+
normed,
|
|
404
|
+
layerWeights,
|
|
405
|
+
weightKey: 'kProj',
|
|
406
|
+
role: 'k_proj',
|
|
407
|
+
outputSize: numKVHeads * headDim,
|
|
408
|
+
outputLabel: 'K',
|
|
409
|
+
loraKey: 'k_proj',
|
|
410
|
+
numTokens,
|
|
411
|
+
hiddenSize,
|
|
412
|
+
layerIdx,
|
|
413
|
+
kernelPath,
|
|
414
|
+
matmulOutputDtype,
|
|
415
|
+
getWeightBuffer,
|
|
416
|
+
lora,
|
|
417
|
+
releaseTemporary,
|
|
298
418
|
});
|
|
299
|
-
|
|
419
|
+
|
|
420
|
+
vTensor = await projectSingleQkvTensor({
|
|
421
|
+
recorder,
|
|
422
|
+
normed,
|
|
423
|
+
layerWeights,
|
|
424
|
+
weightKey: 'vProj',
|
|
425
|
+
role: 'v_proj',
|
|
426
|
+
outputSize: numKVHeads * headDim,
|
|
427
|
+
outputLabel: 'V',
|
|
428
|
+
loraKey: 'v_proj',
|
|
300
429
|
numTokens,
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
430
|
+
hiddenSize,
|
|
431
|
+
layerIdx,
|
|
432
|
+
kernelPath,
|
|
433
|
+
matmulOutputDtype,
|
|
434
|
+
getWeightBuffer,
|
|
435
|
+
lora,
|
|
436
|
+
releaseTemporary,
|
|
304
437
|
});
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
438
|
+
|
|
439
|
+
return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
|
|
440
|
+
} catch (error) {
|
|
441
|
+
for (const tensor of [qTensor, qGateTensor, kTensor, vTensor]) {
|
|
442
|
+
if (tensor?.buffer) {
|
|
443
|
+
releaseTemporary(tensor.buffer);
|
|
444
|
+
}
|
|
308
445
|
}
|
|
309
|
-
|
|
446
|
+
throw error;
|
|
310
447
|
}
|
|
311
|
-
|
|
312
|
-
const { qTensor, qGateTensor } = await projectQueryWithOptionalGate({
|
|
313
|
-
recorder,
|
|
314
|
-
normed,
|
|
315
|
-
layerWeights,
|
|
316
|
-
numTokens,
|
|
317
|
-
numHeads,
|
|
318
|
-
headDim,
|
|
319
|
-
hiddenSize,
|
|
320
|
-
layerIdx,
|
|
321
|
-
kernelPath,
|
|
322
|
-
matmulOutputDtype,
|
|
323
|
-
getWeightBuffer,
|
|
324
|
-
lora,
|
|
325
|
-
releaseTemporary,
|
|
326
|
-
attentionOutputGate,
|
|
327
|
-
});
|
|
328
|
-
|
|
329
|
-
const kTensor = await projectSingleQkvTensor({
|
|
330
|
-
recorder,
|
|
331
|
-
normed,
|
|
332
|
-
layerWeights,
|
|
333
|
-
weightKey: 'kProj',
|
|
334
|
-
role: 'k_proj',
|
|
335
|
-
outputSize: numKVHeads * headDim,
|
|
336
|
-
outputLabel: 'K',
|
|
337
|
-
loraKey: 'k_proj',
|
|
338
|
-
numTokens,
|
|
339
|
-
hiddenSize,
|
|
340
|
-
layerIdx,
|
|
341
|
-
kernelPath,
|
|
342
|
-
matmulOutputDtype,
|
|
343
|
-
getWeightBuffer,
|
|
344
|
-
lora,
|
|
345
|
-
releaseTemporary,
|
|
346
|
-
});
|
|
347
|
-
|
|
348
|
-
const vTensor = await projectSingleQkvTensor({
|
|
349
|
-
recorder,
|
|
350
|
-
normed,
|
|
351
|
-
layerWeights,
|
|
352
|
-
weightKey: 'vProj',
|
|
353
|
-
role: 'v_proj',
|
|
354
|
-
outputSize: numKVHeads * headDim,
|
|
355
|
-
outputLabel: 'V',
|
|
356
|
-
loraKey: 'v_proj',
|
|
357
|
-
numTokens,
|
|
358
|
-
hiddenSize,
|
|
359
|
-
layerIdx,
|
|
360
|
-
kernelPath,
|
|
361
|
-
matmulOutputDtype,
|
|
362
|
-
getWeightBuffer,
|
|
363
|
-
lora,
|
|
364
|
-
releaseTemporary,
|
|
365
|
-
});
|
|
366
|
-
|
|
367
|
-
return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
|
|
368
448
|
}
|
|
369
449
|
|
|
370
450
|
export async function applyAttentionQKNorm({
|
|
@@ -24,10 +24,12 @@ import { selectRuleValue } from '../../../../rules/rule-registry.js';
|
|
|
24
24
|
import { SlidingWindowKVCache } from '../../../kv-cache.js';
|
|
25
25
|
import {
|
|
26
26
|
recordAttentionInputs,
|
|
27
|
+
shouldForceF32AttentionProjectionForRoPE,
|
|
27
28
|
resolveAttentionProjectionOutputDtype,
|
|
28
29
|
projectAttentionQKV,
|
|
29
30
|
applyAttentionQKNorm,
|
|
30
31
|
} from './projections.js';
|
|
32
|
+
import { prepareAttentionProjectionInput } from './output-projection.js';
|
|
31
33
|
|
|
32
34
|
import { releaseOrTrack, shouldDebugLayer } from './types.js';
|
|
33
35
|
|
|
@@ -90,9 +92,20 @@ export async function recordLayerAttentionGPU(
|
|
|
90
92
|
const allowF16Attention = wantsF16Output && kvCacheDtype === 'f16';
|
|
91
93
|
let attentionInput = input;
|
|
92
94
|
let attentionInputTemp = false;
|
|
95
|
+
let normed = attentionInput;
|
|
96
|
+
let qTensor = null;
|
|
97
|
+
let qGateTensor = null;
|
|
98
|
+
let kTensor = null;
|
|
99
|
+
let vTensor = null;
|
|
100
|
+
let attnOutput = null;
|
|
101
|
+
let attnForProjection = null;
|
|
102
|
+
let output = null;
|
|
103
|
+
let finalOutput = null;
|
|
104
|
+
let oProjInputTemp = null;
|
|
93
105
|
if (wantsF16Output && !allowF16Attention) {
|
|
94
106
|
attentionInput = await recordCastF16ToF32(recorder, input);
|
|
95
107
|
attentionInputTemp = true;
|
|
108
|
+
normed = attentionInput;
|
|
96
109
|
}
|
|
97
110
|
|
|
98
111
|
if (!layerWeights) {
|
|
@@ -108,7 +121,7 @@ export async function recordLayerAttentionGPU(
|
|
|
108
121
|
|
|
109
122
|
// 1. Input norm
|
|
110
123
|
|
|
111
|
-
|
|
124
|
+
try {
|
|
112
125
|
if (!skipInputNorm && layerWeights.inputNorm && getNormWeightBuffer) {
|
|
113
126
|
const normWeightBuf = getNormWeightBuffer(layerWeights.inputNorm, 'input_norm');
|
|
114
127
|
normed = await recordRMSNorm(recorder, attentionInput, normWeightBuf, rmsNormEps, {
|
|
@@ -131,8 +144,16 @@ export async function recordLayerAttentionGPU(
|
|
|
131
144
|
}
|
|
132
145
|
|
|
133
146
|
// 2. Q/K/V projections
|
|
134
|
-
const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype
|
|
135
|
-
|
|
147
|
+
const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype, {
|
|
148
|
+
forceF32: shouldForceF32AttentionProjectionForRoPE({
|
|
149
|
+
attentionInputDtype: desiredOutputDtype,
|
|
150
|
+
headDim,
|
|
151
|
+
rotaryDim: config.ropeRotaryDim,
|
|
152
|
+
interleaved: config.ropeInterleaved,
|
|
153
|
+
}),
|
|
154
|
+
});
|
|
155
|
+
let usedFusedQKV = false;
|
|
156
|
+
({ qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
|
|
136
157
|
recorder,
|
|
137
158
|
normed,
|
|
138
159
|
layerWeights,
|
|
@@ -153,7 +174,7 @@ export async function recordLayerAttentionGPU(
|
|
|
153
174
|
trace.attn(layerIdx, `Using fused QKV path: ${qSizeFused}+${kSizeFused}+${vSizeFused}=${totalSize}`);
|
|
154
175
|
}
|
|
155
176
|
: null,
|
|
156
|
-
});
|
|
177
|
+
}));
|
|
157
178
|
|
|
158
179
|
// Optional per-head Q/K normalization.
|
|
159
180
|
// Some models use RMSNorm with (1+weight) offset formula, controlled by rmsNormWeightOffset.
|
|
@@ -502,9 +523,9 @@ export async function recordLayerAttentionGPU(
|
|
|
502
523
|
throw new Error(`Unsupported attention kernel variant "${attentionKernelVariant}" at layer ${layerIdx}`);
|
|
503
524
|
}
|
|
504
525
|
|
|
505
|
-
|
|
526
|
+
attnOutput = await runAttentionKernel();
|
|
506
527
|
|
|
507
|
-
|
|
528
|
+
attnForProjection = attnOutput;
|
|
508
529
|
if (qGateTensor) {
|
|
509
530
|
attnForProjection = await recordSiLU(recorder, attnOutput, {
|
|
510
531
|
size: numTokens * numHeads * headDim,
|
|
@@ -518,19 +539,19 @@ export async function recordLayerAttentionGPU(
|
|
|
518
539
|
|
|
519
540
|
// 6. Output projection (with optional fused residual for decode)
|
|
520
541
|
|
|
521
|
-
|
|
542
|
+
output = null;
|
|
522
543
|
let residualFused = false;
|
|
523
544
|
let oProjInput = attnForProjection;
|
|
524
|
-
|
|
545
|
+
oProjInputTemp = null;
|
|
525
546
|
if (layerWeights.oProj && getWeightBuffer) {
|
|
547
|
+
({ oProjInput, oProjInputTemp } = await prepareAttentionProjectionInput(
|
|
548
|
+
attnForProjection,
|
|
549
|
+
matmulOutputDtype,
|
|
550
|
+
(tensor) => recordCastF32ToF16(recorder, tensor)
|
|
551
|
+
));
|
|
526
552
|
const oProjBuf = getWeightBuffer(layerWeights.oProj, 'o_proj');
|
|
527
553
|
const loraO = getLoRAModule(lora, layerIdx, 'o_proj');
|
|
528
554
|
|
|
529
|
-
if (matmulOutputDtype === 'f16' && attnForProjection.dtype !== 'f16') {
|
|
530
|
-
oProjInput = await recordCastF32ToF16(recorder, attnForProjection);
|
|
531
|
-
oProjInputTemp = oProjInput;
|
|
532
|
-
}
|
|
533
|
-
|
|
534
555
|
// Use fused o_proj + residual for decode when possible
|
|
535
556
|
// Note: dtype from WeightBuffer metadata (buffer-dtypes WeakMap removed)
|
|
536
557
|
const oProjDtype = getWeightDtype(oProjBuf);
|
|
@@ -589,7 +610,7 @@ export async function recordLayerAttentionGPU(
|
|
|
589
610
|
}
|
|
590
611
|
}
|
|
591
612
|
|
|
592
|
-
|
|
613
|
+
finalOutput = output;
|
|
593
614
|
|
|
594
615
|
const buffersToTrack = [];
|
|
595
616
|
if (output.buffer !== attnForProjection.buffer) {
|
|
@@ -619,4 +640,46 @@ export async function recordLayerAttentionGPU(
|
|
|
619
640
|
}
|
|
620
641
|
|
|
621
642
|
return { output: finalOutput, residualFused };
|
|
643
|
+
} catch (error) {
|
|
644
|
+
const tracked = new Set();
|
|
645
|
+
const trackOnce = (buffer) => {
|
|
646
|
+
if (!buffer || tracked.has(buffer)) return;
|
|
647
|
+
tracked.add(buffer);
|
|
648
|
+
recorder.trackTemporaryBuffer(buffer);
|
|
649
|
+
};
|
|
650
|
+
if (finalOutput?.buffer && finalOutput.buffer !== output?.buffer) {
|
|
651
|
+
trackOnce(finalOutput.buffer);
|
|
652
|
+
}
|
|
653
|
+
if (output?.buffer && output.buffer !== attnForProjection?.buffer) {
|
|
654
|
+
trackOnce(output.buffer);
|
|
655
|
+
}
|
|
656
|
+
if (oProjInputTemp?.buffer) {
|
|
657
|
+
trackOnce(oProjInputTemp.buffer);
|
|
658
|
+
}
|
|
659
|
+
if (attnForProjection?.buffer && attnForProjection.buffer !== attnOutput?.buffer) {
|
|
660
|
+
trackOnce(attnForProjection.buffer);
|
|
661
|
+
}
|
|
662
|
+
if (attnOutput?.buffer) {
|
|
663
|
+
trackOnce(attnOutput.buffer);
|
|
664
|
+
}
|
|
665
|
+
if (qGateTensor?.buffer) {
|
|
666
|
+
trackOnce(qGateTensor.buffer);
|
|
667
|
+
}
|
|
668
|
+
if (qTensor?.buffer) {
|
|
669
|
+
trackOnce(qTensor.buffer);
|
|
670
|
+
}
|
|
671
|
+
if (kTensor?.buffer) {
|
|
672
|
+
trackOnce(kTensor.buffer);
|
|
673
|
+
}
|
|
674
|
+
if (vTensor?.buffer) {
|
|
675
|
+
trackOnce(vTensor.buffer);
|
|
676
|
+
}
|
|
677
|
+
if (normed?.buffer && normed.buffer !== attentionInput?.buffer) {
|
|
678
|
+
trackOnce(normed.buffer);
|
|
679
|
+
}
|
|
680
|
+
if (attentionInputTemp && attentionInput?.buffer) {
|
|
681
|
+
trackOnce(attentionInput.buffer);
|
|
682
|
+
}
|
|
683
|
+
throw error;
|
|
684
|
+
}
|
|
622
685
|
}
|