@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -1043,18 +1043,9 @@ export class PipelineGenerator {
|
|
|
1043
1043
|
if (allowReadback(`pipeline.prefill.layer-${l}`)) {
|
|
1044
1044
|
try {
|
|
1045
1045
|
const sampleSize = config.hiddenSize * activationBytes;
|
|
1046
|
-
const staging = device.createBuffer({
|
|
1047
|
-
size: sampleSize,
|
|
1048
|
-
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
|
|
1049
|
-
});
|
|
1050
|
-
const enc = device.createCommandEncoder();
|
|
1051
1046
|
const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
1055
|
-
const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
|
|
1056
|
-
staging.unmap();
|
|
1057
|
-
staging.destroy();
|
|
1047
|
+
const readback = await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize);
|
|
1048
|
+
const data = decodeReadback(readback, activationDtype);
|
|
1058
1049
|
let min = Infinity;
|
|
1059
1050
|
let max = -Infinity;
|
|
1060
1051
|
let maxAbs = 0;
|
|
@@ -1112,20 +1103,12 @@ export class PipelineGenerator {
|
|
|
1112
1103
|
if (opts.debug) {
|
|
1113
1104
|
log.debug('Pipeline', `LAYER_LOOP_DONE, currentHiddenBuffer type=${currentHiddenBuffer?.constructor?.name}`);
|
|
1114
1105
|
if (currentHiddenBuffer && allowReadback('pipeline.prefill.final-hidden')) {
|
|
1115
|
-
const device = getDevice();
|
|
1116
1106
|
const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
|
|
1117
1107
|
const sampleSize = config.hiddenSize * activationBytes;
|
|
1118
|
-
const
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
const enc = device.createCommandEncoder();
|
|
1123
|
-
enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
|
|
1124
|
-
device.queue.submit([enc.finish()]);
|
|
1125
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
1126
|
-
const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
|
|
1127
|
-
staging.unmap();
|
|
1128
|
-
staging.destroy();
|
|
1108
|
+
const data = decodeReadback(
|
|
1109
|
+
await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize),
|
|
1110
|
+
activationDtype
|
|
1111
|
+
);
|
|
1129
1112
|
const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
|
|
1130
1113
|
const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
|
|
1131
1114
|
log.debug('Pipeline', `FINAL_HIDDEN[pos=${numTokens - 1}]: nan=${nanCount}/${data.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);
|
|
@@ -71,9 +71,13 @@ export interface PipelineContexts {
|
|
|
71
71
|
*/
|
|
72
72
|
export interface RoPEConfig {
|
|
73
73
|
headDim: number;
|
|
74
|
+
rotaryDim?: number;
|
|
74
75
|
maxSeqLen: number;
|
|
75
76
|
ropeTheta: number;
|
|
76
77
|
ropeLocalTheta?: number | null;
|
|
78
|
+
mropeInterleaved?: boolean;
|
|
79
|
+
mropeSection?: number[] | null;
|
|
80
|
+
partialRotaryFactor?: number | null;
|
|
77
81
|
ropeScale: number;
|
|
78
82
|
ropeLocalScale?: number;
|
|
79
83
|
ropeScalingType?: string | null;
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import { parseModelConfig } from './config.js';
|
|
4
4
|
import { getDevice, getDeviceLimits, getKernelCapabilities } from '../../../gpu/device.js';
|
|
5
|
-
import { acquireBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
|
|
6
6
|
import { KVCache, SlidingWindowKVCache, TieredKVCache, BasisDecomposedPagedCache } from '../../kv-cache.js';
|
|
7
7
|
import { Tokenizer } from '../../tokenizer.js';
|
|
8
8
|
import { MoERouter } from '../../moe-router.js';
|
|
@@ -14,6 +14,10 @@ import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js'
|
|
|
14
14
|
import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
|
|
15
15
|
import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
|
|
16
16
|
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
17
|
+
import {
|
|
18
|
+
createSourceStorageContext,
|
|
19
|
+
getSourceRuntimeMetadata,
|
|
20
|
+
} from '../../../tooling/source-runtime-bundle.js';
|
|
17
21
|
|
|
18
22
|
function resolveErrorMessage(error) {
|
|
19
23
|
if (error && typeof error === 'object' && typeof error.message === 'string') {
|
|
@@ -56,12 +60,61 @@ function normalizeBaseUrl(baseUrl) {
|
|
|
56
60
|
return baseUrl.replace(/\/$/, '');
|
|
57
61
|
}
|
|
58
62
|
|
|
63
|
+
async function fetchBytes(url, offset = null, length = null) {
|
|
64
|
+
const headers = {};
|
|
65
|
+
if (Number.isFinite(offset) && Number.isFinite(length) && length > 0) {
|
|
66
|
+
const start = Math.max(0, Math.floor(offset));
|
|
67
|
+
const end = start + Math.max(0, Math.floor(length)) - 1;
|
|
68
|
+
headers.Range = `bytes=${start}-${end}`;
|
|
69
|
+
}
|
|
70
|
+
const response = await fetch(url, { headers });
|
|
71
|
+
if (!response.ok) {
|
|
72
|
+
throw new Error(`Failed to fetch ${url}: ${response.status}`);
|
|
73
|
+
}
|
|
74
|
+
return new Uint8Array(await response.arrayBuffer());
|
|
75
|
+
}
|
|
76
|
+
|
|
59
77
|
function createRemoteStorageContext(baseUrl, manifest) {
|
|
60
78
|
const root = normalizeBaseUrl(baseUrl);
|
|
61
79
|
if (!root || !isRDRRManifest(manifest)) {
|
|
62
80
|
return null;
|
|
63
81
|
}
|
|
64
82
|
|
|
83
|
+
const sourceRuntime = getSourceRuntimeMetadata(manifest);
|
|
84
|
+
if (sourceRuntime) {
|
|
85
|
+
const readRange = async (relativePath, offset, length) => {
|
|
86
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
87
|
+
if (!filename) {
|
|
88
|
+
throw new Error('Direct-source artifact path is required.');
|
|
89
|
+
}
|
|
90
|
+
const url = `${root}/${filename}`;
|
|
91
|
+
return fetchBytes(url, offset, length);
|
|
92
|
+
};
|
|
93
|
+
const readText = async (relativePath) => {
|
|
94
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
95
|
+
if (!filename) return null;
|
|
96
|
+
const response = await fetch(`${root}/${filename}`);
|
|
97
|
+
if (!response.ok) {
|
|
98
|
+
throw new Error(`Failed to fetch ${filename} from ${root}: ${response.status}`);
|
|
99
|
+
}
|
|
100
|
+
return response.text();
|
|
101
|
+
};
|
|
102
|
+
const readBinary = async (relativePath) => {
|
|
103
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
104
|
+
if (!filename) {
|
|
105
|
+
throw new Error('Direct-source binary asset path is required.');
|
|
106
|
+
}
|
|
107
|
+
return fetchBytes(`${root}/${filename}`);
|
|
108
|
+
};
|
|
109
|
+
return createSourceStorageContext({
|
|
110
|
+
manifest,
|
|
111
|
+
readRange,
|
|
112
|
+
readText,
|
|
113
|
+
readBinary,
|
|
114
|
+
verifyHashes: true,
|
|
115
|
+
});
|
|
116
|
+
}
|
|
117
|
+
|
|
65
118
|
return {
|
|
66
119
|
async loadShard(index) {
|
|
67
120
|
const shard = manifest.shards[index];
|
|
@@ -69,11 +122,7 @@ function createRemoteStorageContext(baseUrl, manifest) {
|
|
|
69
122
|
if (!filename) {
|
|
70
123
|
throw new Error(`Manifest shard ${index} is missing filename.`);
|
|
71
124
|
}
|
|
72
|
-
|
|
73
|
-
if (!response.ok) {
|
|
74
|
-
throw new Error(`Failed to fetch shard ${index} from ${root}: ${response.status}`);
|
|
75
|
-
}
|
|
76
|
-
return new Uint8Array(await response.arrayBuffer());
|
|
125
|
+
return fetchBytes(`${root}/${filename.replace(/^\/+/, '')}`);
|
|
77
126
|
},
|
|
78
127
|
};
|
|
79
128
|
}
|
|
@@ -206,13 +255,45 @@ function isSameRoPEScalingConfig(
|
|
|
206
255
|
=== (rightScaling?.original_max_position_embeddings ?? null);
|
|
207
256
|
}
|
|
208
257
|
|
|
258
|
+
function resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor) {
|
|
259
|
+
if (rotaryDim != null) {
|
|
260
|
+
if (!Number.isFinite(rotaryDim) || rotaryDim <= 0 || (rotaryDim % 2) !== 0) {
|
|
261
|
+
throw new Error(`RoPE rotary dim must be a positive even integer; got "${rotaryDim}".`);
|
|
262
|
+
}
|
|
263
|
+
if (rotaryDim > headDim) {
|
|
264
|
+
throw new Error(`RoPE rotary dim ${rotaryDim} cannot exceed headDim ${headDim}.`);
|
|
265
|
+
}
|
|
266
|
+
return rotaryDim;
|
|
267
|
+
}
|
|
268
|
+
if (partialRotaryFactor == null) {
|
|
269
|
+
return headDim;
|
|
270
|
+
}
|
|
271
|
+
if (!Number.isFinite(partialRotaryFactor) || partialRotaryFactor <= 0 || partialRotaryFactor > 1) {
|
|
272
|
+
throw new Error(
|
|
273
|
+
`RoPE partialRotaryFactor must be a number in (0, 1]; got "${partialRotaryFactor}".`
|
|
274
|
+
);
|
|
275
|
+
}
|
|
276
|
+
const resolved = Math.trunc(headDim * partialRotaryFactor);
|
|
277
|
+
if (resolved <= 0 || (resolved % 2) !== 0) {
|
|
278
|
+
throw new Error(
|
|
279
|
+
`RoPE partialRotaryFactor=${partialRotaryFactor} with headDim=${headDim} resolves ` +
|
|
280
|
+
`to rotaryDim=${resolved}, but rotaryDim must be a positive even integer.`
|
|
281
|
+
);
|
|
282
|
+
}
|
|
283
|
+
return resolved;
|
|
284
|
+
}
|
|
285
|
+
|
|
209
286
|
|
|
210
287
|
export async function initRoPEFrequencies(config, useGPU) {
|
|
211
288
|
const {
|
|
212
289
|
headDim,
|
|
290
|
+
rotaryDim,
|
|
213
291
|
maxSeqLen,
|
|
214
292
|
ropeTheta,
|
|
215
293
|
ropeLocalTheta,
|
|
294
|
+
mropeInterleaved,
|
|
295
|
+
mropeSection,
|
|
296
|
+
partialRotaryFactor,
|
|
216
297
|
ropeScale,
|
|
217
298
|
ropeLocalScale,
|
|
218
299
|
ropeScalingType,
|
|
@@ -230,14 +311,23 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
230
311
|
const resolvedLocalTheta = ropeLocalTheta ?? ropeTheta;
|
|
231
312
|
const resolvedLocalScalingType = ropeLocalScalingType ?? ropeScalingType;
|
|
232
313
|
const resolvedLocalScaling = ropeLocalScaling ?? ropeScaling;
|
|
314
|
+
const resolvedRotaryDim = resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor);
|
|
315
|
+
const halfDim = resolvedRotaryDim / 2;
|
|
316
|
+
if (mropeInterleaved === true && Array.isArray(mropeSection)) {
|
|
317
|
+
const expandedDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
|
|
318
|
+
if (expandedDim !== resolvedRotaryDim) {
|
|
319
|
+
throw new Error(
|
|
320
|
+
`RoPE mropeSection expands to ${expandedDim} dims, but rotaryDim is ${resolvedRotaryDim}.`
|
|
321
|
+
);
|
|
322
|
+
}
|
|
323
|
+
}
|
|
233
324
|
|
|
234
|
-
const halfDim = headDim / 2;
|
|
235
325
|
const isYarn = ropeScalingType === 'yarn';
|
|
236
326
|
const isLocalYarn = resolvedLocalScalingType === 'yarn';
|
|
237
327
|
|
|
238
328
|
// Compute global (full_attention) frequencies
|
|
239
329
|
const globalFreqs = computeRoPEFreqsForTheta(
|
|
240
|
-
ropeTheta,
|
|
330
|
+
ropeTheta, resolvedRotaryDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
|
|
241
331
|
);
|
|
242
332
|
|
|
243
333
|
// Compute local (sliding_attention) frequencies if different from global.
|
|
@@ -256,7 +346,7 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
256
346
|
if (hasDistinctLocalTheta || hasDistinctLocalScaling) {
|
|
257
347
|
localFreqs = computeRoPEFreqsForTheta(
|
|
258
348
|
resolvedLocalTheta,
|
|
259
|
-
|
|
349
|
+
resolvedRotaryDim,
|
|
260
350
|
maxSeqLen,
|
|
261
351
|
resolvedLocalScale,
|
|
262
352
|
resolvedLocalScalingType,
|
|
@@ -285,27 +375,37 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
285
375
|
// Upload to GPU if available
|
|
286
376
|
const device = getDevice();
|
|
287
377
|
if (device && useGPU) {
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
378
|
+
let cosBuffer = null;
|
|
379
|
+
let sinBuffer = null;
|
|
380
|
+
let localCosBuffer = null;
|
|
381
|
+
let localSinBuffer = null;
|
|
382
|
+
try {
|
|
383
|
+
cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
|
|
384
|
+
sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
|
|
385
|
+
device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
|
|
386
|
+
device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
|
|
387
|
+
|
|
388
|
+
if (localFreqs) {
|
|
389
|
+
localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
|
|
390
|
+
localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
|
|
391
|
+
device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
|
|
392
|
+
device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
|
|
393
|
+
}
|
|
394
|
+
} catch (error) {
|
|
395
|
+
for (const buffer of [cosBuffer, sinBuffer, localCosBuffer, localSinBuffer]) {
|
|
396
|
+
if (buffer) {
|
|
397
|
+
releaseBuffer(buffer);
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
throw error;
|
|
302
401
|
}
|
|
303
402
|
|
|
304
403
|
log.debug(
|
|
305
404
|
'Pipeline',
|
|
306
|
-
`RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
|
|
405
|
+
`RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
|
|
307
406
|
`theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
|
|
308
|
-
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
|
|
407
|
+
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
|
|
408
|
+
`interleaved=${mropeInterleaved === true}`
|
|
309
409
|
);
|
|
310
410
|
|
|
311
411
|
return {
|
|
@@ -318,9 +418,10 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
318
418
|
|
|
319
419
|
log.debug(
|
|
320
420
|
'Pipeline',
|
|
321
|
-
`RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
|
|
421
|
+
`RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
|
|
322
422
|
`theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
|
|
323
|
-
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
|
|
423
|
+
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
|
|
424
|
+
`interleaved=${mropeInterleaved === true}`
|
|
324
425
|
);
|
|
325
426
|
|
|
326
427
|
return {
|
|
@@ -688,6 +789,10 @@ function applyChatMLTemplate(prompt) {
|
|
|
688
789
|
return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
|
|
689
790
|
}
|
|
690
791
|
|
|
792
|
+
function applyQwenTemplate(prompt) {
|
|
793
|
+
return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n`;
|
|
794
|
+
}
|
|
795
|
+
|
|
691
796
|
function applyTranslateGemmaTemplate() {
|
|
692
797
|
throw new Error(
|
|
693
798
|
'TranslateGemma template requires structured messages. ' +
|
|
@@ -702,7 +807,7 @@ const PROMPT_TEMPLATES = {
|
|
|
702
807
|
'llama3': applyHeaderBasedTemplate,
|
|
703
808
|
'gpt-oss': applyChannelBasedTemplate,
|
|
704
809
|
'chatml': applyChatMLTemplate,
|
|
705
|
-
'qwen':
|
|
810
|
+
'qwen': applyQwenTemplate,
|
|
706
811
|
'translategemma': applyTranslateGemmaTemplate,
|
|
707
812
|
};
|
|
708
813
|
|
|
@@ -721,7 +826,7 @@ export function applyChatTemplate(prompt, templateType) {
|
|
|
721
826
|
export const applyGemmaChatTemplate = applyTurnBasedTemplate;
|
|
722
827
|
export const applyLlama3ChatTemplate = applyHeaderBasedTemplate;
|
|
723
828
|
export const applyGptOssChatTemplate = applyChannelBasedTemplate;
|
|
724
|
-
export const applyQwenChatTemplate =
|
|
829
|
+
export const applyQwenChatTemplate = applyQwenTemplate;
|
|
725
830
|
|
|
726
831
|
|
|
727
832
|
export function isStopToken(token, stopTokenIds, eosTokenId) {
|
|
@@ -78,6 +78,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
|
|
|
78
78
|
|
|
79
79
|
const normalizedPolicy = resolveKernelPathPolicy(kernelPathPolicy);
|
|
80
80
|
const hasSubgroups = capabilities?.hasSubgroups === true;
|
|
81
|
+
const hasF16 = capabilities?.hasF16 === true;
|
|
81
82
|
const normalizedSource = normalizeKernelPathSource(kernelPathSource);
|
|
82
83
|
const allowCapabilityAutoSelection = normalizedPolicy.mode === 'capability-aware'
|
|
83
84
|
&& normalizedPolicy.sourceScope.includes(normalizedSource);
|
|
@@ -85,6 +86,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
|
|
|
85
86
|
return selectRuleValue('inference', 'kernelPath', 'autoSelect', {
|
|
86
87
|
kernelPathRef: configuredKernelPathRef,
|
|
87
88
|
hasSubgroups,
|
|
89
|
+
hasF16,
|
|
88
90
|
allowCapabilityAutoSelection,
|
|
89
91
|
});
|
|
90
92
|
}
|
|
@@ -283,6 +283,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
|
|
|
283
283
|
if (layer >= 0 && !kernelTrace.shouldTraceLayer(layer)) return;
|
|
284
284
|
|
|
285
285
|
const output = await snapshotTensor(outputBuffer, outputShape);
|
|
286
|
+
if (!output.ok) {
|
|
287
|
+
throw new Error(`[TRACE] Failed to snapshot output for ${label}: ${output.error}`);
|
|
288
|
+
}
|
|
286
289
|
|
|
287
290
|
// Snapshot inputs if provided (expensive - only do if tracing)
|
|
288
291
|
|
|
@@ -290,6 +293,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
|
|
|
290
293
|
if (options?.inputs && options?.inputShapes) {
|
|
291
294
|
for (let i = 0; i < options.inputs.length; i++) {
|
|
292
295
|
const snap = await snapshotTensor(options.inputs[i], options.inputShapes[i]);
|
|
296
|
+
if (!snap.ok) {
|
|
297
|
+
throw new Error(`[TRACE] Failed to snapshot input ${i} for ${label}: ${snap.error}`);
|
|
298
|
+
}
|
|
293
299
|
inputs.push(snap);
|
|
294
300
|
}
|
|
295
301
|
}
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import { log, trace } from '../../../debug/index.js';
|
|
4
4
|
import { getDevice } from '../../../gpu/device.js';
|
|
5
|
-
import { releaseBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import { releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
|
|
6
6
|
import { allowReadback } from '../../../gpu/perf-guards.js';
|
|
7
7
|
import { createTensor } from '../../../gpu/tensor.js';
|
|
8
8
|
import {
|
|
@@ -228,6 +228,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
228
228
|
linearRuntime: context.linearAttentionRuntime ?? null,
|
|
229
229
|
getWeightBuffer: (weight, label) => getWeightBuffer(weight, label),
|
|
230
230
|
getNormWeightBuffer: (weight, label) => getNormWeightBuffer(weight, label, weightConfig, debugFlags),
|
|
231
|
+
debugProbes: context.debugProbes,
|
|
231
232
|
recorder: recorder ?? null,
|
|
232
233
|
});
|
|
233
234
|
} else {
|
|
@@ -259,6 +260,8 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
259
260
|
attentionOutputGate: config.attentionOutputGate,
|
|
260
261
|
causalAttention: config.causalAttention,
|
|
261
262
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
263
|
+
ropeRotaryDim: config.ropeRotaryDim,
|
|
264
|
+
ropeInterleaved: config.ropeInterleaved,
|
|
262
265
|
tokenIds: context.currentTokenIds ?? null,
|
|
263
266
|
kernelPath: context.kernelPath ?? null,
|
|
264
267
|
disableRoPE,
|
|
@@ -312,14 +315,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
312
315
|
if (allowReadback(`layer.attn-out.${layerIdx}`)) {
|
|
313
316
|
try {
|
|
314
317
|
const sampleSize = Math.min(128, attnOutput.buffer.size);
|
|
315
|
-
const
|
|
316
|
-
const enc = device.createCommandEncoder();
|
|
317
|
-
enc.copyBufferToBuffer(attnOutput.buffer, 0, staging, 0, sampleSize);
|
|
318
|
-
device.queue.submit([enc.finish()]);
|
|
319
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
320
|
-
const data = new Float32Array(staging.getMappedRange().slice(0));
|
|
321
|
-
staging.unmap();
|
|
322
|
-
staging.destroy();
|
|
318
|
+
const data = new Float32Array(await readBuffer(attnOutput.buffer, sampleSize));
|
|
323
319
|
let maxAbs = 0;
|
|
324
320
|
for (let i = 0; i < data.length; i++) {
|
|
325
321
|
const abs = Math.abs(data[i]);
|
|
@@ -661,6 +657,8 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
661
657
|
attentionOutputGate: config.attentionOutputGate,
|
|
662
658
|
causalAttention: config.causalAttention,
|
|
663
659
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
660
|
+
ropeRotaryDim: config.ropeRotaryDim,
|
|
661
|
+
ropeInterleaved: config.ropeInterleaved,
|
|
664
662
|
tokenIds: context.currentTokenIds ?? null,
|
|
665
663
|
skipInputNorm: step.skipInputNorm === true,
|
|
666
664
|
activationDtype,
|
|
@@ -690,6 +688,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
690
688
|
hiddenSize,
|
|
691
689
|
probes: context.debugProbes,
|
|
692
690
|
recorder,
|
|
691
|
+
dtype: outputDtype,
|
|
693
692
|
});
|
|
694
693
|
}
|
|
695
694
|
break;
|
|
@@ -733,6 +732,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
733
732
|
hiddenSize,
|
|
734
733
|
probes: context.debugProbes,
|
|
735
734
|
recorder,
|
|
735
|
+
dtype: outputDtype,
|
|
736
736
|
});
|
|
737
737
|
}
|
|
738
738
|
break;
|
|
@@ -767,6 +767,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
767
767
|
hiddenSize,
|
|
768
768
|
probes: context.debugProbes,
|
|
769
769
|
recorder,
|
|
770
|
+
dtype: outputDtype,
|
|
770
771
|
});
|
|
771
772
|
}
|
|
772
773
|
break;
|
|
@@ -801,6 +802,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
801
802
|
hiddenSize,
|
|
802
803
|
probes: context.debugProbes,
|
|
803
804
|
recorder,
|
|
805
|
+
dtype: outputDtype,
|
|
804
806
|
});
|
|
805
807
|
}
|
|
806
808
|
break;
|
|
@@ -825,6 +827,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
825
827
|
hiddenSize,
|
|
826
828
|
probes: context.debugProbes,
|
|
827
829
|
recorder,
|
|
830
|
+
dtype: outputDtype,
|
|
828
831
|
});
|
|
829
832
|
}
|
|
830
833
|
break;
|
|
@@ -851,6 +854,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
851
854
|
hiddenSize,
|
|
852
855
|
probes: context.debugProbes,
|
|
853
856
|
recorder,
|
|
857
|
+
dtype: toDtype,
|
|
854
858
|
});
|
|
855
859
|
}
|
|
856
860
|
break;
|
|
@@ -880,6 +884,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
880
884
|
hiddenSize,
|
|
881
885
|
probes: context.debugProbes,
|
|
882
886
|
recorder,
|
|
887
|
+
dtype: getSlotDtype('state') ?? activationDtype,
|
|
883
888
|
});
|
|
884
889
|
|
|
885
890
|
const computeConfig = context.runtimeComputeConfig ?? null;
|
|
@@ -3,6 +3,7 @@ import type { Tensor } from '../../../gpu/tensor.js';
|
|
|
3
3
|
import type { WeightBuffer } from '../../../gpu/weight-buffer.js';
|
|
4
4
|
import type { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
5
5
|
import type { LinearNormMode } from '../../../config/schema/index.js';
|
|
6
|
+
import type { ProbeConfigSchema } from '../../../config/schema/index.js';
|
|
6
7
|
|
|
7
8
|
export interface LinearLayerRuntimeState {
|
|
8
9
|
layerIdx: number;
|
|
@@ -67,6 +68,7 @@ export interface RunLinearAttentionLayerOptions {
|
|
|
67
68
|
weight: GPUBuffer | Float32Array | ArrayBuffer,
|
|
68
69
|
label: string
|
|
69
70
|
) => GPUBuffer;
|
|
71
|
+
debugProbes?: ProbeConfigSchema[] | null;
|
|
70
72
|
recorder?: CommandRecorder | null;
|
|
71
73
|
}
|
|
72
74
|
|
|
@@ -74,6 +76,14 @@ export declare function hasLinearAttentionLayers(layerTypes: unknown): boolean;
|
|
|
74
76
|
|
|
75
77
|
export declare function createLinearAttentionRuntime(): LinearAttentionRuntime;
|
|
76
78
|
|
|
79
|
+
export declare function inferLinearNormMode(
|
|
80
|
+
weight: { size?: number; dtype?: string } | GPUBuffer | WeightBuffer | ArrayBufferView | ArrayBuffer | null | undefined,
|
|
81
|
+
projectionLayout: {
|
|
82
|
+
headVDim: number;
|
|
83
|
+
valueDim: number;
|
|
84
|
+
}
|
|
85
|
+
): LinearNormMode | null;
|
|
86
|
+
|
|
77
87
|
export declare function resetLinearAttentionRuntime(
|
|
78
88
|
runtime: LinearAttentionRuntime | null | undefined
|
|
79
89
|
): LinearAttentionRuntime;
|
|
@@ -4,6 +4,7 @@ import { readBuffer, releaseBuffer, uploadData, acquireBuffer } from '../../../m
|
|
|
4
4
|
import { log } from '../../../debug/index.js';
|
|
5
5
|
import { decodeReadback } from './debug-utils/index.js';
|
|
6
6
|
import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
|
|
7
|
+
import { runProbes } from './probes.js';
|
|
7
8
|
|
|
8
9
|
const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
|
|
9
10
|
const QK_L2NORM_EPS = 1e-6;
|
|
@@ -173,9 +174,22 @@ function inferLinearNormModeFromWeight(weight, projectionLayout) {
|
|
|
173
174
|
if (weight instanceof ArrayBuffer) {
|
|
174
175
|
return classify(Math.trunc(weight.byteLength / Float32Array.BYTES_PER_ELEMENT));
|
|
175
176
|
}
|
|
177
|
+
const explicitDtype = typeof weight?.dtype === 'string' ? weight.dtype.toLowerCase() : null;
|
|
178
|
+
const trackedDtype = isGpuBuffer(weight) ? String(getBufferDtype(weight) ?? '').toLowerCase() : '';
|
|
179
|
+
const bytesPerElement = bytesFromDtype(explicitDtype || trackedDtype || null);
|
|
180
|
+
const sizedElements = Number.isFinite(weight?.size)
|
|
181
|
+
? Math.trunc(Number(weight.size) / bytesPerElement)
|
|
182
|
+
: null;
|
|
183
|
+
if (sizedElements && Number(weight.size) % bytesPerElement === 0) {
|
|
184
|
+
return classify(sizedElements);
|
|
185
|
+
}
|
|
176
186
|
return null;
|
|
177
187
|
}
|
|
178
188
|
|
|
189
|
+
export function inferLinearNormMode(weight, projectionLayout) {
|
|
190
|
+
return inferLinearNormModeFromWeight(weight, projectionLayout);
|
|
191
|
+
}
|
|
192
|
+
|
|
179
193
|
function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, layerIdx) {
|
|
180
194
|
const configuredMode = normalizeLinearNormMode(configNormMode);
|
|
181
195
|
const inferredMode = inferLinearNormModeFromWeight(normWeight, projectionLayout);
|
|
@@ -185,7 +199,15 @@ function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, lay
|
|
|
185
199
|
`but norm.weight shape implies "${inferredMode}".`
|
|
186
200
|
);
|
|
187
201
|
}
|
|
188
|
-
|
|
202
|
+
if (configuredMode) {
|
|
203
|
+
return configuredMode;
|
|
204
|
+
}
|
|
205
|
+
if (inferredMode) {
|
|
206
|
+
return inferredMode;
|
|
207
|
+
}
|
|
208
|
+
throw new Error(
|
|
209
|
+
`linear_attention layer ${layerIdx} requires explicit linearNormMode or a norm.weight shape that resolves it.`
|
|
210
|
+
);
|
|
189
211
|
}
|
|
190
212
|
|
|
191
213
|
async function readWeightAsF32(weight, expectedElements, label) {
|
|
@@ -395,10 +417,17 @@ async function createLayerRuntimeState(
|
|
|
395
417
|
|
|
396
418
|
let convKernelSize = toPositiveInt(config.linearConvKernelDim) ?? null;
|
|
397
419
|
if (isWeightBuffer(convKernel) && Array.isArray(convKernel.shape) && convKernel.shape.length >= 3) {
|
|
398
|
-
|
|
420
|
+
const shapeKernelSize = toPositiveInt(convKernel.shape[2]) ?? null;
|
|
421
|
+
if (convKernelSize != null && shapeKernelSize != null && convKernelSize !== shapeKernelSize) {
|
|
422
|
+
throw new Error(
|
|
423
|
+
`linear_attention layer ${layerIdx} declares linearConvKernelDim=${convKernelSize}, ` +
|
|
424
|
+
`but conv1d weight shape implies ${shapeKernelSize}.`
|
|
425
|
+
);
|
|
426
|
+
}
|
|
427
|
+
convKernelSize = shapeKernelSize ?? convKernelSize;
|
|
399
428
|
}
|
|
400
429
|
if (!convKernelSize) {
|
|
401
|
-
|
|
430
|
+
throw new Error(`linear_attention layer ${layerIdx} requires linearConvKernelDim.`);
|
|
402
431
|
}
|
|
403
432
|
|
|
404
433
|
const convWeight = await readWeightAsF32(
|
|
@@ -435,6 +464,11 @@ async function createLayerRuntimeState(
|
|
|
435
464
|
const recurrentState = new Float32Array(
|
|
436
465
|
projectionLayout.numVHeads * projectionLayout.headKDim * projectionLayout.headVDim
|
|
437
466
|
);
|
|
467
|
+
const rmsNormEps = Number(config.rmsNormEps);
|
|
468
|
+
if (!Number.isFinite(rmsNormEps) || rmsNormEps <= 0) {
|
|
469
|
+
throw new Error(`linear_attention layer ${layerIdx} requires a positive rmsNormEps.`);
|
|
470
|
+
}
|
|
471
|
+
|
|
438
472
|
const layerState = {
|
|
439
473
|
layerIdx,
|
|
440
474
|
seqLen: currentSeqLen,
|
|
@@ -452,7 +486,7 @@ async function createLayerRuntimeState(
|
|
|
452
486
|
vSize: projectionLayout.vSize,
|
|
453
487
|
qRep: projectionLayout.qRep,
|
|
454
488
|
normMode,
|
|
455
|
-
rmsNormEps
|
|
489
|
+
rmsNormEps,
|
|
456
490
|
convWeight,
|
|
457
491
|
dtBias,
|
|
458
492
|
aNegExp,
|
|
@@ -681,13 +715,13 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
681
715
|
const normWeightBuffer = getNormWeightBuffer(layerWeights.inputNorm, `L${layerIdx}.linear_input_norm`);
|
|
682
716
|
try {
|
|
683
717
|
if (recorder) {
|
|
684
|
-
normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer,
|
|
718
|
+
normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, layerState.rmsNormEps, {
|
|
685
719
|
batchSize: numTokens,
|
|
686
720
|
hiddenSize,
|
|
687
721
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
688
722
|
});
|
|
689
723
|
} else {
|
|
690
|
-
normedTensor = await runRMSNorm(inputTensor, normWeightBuffer,
|
|
724
|
+
normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, layerState.rmsNormEps, {
|
|
691
725
|
batchSize: numTokens,
|
|
692
726
|
hiddenSize,
|
|
693
727
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
@@ -755,6 +789,38 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
755
789
|
});
|
|
756
790
|
|
|
757
791
|
try {
|
|
792
|
+
await runProbes('linear_qkv_proj', qkvTensor.buffer, {
|
|
793
|
+
layerIdx,
|
|
794
|
+
numTokens,
|
|
795
|
+
hiddenSize: projectionLayout.convDim,
|
|
796
|
+
probes: options.debugProbes,
|
|
797
|
+
recorder,
|
|
798
|
+
dtype: qkvTensor.dtype,
|
|
799
|
+
});
|
|
800
|
+
await runProbes('linear_z_proj', zTensor.buffer, {
|
|
801
|
+
layerIdx,
|
|
802
|
+
numTokens,
|
|
803
|
+
hiddenSize: projectionLayout.valueDim,
|
|
804
|
+
probes: options.debugProbes,
|
|
805
|
+
recorder,
|
|
806
|
+
dtype: zTensor.dtype,
|
|
807
|
+
});
|
|
808
|
+
await runProbes('linear_a_proj', aTensor.buffer, {
|
|
809
|
+
layerIdx,
|
|
810
|
+
numTokens,
|
|
811
|
+
hiddenSize: projectionLayout.numVHeads,
|
|
812
|
+
probes: options.debugProbes,
|
|
813
|
+
recorder,
|
|
814
|
+
dtype: aTensor.dtype,
|
|
815
|
+
});
|
|
816
|
+
await runProbes('linear_b_proj', bTensor.buffer, {
|
|
817
|
+
layerIdx,
|
|
818
|
+
numTokens,
|
|
819
|
+
hiddenSize: projectionLayout.numVHeads,
|
|
820
|
+
probes: options.debugProbes,
|
|
821
|
+
recorder,
|
|
822
|
+
dtype: bTensor.dtype,
|
|
823
|
+
});
|
|
758
824
|
const coreTensor = await runLinearAttentionCoreGPU(
|
|
759
825
|
qkvTensor,
|
|
760
826
|
zTensor,
|
|
@@ -768,6 +834,14 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
768
834
|
recorder,
|
|
769
835
|
}
|
|
770
836
|
);
|
|
837
|
+
await runProbes('linear_core_out', coreTensor.buffer, {
|
|
838
|
+
layerIdx,
|
|
839
|
+
numTokens,
|
|
840
|
+
hiddenSize: projectionLayout.valueDim,
|
|
841
|
+
probes: options.debugProbes,
|
|
842
|
+
recorder,
|
|
843
|
+
dtype: coreTensor.dtype,
|
|
844
|
+
});
|
|
771
845
|
layerState.seqLen = currentSeqLen + numTokens;
|
|
772
846
|
const outProjWeight = getWeightBuffer(layerWeights.oProj, `L${layerIdx}.linear_out_proj`);
|
|
773
847
|
try {
|