@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
|
@@ -122,6 +122,20 @@ function resolveTokenText(tokenizer, tokenIds, fallbackText = '?', renderTokenTe
|
|
|
122
122
|
return fallbackText;
|
|
123
123
|
}
|
|
124
124
|
|
|
125
|
+
export function shouldRetryWithFinitenessFallback(error) {
|
|
126
|
+
if (error?.name === 'FinitenessError') {
|
|
127
|
+
return true;
|
|
128
|
+
}
|
|
129
|
+
const message = typeof error?.message === 'string'
|
|
130
|
+
? error.message
|
|
131
|
+
: (typeof error === 'string' ? error : '');
|
|
132
|
+
if (!message.startsWith('[Sampling]')) {
|
|
133
|
+
return false;
|
|
134
|
+
}
|
|
135
|
+
return message.includes('no finite candidate logits after masking the pad token')
|
|
136
|
+
|| message.includes('Softmax produced no finite candidate probabilities');
|
|
137
|
+
}
|
|
138
|
+
|
|
125
139
|
export class PipelineGenerator {
|
|
126
140
|
|
|
127
141
|
#state;
|
|
@@ -351,7 +365,7 @@ export class PipelineGenerator {
|
|
|
351
365
|
try {
|
|
352
366
|
prefillLogits = await this._prefill(inputIds, opts);
|
|
353
367
|
} catch (error) {
|
|
354
|
-
if (error
|
|
368
|
+
if (shouldRetryWithFinitenessFallback(error)) {
|
|
355
369
|
log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefill. Retrying with F32 precision.`);
|
|
356
370
|
prefillLogits = await this._retryWithFinitenessFallback(
|
|
357
371
|
opts,
|
|
@@ -395,13 +409,34 @@ export class PipelineGenerator {
|
|
|
395
409
|
log.debug('Pipeline', `After rep penalty top-5: ${topAfterPenalty.map(t => `"${t.text}"(${(t.prob * 100).toFixed(1)}%)`).join(', ')}`);
|
|
396
410
|
}
|
|
397
411
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
412
|
+
let firstToken;
|
|
413
|
+
try {
|
|
414
|
+
firstToken = sample(prefillLogits, {
|
|
415
|
+
temperature: opts.temperature,
|
|
416
|
+
topP: opts.topP,
|
|
417
|
+
topK: opts.topK,
|
|
418
|
+
padTokenId,
|
|
419
|
+
seed: opts.seed,
|
|
420
|
+
});
|
|
421
|
+
} catch (error) {
|
|
422
|
+
if (!shouldRetryWithFinitenessFallback(error)) {
|
|
423
|
+
throw error;
|
|
424
|
+
}
|
|
425
|
+
log.warn('Pipeline', 'FinitenessGuard caught non-finite prefill logits at sampling. Retrying with F32 precision.');
|
|
426
|
+
prefillLogits = await this._retryWithFinitenessFallback(
|
|
427
|
+
opts,
|
|
428
|
+
'prefill-sample',
|
|
429
|
+
() => this._prefill(inputIds, opts)
|
|
430
|
+
);
|
|
431
|
+
applyRepetitionPenalty(prefillLogits, generatedIds, opts.repetitionPenalty);
|
|
432
|
+
firstToken = sample(prefillLogits, {
|
|
433
|
+
temperature: opts.temperature,
|
|
434
|
+
topP: opts.topP,
|
|
435
|
+
topK: opts.topK,
|
|
436
|
+
padTokenId,
|
|
437
|
+
seed: opts.seed,
|
|
438
|
+
});
|
|
439
|
+
}
|
|
405
440
|
|
|
406
441
|
if (opts.debug) {
|
|
407
442
|
const firstTokenText = resolveTokenText(this.#state.tokenizer, [firstToken], `[${firstToken}]`, (tokens) => this.#state.tokenizer?.decode?.(tokens, true, false));
|
|
@@ -479,7 +514,7 @@ export class PipelineGenerator {
|
|
|
479
514
|
try {
|
|
480
515
|
prefillResult = await this._prefillToHidden(inputIds, opts);
|
|
481
516
|
} catch (error) {
|
|
482
|
-
if (error
|
|
517
|
+
if (shouldRetryWithFinitenessFallback(error)) {
|
|
483
518
|
log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillKVOnly. Retrying with F32 precision.`);
|
|
484
519
|
prefillResult = await this._retryWithFinitenessFallback(
|
|
485
520
|
opts,
|
|
@@ -544,7 +579,7 @@ export class PipelineGenerator {
|
|
|
544
579
|
try {
|
|
545
580
|
prefillResult = await this._prefillToHidden(inputIds, opts);
|
|
546
581
|
} catch (error) {
|
|
547
|
-
if (error
|
|
582
|
+
if (shouldRetryWithFinitenessFallback(error)) {
|
|
548
583
|
log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillWithEmbedding. Retrying with F32 precision.`);
|
|
549
584
|
prefillResult = await this._retryWithFinitenessFallback(
|
|
550
585
|
opts,
|
|
@@ -833,7 +868,7 @@ export class PipelineGenerator {
|
|
|
833
868
|
try {
|
|
834
869
|
nextToken = await this._decodeStep(generatedIds, opts);
|
|
835
870
|
} catch (singleTokenError) {
|
|
836
|
-
if (singleTokenError
|
|
871
|
+
if (shouldRetryWithFinitenessFallback(singleTokenError)) {
|
|
837
872
|
log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at batch step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
|
|
838
873
|
nextToken = await this._retryDecodeStepWithFinitenessWindow(
|
|
839
874
|
generatedIds,
|
|
@@ -858,7 +893,7 @@ export class PipelineGenerator {
|
|
|
858
893
|
try {
|
|
859
894
|
nextToken = await this._decodeStep(generatedIds, opts);
|
|
860
895
|
} catch (error) {
|
|
861
|
-
if (error
|
|
896
|
+
if (shouldRetryWithFinitenessFallback(error)) {
|
|
862
897
|
log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
|
|
863
898
|
nextToken = await this._retryDecodeStepWithFinitenessWindow(
|
|
864
899
|
generatedIds,
|
|
@@ -918,11 +953,9 @@ export class PipelineGenerator {
|
|
|
918
953
|
throw new Error('Embed buffer not found or not a supported buffer type');
|
|
919
954
|
}
|
|
920
955
|
const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
|
|
921
|
-
const embedDtype =
|
|
922
|
-
?
|
|
923
|
-
:
|
|
924
|
-
? embedBufferRaw.dtype
|
|
925
|
-
: null;
|
|
956
|
+
const embedDtype = isCpuWeightBuffer(embedBufferRaw)
|
|
957
|
+
? embedBufferRaw.dtype
|
|
958
|
+
: getWeightDtype(embedBufferRaw);
|
|
926
959
|
if (opts.debug) {
|
|
927
960
|
const embedSize = embedBuffer instanceof GPUBuffer ? embedBuffer.size : 'N/A';
|
|
928
961
|
log.debug('Pipeline', `Embed buffer: type=${embedBuffer?.constructor?.name}, size=${embedSize}, dtype=${embedDtype}`);
|
|
@@ -1043,18 +1076,9 @@ export class PipelineGenerator {
|
|
|
1043
1076
|
if (allowReadback(`pipeline.prefill.layer-${l}`)) {
|
|
1044
1077
|
try {
|
|
1045
1078
|
const sampleSize = config.hiddenSize * activationBytes;
|
|
1046
|
-
const staging = device.createBuffer({
|
|
1047
|
-
size: sampleSize,
|
|
1048
|
-
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
|
|
1049
|
-
});
|
|
1050
|
-
const enc = device.createCommandEncoder();
|
|
1051
1079
|
const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
1055
|
-
const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
|
|
1056
|
-
staging.unmap();
|
|
1057
|
-
staging.destroy();
|
|
1080
|
+
const readback = await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize);
|
|
1081
|
+
const data = decodeReadback(readback, activationDtype);
|
|
1058
1082
|
let min = Infinity;
|
|
1059
1083
|
let max = -Infinity;
|
|
1060
1084
|
let maxAbs = 0;
|
|
@@ -1112,20 +1136,12 @@ export class PipelineGenerator {
|
|
|
1112
1136
|
if (opts.debug) {
|
|
1113
1137
|
log.debug('Pipeline', `LAYER_LOOP_DONE, currentHiddenBuffer type=${currentHiddenBuffer?.constructor?.name}`);
|
|
1114
1138
|
if (currentHiddenBuffer && allowReadback('pipeline.prefill.final-hidden')) {
|
|
1115
|
-
const device = getDevice();
|
|
1116
1139
|
const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
|
|
1117
1140
|
const sampleSize = config.hiddenSize * activationBytes;
|
|
1118
|
-
const
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
const enc = device.createCommandEncoder();
|
|
1123
|
-
enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
|
|
1124
|
-
device.queue.submit([enc.finish()]);
|
|
1125
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
1126
|
-
const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
|
|
1127
|
-
staging.unmap();
|
|
1128
|
-
staging.destroy();
|
|
1141
|
+
const data = decodeReadback(
|
|
1142
|
+
await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize),
|
|
1143
|
+
activationDtype
|
|
1144
|
+
);
|
|
1129
1145
|
const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
|
|
1130
1146
|
const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
|
|
1131
1147
|
log.debug('Pipeline', `FINAL_HIDDEN[pos=${numTokens - 1}]: nan=${nanCount}/${data.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);
|
|
@@ -190,6 +190,12 @@ export interface WeightLoadResult {
|
|
|
190
190
|
layerRouterWeights: Map<number, RouterWeights>;
|
|
191
191
|
}
|
|
192
192
|
|
|
193
|
+
export interface ResolvedQ4KConfig {
|
|
194
|
+
useFusedQ4K: boolean;
|
|
195
|
+
q4kLayout: 'row' | 'col' | null;
|
|
196
|
+
keepF32Weights: boolean;
|
|
197
|
+
}
|
|
198
|
+
|
|
193
199
|
/** Options for loadWeights */
|
|
194
200
|
export interface LoadWeightsOptions {
|
|
195
201
|
storageContext?: PipelineStorageContext;
|
|
@@ -211,6 +217,13 @@ export function loadWeights(
|
|
|
211
217
|
options?: LoadWeightsOptions
|
|
212
218
|
): Promise<WeightLoadResult>;
|
|
213
219
|
|
|
220
|
+
export function resolveQ4KConfig(
|
|
221
|
+
manifest: Manifest,
|
|
222
|
+
kernelPath?: KernelPathSchema | null,
|
|
223
|
+
kernelPathSource?: KernelPathSource,
|
|
224
|
+
keepF32Weights?: boolean
|
|
225
|
+
): ResolvedQ4KConfig;
|
|
226
|
+
|
|
214
227
|
/**
|
|
215
228
|
* Apply Gemma chat template to a prompt.
|
|
216
229
|
*/
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import { parseModelConfig } from './config.js';
|
|
4
4
|
import { getDevice, getDeviceLimits, getKernelCapabilities } from '../../../gpu/device.js';
|
|
5
|
-
import { acquireBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
|
|
6
6
|
import { KVCache, SlidingWindowKVCache, TieredKVCache, BasisDecomposedPagedCache } from '../../kv-cache.js';
|
|
7
7
|
import { Tokenizer } from '../../tokenizer.js';
|
|
8
8
|
import { MoERouter } from '../../moe-router.js';
|
|
@@ -11,9 +11,13 @@ import { getDopplerLoader } from '../../../loader/doppler-loader.js';
|
|
|
11
11
|
import { log, setGPUDevice, trace as debugTrace } from '../../../debug/index.js';
|
|
12
12
|
import { getRuntimeConfig } from '../../../config/runtime.js';
|
|
13
13
|
import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js';
|
|
14
|
-
import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
|
|
14
|
+
import { isKernelPathFusedQ4K, kernelPathRequiresF32MatmulWeights } from '../../../config/kernel-path-loader.js';
|
|
15
15
|
import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
|
|
16
16
|
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
17
|
+
import {
|
|
18
|
+
createSourceStorageContext,
|
|
19
|
+
getSourceRuntimeMetadata,
|
|
20
|
+
} from '../../../tooling/source-runtime-bundle.js';
|
|
17
21
|
|
|
18
22
|
function resolveErrorMessage(error) {
|
|
19
23
|
if (error && typeof error === 'object' && typeof error.message === 'string') {
|
|
@@ -56,12 +60,61 @@ function normalizeBaseUrl(baseUrl) {
|
|
|
56
60
|
return baseUrl.replace(/\/$/, '');
|
|
57
61
|
}
|
|
58
62
|
|
|
63
|
+
async function fetchBytes(url, offset = null, length = null) {
|
|
64
|
+
const headers = {};
|
|
65
|
+
if (Number.isFinite(offset) && Number.isFinite(length) && length > 0) {
|
|
66
|
+
const start = Math.max(0, Math.floor(offset));
|
|
67
|
+
const end = start + Math.max(0, Math.floor(length)) - 1;
|
|
68
|
+
headers.Range = `bytes=${start}-${end}`;
|
|
69
|
+
}
|
|
70
|
+
const response = await fetch(url, { headers });
|
|
71
|
+
if (!response.ok) {
|
|
72
|
+
throw new Error(`Failed to fetch ${url}: ${response.status}`);
|
|
73
|
+
}
|
|
74
|
+
return new Uint8Array(await response.arrayBuffer());
|
|
75
|
+
}
|
|
76
|
+
|
|
59
77
|
function createRemoteStorageContext(baseUrl, manifest) {
|
|
60
78
|
const root = normalizeBaseUrl(baseUrl);
|
|
61
79
|
if (!root || !isRDRRManifest(manifest)) {
|
|
62
80
|
return null;
|
|
63
81
|
}
|
|
64
82
|
|
|
83
|
+
const sourceRuntime = getSourceRuntimeMetadata(manifest);
|
|
84
|
+
if (sourceRuntime) {
|
|
85
|
+
const readRange = async (relativePath, offset, length) => {
|
|
86
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
87
|
+
if (!filename) {
|
|
88
|
+
throw new Error('Direct-source artifact path is required.');
|
|
89
|
+
}
|
|
90
|
+
const url = `${root}/${filename}`;
|
|
91
|
+
return fetchBytes(url, offset, length);
|
|
92
|
+
};
|
|
93
|
+
const readText = async (relativePath) => {
|
|
94
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
95
|
+
if (!filename) return null;
|
|
96
|
+
const response = await fetch(`${root}/${filename}`);
|
|
97
|
+
if (!response.ok) {
|
|
98
|
+
throw new Error(`Failed to fetch ${filename} from ${root}: ${response.status}`);
|
|
99
|
+
}
|
|
100
|
+
return response.text();
|
|
101
|
+
};
|
|
102
|
+
const readBinary = async (relativePath) => {
|
|
103
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
104
|
+
if (!filename) {
|
|
105
|
+
throw new Error('Direct-source binary asset path is required.');
|
|
106
|
+
}
|
|
107
|
+
return fetchBytes(`${root}/${filename}`);
|
|
108
|
+
};
|
|
109
|
+
return createSourceStorageContext({
|
|
110
|
+
manifest,
|
|
111
|
+
readRange,
|
|
112
|
+
readText,
|
|
113
|
+
readBinary,
|
|
114
|
+
verifyHashes: true,
|
|
115
|
+
});
|
|
116
|
+
}
|
|
117
|
+
|
|
65
118
|
return {
|
|
66
119
|
async loadShard(index) {
|
|
67
120
|
const shard = manifest.shards[index];
|
|
@@ -69,17 +122,13 @@ function createRemoteStorageContext(baseUrl, manifest) {
|
|
|
69
122
|
if (!filename) {
|
|
70
123
|
throw new Error(`Manifest shard ${index} is missing filename.`);
|
|
71
124
|
}
|
|
72
|
-
|
|
73
|
-
if (!response.ok) {
|
|
74
|
-
throw new Error(`Failed to fetch shard ${index} from ${root}: ${response.status}`);
|
|
75
|
-
}
|
|
76
|
-
return new Uint8Array(await response.arrayBuffer());
|
|
125
|
+
return fetchBytes(`${root}/${filename.replace(/^\/+/, '')}`);
|
|
77
126
|
},
|
|
78
127
|
};
|
|
79
128
|
}
|
|
80
129
|
|
|
81
130
|
|
|
82
|
-
function resolveQ4KConfig(
|
|
131
|
+
export function resolveQ4KConfig(
|
|
83
132
|
manifest,
|
|
84
133
|
kernelPath,
|
|
85
134
|
kernelPathSource = 'none',
|
|
@@ -101,18 +150,23 @@ function resolveQ4KConfig(
|
|
|
101
150
|
);
|
|
102
151
|
}
|
|
103
152
|
let useFused = kernelPath ? isKernelPathFusedQ4K(kernelPath) : hasSubgroups;
|
|
153
|
+
const kernelPathKeepsF32Weights = kernelPathRequiresF32MatmulWeights(kernelPath);
|
|
104
154
|
if (q4kLayout === 'col') {
|
|
105
155
|
useFused = false;
|
|
106
156
|
}
|
|
157
|
+
const resolvedKeepF32Weights = keepF32Weights || kernelPathKeepsF32Weights;
|
|
107
158
|
|
|
108
159
|
const pathLabel = kernelPath?.id ?? 'auto';
|
|
109
160
|
const layoutLabel = q4kLayout ?? 'none';
|
|
110
|
-
debugTrace.loader(
|
|
161
|
+
debugTrace.loader(
|
|
162
|
+
`Q4K config: fused=${useFused}, kernelPath=${pathLabel}, source=${kernelPathSource}, ` +
|
|
163
|
+
`layout=${layoutLabel}, keepF32Weights=${resolvedKeepF32Weights}, subgroups=${hasSubgroups}`
|
|
164
|
+
);
|
|
111
165
|
|
|
112
166
|
return {
|
|
113
167
|
useFusedQ4K: useFused,
|
|
114
168
|
q4kLayout,
|
|
115
|
-
keepF32Weights,
|
|
169
|
+
keepF32Weights: resolvedKeepF32Weights,
|
|
116
170
|
};
|
|
117
171
|
}
|
|
118
172
|
|
|
@@ -326,20 +380,29 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
326
380
|
// Upload to GPU if available
|
|
327
381
|
const device = getDevice();
|
|
328
382
|
if (device && useGPU) {
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
383
|
+
let cosBuffer = null;
|
|
384
|
+
let sinBuffer = null;
|
|
385
|
+
let localCosBuffer = null;
|
|
386
|
+
let localSinBuffer = null;
|
|
387
|
+
try {
|
|
388
|
+
cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
|
|
389
|
+
sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
|
|
390
|
+
device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
|
|
391
|
+
device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
|
|
392
|
+
|
|
393
|
+
if (localFreqs) {
|
|
394
|
+
localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
|
|
395
|
+
localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
|
|
396
|
+
device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
|
|
397
|
+
device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
|
|
398
|
+
}
|
|
399
|
+
} catch (error) {
|
|
400
|
+
for (const buffer of [cosBuffer, sinBuffer, localCosBuffer, localSinBuffer]) {
|
|
401
|
+
if (buffer) {
|
|
402
|
+
releaseBuffer(buffer);
|
|
403
|
+
}
|
|
404
|
+
}
|
|
405
|
+
throw error;
|
|
343
406
|
}
|
|
344
407
|
|
|
345
408
|
log.debug(
|
|
@@ -444,6 +507,12 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
|
|
|
444
507
|
cacheLayout = 'paged';
|
|
445
508
|
layoutSource = 'threshold';
|
|
446
509
|
}
|
|
510
|
+
if (forceContiguousKVCache && cacheLayout === 'paged') {
|
|
511
|
+
throw new Error(
|
|
512
|
+
'Paged KV cache layout is not supported for models with full-attention layers. ' +
|
|
513
|
+
'Set runtime.inference.kvcache.layout to "contiguous" instead.'
|
|
514
|
+
);
|
|
515
|
+
}
|
|
447
516
|
if (debug && cacheLayout !== runtimeKV.layout) {
|
|
448
517
|
log.debug('Pipeline', `KV cache layout override: ${runtimeKV.layout} -> ${cacheLayout} (${layoutSource})`);
|
|
449
518
|
}
|
|
@@ -541,7 +610,7 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
|
|
|
541
610
|
|
|
542
611
|
if (debug) {
|
|
543
612
|
if (forceContiguousKVCache && modelConfig.layerTypes) {
|
|
544
|
-
log.debug('Pipeline', 'Layer pattern includes full-attention layers;
|
|
613
|
+
log.debug('Pipeline', 'Layer pattern includes full-attention layers; paged layout blocked, contiguous enforced.');
|
|
545
614
|
}
|
|
546
615
|
const isSliding = kvCache instanceof SlidingWindowKVCache;
|
|
547
616
|
log.debug('Pipeline', `KV cache: type=${kvCache?.constructor?.name || 'unknown'}, kvDtype=${kvCache.kvDtype}, layout=${kvCache.layout}, maxSeqLen=${kvCache.maxSeqLen}, windowSize=${isSliding ? kvCache.windowSize : null}`);
|
|
@@ -78,6 +78,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
|
|
|
78
78
|
|
|
79
79
|
const normalizedPolicy = resolveKernelPathPolicy(kernelPathPolicy);
|
|
80
80
|
const hasSubgroups = capabilities?.hasSubgroups === true;
|
|
81
|
+
const hasF16 = capabilities?.hasF16 === true;
|
|
81
82
|
const normalizedSource = normalizeKernelPathSource(kernelPathSource);
|
|
82
83
|
const allowCapabilityAutoSelection = normalizedPolicy.mode === 'capability-aware'
|
|
83
84
|
&& normalizedPolicy.sourceScope.includes(normalizedSource);
|
|
@@ -85,6 +86,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
|
|
|
85
86
|
return selectRuleValue('inference', 'kernelPath', 'autoSelect', {
|
|
86
87
|
kernelPathRef: configuredKernelPathRef,
|
|
87
88
|
hasSubgroups,
|
|
89
|
+
hasF16,
|
|
88
90
|
allowCapabilityAutoSelection,
|
|
89
91
|
});
|
|
90
92
|
}
|
|
@@ -283,6 +283,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
|
|
|
283
283
|
if (layer >= 0 && !kernelTrace.shouldTraceLayer(layer)) return;
|
|
284
284
|
|
|
285
285
|
const output = await snapshotTensor(outputBuffer, outputShape);
|
|
286
|
+
if (!output.ok) {
|
|
287
|
+
throw new Error(`[TRACE] Failed to snapshot output for ${label}: ${output.error}`);
|
|
288
|
+
}
|
|
286
289
|
|
|
287
290
|
// Snapshot inputs if provided (expensive - only do if tracing)
|
|
288
291
|
|
|
@@ -290,6 +293,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
|
|
|
290
293
|
if (options?.inputs && options?.inputShapes) {
|
|
291
294
|
for (let i = 0; i < options.inputs.length; i++) {
|
|
292
295
|
const snap = await snapshotTensor(options.inputs[i], options.inputShapes[i]);
|
|
296
|
+
if (!snap.ok) {
|
|
297
|
+
throw new Error(`[TRACE] Failed to snapshot input ${i} for ${label}: ${snap.error}`);
|
|
298
|
+
}
|
|
293
299
|
inputs.push(snap);
|
|
294
300
|
}
|
|
295
301
|
}
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import { log, trace } from '../../../debug/index.js';
|
|
4
4
|
import { getDevice } from '../../../gpu/device.js';
|
|
5
|
-
import { releaseBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import { releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
|
|
6
6
|
import { allowReadback } from '../../../gpu/perf-guards.js';
|
|
7
7
|
import { createTensor } from '../../../gpu/tensor.js';
|
|
8
8
|
import {
|
|
@@ -228,6 +228,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
228
228
|
linearRuntime: context.linearAttentionRuntime ?? null,
|
|
229
229
|
getWeightBuffer: (weight, label) => getWeightBuffer(weight, label),
|
|
230
230
|
getNormWeightBuffer: (weight, label) => getNormWeightBuffer(weight, label, weightConfig, debugFlags),
|
|
231
|
+
debugProbes: context.debugProbes,
|
|
231
232
|
recorder: recorder ?? null,
|
|
232
233
|
});
|
|
233
234
|
} else {
|
|
@@ -275,6 +276,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
275
276
|
: (ropeFreqsSin),
|
|
276
277
|
kvCache: ((kvCache)),
|
|
277
278
|
stats: context.stats,
|
|
279
|
+
debugProbes: context.debugProbes,
|
|
278
280
|
linearRuntime: context.linearAttentionRuntime ?? null,
|
|
279
281
|
};
|
|
280
282
|
|
|
@@ -314,14 +316,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
314
316
|
if (allowReadback(`layer.attn-out.${layerIdx}`)) {
|
|
315
317
|
try {
|
|
316
318
|
const sampleSize = Math.min(128, attnOutput.buffer.size);
|
|
317
|
-
const
|
|
318
|
-
const enc = device.createCommandEncoder();
|
|
319
|
-
enc.copyBufferToBuffer(attnOutput.buffer, 0, staging, 0, sampleSize);
|
|
320
|
-
device.queue.submit([enc.finish()]);
|
|
321
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
322
|
-
const data = new Float32Array(staging.getMappedRange().slice(0));
|
|
323
|
-
staging.unmap();
|
|
324
|
-
staging.destroy();
|
|
319
|
+
const data = new Float32Array(await readBuffer(attnOutput.buffer, sampleSize));
|
|
325
320
|
let maxAbs = 0;
|
|
326
321
|
for (let i = 0; i < data.length; i++) {
|
|
327
322
|
const abs = Math.abs(data[i]);
|
|
@@ -3,6 +3,7 @@ import type { Tensor } from '../../../gpu/tensor.js';
|
|
|
3
3
|
import type { WeightBuffer } from '../../../gpu/weight-buffer.js';
|
|
4
4
|
import type { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
5
5
|
import type { LinearNormMode } from '../../../config/schema/index.js';
|
|
6
|
+
import type { ProbeConfigSchema } from '../../../config/schema/index.js';
|
|
6
7
|
|
|
7
8
|
export interface LinearLayerRuntimeState {
|
|
8
9
|
layerIdx: number;
|
|
@@ -67,6 +68,7 @@ export interface RunLinearAttentionLayerOptions {
|
|
|
67
68
|
weight: GPUBuffer | Float32Array | ArrayBuffer,
|
|
68
69
|
label: string
|
|
69
70
|
) => GPUBuffer;
|
|
71
|
+
debugProbes?: ProbeConfigSchema[] | null;
|
|
70
72
|
recorder?: CommandRecorder | null;
|
|
71
73
|
}
|
|
72
74
|
|
|
@@ -74,6 +76,19 @@ export declare function hasLinearAttentionLayers(layerTypes: unknown): boolean;
|
|
|
74
76
|
|
|
75
77
|
export declare function createLinearAttentionRuntime(): LinearAttentionRuntime;
|
|
76
78
|
|
|
79
|
+
export declare function inferLinearNormMode(
|
|
80
|
+
weight: { size?: number; dtype?: string } | GPUBuffer | WeightBuffer | ArrayBufferView | ArrayBuffer | null | undefined,
|
|
81
|
+
projectionLayout: {
|
|
82
|
+
headVDim: number;
|
|
83
|
+
valueDim: number;
|
|
84
|
+
}
|
|
85
|
+
): LinearNormMode | null;
|
|
86
|
+
|
|
87
|
+
export declare function applyLinearNormWeightOffset(
|
|
88
|
+
values: Float32Array,
|
|
89
|
+
rmsNormWeightOffset: boolean
|
|
90
|
+
): Float32Array;
|
|
91
|
+
|
|
77
92
|
export declare function resetLinearAttentionRuntime(
|
|
78
93
|
runtime: LinearAttentionRuntime | null | undefined
|
|
79
94
|
): LinearAttentionRuntime;
|