@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import { parseModelConfig } from './config.js';
|
|
4
4
|
import { getDevice, getDeviceLimits, getKernelCapabilities } from '../../../gpu/device.js';
|
|
5
|
-
import { acquireBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
|
|
6
6
|
import { KVCache, SlidingWindowKVCache, TieredKVCache, BasisDecomposedPagedCache } from '../../kv-cache.js';
|
|
7
7
|
import { Tokenizer } from '../../tokenizer.js';
|
|
8
8
|
import { MoERouter } from '../../moe-router.js';
|
|
@@ -14,6 +14,10 @@ import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js'
|
|
|
14
14
|
import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
|
|
15
15
|
import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
|
|
16
16
|
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
17
|
+
import {
|
|
18
|
+
createSourceStorageContext,
|
|
19
|
+
getSourceRuntimeMetadata,
|
|
20
|
+
} from '../../../tooling/source-runtime-bundle.js';
|
|
17
21
|
|
|
18
22
|
function resolveErrorMessage(error) {
|
|
19
23
|
if (error && typeof error === 'object' && typeof error.message === 'string') {
|
|
@@ -56,12 +60,61 @@ function normalizeBaseUrl(baseUrl) {
|
|
|
56
60
|
return baseUrl.replace(/\/$/, '');
|
|
57
61
|
}
|
|
58
62
|
|
|
63
|
+
async function fetchBytes(url, offset = null, length = null) {
|
|
64
|
+
const headers = {};
|
|
65
|
+
if (Number.isFinite(offset) && Number.isFinite(length) && length > 0) {
|
|
66
|
+
const start = Math.max(0, Math.floor(offset));
|
|
67
|
+
const end = start + Math.max(0, Math.floor(length)) - 1;
|
|
68
|
+
headers.Range = `bytes=${start}-${end}`;
|
|
69
|
+
}
|
|
70
|
+
const response = await fetch(url, { headers });
|
|
71
|
+
if (!response.ok) {
|
|
72
|
+
throw new Error(`Failed to fetch ${url}: ${response.status}`);
|
|
73
|
+
}
|
|
74
|
+
return new Uint8Array(await response.arrayBuffer());
|
|
75
|
+
}
|
|
76
|
+
|
|
59
77
|
function createRemoteStorageContext(baseUrl, manifest) {
|
|
60
78
|
const root = normalizeBaseUrl(baseUrl);
|
|
61
79
|
if (!root || !isRDRRManifest(manifest)) {
|
|
62
80
|
return null;
|
|
63
81
|
}
|
|
64
82
|
|
|
83
|
+
const sourceRuntime = getSourceRuntimeMetadata(manifest);
|
|
84
|
+
if (sourceRuntime) {
|
|
85
|
+
const readRange = async (relativePath, offset, length) => {
|
|
86
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
87
|
+
if (!filename) {
|
|
88
|
+
throw new Error('Direct-source artifact path is required.');
|
|
89
|
+
}
|
|
90
|
+
const url = `${root}/${filename}`;
|
|
91
|
+
return fetchBytes(url, offset, length);
|
|
92
|
+
};
|
|
93
|
+
const readText = async (relativePath) => {
|
|
94
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
95
|
+
if (!filename) return null;
|
|
96
|
+
const response = await fetch(`${root}/${filename}`);
|
|
97
|
+
if (!response.ok) {
|
|
98
|
+
throw new Error(`Failed to fetch ${filename} from ${root}: ${response.status}`);
|
|
99
|
+
}
|
|
100
|
+
return response.text();
|
|
101
|
+
};
|
|
102
|
+
const readBinary = async (relativePath) => {
|
|
103
|
+
const filename = String(relativePath || '').replace(/^\/+/, '');
|
|
104
|
+
if (!filename) {
|
|
105
|
+
throw new Error('Direct-source binary asset path is required.');
|
|
106
|
+
}
|
|
107
|
+
return fetchBytes(`${root}/${filename}`);
|
|
108
|
+
};
|
|
109
|
+
return createSourceStorageContext({
|
|
110
|
+
manifest,
|
|
111
|
+
readRange,
|
|
112
|
+
readText,
|
|
113
|
+
readBinary,
|
|
114
|
+
verifyHashes: true,
|
|
115
|
+
});
|
|
116
|
+
}
|
|
117
|
+
|
|
65
118
|
return {
|
|
66
119
|
async loadShard(index) {
|
|
67
120
|
const shard = manifest.shards[index];
|
|
@@ -69,11 +122,7 @@ function createRemoteStorageContext(baseUrl, manifest) {
|
|
|
69
122
|
if (!filename) {
|
|
70
123
|
throw new Error(`Manifest shard ${index} is missing filename.`);
|
|
71
124
|
}
|
|
72
|
-
|
|
73
|
-
if (!response.ok) {
|
|
74
|
-
throw new Error(`Failed to fetch shard ${index} from ${root}: ${response.status}`);
|
|
75
|
-
}
|
|
76
|
-
return new Uint8Array(await response.arrayBuffer());
|
|
125
|
+
return fetchBytes(`${root}/${filename.replace(/^\/+/, '')}`);
|
|
77
126
|
},
|
|
78
127
|
};
|
|
79
128
|
}
|
|
@@ -326,20 +375,29 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
326
375
|
// Upload to GPU if available
|
|
327
376
|
const device = getDevice();
|
|
328
377
|
if (device && useGPU) {
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
378
|
+
let cosBuffer = null;
|
|
379
|
+
let sinBuffer = null;
|
|
380
|
+
let localCosBuffer = null;
|
|
381
|
+
let localSinBuffer = null;
|
|
382
|
+
try {
|
|
383
|
+
cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
|
|
384
|
+
sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
|
|
385
|
+
device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
|
|
386
|
+
device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
|
|
387
|
+
|
|
388
|
+
if (localFreqs) {
|
|
389
|
+
localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
|
|
390
|
+
localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
|
|
391
|
+
device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
|
|
392
|
+
device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
|
|
393
|
+
}
|
|
394
|
+
} catch (error) {
|
|
395
|
+
for (const buffer of [cosBuffer, sinBuffer, localCosBuffer, localSinBuffer]) {
|
|
396
|
+
if (buffer) {
|
|
397
|
+
releaseBuffer(buffer);
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
throw error;
|
|
343
401
|
}
|
|
344
402
|
|
|
345
403
|
log.debug(
|
|
@@ -78,6 +78,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
|
|
|
78
78
|
|
|
79
79
|
const normalizedPolicy = resolveKernelPathPolicy(kernelPathPolicy);
|
|
80
80
|
const hasSubgroups = capabilities?.hasSubgroups === true;
|
|
81
|
+
const hasF16 = capabilities?.hasF16 === true;
|
|
81
82
|
const normalizedSource = normalizeKernelPathSource(kernelPathSource);
|
|
82
83
|
const allowCapabilityAutoSelection = normalizedPolicy.mode === 'capability-aware'
|
|
83
84
|
&& normalizedPolicy.sourceScope.includes(normalizedSource);
|
|
@@ -85,6 +86,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
|
|
|
85
86
|
return selectRuleValue('inference', 'kernelPath', 'autoSelect', {
|
|
86
87
|
kernelPathRef: configuredKernelPathRef,
|
|
87
88
|
hasSubgroups,
|
|
89
|
+
hasF16,
|
|
88
90
|
allowCapabilityAutoSelection,
|
|
89
91
|
});
|
|
90
92
|
}
|
|
@@ -283,6 +283,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
|
|
|
283
283
|
if (layer >= 0 && !kernelTrace.shouldTraceLayer(layer)) return;
|
|
284
284
|
|
|
285
285
|
const output = await snapshotTensor(outputBuffer, outputShape);
|
|
286
|
+
if (!output.ok) {
|
|
287
|
+
throw new Error(`[TRACE] Failed to snapshot output for ${label}: ${output.error}`);
|
|
288
|
+
}
|
|
286
289
|
|
|
287
290
|
// Snapshot inputs if provided (expensive - only do if tracing)
|
|
288
291
|
|
|
@@ -290,6 +293,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
|
|
|
290
293
|
if (options?.inputs && options?.inputShapes) {
|
|
291
294
|
for (let i = 0; i < options.inputs.length; i++) {
|
|
292
295
|
const snap = await snapshotTensor(options.inputs[i], options.inputShapes[i]);
|
|
296
|
+
if (!snap.ok) {
|
|
297
|
+
throw new Error(`[TRACE] Failed to snapshot input ${i} for ${label}: ${snap.error}`);
|
|
298
|
+
}
|
|
293
299
|
inputs.push(snap);
|
|
294
300
|
}
|
|
295
301
|
}
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import { log, trace } from '../../../debug/index.js';
|
|
4
4
|
import { getDevice } from '../../../gpu/device.js';
|
|
5
|
-
import { releaseBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import { releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
|
|
6
6
|
import { allowReadback } from '../../../gpu/perf-guards.js';
|
|
7
7
|
import { createTensor } from '../../../gpu/tensor.js';
|
|
8
8
|
import {
|
|
@@ -228,6 +228,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
228
228
|
linearRuntime: context.linearAttentionRuntime ?? null,
|
|
229
229
|
getWeightBuffer: (weight, label) => getWeightBuffer(weight, label),
|
|
230
230
|
getNormWeightBuffer: (weight, label) => getNormWeightBuffer(weight, label, weightConfig, debugFlags),
|
|
231
|
+
debugProbes: context.debugProbes,
|
|
231
232
|
recorder: recorder ?? null,
|
|
232
233
|
});
|
|
233
234
|
} else {
|
|
@@ -314,14 +315,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
314
315
|
if (allowReadback(`layer.attn-out.${layerIdx}`)) {
|
|
315
316
|
try {
|
|
316
317
|
const sampleSize = Math.min(128, attnOutput.buffer.size);
|
|
317
|
-
const
|
|
318
|
-
const enc = device.createCommandEncoder();
|
|
319
|
-
enc.copyBufferToBuffer(attnOutput.buffer, 0, staging, 0, sampleSize);
|
|
320
|
-
device.queue.submit([enc.finish()]);
|
|
321
|
-
await staging.mapAsync(GPUMapMode.READ);
|
|
322
|
-
const data = new Float32Array(staging.getMappedRange().slice(0));
|
|
323
|
-
staging.unmap();
|
|
324
|
-
staging.destroy();
|
|
318
|
+
const data = new Float32Array(await readBuffer(attnOutput.buffer, sampleSize));
|
|
325
319
|
let maxAbs = 0;
|
|
326
320
|
for (let i = 0; i < data.length; i++) {
|
|
327
321
|
const abs = Math.abs(data[i]);
|
|
@@ -3,6 +3,7 @@ import type { Tensor } from '../../../gpu/tensor.js';
|
|
|
3
3
|
import type { WeightBuffer } from '../../../gpu/weight-buffer.js';
|
|
4
4
|
import type { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
5
5
|
import type { LinearNormMode } from '../../../config/schema/index.js';
|
|
6
|
+
import type { ProbeConfigSchema } from '../../../config/schema/index.js';
|
|
6
7
|
|
|
7
8
|
export interface LinearLayerRuntimeState {
|
|
8
9
|
layerIdx: number;
|
|
@@ -67,6 +68,7 @@ export interface RunLinearAttentionLayerOptions {
|
|
|
67
68
|
weight: GPUBuffer | Float32Array | ArrayBuffer,
|
|
68
69
|
label: string
|
|
69
70
|
) => GPUBuffer;
|
|
71
|
+
debugProbes?: ProbeConfigSchema[] | null;
|
|
70
72
|
recorder?: CommandRecorder | null;
|
|
71
73
|
}
|
|
72
74
|
|
|
@@ -74,6 +76,14 @@ export declare function hasLinearAttentionLayers(layerTypes: unknown): boolean;
|
|
|
74
76
|
|
|
75
77
|
export declare function createLinearAttentionRuntime(): LinearAttentionRuntime;
|
|
76
78
|
|
|
79
|
+
export declare function inferLinearNormMode(
|
|
80
|
+
weight: { size?: number; dtype?: string } | GPUBuffer | WeightBuffer | ArrayBufferView | ArrayBuffer | null | undefined,
|
|
81
|
+
projectionLayout: {
|
|
82
|
+
headVDim: number;
|
|
83
|
+
valueDim: number;
|
|
84
|
+
}
|
|
85
|
+
): LinearNormMode | null;
|
|
86
|
+
|
|
77
87
|
export declare function resetLinearAttentionRuntime(
|
|
78
88
|
runtime: LinearAttentionRuntime | null | undefined
|
|
79
89
|
): LinearAttentionRuntime;
|
|
@@ -4,6 +4,7 @@ import { readBuffer, releaseBuffer, uploadData, acquireBuffer } from '../../../m
|
|
|
4
4
|
import { log } from '../../../debug/index.js';
|
|
5
5
|
import { decodeReadback } from './debug-utils/index.js';
|
|
6
6
|
import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
|
|
7
|
+
import { runProbes } from './probes.js';
|
|
7
8
|
|
|
8
9
|
const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
|
|
9
10
|
const QK_L2NORM_EPS = 1e-6;
|
|
@@ -173,9 +174,22 @@ function inferLinearNormModeFromWeight(weight, projectionLayout) {
|
|
|
173
174
|
if (weight instanceof ArrayBuffer) {
|
|
174
175
|
return classify(Math.trunc(weight.byteLength / Float32Array.BYTES_PER_ELEMENT));
|
|
175
176
|
}
|
|
177
|
+
const explicitDtype = typeof weight?.dtype === 'string' ? weight.dtype.toLowerCase() : null;
|
|
178
|
+
const trackedDtype = isGpuBuffer(weight) ? String(getBufferDtype(weight) ?? '').toLowerCase() : '';
|
|
179
|
+
const bytesPerElement = bytesFromDtype(explicitDtype || trackedDtype || null);
|
|
180
|
+
const sizedElements = Number.isFinite(weight?.size)
|
|
181
|
+
? Math.trunc(Number(weight.size) / bytesPerElement)
|
|
182
|
+
: null;
|
|
183
|
+
if (sizedElements && Number(weight.size) % bytesPerElement === 0) {
|
|
184
|
+
return classify(sizedElements);
|
|
185
|
+
}
|
|
176
186
|
return null;
|
|
177
187
|
}
|
|
178
188
|
|
|
189
|
+
export function inferLinearNormMode(weight, projectionLayout) {
|
|
190
|
+
return inferLinearNormModeFromWeight(weight, projectionLayout);
|
|
191
|
+
}
|
|
192
|
+
|
|
179
193
|
function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, layerIdx) {
|
|
180
194
|
const configuredMode = normalizeLinearNormMode(configNormMode);
|
|
181
195
|
const inferredMode = inferLinearNormModeFromWeight(normWeight, projectionLayout);
|
|
@@ -185,7 +199,15 @@ function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, lay
|
|
|
185
199
|
`but norm.weight shape implies "${inferredMode}".`
|
|
186
200
|
);
|
|
187
201
|
}
|
|
188
|
-
|
|
202
|
+
if (configuredMode) {
|
|
203
|
+
return configuredMode;
|
|
204
|
+
}
|
|
205
|
+
if (inferredMode) {
|
|
206
|
+
return inferredMode;
|
|
207
|
+
}
|
|
208
|
+
throw new Error(
|
|
209
|
+
`linear_attention layer ${layerIdx} requires explicit linearNormMode or a norm.weight shape that resolves it.`
|
|
210
|
+
);
|
|
189
211
|
}
|
|
190
212
|
|
|
191
213
|
async function readWeightAsF32(weight, expectedElements, label) {
|
|
@@ -395,10 +417,17 @@ async function createLayerRuntimeState(
|
|
|
395
417
|
|
|
396
418
|
let convKernelSize = toPositiveInt(config.linearConvKernelDim) ?? null;
|
|
397
419
|
if (isWeightBuffer(convKernel) && Array.isArray(convKernel.shape) && convKernel.shape.length >= 3) {
|
|
398
|
-
|
|
420
|
+
const shapeKernelSize = toPositiveInt(convKernel.shape[2]) ?? null;
|
|
421
|
+
if (convKernelSize != null && shapeKernelSize != null && convKernelSize !== shapeKernelSize) {
|
|
422
|
+
throw new Error(
|
|
423
|
+
`linear_attention layer ${layerIdx} declares linearConvKernelDim=${convKernelSize}, ` +
|
|
424
|
+
`but conv1d weight shape implies ${shapeKernelSize}.`
|
|
425
|
+
);
|
|
426
|
+
}
|
|
427
|
+
convKernelSize = shapeKernelSize ?? convKernelSize;
|
|
399
428
|
}
|
|
400
429
|
if (!convKernelSize) {
|
|
401
|
-
|
|
430
|
+
throw new Error(`linear_attention layer ${layerIdx} requires linearConvKernelDim.`);
|
|
402
431
|
}
|
|
403
432
|
|
|
404
433
|
const convWeight = await readWeightAsF32(
|
|
@@ -435,6 +464,11 @@ async function createLayerRuntimeState(
|
|
|
435
464
|
const recurrentState = new Float32Array(
|
|
436
465
|
projectionLayout.numVHeads * projectionLayout.headKDim * projectionLayout.headVDim
|
|
437
466
|
);
|
|
467
|
+
const rmsNormEps = Number(config.rmsNormEps);
|
|
468
|
+
if (!Number.isFinite(rmsNormEps) || rmsNormEps <= 0) {
|
|
469
|
+
throw new Error(`linear_attention layer ${layerIdx} requires a positive rmsNormEps.`);
|
|
470
|
+
}
|
|
471
|
+
|
|
438
472
|
const layerState = {
|
|
439
473
|
layerIdx,
|
|
440
474
|
seqLen: currentSeqLen,
|
|
@@ -452,7 +486,7 @@ async function createLayerRuntimeState(
|
|
|
452
486
|
vSize: projectionLayout.vSize,
|
|
453
487
|
qRep: projectionLayout.qRep,
|
|
454
488
|
normMode,
|
|
455
|
-
rmsNormEps
|
|
489
|
+
rmsNormEps,
|
|
456
490
|
convWeight,
|
|
457
491
|
dtBias,
|
|
458
492
|
aNegExp,
|
|
@@ -681,13 +715,13 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
681
715
|
const normWeightBuffer = getNormWeightBuffer(layerWeights.inputNorm, `L${layerIdx}.linear_input_norm`);
|
|
682
716
|
try {
|
|
683
717
|
if (recorder) {
|
|
684
|
-
normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer,
|
|
718
|
+
normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, layerState.rmsNormEps, {
|
|
685
719
|
batchSize: numTokens,
|
|
686
720
|
hiddenSize,
|
|
687
721
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
688
722
|
});
|
|
689
723
|
} else {
|
|
690
|
-
normedTensor = await runRMSNorm(inputTensor, normWeightBuffer,
|
|
724
|
+
normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, layerState.rmsNormEps, {
|
|
691
725
|
batchSize: numTokens,
|
|
692
726
|
hiddenSize,
|
|
693
727
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
@@ -755,6 +789,38 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
755
789
|
});
|
|
756
790
|
|
|
757
791
|
try {
|
|
792
|
+
await runProbes('linear_qkv_proj', qkvTensor.buffer, {
|
|
793
|
+
layerIdx,
|
|
794
|
+
numTokens,
|
|
795
|
+
hiddenSize: projectionLayout.convDim,
|
|
796
|
+
probes: options.debugProbes,
|
|
797
|
+
recorder,
|
|
798
|
+
dtype: qkvTensor.dtype,
|
|
799
|
+
});
|
|
800
|
+
await runProbes('linear_z_proj', zTensor.buffer, {
|
|
801
|
+
layerIdx,
|
|
802
|
+
numTokens,
|
|
803
|
+
hiddenSize: projectionLayout.valueDim,
|
|
804
|
+
probes: options.debugProbes,
|
|
805
|
+
recorder,
|
|
806
|
+
dtype: zTensor.dtype,
|
|
807
|
+
});
|
|
808
|
+
await runProbes('linear_a_proj', aTensor.buffer, {
|
|
809
|
+
layerIdx,
|
|
810
|
+
numTokens,
|
|
811
|
+
hiddenSize: projectionLayout.numVHeads,
|
|
812
|
+
probes: options.debugProbes,
|
|
813
|
+
recorder,
|
|
814
|
+
dtype: aTensor.dtype,
|
|
815
|
+
});
|
|
816
|
+
await runProbes('linear_b_proj', bTensor.buffer, {
|
|
817
|
+
layerIdx,
|
|
818
|
+
numTokens,
|
|
819
|
+
hiddenSize: projectionLayout.numVHeads,
|
|
820
|
+
probes: options.debugProbes,
|
|
821
|
+
recorder,
|
|
822
|
+
dtype: bTensor.dtype,
|
|
823
|
+
});
|
|
758
824
|
const coreTensor = await runLinearAttentionCoreGPU(
|
|
759
825
|
qkvTensor,
|
|
760
826
|
zTensor,
|
|
@@ -768,6 +834,14 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
|
|
|
768
834
|
recorder,
|
|
769
835
|
}
|
|
770
836
|
);
|
|
837
|
+
await runProbes('linear_core_out', coreTensor.buffer, {
|
|
838
|
+
layerIdx,
|
|
839
|
+
numTokens,
|
|
840
|
+
hiddenSize: projectionLayout.valueDim,
|
|
841
|
+
probes: options.debugProbes,
|
|
842
|
+
recorder,
|
|
843
|
+
dtype: coreTensor.dtype,
|
|
844
|
+
});
|
|
771
845
|
layerState.seqLen = currentSeqLen + numTokens;
|
|
772
846
|
const outProjWeight = getWeightBuffer(layerWeights.oProj, `L${layerIdx}.linear_out_proj`);
|
|
773
847
|
try {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice, getKernelCapabilities } from '../../../../gpu/device.js';
|
|
4
|
-
import { acquireBuffer, releaseBuffer
|
|
4
|
+
import { acquireBuffer, releaseBuffer } from '../../../../memory/buffer-pool.js';
|
|
5
5
|
import { runMatmul, runRMSNorm } from '../../../../gpu/kernel-selector.js';
|
|
6
6
|
import { recordMatmul } from '../../../../gpu/kernels/matmul.js';
|
|
7
7
|
import { recordRMSNorm } from '../../../../gpu/kernels/rmsnorm.js';
|
|
@@ -13,6 +13,7 @@ import { getRuntimeConfig } from '../../../../config/runtime.js';
|
|
|
13
13
|
import { selectRuleValue } from '../../../../rules/rule-registry.js';
|
|
14
14
|
import { runProbes } from '../probes.js';
|
|
15
15
|
import { f16BufferToF32 } from './cpu.js';
|
|
16
|
+
import { readBufferWithCleanup } from './utils.js';
|
|
16
17
|
|
|
17
18
|
function shouldForceStableF32Logits(config, inputDtype) {
|
|
18
19
|
// Small Gemma-family checkpoints can overflow in pure F16 logits path after RMSNorm offset.
|
|
@@ -187,14 +188,18 @@ export async function computeChunkedLogitsGPU(
|
|
|
187
188
|
}
|
|
188
189
|
|
|
189
190
|
const logitsBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: logitsTensor.dtype });
|
|
190
|
-
const chunkLogitsData = await
|
|
191
|
+
const chunkLogitsData = await readBufferWithCleanup(
|
|
192
|
+
logitsTensor.buffer,
|
|
193
|
+
numTokens * rowCount * logitsBytes,
|
|
194
|
+
() => {
|
|
195
|
+
releaseBuffer(logitsTensor.buffer);
|
|
196
|
+
releaseBuffer(weightBuffer.buffer);
|
|
197
|
+
}
|
|
198
|
+
);
|
|
191
199
|
const chunkLogits = logitsTensor.dtype === 'f16'
|
|
192
200
|
? f16BufferToF32(chunkLogitsData)
|
|
193
201
|
: new Float32Array(chunkLogitsData);
|
|
194
202
|
writeChunkLogits(logits, chunkLogits, numTokens, vocabSize, rowOffset, rowCount);
|
|
195
|
-
|
|
196
|
-
releaseBuffer(logitsTensor.buffer);
|
|
197
|
-
releaseBuffer(weightBuffer.buffer);
|
|
198
203
|
}
|
|
199
204
|
|
|
200
205
|
return logits;
|
|
@@ -7,7 +7,7 @@ export { rmsNormCPU, matmulCPU, applySoftcapping, f16ToF32, f16BufferToF32 } fro
|
|
|
7
7
|
export { computeLogitsGPU, recordLogitsGPU, computeChunkedLogitsGPU, resolveCpuWeightDims, resolveLmHeadChunkRows, extractLmHeadChunk, writeChunkLogits } from './gpu.js';
|
|
8
8
|
|
|
9
9
|
// Re-export utilities
|
|
10
|
-
export { extractLastPositionLogits, finalizeLogits } from './utils.js';
|
|
10
|
+
export { extractLastPositionLogits, finalizeLogits, readBufferWithCleanup } from './utils.js';
|
|
11
11
|
|
|
12
12
|
// Imports for computeLogits orchestrator
|
|
13
13
|
import { getDevice } from '../../../../gpu/device.js';
|
|
@@ -20,7 +20,7 @@ import { log, trace, isTraceEnabled } from '../../../../debug/index.js';
|
|
|
20
20
|
import { runProbes } from '../probes.js';
|
|
21
21
|
import { rmsNormCPU, matmulCPU, f16BufferToF32 } from './cpu.js';
|
|
22
22
|
import { resolveCpuWeightDims, computeChunkedLogitsGPU } from './gpu.js';
|
|
23
|
-
import { finalizeLogits } from './utils.js';
|
|
23
|
+
import { finalizeLogits, readBufferWithCleanup } from './utils.js';
|
|
24
24
|
import { getRuntimeConfig } from '../../../../config/runtime.js';
|
|
25
25
|
import { selectRuleValue } from '../../../../rules/rule-registry.js';
|
|
26
26
|
|
|
@@ -288,15 +288,14 @@ export async function computeLogits(
|
|
|
288
288
|
// 4. Read back logits
|
|
289
289
|
const logitsBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: logitsTensor.dtype });
|
|
290
290
|
const logitsReadSize = matmulRows * matmulVocabSize * logitsBytes;
|
|
291
|
-
const logitsData = await
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
if (lmHeadBufferOwned) releaseBuffer(lmHeadGPU);
|
|
291
|
+
const logitsData = await readBufferWithCleanup(logitsTensor.buffer, logitsReadSize, () => {
|
|
292
|
+
if (inputBufferOwned) releaseBuffer(inputBuffer);
|
|
293
|
+
releaseBuffer(normedTensor.buffer);
|
|
294
|
+
if (matmulInputOwned) releaseBuffer(matmulInputTensor.buffer);
|
|
295
|
+
releaseBuffer(logitsTensor.buffer);
|
|
296
|
+
if (!getNormWeightBuffer && !(finalNorm instanceof GPUBuffer)) releaseBuffer(normWeightBuffer);
|
|
297
|
+
if (lmHeadBufferOwned) releaseBuffer(lmHeadGPU);
|
|
298
|
+
});
|
|
300
299
|
|
|
301
300
|
const rawLogits = logitsTensor.dtype === 'f16'
|
|
302
301
|
? f16BufferToF32(logitsData)
|
|
@@ -25,6 +25,13 @@ export function extractLastPositionLogits(
|
|
|
25
25
|
vocabSize: number
|
|
26
26
|
): Float32Array;
|
|
27
27
|
|
|
28
|
+
export function readBufferWithCleanup(
|
|
29
|
+
buffer: GPUBuffer,
|
|
30
|
+
byteLength: number,
|
|
31
|
+
cleanup?: (() => void) | null,
|
|
32
|
+
reader?: ((buffer: GPUBuffer, byteLength: number) => Promise<ArrayBuffer>) | null
|
|
33
|
+
): Promise<ArrayBuffer>;
|
|
34
|
+
|
|
28
35
|
/**
|
|
29
36
|
* Finalize logits by applying padding and softcapping.
|
|
30
37
|
*
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
|
+
import { readBuffer } from '../../../../memory/buffer-pool.js';
|
|
3
4
|
import { runProbes } from '../probes.js';
|
|
4
5
|
import { applySoftcapping } from './cpu.js';
|
|
5
6
|
|
|
@@ -19,6 +20,14 @@ export function extractLastPositionLogits(
|
|
|
19
20
|
return lastPosLogits;
|
|
20
21
|
}
|
|
21
22
|
|
|
23
|
+
export async function readBufferWithCleanup(buffer, byteLength, cleanup, reader = readBuffer) {
|
|
24
|
+
try {
|
|
25
|
+
return await reader(buffer, byteLength);
|
|
26
|
+
} finally {
|
|
27
|
+
cleanup?.();
|
|
28
|
+
}
|
|
29
|
+
}
|
|
30
|
+
|
|
22
31
|
|
|
23
32
|
export async function finalizeLogits(
|
|
24
33
|
rawLogits,
|
|
@@ -17,42 +17,60 @@ export async function applyLoRA(input, baseOutput, lora, dims, getWeightBuffer,
|
|
|
17
17
|
|
|
18
18
|
const aBuf = getWeightBuffer(lora.a, 'lora_a');
|
|
19
19
|
const bBuf = getWeightBuffer(lora.b, 'lora_b');
|
|
20
|
-
const ownsA = !(lora.a instanceof GPUBuffer) && !isWeightBuffer(lora.a);
|
|
21
|
-
const ownsB = !(lora.b instanceof GPUBuffer) && !isWeightBuffer(lora.b);
|
|
22
|
-
|
|
23
|
-
const
|
|
24
|
-
|
|
25
|
-
|
|
20
|
+
const ownsA = !(typeof GPUBuffer !== 'undefined' && lora.a instanceof GPUBuffer) && !isWeightBuffer(lora.a);
|
|
21
|
+
const ownsB = !(typeof GPUBuffer !== 'undefined' && lora.b instanceof GPUBuffer) && !isWeightBuffer(lora.b);
|
|
22
|
+
// Extract underlying GPUBuffer for WeightBuffers
|
|
23
|
+
const aBufGPU = isWeightBuffer(aBuf) ? aBuf.buffer : aBuf;
|
|
24
|
+
const bBufGPU = isWeightBuffer(bBuf) ? bBuf.buffer : bBuf;
|
|
25
|
+
let loraIntermediate = null;
|
|
26
|
+
let loraOutput = null;
|
|
27
|
+
let scaled = null;
|
|
28
|
+
try {
|
|
29
|
+
loraIntermediate = recorder
|
|
30
|
+
? await recordMatmul(recorder, input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath })
|
|
31
|
+
: await runMatmul(input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath });
|
|
26
32
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
33
|
+
loraOutput = recorder
|
|
34
|
+
? await recordMatmul(recorder, loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath })
|
|
35
|
+
: await runMatmul(loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath });
|
|
30
36
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
37
|
+
scaled = recorder
|
|
38
|
+
? await recordScale(recorder, loraOutput, lora.scale, { outputBuffer: null })
|
|
39
|
+
: await runScale(loraOutput, lora.scale, { outputBuffer: null });
|
|
34
40
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
41
|
+
const combined = recorder
|
|
42
|
+
? await recordResidualAdd(recorder, baseOutput, scaled, M * N)
|
|
43
|
+
: await runResidualAdd(baseOutput, scaled, M * N);
|
|
38
44
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
45
|
+
if (recorder) {
|
|
46
|
+
recorder.trackTemporaryBuffer(loraIntermediate.buffer);
|
|
47
|
+
recorder.trackTemporaryBuffer(loraOutput.buffer);
|
|
48
|
+
recorder.trackTemporaryBuffer(scaled.buffer);
|
|
49
|
+
if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
|
|
50
|
+
if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
|
|
51
|
+
} else {
|
|
52
|
+
releaseBuffer(loraIntermediate.buffer);
|
|
53
|
+
releaseBuffer(loraOutput.buffer);
|
|
54
|
+
releaseBuffer(scaled.buffer);
|
|
55
|
+
if (ownsA) releaseBuffer(aBufGPU);
|
|
56
|
+
if (ownsB) releaseBuffer(bBufGPU);
|
|
57
|
+
}
|
|
42
58
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
recorder
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
59
|
+
return combined;
|
|
60
|
+
} catch (error) {
|
|
61
|
+
if (recorder) {
|
|
62
|
+
if (loraIntermediate) recorder.trackTemporaryBuffer(loraIntermediate.buffer);
|
|
63
|
+
if (loraOutput) recorder.trackTemporaryBuffer(loraOutput.buffer);
|
|
64
|
+
if (scaled) recorder.trackTemporaryBuffer(scaled.buffer);
|
|
65
|
+
if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
|
|
66
|
+
if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
|
|
67
|
+
} else {
|
|
68
|
+
if (loraIntermediate) releaseBuffer(loraIntermediate.buffer);
|
|
69
|
+
if (loraOutput) releaseBuffer(loraOutput.buffer);
|
|
70
|
+
if (scaled) releaseBuffer(scaled.buffer);
|
|
71
|
+
if (ownsA) releaseBuffer(aBufGPU);
|
|
72
|
+
if (ownsB) releaseBuffer(bBufGPU);
|
|
73
|
+
}
|
|
74
|
+
throw error;
|
|
55
75
|
}
|
|
56
|
-
|
|
57
|
-
return combined;
|
|
58
76
|
}
|