@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
|
@@ -4,6 +4,9 @@ import { readBuffer, releaseBuffer, uploadData, acquireBuffer } from '../../../m
|
|
|
4
4
|
import { log } from '../../../debug/index.js';
|
|
5
5
|
import { decodeReadback } from './debug-utils/index.js';
|
|
6
6
|
import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
|
|
7
|
+
import { runProbes } from './probes.js';
|
|
8
|
+
import { QK_K, Q4K_BLOCK_BYTES } from '../../../config/schema/index.js';
|
|
9
|
+
import { dequantizeQ4KM } from '../../../converter/quantizer.js';
|
|
7
10
|
|
|
8
11
|
const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
|
|
9
12
|
const QK_L2NORM_EPS = 1e-6;
|
|
@@ -33,6 +36,15 @@ function bytesFromDtype(dtype) {
|
|
|
33
36
|
return 4;
|
|
34
37
|
}
|
|
35
38
|
|
|
39
|
+
export function applyLinearNormWeightOffset(values, rmsNormWeightOffset) {
|
|
40
|
+
if (!(values instanceof Float32Array)) {
|
|
41
|
+
throw new Error('applyLinearNormWeightOffset requires Float32Array input.');
|
|
42
|
+
}
|
|
43
|
+
// Qwen linear-attention output norm uses direct weights even when surrounding
|
|
44
|
+
// transformer RMSNorm sites use the Gemma-style (1 + weight) formula.
|
|
45
|
+
return values;
|
|
46
|
+
}
|
|
47
|
+
|
|
36
48
|
function cloneLayerRuntimeState(layerState) {
|
|
37
49
|
return {
|
|
38
50
|
layerIdx: layerState.layerIdx,
|
|
@@ -173,9 +185,22 @@ function inferLinearNormModeFromWeight(weight, projectionLayout) {
|
|
|
173
185
|
if (weight instanceof ArrayBuffer) {
|
|
174
186
|
return classify(Math.trunc(weight.byteLength / Float32Array.BYTES_PER_ELEMENT));
|
|
175
187
|
}
|
|
188
|
+
const explicitDtype = typeof weight?.dtype === 'string' ? weight.dtype.toLowerCase() : null;
|
|
189
|
+
const trackedDtype = isGpuBuffer(weight) ? String(getBufferDtype(weight) ?? '').toLowerCase() : '';
|
|
190
|
+
const bytesPerElement = bytesFromDtype(explicitDtype || trackedDtype || null);
|
|
191
|
+
const sizedElements = Number.isFinite(weight?.size)
|
|
192
|
+
? Math.trunc(Number(weight.size) / bytesPerElement)
|
|
193
|
+
: null;
|
|
194
|
+
if (sizedElements && Number(weight.size) % bytesPerElement === 0) {
|
|
195
|
+
return classify(sizedElements);
|
|
196
|
+
}
|
|
176
197
|
return null;
|
|
177
198
|
}
|
|
178
199
|
|
|
200
|
+
export function inferLinearNormMode(weight, projectionLayout) {
|
|
201
|
+
return inferLinearNormModeFromWeight(weight, projectionLayout);
|
|
202
|
+
}
|
|
203
|
+
|
|
179
204
|
function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, layerIdx) {
|
|
180
205
|
const configuredMode = normalizeLinearNormMode(configNormMode);
|
|
181
206
|
const inferredMode = inferLinearNormModeFromWeight(normWeight, projectionLayout);
|
|
@@ -185,7 +210,15 @@ function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, lay
|
|
|
185
210
|
`but norm.weight shape implies "${inferredMode}".`
|
|
186
211
|
);
|
|
187
212
|
}
|
|
188
|
-
|
|
213
|
+
if (configuredMode) {
|
|
214
|
+
return configuredMode;
|
|
215
|
+
}
|
|
216
|
+
if (inferredMode) {
|
|
217
|
+
return inferredMode;
|
|
218
|
+
}
|
|
219
|
+
throw new Error(
|
|
220
|
+
`linear_attention layer ${layerIdx} requires explicit linearNormMode or a norm.weight shape that resolves it.`
|
|
221
|
+
);
|
|
189
222
|
}
|
|
190
223
|
|
|
191
224
|
async function readWeightAsF32(weight, expectedElements, label) {
|
|
@@ -261,9 +294,27 @@ async function readWeightAsF32(weight, expectedElements, label) {
|
|
|
261
294
|
if (!elementCount && isWeightBuffer(weight) && Array.isArray(weight.shape) && weight.shape.length > 0) {
|
|
262
295
|
elementCount = weight.shape.reduce((total, dim) => total * Math.max(1, Math.trunc(Number(dim) || 0)), 1);
|
|
263
296
|
}
|
|
297
|
+
const isQ4K = sourceDtype === 'q4k' || sourceDtype === 'q4_k_m' || sourceDtype === 'q4_k';
|
|
264
298
|
if (!elementCount) {
|
|
265
|
-
|
|
266
|
-
|
|
299
|
+
if (isQ4K) {
|
|
300
|
+
elementCount = Math.trunc(sourceBuffer.size / Q4K_BLOCK_BYTES) * QK_K;
|
|
301
|
+
} else {
|
|
302
|
+
const inferredBytes = sourceDtype === 'f16' || sourceDtype === 'bf16' ? 2 : 4;
|
|
303
|
+
elementCount = Math.trunc(sourceBuffer.size / inferredBytes);
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
if (isQ4K) {
|
|
308
|
+
const numBlocks = Math.ceil(elementCount / QK_K);
|
|
309
|
+
const q4kBytes = numBlocks * Q4K_BLOCK_BYTES;
|
|
310
|
+
const raw = await readBuffer(sourceBuffer, q4kBytes);
|
|
311
|
+
const decoded = dequantizeQ4KM(new Uint8Array(raw), numBlocks, [elementCount]);
|
|
312
|
+
if (expectedElements != null && decoded.length !== expectedElements) {
|
|
313
|
+
throw new Error(
|
|
314
|
+
`Weight "${label}" Q4K decoded length ${decoded.length}, expected ${expectedElements}.`
|
|
315
|
+
);
|
|
316
|
+
}
|
|
317
|
+
return decoded;
|
|
267
318
|
}
|
|
268
319
|
|
|
269
320
|
if (!sourceDtype) {
|
|
@@ -395,10 +446,17 @@ async function createLayerRuntimeState(
|
|
|
395
446
|
|
|
396
447
|
let convKernelSize = toPositiveInt(config.linearConvKernelDim) ?? null;
|
|
397
448
|
if (isWeightBuffer(convKernel) && Array.isArray(convKernel.shape) && convKernel.shape.length >= 3) {
|
|
398
|
-
|
|
449
|
+
const shapeKernelSize = toPositiveInt(convKernel.shape[2]) ?? null;
|
|
450
|
+
if (convKernelSize != null && shapeKernelSize != null && convKernelSize !== shapeKernelSize) {
|
|
451
|
+
throw new Error(
|
|
452
|
+
`linear_attention layer ${layerIdx} declares linearConvKernelDim=${convKernelSize}, ` +
|
|
453
|
+
`but conv1d weight shape implies ${shapeKernelSize}.`
|
|
454
|
+
);
|
|
455
|
+
}
|
|
456
|
+
convKernelSize = shapeKernelSize ?? convKernelSize;
|
|
399
457
|
}
|
|
400
458
|
if (!convKernelSize) {
|
|
401
|
-
|
|
459
|
+
throw new Error(`linear_attention layer ${layerIdx} requires linearConvKernelDim.`);
|
|
402
460
|
}
|
|
403
461
|
|
|
404
462
|
const convWeight = await readWeightAsF32(
|
|
@@ -425,6 +483,7 @@ async function createLayerRuntimeState(
|
|
|
425
483
|
expectedNormElements,
|
|
426
484
|
`L${layerIdx}.linear_attn.norm.weight`
|
|
427
485
|
);
|
|
486
|
+
const runtimeNorm = applyLinearNormWeightOffset(norm, config.rmsNormWeightOffset === true);
|
|
428
487
|
|
|
429
488
|
const aNegExp = new Float32Array(aLog.length);
|
|
430
489
|
for (let i = 0; i < aLog.length; i++) {
|
|
@@ -435,6 +494,11 @@ async function createLayerRuntimeState(
|
|
|
435
494
|
const recurrentState = new Float32Array(
|
|
436
495
|
projectionLayout.numVHeads * projectionLayout.headKDim * projectionLayout.headVDim
|
|
437
496
|
);
|
|
497
|
+
const rmsNormEps = Number(config.rmsNormEps);
|
|
498
|
+
if (!Number.isFinite(rmsNormEps) || rmsNormEps <= 0) {
|
|
499
|
+
throw new Error(`linear_attention layer ${layerIdx} requires a positive rmsNormEps.`);
|
|
500
|
+
}
|
|
501
|
+
|
|
438
502
|
const layerState = {
|
|
439
503
|
layerIdx,
|
|
440
504
|
seqLen: currentSeqLen,
|
|
@@ -452,11 +516,11 @@ async function createLayerRuntimeState(
|
|
|
452
516
|
vSize: projectionLayout.vSize,
|
|
453
517
|
qRep: projectionLayout.qRep,
|
|
454
518
|
normMode,
|
|
455
|
-
rmsNormEps
|
|
519
|
+
rmsNormEps,
|
|
456
520
|
convWeight,
|
|
457
521
|
dtBias,
|
|
458
522
|
aNegExp,
|
|
459
|
-
normWeight:
|
|
523
|
+
normWeight: runtimeNorm,
|
|
460
524
|
convState,
|
|
461
525
|
recurrentState,
|
|
462
526
|
convWeightGPU: null,
|
|
@@ -681,13 +745,13 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
681
745
|
const normWeightBuffer = getNormWeightBuffer(layerWeights.inputNorm, `L${layerIdx}.linear_input_norm`);
|
|
682
746
|
try {
|
|
683
747
|
if (recorder) {
|
|
684
|
-
normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer,
|
|
748
|
+
normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, layerState.rmsNormEps, {
|
|
685
749
|
batchSize: numTokens,
|
|
686
750
|
hiddenSize,
|
|
687
751
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
688
752
|
});
|
|
689
753
|
} else {
|
|
690
|
-
normedTensor = await runRMSNorm(inputTensor, normWeightBuffer,
|
|
754
|
+
normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, layerState.rmsNormEps, {
|
|
691
755
|
batchSize: numTokens,
|
|
692
756
|
hiddenSize,
|
|
693
757
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
@@ -755,6 +819,38 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
755
819
|
});
|
|
756
820
|
|
|
757
821
|
try {
|
|
822
|
+
await runProbes('linear_qkv_proj', qkvTensor.buffer, {
|
|
823
|
+
layerIdx,
|
|
824
|
+
numTokens,
|
|
825
|
+
hiddenSize: projectionLayout.convDim,
|
|
826
|
+
probes: options.debugProbes,
|
|
827
|
+
recorder,
|
|
828
|
+
dtype: qkvTensor.dtype,
|
|
829
|
+
});
|
|
830
|
+
await runProbes('linear_z_proj', zTensor.buffer, {
|
|
831
|
+
layerIdx,
|
|
832
|
+
numTokens,
|
|
833
|
+
hiddenSize: projectionLayout.valueDim,
|
|
834
|
+
probes: options.debugProbes,
|
|
835
|
+
recorder,
|
|
836
|
+
dtype: zTensor.dtype,
|
|
837
|
+
});
|
|
838
|
+
await runProbes('linear_a_proj', aTensor.buffer, {
|
|
839
|
+
layerIdx,
|
|
840
|
+
numTokens,
|
|
841
|
+
hiddenSize: projectionLayout.numVHeads,
|
|
842
|
+
probes: options.debugProbes,
|
|
843
|
+
recorder,
|
|
844
|
+
dtype: aTensor.dtype,
|
|
845
|
+
});
|
|
846
|
+
await runProbes('linear_b_proj', bTensor.buffer, {
|
|
847
|
+
layerIdx,
|
|
848
|
+
numTokens,
|
|
849
|
+
hiddenSize: projectionLayout.numVHeads,
|
|
850
|
+
probes: options.debugProbes,
|
|
851
|
+
recorder,
|
|
852
|
+
dtype: bTensor.dtype,
|
|
853
|
+
});
|
|
758
854
|
const coreTensor = await runLinearAttentionCoreGPU(
|
|
759
855
|
qkvTensor,
|
|
760
856
|
zTensor,
|
|
@@ -768,6 +864,14 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
768
864
|
recorder,
|
|
769
865
|
}
|
|
770
866
|
);
|
|
867
|
+
await runProbes('linear_core_out', coreTensor.buffer, {
|
|
868
|
+
layerIdx,
|
|
869
|
+
numTokens,
|
|
870
|
+
hiddenSize: projectionLayout.valueDim,
|
|
871
|
+
probes: options.debugProbes,
|
|
872
|
+
recorder,
|
|
873
|
+
dtype: coreTensor.dtype,
|
|
874
|
+
});
|
|
771
875
|
layerState.seqLen = currentSeqLen + numTokens;
|
|
772
876
|
const outProjWeight = getWeightBuffer(layerWeights.oProj, `L${layerIdx}.linear_out_proj`);
|
|
773
877
|
try {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice, getKernelCapabilities } from '../../../../gpu/device.js';
|
|
4
|
-
import { acquireBuffer, releaseBuffer
|
|
4
|
+
import { acquireBuffer, releaseBuffer } from '../../../../memory/buffer-pool.js';
|
|
5
5
|
import { runMatmul, runRMSNorm } from '../../../../gpu/kernel-selector.js';
|
|
6
6
|
import { recordMatmul } from '../../../../gpu/kernels/matmul.js';
|
|
7
7
|
import { recordRMSNorm } from '../../../../gpu/kernels/rmsnorm.js';
|
|
@@ -13,6 +13,7 @@ import { getRuntimeConfig } from '../../../../config/runtime.js';
|
|
|
13
13
|
import { selectRuleValue } from '../../../../rules/rule-registry.js';
|
|
14
14
|
import { runProbes } from '../probes.js';
|
|
15
15
|
import { f16BufferToF32 } from './cpu.js';
|
|
16
|
+
import { readBufferWithCleanup } from './utils.js';
|
|
16
17
|
|
|
17
18
|
function shouldForceStableF32Logits(config, inputDtype) {
|
|
18
19
|
// Small Gemma-family checkpoints can overflow in pure F16 logits path after RMSNorm offset.
|
|
@@ -187,14 +188,18 @@ export async function computeChunkedLogitsGPU(
|
|
|
187
188
|
}
|
|
188
189
|
|
|
189
190
|
const logitsBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: logitsTensor.dtype });
|
|
190
|
-
const chunkLogitsData = await
|
|
191
|
+
const chunkLogitsData = await readBufferWithCleanup(
|
|
192
|
+
logitsTensor.buffer,
|
|
193
|
+
numTokens * rowCount * logitsBytes,
|
|
194
|
+
() => {
|
|
195
|
+
releaseBuffer(logitsTensor.buffer);
|
|
196
|
+
releaseBuffer(weightBuffer.buffer);
|
|
197
|
+
}
|
|
198
|
+
);
|
|
191
199
|
const chunkLogits = logitsTensor.dtype === 'f16'
|
|
192
200
|
? f16BufferToF32(chunkLogitsData)
|
|
193
201
|
: new Float32Array(chunkLogitsData);
|
|
194
202
|
writeChunkLogits(logits, chunkLogits, numTokens, vocabSize, rowOffset, rowCount);
|
|
195
|
-
|
|
196
|
-
releaseBuffer(logitsTensor.buffer);
|
|
197
|
-
releaseBuffer(weightBuffer.buffer);
|
|
198
203
|
}
|
|
199
204
|
|
|
200
205
|
return logits;
|
|
@@ -299,7 +304,7 @@ export async function computeLogitsGPU(
|
|
|
299
304
|
|
|
300
305
|
const logitsTensor = await runMatmul(normedTensor, lmHeadBuffer, numTokens, matmulVocabSize, hiddenSize, {
|
|
301
306
|
transposeB: 'auto',
|
|
302
|
-
role:
|
|
307
|
+
role: 'lm_head',
|
|
303
308
|
kernelPath: config.kernelPath ?? null,
|
|
304
309
|
});
|
|
305
310
|
|
|
@@ -386,7 +391,7 @@ export async function recordLogitsGPU(
|
|
|
386
391
|
// Record matmul (no submit)
|
|
387
392
|
const logitsTensor = await recordMatmul(recorder, normedTensor, lmHeadBuffer, numTokens, matmulVocabSize, hiddenSize, {
|
|
388
393
|
transposeB: 'auto',
|
|
389
|
-
role:
|
|
394
|
+
role: 'lm_head',
|
|
390
395
|
kernelPath: config.kernelPath ?? null,
|
|
391
396
|
});
|
|
392
397
|
|
|
@@ -25,6 +25,10 @@ export { computeLogitsGPU, recordLogitsGPU, computeChunkedLogitsGPU, resolveCpuW
|
|
|
25
25
|
// Re-export utilities
|
|
26
26
|
export { extractLastPositionLogits, finalizeLogits } from './utils.js';
|
|
27
27
|
|
|
28
|
+
export interface ComputeLogitsOptions {
|
|
29
|
+
lastPositionOnly?: boolean;
|
|
30
|
+
}
|
|
31
|
+
|
|
28
32
|
/**
|
|
29
33
|
* Compute logits from hidden states.
|
|
30
34
|
*
|
|
@@ -53,5 +57,6 @@ export function computeLogits(
|
|
|
53
57
|
debugFlags?: LogitsDebugFlags,
|
|
54
58
|
getNormWeightBuffer?: (weight: GPUBuffer | Float32Array | ArrayBuffer, label: string) => GPUBuffer,
|
|
55
59
|
debugCheckBuffer?: (buffer: GPUBuffer, label: string, numTokens: number, expectedDim?: number) => Promise<void>,
|
|
56
|
-
debugProbes?: ProbeConfigSchema[] | null
|
|
60
|
+
debugProbes?: ProbeConfigSchema[] | null,
|
|
61
|
+
options?: ComputeLogitsOptions
|
|
57
62
|
): Promise<Float32Array>;
|
|
@@ -7,7 +7,7 @@ export { rmsNormCPU, matmulCPU, applySoftcapping, f16ToF32, f16BufferToF32 } fro
|
|
|
7
7
|
export { computeLogitsGPU, recordLogitsGPU, computeChunkedLogitsGPU, resolveCpuWeightDims, resolveLmHeadChunkRows, extractLmHeadChunk, writeChunkLogits } from './gpu.js';
|
|
8
8
|
|
|
9
9
|
// Re-export utilities
|
|
10
|
-
export { extractLastPositionLogits, finalizeLogits } from './utils.js';
|
|
10
|
+
export { extractLastPositionLogits, finalizeLogits, readBufferWithCleanup } from './utils.js';
|
|
11
11
|
|
|
12
12
|
// Imports for computeLogits orchestrator
|
|
13
13
|
import { getDevice } from '../../../../gpu/device.js';
|
|
@@ -20,7 +20,7 @@ import { log, trace, isTraceEnabled } from '../../../../debug/index.js';
|
|
|
20
20
|
import { runProbes } from '../probes.js';
|
|
21
21
|
import { rmsNormCPU, matmulCPU, f16BufferToF32 } from './cpu.js';
|
|
22
22
|
import { resolveCpuWeightDims, computeChunkedLogitsGPU } from './gpu.js';
|
|
23
|
-
import { finalizeLogits } from './utils.js';
|
|
23
|
+
import { finalizeLogits, readBufferWithCleanup } from './utils.js';
|
|
24
24
|
import { getRuntimeConfig } from '../../../../config/runtime.js';
|
|
25
25
|
import { selectRuleValue } from '../../../../rules/rule-registry.js';
|
|
26
26
|
|
|
@@ -253,6 +253,7 @@ export async function computeLogits(
|
|
|
253
253
|
|
|
254
254
|
const lastPositionOnly = options?.lastPositionOnly === true && numTokens > 1;
|
|
255
255
|
const matmulRows = lastPositionOnly ? 1 : numTokens;
|
|
256
|
+
const matmulPhaseOverride = lastPositionOnly ? 'prefill' : null;
|
|
256
257
|
let matmulInputTensor = normedTensor;
|
|
257
258
|
let matmulInputOwned = false;
|
|
258
259
|
if (lastPositionOnly) {
|
|
@@ -270,7 +271,8 @@ export async function computeLogits(
|
|
|
270
271
|
// HuggingFace models store lm_head as [vocabSize, hiddenSize], so transposeB=true
|
|
271
272
|
const logitsTensor = await runMatmul(matmulInputTensor, lmHeadBuffer, matmulRows, matmulVocabSize, hiddenSize, {
|
|
272
273
|
transposeB: 'auto',
|
|
273
|
-
role:
|
|
274
|
+
role: 'lm_head',
|
|
275
|
+
phaseOverride: matmulPhaseOverride,
|
|
274
276
|
kernelPath: config.kernelPath ?? null,
|
|
275
277
|
});
|
|
276
278
|
await runProbes('logits', logitsTensor.buffer, {
|
|
@@ -288,15 +290,14 @@ export async function computeLogits(
|
|
|
288
290
|
// 4. Read back logits
|
|
289
291
|
const logitsBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: logitsTensor.dtype });
|
|
290
292
|
const logitsReadSize = matmulRows * matmulVocabSize * logitsBytes;
|
|
291
|
-
const logitsData = await
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
if (lmHeadBufferOwned) releaseBuffer(lmHeadGPU);
|
|
293
|
+
const logitsData = await readBufferWithCleanup(logitsTensor.buffer, logitsReadSize, () => {
|
|
294
|
+
if (inputBufferOwned) releaseBuffer(inputBuffer);
|
|
295
|
+
releaseBuffer(normedTensor.buffer);
|
|
296
|
+
if (matmulInputOwned) releaseBuffer(matmulInputTensor.buffer);
|
|
297
|
+
releaseBuffer(logitsTensor.buffer);
|
|
298
|
+
if (!getNormWeightBuffer && !(finalNorm instanceof GPUBuffer)) releaseBuffer(normWeightBuffer);
|
|
299
|
+
if (lmHeadBufferOwned) releaseBuffer(lmHeadGPU);
|
|
300
|
+
});
|
|
300
301
|
|
|
301
302
|
const rawLogits = logitsTensor.dtype === 'f16'
|
|
302
303
|
? f16BufferToF32(logitsData)
|
|
@@ -25,6 +25,13 @@ export function extractLastPositionLogits(
|
|
|
25
25
|
vocabSize: number
|
|
26
26
|
): Float32Array;
|
|
27
27
|
|
|
28
|
+
export function readBufferWithCleanup(
|
|
29
|
+
buffer: GPUBuffer,
|
|
30
|
+
byteLength: number,
|
|
31
|
+
cleanup?: (() => void) | null,
|
|
32
|
+
reader?: ((buffer: GPUBuffer, byteLength: number) => Promise<ArrayBuffer>) | null
|
|
33
|
+
): Promise<ArrayBuffer>;
|
|
34
|
+
|
|
28
35
|
/**
|
|
29
36
|
* Finalize logits by applying padding and softcapping.
|
|
30
37
|
*
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
|
+
import { readBuffer } from '../../../../memory/buffer-pool.js';
|
|
3
4
|
import { runProbes } from '../probes.js';
|
|
4
5
|
import { applySoftcapping } from './cpu.js';
|
|
5
6
|
|
|
@@ -19,6 +20,14 @@ export function extractLastPositionLogits(
|
|
|
19
20
|
return lastPosLogits;
|
|
20
21
|
}
|
|
21
22
|
|
|
23
|
+
/**
 * Read `byteLength` bytes from `buffer` and run `cleanup` afterwards,
 * whether the read resolves or rejects.
 *
 * @param {GPUBuffer} buffer - Buffer handed to the reader.
 * @param {number} byteLength - Number of bytes requested from the reader.
 * @param {(() => void) | null} [cleanup] - Optional callback invoked exactly once
 *   after the read settles; skipped when `null`/`undefined`.
 * @param {((buffer: GPUBuffer, byteLength: number) => Promise<ArrayBuffer>) | null} [reader]
 *   Read implementation. Falls back to `readBuffer` when omitted or `null`.
 * @returns {Promise<ArrayBuffer>} Whatever the reader resolves with.
 * @throws Re-throws the reader's rejection. An exception thrown by `cleanup`
 *   propagates as well (and, per `finally` semantics, supersedes a read error).
 */
export async function readBufferWithCleanup(buffer, byteLength, cleanup, reader = readBuffer) {
  // A default parameter only replaces `undefined`; the typed contract also allows
  // `null`, so resolve the fallback explicitly instead of calling `null(...)`.
  const read = reader ?? readBuffer;
  try {
    return await read(buffer, byteLength);
  } finally {
    // Runs on both success and failure so callers can release GPU buffers here.
    cleanup?.();
  }
}
|
|
30
|
+
|
|
22
31
|
|
|
23
32
|
export async function finalizeLogits(
|
|
24
33
|
rawLogits,
|
|
@@ -17,42 +17,60 @@ export async function applyLoRA(input, baseOutput, lora, dims, getWeightBuffer,
|
|
|
17
17
|
|
|
18
18
|
const aBuf = getWeightBuffer(lora.a, 'lora_a');
|
|
19
19
|
const bBuf = getWeightBuffer(lora.b, 'lora_b');
|
|
20
|
-
const ownsA = !(lora.a instanceof GPUBuffer) && !isWeightBuffer(lora.a);
|
|
21
|
-
const ownsB = !(lora.b instanceof GPUBuffer) && !isWeightBuffer(lora.b);
|
|
22
|
-
|
|
23
|
-
const
|
|
24
|
-
|
|
25
|
-
|
|
20
|
+
const ownsA = !(typeof GPUBuffer !== 'undefined' && lora.a instanceof GPUBuffer) && !isWeightBuffer(lora.a);
|
|
21
|
+
const ownsB = !(typeof GPUBuffer !== 'undefined' && lora.b instanceof GPUBuffer) && !isWeightBuffer(lora.b);
|
|
22
|
+
// Extract underlying GPUBuffer for WeightBuffers
|
|
23
|
+
const aBufGPU = isWeightBuffer(aBuf) ? aBuf.buffer : aBuf;
|
|
24
|
+
const bBufGPU = isWeightBuffer(bBuf) ? bBuf.buffer : bBuf;
|
|
25
|
+
let loraIntermediate = null;
|
|
26
|
+
let loraOutput = null;
|
|
27
|
+
let scaled = null;
|
|
28
|
+
try {
|
|
29
|
+
loraIntermediate = recorder
|
|
30
|
+
? await recordMatmul(recorder, input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath })
|
|
31
|
+
: await runMatmul(input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath });
|
|
26
32
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
33
|
+
loraOutput = recorder
|
|
34
|
+
? await recordMatmul(recorder, loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath })
|
|
35
|
+
: await runMatmul(loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath });
|
|
30
36
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
37
|
+
scaled = recorder
|
|
38
|
+
? await recordScale(recorder, loraOutput, lora.scale, { outputBuffer: null })
|
|
39
|
+
: await runScale(loraOutput, lora.scale, { outputBuffer: null });
|
|
34
40
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
41
|
+
const combined = recorder
|
|
42
|
+
? await recordResidualAdd(recorder, baseOutput, scaled, M * N)
|
|
43
|
+
: await runResidualAdd(baseOutput, scaled, M * N);
|
|
38
44
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
45
|
+
if (recorder) {
|
|
46
|
+
recorder.trackTemporaryBuffer(loraIntermediate.buffer);
|
|
47
|
+
recorder.trackTemporaryBuffer(loraOutput.buffer);
|
|
48
|
+
recorder.trackTemporaryBuffer(scaled.buffer);
|
|
49
|
+
if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
|
|
50
|
+
if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
|
|
51
|
+
} else {
|
|
52
|
+
releaseBuffer(loraIntermediate.buffer);
|
|
53
|
+
releaseBuffer(loraOutput.buffer);
|
|
54
|
+
releaseBuffer(scaled.buffer);
|
|
55
|
+
if (ownsA) releaseBuffer(aBufGPU);
|
|
56
|
+
if (ownsB) releaseBuffer(bBufGPU);
|
|
57
|
+
}
|
|
42
58
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
recorder
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
59
|
+
return combined;
|
|
60
|
+
} catch (error) {
|
|
61
|
+
if (recorder) {
|
|
62
|
+
if (loraIntermediate) recorder.trackTemporaryBuffer(loraIntermediate.buffer);
|
|
63
|
+
if (loraOutput) recorder.trackTemporaryBuffer(loraOutput.buffer);
|
|
64
|
+
if (scaled) recorder.trackTemporaryBuffer(scaled.buffer);
|
|
65
|
+
if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
|
|
66
|
+
if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
|
|
67
|
+
} else {
|
|
68
|
+
if (loraIntermediate) releaseBuffer(loraIntermediate.buffer);
|
|
69
|
+
if (loraOutput) releaseBuffer(loraOutput.buffer);
|
|
70
|
+
if (scaled) releaseBuffer(scaled.buffer);
|
|
71
|
+
if (ownsA) releaseBuffer(aBufGPU);
|
|
72
|
+
if (ownsB) releaseBuffer(bBufGPU);
|
|
73
|
+
}
|
|
74
|
+
throw error;
|
|
55
75
|
}
|
|
56
|
-
|
|
57
|
-
return combined;
|
|
58
76
|
}
|