@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { releaseBuffer } from '../../../../memory/buffer-pool.js';
|
|
2
2
|
import { isWeightBuffer, getLayout, getWeightDtype } from '../../../../gpu/weight-buffer.js';
|
|
3
3
|
import {
|
|
4
4
|
runMatmul,
|
|
@@ -36,7 +36,7 @@ function getRmsNormRunner(recorder) {
|
|
|
36
36
|
}
|
|
37
37
|
|
|
38
38
|
function releaseOwnedWeightBuffer(layerWeight, resolvedWeightBuffer, releaseTemporary) {
|
|
39
|
-
if (layerWeight instanceof GPUBuffer || isWeightBuffer(layerWeight)) {
|
|
39
|
+
if ((typeof GPUBuffer !== 'undefined' && layerWeight instanceof GPUBuffer) || isWeightBuffer(layerWeight)) {
|
|
40
40
|
return;
|
|
41
41
|
}
|
|
42
42
|
if (!resolvedWeightBuffer) {
|
|
@@ -66,10 +66,16 @@ async function projectSingleQkvTensor({
|
|
|
66
66
|
}) {
|
|
67
67
|
const runMatmulForMode = getMatmulRunner(recorder);
|
|
68
68
|
const layerWeight = layerWeights?.[weightKey];
|
|
69
|
-
|
|
69
|
+
if (!layerWeight) {
|
|
70
|
+
throw new Error(`Attention projection requires ${weightKey}.`);
|
|
71
|
+
}
|
|
72
|
+
if (!getWeightBuffer) {
|
|
73
|
+
throw new Error(`Attention projection requires getWeightBuffer for ${role}.`);
|
|
74
|
+
}
|
|
70
75
|
|
|
71
|
-
|
|
72
|
-
|
|
76
|
+
let projected;
|
|
77
|
+
const projBuffer = getWeightBuffer(layerWeight, role);
|
|
78
|
+
try {
|
|
73
79
|
projected = await runMatmulForMode(normed, projBuffer, numTokens, outputSize, hiddenSize, {
|
|
74
80
|
transposeB: 'auto',
|
|
75
81
|
role,
|
|
@@ -77,26 +83,31 @@ async function projectSingleQkvTensor({
|
|
|
77
83
|
kernelPath,
|
|
78
84
|
outputDtype: matmulOutputDtype,
|
|
79
85
|
});
|
|
86
|
+
} finally {
|
|
80
87
|
releaseOwnedWeightBuffer(layerWeight, projBuffer, releaseTemporary);
|
|
81
|
-
} else {
|
|
82
|
-
const fallback = acquireBuffer(numTokens * outputSize * 4, undefined, outputLabel);
|
|
83
|
-
projected = createTensor(fallback, normed.dtype, [numTokens, outputSize], outputLabel);
|
|
84
88
|
}
|
|
85
89
|
|
|
86
90
|
const loraModule = getLoRAModule(lora, layerIdx, loraKey);
|
|
87
91
|
if (loraModule && getWeightBuffer) {
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
92
|
+
try {
|
|
93
|
+
const combined = await applyLoRA(
|
|
94
|
+
normed,
|
|
95
|
+
projected,
|
|
96
|
+
loraModule,
|
|
97
|
+
{ M: numTokens, N: outputSize, K: hiddenSize },
|
|
98
|
+
getWeightBuffer,
|
|
99
|
+
recorder ?? undefined,
|
|
100
|
+
{ kernelPath }
|
|
101
|
+
);
|
|
102
|
+
if (combined.buffer !== projected.buffer) {
|
|
103
|
+
releaseTemporary(projected.buffer);
|
|
104
|
+
projected = combined;
|
|
105
|
+
}
|
|
106
|
+
} catch (error) {
|
|
107
|
+
if (projected?.buffer) {
|
|
108
|
+
releaseTemporary(projected.buffer);
|
|
109
|
+
}
|
|
110
|
+
throw error;
|
|
100
111
|
}
|
|
101
112
|
}
|
|
102
113
|
|
|
@@ -212,24 +223,42 @@ async function projectQueryWithOptionalGate({
|
|
|
212
223
|
bOffset: gateOffset,
|
|
213
224
|
outputDtype: matmulOutputDtype,
|
|
214
225
|
});
|
|
226
|
+
} catch (error) {
|
|
227
|
+
if (qTensor) {
|
|
228
|
+
releaseTemporary(qTensor.buffer);
|
|
229
|
+
}
|
|
230
|
+
if (qGateTensor) {
|
|
231
|
+
releaseTemporary(qGateTensor.buffer);
|
|
232
|
+
}
|
|
233
|
+
throw error;
|
|
215
234
|
} finally {
|
|
216
235
|
releaseOwnedWeightBuffer(qWeight, qWeightBuffer, releaseTemporary);
|
|
217
236
|
}
|
|
218
237
|
|
|
219
238
|
const loraModule = getLoRAModule(lora, layerIdx, 'q_proj');
|
|
220
239
|
if (loraModule && getWeightBuffer) {
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
240
|
+
try {
|
|
241
|
+
const combined = await applyLoRA(
|
|
242
|
+
normed,
|
|
243
|
+
qTensor,
|
|
244
|
+
loraModule,
|
|
245
|
+
{ M: numTokens, N: qSize, K: hiddenSize },
|
|
246
|
+
getWeightBuffer,
|
|
247
|
+
recorder ?? undefined,
|
|
248
|
+
{ kernelPath }
|
|
249
|
+
);
|
|
250
|
+
if (combined.buffer !== qTensor.buffer) {
|
|
251
|
+
releaseTemporary(qTensor.buffer);
|
|
252
|
+
qTensor = combined;
|
|
253
|
+
}
|
|
254
|
+
} catch (error) {
|
|
255
|
+
if (qTensor?.buffer) {
|
|
256
|
+
releaseTemporary(qTensor.buffer);
|
|
257
|
+
}
|
|
258
|
+
if (qGateTensor?.buffer) {
|
|
259
|
+
releaseTemporary(qGateTensor.buffer);
|
|
260
|
+
}
|
|
261
|
+
throw error;
|
|
233
262
|
}
|
|
234
263
|
}
|
|
235
264
|
|
|
@@ -289,82 +318,103 @@ export async function projectAttentionQKV({
|
|
|
289
318
|
if (useFusedQKV && layerWeights.qkvProj && layerWeights.qkvSizes) {
|
|
290
319
|
const [qSizeFused, kSizeFused, vSizeFused] = layerWeights.qkvSizes;
|
|
291
320
|
const qkvSizeTotal = qSizeFused + kSizeFused + vSizeFused;
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
321
|
+
let qkvTensor = null;
|
|
322
|
+
try {
|
|
323
|
+
qkvTensor = await runMatmulForMode(normed, layerWeights.qkvProj, numTokens, qkvSizeTotal, hiddenSize, {
|
|
324
|
+
transposeB: 'auto',
|
|
325
|
+
role: 'qkv_proj',
|
|
326
|
+
layerIdx,
|
|
327
|
+
kernelPath,
|
|
328
|
+
outputDtype: matmulOutputDtype,
|
|
329
|
+
});
|
|
330
|
+
const split = await runSplitForMode(qkvTensor, {
|
|
331
|
+
numTokens,
|
|
332
|
+
qSize: qSizeFused,
|
|
333
|
+
kSize: kSizeFused,
|
|
334
|
+
vSize: vSizeFused,
|
|
335
|
+
});
|
|
336
|
+
releaseTemporary(qkvTensor.buffer);
|
|
337
|
+
if (onFusedQKV) {
|
|
338
|
+
onFusedQKV({ qSize: qSizeFused, kSize: kSizeFused, vSize: vSizeFused, totalSize: qkvSizeTotal });
|
|
339
|
+
}
|
|
340
|
+
return { qTensor: split.Q, qGateTensor: null, kTensor: split.K, vTensor: split.V, usedFusedQKV: true };
|
|
341
|
+
} catch (error) {
|
|
342
|
+
if (qkvTensor) {
|
|
343
|
+
releaseTemporary(qkvTensor.buffer);
|
|
344
|
+
}
|
|
345
|
+
throw error;
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
let qTensor = null;
|
|
350
|
+
let qGateTensor = null;
|
|
351
|
+
let kTensor = null;
|
|
352
|
+
let vTensor = null;
|
|
353
|
+
try {
|
|
354
|
+
({ qTensor, qGateTensor } = await projectQueryWithOptionalGate({
|
|
355
|
+
recorder,
|
|
356
|
+
normed,
|
|
357
|
+
layerWeights,
|
|
358
|
+
numTokens,
|
|
359
|
+
numHeads,
|
|
360
|
+
headDim,
|
|
361
|
+
hiddenSize,
|
|
295
362
|
layerIdx,
|
|
296
363
|
kernelPath,
|
|
297
|
-
|
|
364
|
+
matmulOutputDtype,
|
|
365
|
+
getWeightBuffer,
|
|
366
|
+
lora,
|
|
367
|
+
releaseTemporary,
|
|
368
|
+
attentionOutputGate,
|
|
369
|
+
}));
|
|
370
|
+
|
|
371
|
+
kTensor = await projectSingleQkvTensor({
|
|
372
|
+
recorder,
|
|
373
|
+
normed,
|
|
374
|
+
layerWeights,
|
|
375
|
+
weightKey: 'kProj',
|
|
376
|
+
role: 'k_proj',
|
|
377
|
+
outputSize: numKVHeads * headDim,
|
|
378
|
+
outputLabel: 'K',
|
|
379
|
+
loraKey: 'k_proj',
|
|
380
|
+
numTokens,
|
|
381
|
+
hiddenSize,
|
|
382
|
+
layerIdx,
|
|
383
|
+
kernelPath,
|
|
384
|
+
matmulOutputDtype,
|
|
385
|
+
getWeightBuffer,
|
|
386
|
+
lora,
|
|
387
|
+
releaseTemporary,
|
|
298
388
|
});
|
|
299
|
-
|
|
389
|
+
|
|
390
|
+
vTensor = await projectSingleQkvTensor({
|
|
391
|
+
recorder,
|
|
392
|
+
normed,
|
|
393
|
+
layerWeights,
|
|
394
|
+
weightKey: 'vProj',
|
|
395
|
+
role: 'v_proj',
|
|
396
|
+
outputSize: numKVHeads * headDim,
|
|
397
|
+
outputLabel: 'V',
|
|
398
|
+
loraKey: 'v_proj',
|
|
300
399
|
numTokens,
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
400
|
+
hiddenSize,
|
|
401
|
+
layerIdx,
|
|
402
|
+
kernelPath,
|
|
403
|
+
matmulOutputDtype,
|
|
404
|
+
getWeightBuffer,
|
|
405
|
+
lora,
|
|
406
|
+
releaseTemporary,
|
|
304
407
|
});
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
408
|
+
|
|
409
|
+
return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
|
|
410
|
+
} catch (error) {
|
|
411
|
+
for (const tensor of [qTensor, qGateTensor, kTensor, vTensor]) {
|
|
412
|
+
if (tensor?.buffer) {
|
|
413
|
+
releaseTemporary(tensor.buffer);
|
|
414
|
+
}
|
|
308
415
|
}
|
|
309
|
-
|
|
416
|
+
throw error;
|
|
310
417
|
}
|
|
311
|
-
|
|
312
|
-
const { qTensor, qGateTensor } = await projectQueryWithOptionalGate({
|
|
313
|
-
recorder,
|
|
314
|
-
normed,
|
|
315
|
-
layerWeights,
|
|
316
|
-
numTokens,
|
|
317
|
-
numHeads,
|
|
318
|
-
headDim,
|
|
319
|
-
hiddenSize,
|
|
320
|
-
layerIdx,
|
|
321
|
-
kernelPath,
|
|
322
|
-
matmulOutputDtype,
|
|
323
|
-
getWeightBuffer,
|
|
324
|
-
lora,
|
|
325
|
-
releaseTemporary,
|
|
326
|
-
attentionOutputGate,
|
|
327
|
-
});
|
|
328
|
-
|
|
329
|
-
const kTensor = await projectSingleQkvTensor({
|
|
330
|
-
recorder,
|
|
331
|
-
normed,
|
|
332
|
-
layerWeights,
|
|
333
|
-
weightKey: 'kProj',
|
|
334
|
-
role: 'k_proj',
|
|
335
|
-
outputSize: numKVHeads * headDim,
|
|
336
|
-
outputLabel: 'K',
|
|
337
|
-
loraKey: 'k_proj',
|
|
338
|
-
numTokens,
|
|
339
|
-
hiddenSize,
|
|
340
|
-
layerIdx,
|
|
341
|
-
kernelPath,
|
|
342
|
-
matmulOutputDtype,
|
|
343
|
-
getWeightBuffer,
|
|
344
|
-
lora,
|
|
345
|
-
releaseTemporary,
|
|
346
|
-
});
|
|
347
|
-
|
|
348
|
-
const vTensor = await projectSingleQkvTensor({
|
|
349
|
-
recorder,
|
|
350
|
-
normed,
|
|
351
|
-
layerWeights,
|
|
352
|
-
weightKey: 'vProj',
|
|
353
|
-
role: 'v_proj',
|
|
354
|
-
outputSize: numKVHeads * headDim,
|
|
355
|
-
outputLabel: 'V',
|
|
356
|
-
loraKey: 'v_proj',
|
|
357
|
-
numTokens,
|
|
358
|
-
hiddenSize,
|
|
359
|
-
layerIdx,
|
|
360
|
-
kernelPath,
|
|
361
|
-
matmulOutputDtype,
|
|
362
|
-
getWeightBuffer,
|
|
363
|
-
lora,
|
|
364
|
-
releaseTemporary,
|
|
365
|
-
});
|
|
366
|
-
|
|
367
|
-
return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
|
|
368
418
|
}
|
|
369
419
|
|
|
370
420
|
export async function applyAttentionQKNorm({
|
|
@@ -90,9 +90,20 @@ export async function recordLayerAttentionGPU(
|
|
|
90
90
|
const allowF16Attention = wantsF16Output && kvCacheDtype === 'f16';
|
|
91
91
|
let attentionInput = input;
|
|
92
92
|
let attentionInputTemp = false;
|
|
93
|
+
let normed = attentionInput;
|
|
94
|
+
let qTensor = null;
|
|
95
|
+
let qGateTensor = null;
|
|
96
|
+
let kTensor = null;
|
|
97
|
+
let vTensor = null;
|
|
98
|
+
let attnOutput = null;
|
|
99
|
+
let attnForProjection = null;
|
|
100
|
+
let output = null;
|
|
101
|
+
let finalOutput = null;
|
|
102
|
+
let oProjInputTemp = null;
|
|
93
103
|
if (wantsF16Output && !allowF16Attention) {
|
|
94
104
|
attentionInput = await recordCastF16ToF32(recorder, input);
|
|
95
105
|
attentionInputTemp = true;
|
|
106
|
+
normed = attentionInput;
|
|
96
107
|
}
|
|
97
108
|
|
|
98
109
|
if (!layerWeights) {
|
|
@@ -108,7 +119,7 @@ export async function recordLayerAttentionGPU(
|
|
|
108
119
|
|
|
109
120
|
// 1. Input norm
|
|
110
121
|
|
|
111
|
-
|
|
122
|
+
try {
|
|
112
123
|
if (!skipInputNorm && layerWeights.inputNorm && getNormWeightBuffer) {
|
|
113
124
|
const normWeightBuf = getNormWeightBuffer(layerWeights.inputNorm, 'input_norm');
|
|
114
125
|
normed = await recordRMSNorm(recorder, attentionInput, normWeightBuf, rmsNormEps, {
|
|
@@ -132,7 +143,8 @@ export async function recordLayerAttentionGPU(
|
|
|
132
143
|
|
|
133
144
|
// 2. Q/K/V projections
|
|
134
145
|
const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype);
|
|
135
|
-
let
|
|
146
|
+
let usedFusedQKV = false;
|
|
147
|
+
({ qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
|
|
136
148
|
recorder,
|
|
137
149
|
normed,
|
|
138
150
|
layerWeights,
|
|
@@ -153,7 +165,7 @@ export async function recordLayerAttentionGPU(
|
|
|
153
165
|
trace.attn(layerIdx, `Using fused QKV path: ${qSizeFused}+${kSizeFused}+${vSizeFused}=${totalSize}`);
|
|
154
166
|
}
|
|
155
167
|
: null,
|
|
156
|
-
});
|
|
168
|
+
}));
|
|
157
169
|
|
|
158
170
|
// Optional per-head Q/K normalization.
|
|
159
171
|
// Some models use RMSNorm with (1+weight) offset formula, controlled by rmsNormWeightOffset.
|
|
@@ -502,9 +514,9 @@ export async function recordLayerAttentionGPU(
|
|
|
502
514
|
throw new Error(`Unsupported attention kernel variant "${attentionKernelVariant}" at layer ${layerIdx}`);
|
|
503
515
|
}
|
|
504
516
|
|
|
505
|
-
|
|
517
|
+
attnOutput = await runAttentionKernel();
|
|
506
518
|
|
|
507
|
-
|
|
519
|
+
attnForProjection = attnOutput;
|
|
508
520
|
if (qGateTensor) {
|
|
509
521
|
attnForProjection = await recordSiLU(recorder, attnOutput, {
|
|
510
522
|
size: numTokens * numHeads * headDim,
|
|
@@ -518,10 +530,10 @@ export async function recordLayerAttentionGPU(
|
|
|
518
530
|
|
|
519
531
|
// 6. Output projection (with optional fused residual for decode)
|
|
520
532
|
|
|
521
|
-
|
|
533
|
+
output = null;
|
|
522
534
|
let residualFused = false;
|
|
523
535
|
let oProjInput = attnForProjection;
|
|
524
|
-
|
|
536
|
+
oProjInputTemp = null;
|
|
525
537
|
if (layerWeights.oProj && getWeightBuffer) {
|
|
526
538
|
const oProjBuf = getWeightBuffer(layerWeights.oProj, 'o_proj');
|
|
527
539
|
const loraO = getLoRAModule(lora, layerIdx, 'o_proj');
|
|
@@ -589,7 +601,7 @@ export async function recordLayerAttentionGPU(
|
|
|
589
601
|
}
|
|
590
602
|
}
|
|
591
603
|
|
|
592
|
-
|
|
604
|
+
finalOutput = output;
|
|
593
605
|
|
|
594
606
|
const buffersToTrack = [];
|
|
595
607
|
if (output.buffer !== attnForProjection.buffer) {
|
|
@@ -619,4 +631,46 @@ export async function recordLayerAttentionGPU(
|
|
|
619
631
|
}
|
|
620
632
|
|
|
621
633
|
return { output: finalOutput, residualFused };
|
|
634
|
+
} catch (error) {
|
|
635
|
+
const tracked = new Set();
|
|
636
|
+
const trackOnce = (buffer) => {
|
|
637
|
+
if (!buffer || tracked.has(buffer)) return;
|
|
638
|
+
tracked.add(buffer);
|
|
639
|
+
recorder.trackTemporaryBuffer(buffer);
|
|
640
|
+
};
|
|
641
|
+
if (finalOutput?.buffer && finalOutput.buffer !== output?.buffer) {
|
|
642
|
+
trackOnce(finalOutput.buffer);
|
|
643
|
+
}
|
|
644
|
+
if (output?.buffer && output.buffer !== attnForProjection?.buffer) {
|
|
645
|
+
trackOnce(output.buffer);
|
|
646
|
+
}
|
|
647
|
+
if (oProjInputTemp?.buffer) {
|
|
648
|
+
trackOnce(oProjInputTemp.buffer);
|
|
649
|
+
}
|
|
650
|
+
if (attnForProjection?.buffer && attnForProjection.buffer !== attnOutput?.buffer) {
|
|
651
|
+
trackOnce(attnForProjection.buffer);
|
|
652
|
+
}
|
|
653
|
+
if (attnOutput?.buffer) {
|
|
654
|
+
trackOnce(attnOutput.buffer);
|
|
655
|
+
}
|
|
656
|
+
if (qGateTensor?.buffer) {
|
|
657
|
+
trackOnce(qGateTensor.buffer);
|
|
658
|
+
}
|
|
659
|
+
if (qTensor?.buffer) {
|
|
660
|
+
trackOnce(qTensor.buffer);
|
|
661
|
+
}
|
|
662
|
+
if (kTensor?.buffer) {
|
|
663
|
+
trackOnce(kTensor.buffer);
|
|
664
|
+
}
|
|
665
|
+
if (vTensor?.buffer) {
|
|
666
|
+
trackOnce(vTensor.buffer);
|
|
667
|
+
}
|
|
668
|
+
if (normed?.buffer && normed.buffer !== attentionInput?.buffer) {
|
|
669
|
+
trackOnce(normed.buffer);
|
|
670
|
+
}
|
|
671
|
+
if (attentionInputTemp && attentionInput?.buffer) {
|
|
672
|
+
trackOnce(attentionInput.buffer);
|
|
673
|
+
}
|
|
674
|
+
throw error;
|
|
675
|
+
}
|
|
622
676
|
}
|
|
@@ -97,9 +97,20 @@ export async function runLayerAttentionGPU(
|
|
|
97
97
|
const allowF16Attention = wantsF16Output && kvCacheDtype === 'f16';
|
|
98
98
|
let attentionInput = input;
|
|
99
99
|
let attentionInputTemp = false;
|
|
100
|
+
let normed = attentionInput;
|
|
101
|
+
let qTensor = null;
|
|
102
|
+
let qGateTensor = null;
|
|
103
|
+
let kTensor = null;
|
|
104
|
+
let vTensor = null;
|
|
105
|
+
let attnOutput = null;
|
|
106
|
+
let attnForProjection = null;
|
|
107
|
+
let output = null;
|
|
108
|
+
let finalOutput = null;
|
|
109
|
+
let oProjInputTemp = null;
|
|
100
110
|
if (wantsF16Output && !allowF16Attention) {
|
|
101
111
|
attentionInput = await castF16ToF32(input);
|
|
102
112
|
attentionInputTemp = true;
|
|
113
|
+
normed = attentionInput;
|
|
103
114
|
}
|
|
104
115
|
|
|
105
116
|
// Debug: attention input for configured layers
|
|
@@ -123,7 +134,7 @@ export async function runLayerAttentionGPU(
|
|
|
123
134
|
|
|
124
135
|
// 1. Input norm
|
|
125
136
|
|
|
126
|
-
|
|
137
|
+
try {
|
|
127
138
|
if (!skipInputNorm && layerWeights.inputNorm && getNormWeightBuffer) {
|
|
128
139
|
const normWeightBuf = getNormWeightBuffer(layerWeights.inputNorm, 'input_norm');
|
|
129
140
|
|
|
@@ -183,7 +194,8 @@ export async function runLayerAttentionGPU(
|
|
|
183
194
|
|
|
184
195
|
// 2. Q/K/V projections
|
|
185
196
|
const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype);
|
|
186
|
-
let
|
|
197
|
+
let usedFusedQKV = false;
|
|
198
|
+
({ qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
|
|
187
199
|
recorder: null,
|
|
188
200
|
normed,
|
|
189
201
|
layerWeights,
|
|
@@ -204,7 +216,7 @@ export async function runLayerAttentionGPU(
|
|
|
204
216
|
trace.attn(layerIdx, `Using fused QKV path: ${qSizeFused}+${kSizeFused}+${vSizeFused}=${totalSize}`);
|
|
205
217
|
}
|
|
206
218
|
: null,
|
|
207
|
-
});
|
|
219
|
+
}));
|
|
208
220
|
|
|
209
221
|
// Trace Q/K/V projections
|
|
210
222
|
if (kernelTrace.enabled) {
|
|
@@ -669,7 +681,7 @@ export async function runLayerAttentionGPU(
|
|
|
669
681
|
throw new Error(`Unsupported attention kernel variant "${attentionKernelVariant}" at layer ${layerIdx}`);
|
|
670
682
|
}
|
|
671
683
|
|
|
672
|
-
|
|
684
|
+
attnOutput = await runAttentionKernel();
|
|
673
685
|
|
|
674
686
|
// Trace attention output
|
|
675
687
|
if (kernelTrace.enabled) {
|
|
@@ -692,7 +704,7 @@ export async function runLayerAttentionGPU(
|
|
|
692
704
|
await debugCheckBuffer(attnOutput.buffer, `L${layerIdx} attention output (before o_proj, GPU)`, numTokens, numHeads * headDim);
|
|
693
705
|
}
|
|
694
706
|
|
|
695
|
-
|
|
707
|
+
attnForProjection = attnOutput;
|
|
696
708
|
if (qGateTensor) {
|
|
697
709
|
attnForProjection = await runSiLU(attnOutput, {
|
|
698
710
|
size: numTokens * numHeads * headDim,
|
|
@@ -706,10 +718,10 @@ export async function runLayerAttentionGPU(
|
|
|
706
718
|
|
|
707
719
|
// 6. Output projection (with optional fused residual for decode)
|
|
708
720
|
|
|
709
|
-
|
|
721
|
+
output = null;
|
|
710
722
|
let residualFused = false;
|
|
711
723
|
let oProjInput = attnForProjection;
|
|
712
|
-
|
|
724
|
+
oProjInputTemp = null;
|
|
713
725
|
if (layerWeights.oProj && getWeightBuffer) {
|
|
714
726
|
const oProjBuf = getWeightBuffer(layerWeights.oProj, 'o_proj');
|
|
715
727
|
const loraO = getLoRAModule(lora, layerIdx, 'o_proj');
|
|
@@ -807,7 +819,7 @@ export async function runLayerAttentionGPU(
|
|
|
807
819
|
await debugCheckBuffer(output.buffer, `L${layerIdx} attention output (after o_proj, GPU)`, numTokens, hiddenSize);
|
|
808
820
|
}
|
|
809
821
|
|
|
810
|
-
|
|
822
|
+
finalOutput = output;
|
|
811
823
|
|
|
812
824
|
const buffersToRelease = [];
|
|
813
825
|
if (output.buffer !== attnForProjection.buffer) {
|
|
@@ -832,4 +844,46 @@ export async function runLayerAttentionGPU(
|
|
|
832
844
|
}
|
|
833
845
|
|
|
834
846
|
return { output: finalOutput, residualFused };
|
|
847
|
+
} catch (error) {
|
|
848
|
+
const released = new Set();
|
|
849
|
+
const releaseOnce = (buffer) => {
|
|
850
|
+
if (!buffer || released.has(buffer)) return;
|
|
851
|
+
released.add(buffer);
|
|
852
|
+
releaseBuffer(buffer);
|
|
853
|
+
};
|
|
854
|
+
if (finalOutput?.buffer && finalOutput.buffer !== output?.buffer) {
|
|
855
|
+
releaseOnce(finalOutput.buffer);
|
|
856
|
+
}
|
|
857
|
+
if (output?.buffer && output.buffer !== attnForProjection?.buffer) {
|
|
858
|
+
releaseOnce(output.buffer);
|
|
859
|
+
}
|
|
860
|
+
if (oProjInputTemp?.buffer) {
|
|
861
|
+
releaseOnce(oProjInputTemp.buffer);
|
|
862
|
+
}
|
|
863
|
+
if (attnForProjection?.buffer && attnForProjection.buffer !== attnOutput?.buffer) {
|
|
864
|
+
releaseOnce(attnForProjection.buffer);
|
|
865
|
+
}
|
|
866
|
+
if (attnOutput?.buffer) {
|
|
867
|
+
releaseOnce(attnOutput.buffer);
|
|
868
|
+
}
|
|
869
|
+
if (qGateTensor?.buffer) {
|
|
870
|
+
releaseOnce(qGateTensor.buffer);
|
|
871
|
+
}
|
|
872
|
+
if (qTensor?.buffer) {
|
|
873
|
+
releaseOnce(qTensor.buffer);
|
|
874
|
+
}
|
|
875
|
+
if (kTensor?.buffer) {
|
|
876
|
+
releaseOnce(kTensor.buffer);
|
|
877
|
+
}
|
|
878
|
+
if (vTensor?.buffer) {
|
|
879
|
+
releaseOnce(vTensor.buffer);
|
|
880
|
+
}
|
|
881
|
+
if (normed?.buffer && normed.buffer !== attentionInput?.buffer) {
|
|
882
|
+
releaseOnce(normed.buffer);
|
|
883
|
+
}
|
|
884
|
+
if (attentionInputTemp && attentionInput?.buffer) {
|
|
885
|
+
releaseOnce(attentionInput.buffer);
|
|
886
|
+
}
|
|
887
|
+
throw error;
|
|
888
|
+
}
|
|
835
889
|
}
|
|
@@ -134,11 +134,10 @@ function resolveIntermediateSizeForRuntime(manifest, inf, arch, modelId) {
|
|
|
134
134
|
if (inferred == null || inferred === fromArch) {
|
|
135
135
|
return fromArch;
|
|
136
136
|
}
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
137
|
+
throw new Error(
|
|
138
|
+
`Manifest "${modelId}" has intermediateSize=${fromArch}, but FFN tensors imply ${inferred}. ` +
|
|
139
|
+
'Re-convert the model so manifest architecture matches the weights.'
|
|
140
140
|
);
|
|
141
|
-
return inferred;
|
|
142
141
|
}
|
|
143
142
|
|
|
144
143
|
// =============================================================================
|
|
@@ -319,14 +319,8 @@ export async function embed(tokenIds, embedBuffer, config) {
|
|
|
319
319
|
const firstTokenId = tokenIdArray[0];
|
|
320
320
|
const bytesPerElement = useF16 ? 2 : 4;
|
|
321
321
|
const sampleSize = Math.min(32 * bytesPerElement, hiddenSize * bytesPerElement);
|
|
322
|
-
const
|
|
323
|
-
const
|
|
324
|
-
enc.copyBufferToBuffer(gatherOutput.buffer, 0, staging, 0, sampleSize);
|
|
325
|
-
device.queue.submit([enc.finish()]);
|
|
326
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
327
|
-
const data = decodeReadback(staging.getMappedRange().slice(0), gatherOptions.outputDtype);
|
|
328
|
-
staging.unmap();
|
|
329
|
-
staging.destroy();
|
|
322
|
+
const readback = await readBuffer(gatherOutput.buffer, sampleSize);
|
|
323
|
+
const data = decodeReadback(readback, gatherOptions.outputDtype);
|
|
330
324
|
|
|
331
325
|
// Compute statistics
|
|
332
326
|
let sum = 0, sumSq = 0;
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import { log } from '../../../debug/index.js';
|
|
2
1
|
import { resolveKernelPath } from '../../../config/kernel-path-loader.js';
|
|
3
2
|
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
4
3
|
import {
|
|
@@ -9,19 +8,36 @@ import {
|
|
|
9
8
|
export const PRIMARY_EXECUTION_PLAN_ID = 'primary';
|
|
10
9
|
export const FINITENESS_FALLBACK_EXECUTION_PLAN_ID = 'finiteness_fallback';
|
|
11
10
|
|
|
12
|
-
function
|
|
13
|
-
if (
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
11
|
+
function assertOptionalBoolean(value, label) {
|
|
12
|
+
if (value === undefined) {
|
|
13
|
+
return undefined;
|
|
14
|
+
}
|
|
15
|
+
if (typeof value !== 'boolean') {
|
|
16
|
+
throw new Error(`[ExecutionPlan] ${label} must be boolean when provided; got ${JSON.stringify(value)}.`);
|
|
17
|
+
}
|
|
18
|
+
return value;
|
|
18
19
|
}
|
|
19
20
|
|
|
20
|
-
function
|
|
21
|
-
if (value ===
|
|
22
|
-
return
|
|
21
|
+
function assertOptionalPositiveInt(value, label) {
|
|
22
|
+
if (value === undefined) {
|
|
23
|
+
return undefined;
|
|
24
|
+
}
|
|
25
|
+
if (!Number.isInteger(value) || value < 1) {
|
|
26
|
+
throw new Error(`[ExecutionPlan] ${label} must be a positive integer when provided; got ${JSON.stringify(value)}.`);
|
|
27
|
+
}
|
|
28
|
+
return value;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
function assertOptionalStopCheckMode(value) {
|
|
32
|
+
if (value === undefined) {
|
|
33
|
+
return undefined;
|
|
34
|
+
}
|
|
35
|
+
if (value !== 'batch' && value !== 'per-token') {
|
|
36
|
+
throw new Error(
|
|
37
|
+
`[ExecutionPlan] stopCheckMode must be "batch" or "per-token" when provided; got ${JSON.stringify(value)}.`
|
|
38
|
+
);
|
|
23
39
|
}
|
|
24
|
-
return
|
|
40
|
+
return value;
|
|
25
41
|
}
|
|
26
42
|
|
|
27
43
|
function resolveFallbackActivationDtype(primaryActivationDtype) {
|
|
@@ -244,11 +260,17 @@ export function activateFallbackExecutionPlan(container) {
|
|
|
244
260
|
|
|
245
261
|
function resolveExecutionOverrides(options = {}) {
|
|
246
262
|
return {
|
|
247
|
-
disableCommandBatching:
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
263
|
+
disableCommandBatching: assertOptionalBoolean(
|
|
264
|
+
options.disableCommandBatching,
|
|
265
|
+
'disableCommandBatching'
|
|
266
|
+
),
|
|
267
|
+
disableMultiTokenDecode: assertOptionalBoolean(
|
|
268
|
+
options.disableMultiTokenDecode,
|
|
269
|
+
'disableMultiTokenDecode'
|
|
270
|
+
),
|
|
271
|
+
batchSize: assertOptionalPositiveInt(options.batchSize, 'batchSize'),
|
|
272
|
+
stopCheckMode: assertOptionalStopCheckMode(options.stopCheckMode),
|
|
273
|
+
maxTokens: assertOptionalPositiveInt(options.maxTokens, 'maxTokens'),
|
|
252
274
|
};
|
|
253
275
|
}
|
|
254
276
|
|
|
@@ -268,9 +290,9 @@ export function resolveExecutionSessionPlan(container, options = {}) {
|
|
|
268
290
|
deferredRoundingWindowTokens: activePlan.deferredRoundingWindowTokens,
|
|
269
291
|
disableCommandBatching: overrides.disableCommandBatching ?? activePlan.defaultDisableCommandBatching,
|
|
270
292
|
disableMultiTokenDecode: overrides.disableMultiTokenDecode ?? activePlan.defaultDisableMultiTokenDecode,
|
|
271
|
-
batchSize:
|
|
272
|
-
stopCheckMode:
|
|
273
|
-
maxTokens:
|
|
293
|
+
batchSize: overrides.batchSize ?? activePlan.defaultBatchSize,
|
|
294
|
+
stopCheckMode: overrides.stopCheckMode ?? activePlan.defaultStopCheckMode,
|
|
295
|
+
maxTokens: overrides.maxTokens ?? activePlan.defaultMaxTokens,
|
|
274
296
|
readbackInterval: activePlan.readbackInterval,
|
|
275
297
|
ringTokens: activePlan.ringTokens,
|
|
276
298
|
ringStop: activePlan.ringStop,
|