@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -84,20 +84,35 @@ function parseStructuredJSONObject(rawText) {
|
|
|
84
84
|
function resolveStructuredRuntime(manifest, runtimeConfig) {
|
|
85
85
|
const modelCfg = isObj(manifest?.inference?.structuredJsonHead)
|
|
86
86
|
? manifest.inference.structuredJsonHead
|
|
87
|
-
:
|
|
87
|
+
: null;
|
|
88
|
+
if (!modelCfg) {
|
|
89
|
+
throw new Error('StructuredJsonHeadPipeline: manifest.inference.structuredJsonHead is required.');
|
|
90
|
+
}
|
|
88
91
|
const runtimeCfg = isObj(runtimeConfig?.inference?.structuredJsonHead)
|
|
89
92
|
? runtimeConfig.inference.structuredJsonHead
|
|
90
|
-
:
|
|
93
|
+
: {};
|
|
94
|
+
const resolvedMaxTokens = Number.isFinite(runtimeCfg.maxTokens)
|
|
95
|
+
? Math.max(1, Math.floor(runtimeCfg.maxTokens))
|
|
96
|
+
: (Number.isFinite(modelCfg.maxTokens) ? Math.max(1, Math.floor(modelCfg.maxTokens)) : null);
|
|
97
|
+
const resolvedTemperature = Number.isFinite(runtimeCfg.temperature)
|
|
98
|
+
? Number(runtimeCfg.temperature)
|
|
99
|
+
: (Number.isFinite(modelCfg.temperature) ? Number(modelCfg.temperature) : null);
|
|
100
|
+
const resolvedMaxOutputChars = Number.isFinite(runtimeCfg.maxOutputChars)
|
|
101
|
+
? Math.max(4096, Math.floor(runtimeCfg.maxOutputChars))
|
|
102
|
+
: (Number.isFinite(modelCfg.maxOutputChars) ? Math.max(4096, Math.floor(modelCfg.maxOutputChars)) : null);
|
|
103
|
+
if (!Number.isFinite(resolvedMaxTokens)) {
|
|
104
|
+
throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.maxTokens is required.');
|
|
105
|
+
}
|
|
106
|
+
if (!Number.isFinite(resolvedTemperature)) {
|
|
107
|
+
throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.temperature is required.');
|
|
108
|
+
}
|
|
109
|
+
if (!Number.isFinite(resolvedMaxOutputChars)) {
|
|
110
|
+
throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.maxOutputChars is required.');
|
|
111
|
+
}
|
|
91
112
|
return {
|
|
92
|
-
maxTokens:
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
temperature: Number.isFinite(runtimeCfg.temperature)
|
|
96
|
-
? Number(runtimeCfg.temperature)
|
|
97
|
-
: (Number.isFinite(modelCfg.temperature) ? Number(modelCfg.temperature) : 0),
|
|
98
|
-
maxOutputChars: Number.isFinite(runtimeCfg.maxOutputChars)
|
|
99
|
-
? Math.max(4096, Math.floor(runtimeCfg.maxOutputChars))
|
|
100
|
-
: (Number.isFinite(modelCfg.maxOutputChars) ? Math.max(4096, Math.floor(modelCfg.maxOutputChars)) : 262144),
|
|
113
|
+
maxTokens: resolvedMaxTokens,
|
|
114
|
+
temperature: resolvedTemperature,
|
|
115
|
+
maxOutputChars: resolvedMaxOutputChars,
|
|
101
116
|
};
|
|
102
117
|
}
|
|
103
118
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { releaseBuffer } from '../../../../memory/buffer-pool.js';
|
|
2
2
|
import { isWeightBuffer, getLayout, getWeightDtype } from '../../../../gpu/weight-buffer.js';
|
|
3
3
|
import {
|
|
4
4
|
runMatmul,
|
|
@@ -36,7 +36,7 @@ function getRmsNormRunner(recorder) {
|
|
|
36
36
|
}
|
|
37
37
|
|
|
38
38
|
function releaseOwnedWeightBuffer(layerWeight, resolvedWeightBuffer, releaseTemporary) {
|
|
39
|
-
if (layerWeight instanceof GPUBuffer || isWeightBuffer(layerWeight)) {
|
|
39
|
+
if ((typeof GPUBuffer !== 'undefined' && layerWeight instanceof GPUBuffer) || isWeightBuffer(layerWeight)) {
|
|
40
40
|
return;
|
|
41
41
|
}
|
|
42
42
|
if (!resolvedWeightBuffer) {
|
|
@@ -66,10 +66,16 @@ async function projectSingleQkvTensor({
|
|
|
66
66
|
}) {
|
|
67
67
|
const runMatmulForMode = getMatmulRunner(recorder);
|
|
68
68
|
const layerWeight = layerWeights?.[weightKey];
|
|
69
|
-
|
|
69
|
+
if (!layerWeight) {
|
|
70
|
+
throw new Error(`Attention projection requires ${weightKey}.`);
|
|
71
|
+
}
|
|
72
|
+
if (!getWeightBuffer) {
|
|
73
|
+
throw new Error(`Attention projection requires getWeightBuffer for ${role}.`);
|
|
74
|
+
}
|
|
70
75
|
|
|
71
|
-
|
|
72
|
-
|
|
76
|
+
let projected;
|
|
77
|
+
const projBuffer = getWeightBuffer(layerWeight, role);
|
|
78
|
+
try {
|
|
73
79
|
projected = await runMatmulForMode(normed, projBuffer, numTokens, outputSize, hiddenSize, {
|
|
74
80
|
transposeB: 'auto',
|
|
75
81
|
role,
|
|
@@ -77,26 +83,31 @@ async function projectSingleQkvTensor({
|
|
|
77
83
|
kernelPath,
|
|
78
84
|
outputDtype: matmulOutputDtype,
|
|
79
85
|
});
|
|
86
|
+
} finally {
|
|
80
87
|
releaseOwnedWeightBuffer(layerWeight, projBuffer, releaseTemporary);
|
|
81
|
-
} else {
|
|
82
|
-
const fallback = acquireBuffer(numTokens * outputSize * 4, undefined, outputLabel);
|
|
83
|
-
projected = createTensor(fallback, normed.dtype, [numTokens, outputSize], outputLabel);
|
|
84
88
|
}
|
|
85
89
|
|
|
86
90
|
const loraModule = getLoRAModule(lora, layerIdx, loraKey);
|
|
87
91
|
if (loraModule && getWeightBuffer) {
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
92
|
+
try {
|
|
93
|
+
const combined = await applyLoRA(
|
|
94
|
+
normed,
|
|
95
|
+
projected,
|
|
96
|
+
loraModule,
|
|
97
|
+
{ M: numTokens, N: outputSize, K: hiddenSize },
|
|
98
|
+
getWeightBuffer,
|
|
99
|
+
recorder ?? undefined,
|
|
100
|
+
{ kernelPath }
|
|
101
|
+
);
|
|
102
|
+
if (combined.buffer !== projected.buffer) {
|
|
103
|
+
releaseTemporary(projected.buffer);
|
|
104
|
+
projected = combined;
|
|
105
|
+
}
|
|
106
|
+
} catch (error) {
|
|
107
|
+
if (projected?.buffer) {
|
|
108
|
+
releaseTemporary(projected.buffer);
|
|
109
|
+
}
|
|
110
|
+
throw error;
|
|
100
111
|
}
|
|
101
112
|
}
|
|
102
113
|
|
|
@@ -212,24 +223,42 @@ async function projectQueryWithOptionalGate({
|
|
|
212
223
|
bOffset: gateOffset,
|
|
213
224
|
outputDtype: matmulOutputDtype,
|
|
214
225
|
});
|
|
226
|
+
} catch (error) {
|
|
227
|
+
if (qTensor) {
|
|
228
|
+
releaseTemporary(qTensor.buffer);
|
|
229
|
+
}
|
|
230
|
+
if (qGateTensor) {
|
|
231
|
+
releaseTemporary(qGateTensor.buffer);
|
|
232
|
+
}
|
|
233
|
+
throw error;
|
|
215
234
|
} finally {
|
|
216
235
|
releaseOwnedWeightBuffer(qWeight, qWeightBuffer, releaseTemporary);
|
|
217
236
|
}
|
|
218
237
|
|
|
219
238
|
const loraModule = getLoRAModule(lora, layerIdx, 'q_proj');
|
|
220
239
|
if (loraModule && getWeightBuffer) {
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
240
|
+
try {
|
|
241
|
+
const combined = await applyLoRA(
|
|
242
|
+
normed,
|
|
243
|
+
qTensor,
|
|
244
|
+
loraModule,
|
|
245
|
+
{ M: numTokens, N: qSize, K: hiddenSize },
|
|
246
|
+
getWeightBuffer,
|
|
247
|
+
recorder ?? undefined,
|
|
248
|
+
{ kernelPath }
|
|
249
|
+
);
|
|
250
|
+
if (combined.buffer !== qTensor.buffer) {
|
|
251
|
+
releaseTemporary(qTensor.buffer);
|
|
252
|
+
qTensor = combined;
|
|
253
|
+
}
|
|
254
|
+
} catch (error) {
|
|
255
|
+
if (qTensor?.buffer) {
|
|
256
|
+
releaseTemporary(qTensor.buffer);
|
|
257
|
+
}
|
|
258
|
+
if (qGateTensor?.buffer) {
|
|
259
|
+
releaseTemporary(qGateTensor.buffer);
|
|
260
|
+
}
|
|
261
|
+
throw error;
|
|
233
262
|
}
|
|
234
263
|
}
|
|
235
264
|
|
|
@@ -289,82 +318,103 @@ export async function projectAttentionQKV({
|
|
|
289
318
|
if (useFusedQKV && layerWeights.qkvProj && layerWeights.qkvSizes) {
|
|
290
319
|
const [qSizeFused, kSizeFused, vSizeFused] = layerWeights.qkvSizes;
|
|
291
320
|
const qkvSizeTotal = qSizeFused + kSizeFused + vSizeFused;
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
321
|
+
let qkvTensor = null;
|
|
322
|
+
try {
|
|
323
|
+
qkvTensor = await runMatmulForMode(normed, layerWeights.qkvProj, numTokens, qkvSizeTotal, hiddenSize, {
|
|
324
|
+
transposeB: 'auto',
|
|
325
|
+
role: 'qkv_proj',
|
|
326
|
+
layerIdx,
|
|
327
|
+
kernelPath,
|
|
328
|
+
outputDtype: matmulOutputDtype,
|
|
329
|
+
});
|
|
330
|
+
const split = await runSplitForMode(qkvTensor, {
|
|
331
|
+
numTokens,
|
|
332
|
+
qSize: qSizeFused,
|
|
333
|
+
kSize: kSizeFused,
|
|
334
|
+
vSize: vSizeFused,
|
|
335
|
+
});
|
|
336
|
+
releaseTemporary(qkvTensor.buffer);
|
|
337
|
+
if (onFusedQKV) {
|
|
338
|
+
onFusedQKV({ qSize: qSizeFused, kSize: kSizeFused, vSize: vSizeFused, totalSize: qkvSizeTotal });
|
|
339
|
+
}
|
|
340
|
+
return { qTensor: split.Q, qGateTensor: null, kTensor: split.K, vTensor: split.V, usedFusedQKV: true };
|
|
341
|
+
} catch (error) {
|
|
342
|
+
if (qkvTensor) {
|
|
343
|
+
releaseTemporary(qkvTensor.buffer);
|
|
344
|
+
}
|
|
345
|
+
throw error;
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
let qTensor = null;
|
|
350
|
+
let qGateTensor = null;
|
|
351
|
+
let kTensor = null;
|
|
352
|
+
let vTensor = null;
|
|
353
|
+
try {
|
|
354
|
+
({ qTensor, qGateTensor } = await projectQueryWithOptionalGate({
|
|
355
|
+
recorder,
|
|
356
|
+
normed,
|
|
357
|
+
layerWeights,
|
|
358
|
+
numTokens,
|
|
359
|
+
numHeads,
|
|
360
|
+
headDim,
|
|
361
|
+
hiddenSize,
|
|
295
362
|
layerIdx,
|
|
296
363
|
kernelPath,
|
|
297
|
-
|
|
364
|
+
matmulOutputDtype,
|
|
365
|
+
getWeightBuffer,
|
|
366
|
+
lora,
|
|
367
|
+
releaseTemporary,
|
|
368
|
+
attentionOutputGate,
|
|
369
|
+
}));
|
|
370
|
+
|
|
371
|
+
kTensor = await projectSingleQkvTensor({
|
|
372
|
+
recorder,
|
|
373
|
+
normed,
|
|
374
|
+
layerWeights,
|
|
375
|
+
weightKey: 'kProj',
|
|
376
|
+
role: 'k_proj',
|
|
377
|
+
outputSize: numKVHeads * headDim,
|
|
378
|
+
outputLabel: 'K',
|
|
379
|
+
loraKey: 'k_proj',
|
|
380
|
+
numTokens,
|
|
381
|
+
hiddenSize,
|
|
382
|
+
layerIdx,
|
|
383
|
+
kernelPath,
|
|
384
|
+
matmulOutputDtype,
|
|
385
|
+
getWeightBuffer,
|
|
386
|
+
lora,
|
|
387
|
+
releaseTemporary,
|
|
298
388
|
});
|
|
299
|
-
|
|
389
|
+
|
|
390
|
+
vTensor = await projectSingleQkvTensor({
|
|
391
|
+
recorder,
|
|
392
|
+
normed,
|
|
393
|
+
layerWeights,
|
|
394
|
+
weightKey: 'vProj',
|
|
395
|
+
role: 'v_proj',
|
|
396
|
+
outputSize: numKVHeads * headDim,
|
|
397
|
+
outputLabel: 'V',
|
|
398
|
+
loraKey: 'v_proj',
|
|
300
399
|
numTokens,
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
400
|
+
hiddenSize,
|
|
401
|
+
layerIdx,
|
|
402
|
+
kernelPath,
|
|
403
|
+
matmulOutputDtype,
|
|
404
|
+
getWeightBuffer,
|
|
405
|
+
lora,
|
|
406
|
+
releaseTemporary,
|
|
304
407
|
});
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
408
|
+
|
|
409
|
+
return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
|
|
410
|
+
} catch (error) {
|
|
411
|
+
for (const tensor of [qTensor, qGateTensor, kTensor, vTensor]) {
|
|
412
|
+
if (tensor?.buffer) {
|
|
413
|
+
releaseTemporary(tensor.buffer);
|
|
414
|
+
}
|
|
308
415
|
}
|
|
309
|
-
|
|
416
|
+
throw error;
|
|
310
417
|
}
|
|
311
|
-
|
|
312
|
-
const { qTensor, qGateTensor } = await projectQueryWithOptionalGate({
|
|
313
|
-
recorder,
|
|
314
|
-
normed,
|
|
315
|
-
layerWeights,
|
|
316
|
-
numTokens,
|
|
317
|
-
numHeads,
|
|
318
|
-
headDim,
|
|
319
|
-
hiddenSize,
|
|
320
|
-
layerIdx,
|
|
321
|
-
kernelPath,
|
|
322
|
-
matmulOutputDtype,
|
|
323
|
-
getWeightBuffer,
|
|
324
|
-
lora,
|
|
325
|
-
releaseTemporary,
|
|
326
|
-
attentionOutputGate,
|
|
327
|
-
});
|
|
328
|
-
|
|
329
|
-
const kTensor = await projectSingleQkvTensor({
|
|
330
|
-
recorder,
|
|
331
|
-
normed,
|
|
332
|
-
layerWeights,
|
|
333
|
-
weightKey: 'kProj',
|
|
334
|
-
role: 'k_proj',
|
|
335
|
-
outputSize: numKVHeads * headDim,
|
|
336
|
-
outputLabel: 'K',
|
|
337
|
-
loraKey: 'k_proj',
|
|
338
|
-
numTokens,
|
|
339
|
-
hiddenSize,
|
|
340
|
-
layerIdx,
|
|
341
|
-
kernelPath,
|
|
342
|
-
matmulOutputDtype,
|
|
343
|
-
getWeightBuffer,
|
|
344
|
-
lora,
|
|
345
|
-
releaseTemporary,
|
|
346
|
-
});
|
|
347
|
-
|
|
348
|
-
const vTensor = await projectSingleQkvTensor({
|
|
349
|
-
recorder,
|
|
350
|
-
normed,
|
|
351
|
-
layerWeights,
|
|
352
|
-
weightKey: 'vProj',
|
|
353
|
-
role: 'v_proj',
|
|
354
|
-
outputSize: numKVHeads * headDim,
|
|
355
|
-
outputLabel: 'V',
|
|
356
|
-
loraKey: 'v_proj',
|
|
357
|
-
numTokens,
|
|
358
|
-
hiddenSize,
|
|
359
|
-
layerIdx,
|
|
360
|
-
kernelPath,
|
|
361
|
-
matmulOutputDtype,
|
|
362
|
-
getWeightBuffer,
|
|
363
|
-
lora,
|
|
364
|
-
releaseTemporary,
|
|
365
|
-
});
|
|
366
|
-
|
|
367
|
-
return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
|
|
368
418
|
}
|
|
369
419
|
|
|
370
420
|
export async function applyAttentionQKNorm({
|
|
@@ -90,9 +90,20 @@ export async function recordLayerAttentionGPU(
|
|
|
90
90
|
const allowF16Attention = wantsF16Output && kvCacheDtype === 'f16';
|
|
91
91
|
let attentionInput = input;
|
|
92
92
|
let attentionInputTemp = false;
|
|
93
|
+
let normed = attentionInput;
|
|
94
|
+
let qTensor = null;
|
|
95
|
+
let qGateTensor = null;
|
|
96
|
+
let kTensor = null;
|
|
97
|
+
let vTensor = null;
|
|
98
|
+
let attnOutput = null;
|
|
99
|
+
let attnForProjection = null;
|
|
100
|
+
let output = null;
|
|
101
|
+
let finalOutput = null;
|
|
102
|
+
let oProjInputTemp = null;
|
|
93
103
|
if (wantsF16Output && !allowF16Attention) {
|
|
94
104
|
attentionInput = await recordCastF16ToF32(recorder, input);
|
|
95
105
|
attentionInputTemp = true;
|
|
106
|
+
normed = attentionInput;
|
|
96
107
|
}
|
|
97
108
|
|
|
98
109
|
if (!layerWeights) {
|
|
@@ -108,7 +119,7 @@ export async function recordLayerAttentionGPU(
|
|
|
108
119
|
|
|
109
120
|
// 1. Input norm
|
|
110
121
|
|
|
111
|
-
|
|
122
|
+
try {
|
|
112
123
|
if (!skipInputNorm && layerWeights.inputNorm && getNormWeightBuffer) {
|
|
113
124
|
const normWeightBuf = getNormWeightBuffer(layerWeights.inputNorm, 'input_norm');
|
|
114
125
|
normed = await recordRMSNorm(recorder, attentionInput, normWeightBuf, rmsNormEps, {
|
|
@@ -132,7 +143,8 @@ export async function recordLayerAttentionGPU(
|
|
|
132
143
|
|
|
133
144
|
// 2. Q/K/V projections
|
|
134
145
|
const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype);
|
|
135
|
-
let
|
|
146
|
+
let usedFusedQKV = false;
|
|
147
|
+
({ qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
|
|
136
148
|
recorder,
|
|
137
149
|
normed,
|
|
138
150
|
layerWeights,
|
|
@@ -153,7 +165,7 @@ export async function recordLayerAttentionGPU(
|
|
|
153
165
|
trace.attn(layerIdx, `Using fused QKV path: ${qSizeFused}+${kSizeFused}+${vSizeFused}=${totalSize}`);
|
|
154
166
|
}
|
|
155
167
|
: null,
|
|
156
|
-
});
|
|
168
|
+
}));
|
|
157
169
|
|
|
158
170
|
// Optional per-head Q/K normalization.
|
|
159
171
|
// Some models use RMSNorm with (1+weight) offset formula, controlled by rmsNormWeightOffset.
|
|
@@ -182,10 +194,18 @@ export async function recordLayerAttentionGPU(
|
|
|
182
194
|
// 3. RoPE (modifies tensor in-place)
|
|
183
195
|
if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
|
|
184
196
|
await recordRoPE(recorder, qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
185
|
-
numHeads,
|
|
197
|
+
numHeads,
|
|
198
|
+
headDim,
|
|
199
|
+
rotaryDim: config.ropeRotaryDim,
|
|
200
|
+
interleaved: config.ropeInterleaved,
|
|
201
|
+
startPos: currentSeqLen,
|
|
186
202
|
});
|
|
187
203
|
await recordRoPE(recorder, kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
188
|
-
numHeads: numKVHeads,
|
|
204
|
+
numHeads: numKVHeads,
|
|
205
|
+
headDim,
|
|
206
|
+
rotaryDim: config.ropeRotaryDim,
|
|
207
|
+
interleaved: config.ropeInterleaved,
|
|
208
|
+
startPos: currentSeqLen,
|
|
189
209
|
});
|
|
190
210
|
}
|
|
191
211
|
|
|
@@ -494,14 +514,15 @@ export async function recordLayerAttentionGPU(
|
|
|
494
514
|
throw new Error(`Unsupported attention kernel variant "${attentionKernelVariant}" at layer ${layerIdx}`);
|
|
495
515
|
}
|
|
496
516
|
|
|
497
|
-
|
|
517
|
+
attnOutput = await runAttentionKernel();
|
|
498
518
|
|
|
499
|
-
|
|
519
|
+
attnForProjection = attnOutput;
|
|
500
520
|
if (qGateTensor) {
|
|
501
521
|
attnForProjection = await recordSiLU(recorder, attnOutput, {
|
|
502
522
|
size: numTokens * numHeads * headDim,
|
|
503
523
|
gate: qGateTensor,
|
|
504
524
|
gateActivation: 'sigmoid',
|
|
525
|
+
inputActivation: 'identity',
|
|
505
526
|
swigluLimit: null,
|
|
506
527
|
});
|
|
507
528
|
recorder.trackTemporaryBuffer(attnOutput.buffer);
|
|
@@ -509,10 +530,10 @@ export async function recordLayerAttentionGPU(
|
|
|
509
530
|
|
|
510
531
|
// 6. Output projection (with optional fused residual for decode)
|
|
511
532
|
|
|
512
|
-
|
|
533
|
+
output = null;
|
|
513
534
|
let residualFused = false;
|
|
514
535
|
let oProjInput = attnForProjection;
|
|
515
|
-
|
|
536
|
+
oProjInputTemp = null;
|
|
516
537
|
if (layerWeights.oProj && getWeightBuffer) {
|
|
517
538
|
const oProjBuf = getWeightBuffer(layerWeights.oProj, 'o_proj');
|
|
518
539
|
const loraO = getLoRAModule(lora, layerIdx, 'o_proj');
|
|
@@ -580,7 +601,7 @@ export async function recordLayerAttentionGPU(
|
|
|
580
601
|
}
|
|
581
602
|
}
|
|
582
603
|
|
|
583
|
-
|
|
604
|
+
finalOutput = output;
|
|
584
605
|
|
|
585
606
|
const buffersToTrack = [];
|
|
586
607
|
if (output.buffer !== attnForProjection.buffer) {
|
|
@@ -610,4 +631,46 @@ export async function recordLayerAttentionGPU(
|
|
|
610
631
|
}
|
|
611
632
|
|
|
612
633
|
return { output: finalOutput, residualFused };
|
|
634
|
+
} catch (error) {
|
|
635
|
+
const tracked = new Set();
|
|
636
|
+
const trackOnce = (buffer) => {
|
|
637
|
+
if (!buffer || tracked.has(buffer)) return;
|
|
638
|
+
tracked.add(buffer);
|
|
639
|
+
recorder.trackTemporaryBuffer(buffer);
|
|
640
|
+
};
|
|
641
|
+
if (finalOutput?.buffer && finalOutput.buffer !== output?.buffer) {
|
|
642
|
+
trackOnce(finalOutput.buffer);
|
|
643
|
+
}
|
|
644
|
+
if (output?.buffer && output.buffer !== attnForProjection?.buffer) {
|
|
645
|
+
trackOnce(output.buffer);
|
|
646
|
+
}
|
|
647
|
+
if (oProjInputTemp?.buffer) {
|
|
648
|
+
trackOnce(oProjInputTemp.buffer);
|
|
649
|
+
}
|
|
650
|
+
if (attnForProjection?.buffer && attnForProjection.buffer !== attnOutput?.buffer) {
|
|
651
|
+
trackOnce(attnForProjection.buffer);
|
|
652
|
+
}
|
|
653
|
+
if (attnOutput?.buffer) {
|
|
654
|
+
trackOnce(attnOutput.buffer);
|
|
655
|
+
}
|
|
656
|
+
if (qGateTensor?.buffer) {
|
|
657
|
+
trackOnce(qGateTensor.buffer);
|
|
658
|
+
}
|
|
659
|
+
if (qTensor?.buffer) {
|
|
660
|
+
trackOnce(qTensor.buffer);
|
|
661
|
+
}
|
|
662
|
+
if (kTensor?.buffer) {
|
|
663
|
+
trackOnce(kTensor.buffer);
|
|
664
|
+
}
|
|
665
|
+
if (vTensor?.buffer) {
|
|
666
|
+
trackOnce(vTensor.buffer);
|
|
667
|
+
}
|
|
668
|
+
if (normed?.buffer && normed.buffer !== attentionInput?.buffer) {
|
|
669
|
+
trackOnce(normed.buffer);
|
|
670
|
+
}
|
|
671
|
+
if (attentionInputTemp && attentionInput?.buffer) {
|
|
672
|
+
trackOnce(attentionInput.buffer);
|
|
673
|
+
}
|
|
674
|
+
throw error;
|
|
675
|
+
}
|
|
613
676
|
}
|
|
@@ -97,9 +97,20 @@ export async function runLayerAttentionGPU(
|
|
|
97
97
|
const allowF16Attention = wantsF16Output && kvCacheDtype === 'f16';
|
|
98
98
|
let attentionInput = input;
|
|
99
99
|
let attentionInputTemp = false;
|
|
100
|
+
let normed = attentionInput;
|
|
101
|
+
let qTensor = null;
|
|
102
|
+
let qGateTensor = null;
|
|
103
|
+
let kTensor = null;
|
|
104
|
+
let vTensor = null;
|
|
105
|
+
let attnOutput = null;
|
|
106
|
+
let attnForProjection = null;
|
|
107
|
+
let output = null;
|
|
108
|
+
let finalOutput = null;
|
|
109
|
+
let oProjInputTemp = null;
|
|
100
110
|
if (wantsF16Output && !allowF16Attention) {
|
|
101
111
|
attentionInput = await castF16ToF32(input);
|
|
102
112
|
attentionInputTemp = true;
|
|
113
|
+
normed = attentionInput;
|
|
103
114
|
}
|
|
104
115
|
|
|
105
116
|
// Debug: attention input for configured layers
|
|
@@ -123,7 +134,7 @@ export async function runLayerAttentionGPU(
|
|
|
123
134
|
|
|
124
135
|
// 1. Input norm
|
|
125
136
|
|
|
126
|
-
|
|
137
|
+
try {
|
|
127
138
|
if (!skipInputNorm && layerWeights.inputNorm && getNormWeightBuffer) {
|
|
128
139
|
const normWeightBuf = getNormWeightBuffer(layerWeights.inputNorm, 'input_norm');
|
|
129
140
|
|
|
@@ -183,7 +194,8 @@ export async function runLayerAttentionGPU(
|
|
|
183
194
|
|
|
184
195
|
// 2. Q/K/V projections
|
|
185
196
|
const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype);
|
|
186
|
-
let
|
|
197
|
+
let usedFusedQKV = false;
|
|
198
|
+
({ qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
|
|
187
199
|
recorder: null,
|
|
188
200
|
normed,
|
|
189
201
|
layerWeights,
|
|
@@ -204,7 +216,7 @@ export async function runLayerAttentionGPU(
|
|
|
204
216
|
trace.attn(layerIdx, `Using fused QKV path: ${qSizeFused}+${kSizeFused}+${vSizeFused}=${totalSize}`);
|
|
205
217
|
}
|
|
206
218
|
: null,
|
|
207
|
-
});
|
|
219
|
+
}));
|
|
208
220
|
|
|
209
221
|
// Trace Q/K/V projections
|
|
210
222
|
if (kernelTrace.enabled) {
|
|
@@ -299,10 +311,18 @@ export async function runLayerAttentionGPU(
|
|
|
299
311
|
|
|
300
312
|
if (!disableRoPE && state.ropeFreqsCos && state.ropeFreqsSin) {
|
|
301
313
|
await runRoPE(qTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
302
|
-
numHeads,
|
|
314
|
+
numHeads,
|
|
315
|
+
headDim,
|
|
316
|
+
rotaryDim: config.ropeRotaryDim,
|
|
317
|
+
interleaved: config.ropeInterleaved,
|
|
318
|
+
startPos: currentSeqLen,
|
|
303
319
|
});
|
|
304
320
|
await runRoPE(kTensor, state.ropeFreqsCos, state.ropeFreqsSin, numTokens, {
|
|
305
|
-
numHeads: numKVHeads,
|
|
321
|
+
numHeads: numKVHeads,
|
|
322
|
+
headDim,
|
|
323
|
+
rotaryDim: config.ropeRotaryDim,
|
|
324
|
+
interleaved: config.ropeInterleaved,
|
|
325
|
+
startPos: currentSeqLen,
|
|
306
326
|
});
|
|
307
327
|
|
|
308
328
|
// Trace RoPE outputs
|
|
@@ -661,7 +681,7 @@ export async function runLayerAttentionGPU(
|
|
|
661
681
|
throw new Error(`Unsupported attention kernel variant "${attentionKernelVariant}" at layer ${layerIdx}`);
|
|
662
682
|
}
|
|
663
683
|
|
|
664
|
-
|
|
684
|
+
attnOutput = await runAttentionKernel();
|
|
665
685
|
|
|
666
686
|
// Trace attention output
|
|
667
687
|
if (kernelTrace.enabled) {
|
|
@@ -684,12 +704,13 @@ export async function runLayerAttentionGPU(
|
|
|
684
704
|
await debugCheckBuffer(attnOutput.buffer, `L${layerIdx} attention output (before o_proj, GPU)`, numTokens, numHeads * headDim);
|
|
685
705
|
}
|
|
686
706
|
|
|
687
|
-
|
|
707
|
+
attnForProjection = attnOutput;
|
|
688
708
|
if (qGateTensor) {
|
|
689
709
|
attnForProjection = await runSiLU(attnOutput, {
|
|
690
710
|
size: numTokens * numHeads * headDim,
|
|
691
711
|
gate: qGateTensor,
|
|
692
712
|
gateActivation: 'sigmoid',
|
|
713
|
+
inputActivation: 'identity',
|
|
693
714
|
swigluLimit: null,
|
|
694
715
|
});
|
|
695
716
|
releaseBuffer(attnOutput.buffer);
|
|
@@ -697,10 +718,10 @@ export async function runLayerAttentionGPU(
|
|
|
697
718
|
|
|
698
719
|
// 6. Output projection (with optional fused residual for decode)
|
|
699
720
|
|
|
700
|
-
|
|
721
|
+
output = null;
|
|
701
722
|
let residualFused = false;
|
|
702
723
|
let oProjInput = attnForProjection;
|
|
703
|
-
|
|
724
|
+
oProjInputTemp = null;
|
|
704
725
|
if (layerWeights.oProj && getWeightBuffer) {
|
|
705
726
|
const oProjBuf = getWeightBuffer(layerWeights.oProj, 'o_proj');
|
|
706
727
|
const loraO = getLoRAModule(lora, layerIdx, 'o_proj');
|
|
@@ -798,7 +819,7 @@ export async function runLayerAttentionGPU(
|
|
|
798
819
|
await debugCheckBuffer(output.buffer, `L${layerIdx} attention output (after o_proj, GPU)`, numTokens, hiddenSize);
|
|
799
820
|
}
|
|
800
821
|
|
|
801
|
-
|
|
822
|
+
finalOutput = output;
|
|
802
823
|
|
|
803
824
|
const buffersToRelease = [];
|
|
804
825
|
if (output.buffer !== attnForProjection.buffer) {
|
|
@@ -823,4 +844,46 @@ export async function runLayerAttentionGPU(
|
|
|
823
844
|
}
|
|
824
845
|
|
|
825
846
|
return { output: finalOutput, residualFused };
|
|
847
|
+
} catch (error) {
|
|
848
|
+
const released = new Set();
|
|
849
|
+
const releaseOnce = (buffer) => {
|
|
850
|
+
if (!buffer || released.has(buffer)) return;
|
|
851
|
+
released.add(buffer);
|
|
852
|
+
releaseBuffer(buffer);
|
|
853
|
+
};
|
|
854
|
+
if (finalOutput?.buffer && finalOutput.buffer !== output?.buffer) {
|
|
855
|
+
releaseOnce(finalOutput.buffer);
|
|
856
|
+
}
|
|
857
|
+
if (output?.buffer && output.buffer !== attnForProjection?.buffer) {
|
|
858
|
+
releaseOnce(output.buffer);
|
|
859
|
+
}
|
|
860
|
+
if (oProjInputTemp?.buffer) {
|
|
861
|
+
releaseOnce(oProjInputTemp.buffer);
|
|
862
|
+
}
|
|
863
|
+
if (attnForProjection?.buffer && attnForProjection.buffer !== attnOutput?.buffer) {
|
|
864
|
+
releaseOnce(attnForProjection.buffer);
|
|
865
|
+
}
|
|
866
|
+
if (attnOutput?.buffer) {
|
|
867
|
+
releaseOnce(attnOutput.buffer);
|
|
868
|
+
}
|
|
869
|
+
if (qGateTensor?.buffer) {
|
|
870
|
+
releaseOnce(qGateTensor.buffer);
|
|
871
|
+
}
|
|
872
|
+
if (qTensor?.buffer) {
|
|
873
|
+
releaseOnce(qTensor.buffer);
|
|
874
|
+
}
|
|
875
|
+
if (kTensor?.buffer) {
|
|
876
|
+
releaseOnce(kTensor.buffer);
|
|
877
|
+
}
|
|
878
|
+
if (vTensor?.buffer) {
|
|
879
|
+
releaseOnce(vTensor.buffer);
|
|
880
|
+
}
|
|
881
|
+
if (normed?.buffer && normed.buffer !== attentionInput?.buffer) {
|
|
882
|
+
releaseOnce(normed.buffer);
|
|
883
|
+
}
|
|
884
|
+
if (attentionInputTemp && attentionInput?.buffer) {
|
|
885
|
+
releaseOnce(attentionInput.buffer);
|
|
886
|
+
}
|
|
887
|
+
throw error;
|
|
888
|
+
}
|
|
826
889
|
}
|