@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
|
@@ -9,7 +9,7 @@ import {
|
|
|
9
9
|
verifyIntegrity,
|
|
10
10
|
loadManifestFromStore,
|
|
11
11
|
} from '../storage/shard-manager.js';
|
|
12
|
-
import { parseManifest } from '../formats/rdrr/index.js';
|
|
12
|
+
import { clearManifest, parseManifest, setManifest as setCurrentManifest } from '../formats/rdrr/index.js';
|
|
13
13
|
import { initDevice, getDevice, getKernelCapabilities } from '../gpu/device.js';
|
|
14
14
|
import { acquireBuffer, releaseBuffer, forceBufferPoolReclaim } from '../memory/buffer-pool.js';
|
|
15
15
|
import { getExpertCache } from './experts/expert-cache.js';
|
|
@@ -50,6 +50,10 @@ function hasExpertGroups(manifest) {
|
|
|
50
50
|
return Object.keys(manifest.groups).some((groupId) => groupId.includes('.expert.'));
|
|
51
51
|
}
|
|
52
52
|
|
|
53
|
+
function isGpuBufferInstance(value) {
|
|
54
|
+
return typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
|
|
55
|
+
}
|
|
56
|
+
|
|
53
57
|
// Re-export types for backward compatibility
|
|
54
58
|
export {
|
|
55
59
|
// Types are in .d.ts file
|
|
@@ -252,6 +256,7 @@ export class DopplerLoader {
|
|
|
252
256
|
|
|
253
257
|
setManifest(manifest) {
|
|
254
258
|
this.manifest = manifest;
|
|
259
|
+
setCurrentManifest(manifest);
|
|
255
260
|
const moeConfig = manifest.moeConfig;
|
|
256
261
|
this.isMoE = moeConfig != null && (moeConfig.numExperts ?? 0) > 1;
|
|
257
262
|
if (!this.isMoE && hasExpertGroups(manifest)) {
|
|
@@ -259,6 +264,7 @@ export class DopplerLoader {
|
|
|
259
264
|
`Manifest "${manifest.modelId ?? 'unknown'}" missing moeConfig for MoE model. Re-convert with moeConfig.`
|
|
260
265
|
);
|
|
261
266
|
}
|
|
267
|
+
this.shardCache.setManifest(this.manifest);
|
|
262
268
|
this.shardCache.configureForModel(this.manifest, this.shardCache.hasCustomLoader);
|
|
263
269
|
debugTrace.loader('Manifest set externally');
|
|
264
270
|
}
|
|
@@ -679,7 +685,7 @@ export class DopplerLoader {
|
|
|
679
685
|
const device = getDevice();
|
|
680
686
|
if (!device) {
|
|
681
687
|
log.warn('Loader', 'GPU device not available; falling back to CPU');
|
|
682
|
-
if (shardData
|
|
688
|
+
if (isGpuBufferInstance(shardData)) {
|
|
683
689
|
releaseBuffer(shardData);
|
|
684
690
|
shardData = await this.#assembleShardData(location, name);
|
|
685
691
|
}
|
|
@@ -708,7 +714,7 @@ export class DopplerLoader {
|
|
|
708
714
|
return result.data;
|
|
709
715
|
}
|
|
710
716
|
|
|
711
|
-
if (shardData
|
|
717
|
+
if (isGpuBufferInstance(shardData)) {
|
|
712
718
|
// Shouldn't happen (streaming is only used for toGPU), but keep this leak-proof.
|
|
713
719
|
releaseBuffer(shardData);
|
|
714
720
|
shardData = await this.#assembleShardData(location, name);
|
|
@@ -751,31 +757,40 @@ export class DopplerLoader {
|
|
|
751
757
|
// queue.writeBuffer requires 4-byte aligned sizes; we pad the buffer.
|
|
752
758
|
const alignedSize = Math.ceil(location.size / 4) * 4;
|
|
753
759
|
const raw = acquireBuffer(alignedSize, undefined, `raw_${name}`);
|
|
760
|
+
let complete = false;
|
|
754
761
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
762
|
+
try {
|
|
763
|
+
let dstOffset = 0;
|
|
764
|
+
const uploadChunk = (bytes) => {
|
|
765
|
+
device.queue.writeBuffer(raw, dstOffset, bytes, bytes.byteOffset, bytes.byteLength);
|
|
766
|
+
dstOffset += bytes.byteLength;
|
|
767
|
+
};
|
|
768
|
+
const streamRange = (idx, offset, length) => this.shardCache.streamRange(idx, offset, length, { chunkBytes });
|
|
761
769
|
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
770
|
+
if (location.spans) {
|
|
771
|
+
for (const span of location.spans) {
|
|
772
|
+
for await (const chunk of streamRange(span.shardIndex, span.offset, span.size)) {
|
|
773
|
+
uploadChunk(chunk);
|
|
774
|
+
}
|
|
775
|
+
}
|
|
776
|
+
} else {
|
|
777
|
+
for await (const chunk of streamRange(location.shardIndex, location.offset, location.size)) {
|
|
765
778
|
uploadChunk(chunk);
|
|
766
779
|
}
|
|
767
780
|
}
|
|
768
|
-
} else {
|
|
769
|
-
for await (const chunk of streamRange(location.shardIndex, location.offset, location.size)) {
|
|
770
|
-
uploadChunk(chunk);
|
|
771
|
-
}
|
|
772
|
-
}
|
|
773
781
|
|
|
774
|
-
|
|
775
|
-
|
|
782
|
+
if (dstOffset !== location.size) {
|
|
783
|
+
throw new Error(
|
|
784
|
+
`Stream upload short read for "${name}": got=${dstOffset}, expected=${location.size}.`
|
|
785
|
+
);
|
|
786
|
+
}
|
|
787
|
+
complete = true;
|
|
788
|
+
return raw;
|
|
789
|
+
} finally {
|
|
790
|
+
if (!complete) {
|
|
791
|
+
releaseBuffer(raw);
|
|
792
|
+
}
|
|
776
793
|
}
|
|
777
|
-
|
|
778
|
-
return raw;
|
|
779
794
|
}
|
|
780
795
|
|
|
781
796
|
|
|
@@ -950,7 +965,7 @@ export class DopplerLoader {
|
|
|
950
965
|
if (!value) return;
|
|
951
966
|
const gpuBuffer = isWeightBuffer(value)
|
|
952
967
|
? value.buffer
|
|
953
|
-
: (value
|
|
968
|
+
: (isGpuBufferInstance(value) ? value : null);
|
|
954
969
|
if (!gpuBuffer) return;
|
|
955
970
|
try {
|
|
956
971
|
releaseBuffer(gpuBuffer);
|
|
@@ -990,6 +1005,7 @@ export class DopplerLoader {
|
|
|
990
1005
|
this.lmHead = null;
|
|
991
1006
|
this.finalNorm = null;
|
|
992
1007
|
this.manifest = null;
|
|
1008
|
+
clearManifest();
|
|
993
1009
|
this.modelId = null;
|
|
994
1010
|
this.loadedShards.clear();
|
|
995
1011
|
this.isLoaded = false;
|
|
@@ -1,7 +1,4 @@
|
|
|
1
1
|
|
|
2
|
-
|
|
3
|
-
import { getDevice } from '../gpu/device.js';
|
|
4
|
-
import { isTraceEnabled, log, trace as debugTrace } from '../debug/index.js';
|
|
5
2
|
import { selectRuleValue } from '../rules/rule-registry.js';
|
|
6
3
|
import { tagBufferDtype } from '../gpu/weight-buffer.js';
|
|
7
4
|
|
|
@@ -26,46 +23,8 @@ export function f16ToF32(h) {
|
|
|
26
23
|
|
|
27
24
|
|
|
28
25
|
export async function convertBF16ToF32GPU(srcBuffer, numElements, name) {
|
|
29
|
-
|
|
30
|
-
const castModule = await import('../gpu/kernels/cast.js');
|
|
31
|
-
debugTrace.loader(`[BF16->F32] castModule keys:`, Object.keys(castModule));
|
|
32
|
-
const { runBF16ToF32 } = castModule;
|
|
33
|
-
debugTrace.loader(`[BF16->F32] runBF16ToF32 type: ${typeof runBF16ToF32}`);
|
|
26
|
+
const { runBF16ToF32 } = await import('../gpu/kernels/cast.js');
|
|
34
27
|
const resultTensor = await runBF16ToF32(srcBuffer, [numElements], name);
|
|
35
|
-
debugTrace.loader(`[BF16->F32] runBF16ToF32 returned, result.size=${resultTensor.buffer?.size}`);
|
|
36
|
-
|
|
37
|
-
// Debug: Verify conversion produced non-zero values
|
|
38
|
-
const shouldCheckEmbed = isTraceEnabled('loader') &&
|
|
39
|
-
name.includes('embed') &&
|
|
40
|
-
name.includes('embed_tokens');
|
|
41
|
-
if (shouldCheckEmbed) {
|
|
42
|
-
try {
|
|
43
|
-
debugTrace.loader(`[BF16->F32] Checking embed buffer for non-zeros...`);
|
|
44
|
-
const device = getDevice();
|
|
45
|
-
const sampleSize = Math.min(1024, resultTensor.buffer.size);
|
|
46
|
-
debugTrace.loader(`[BF16->F32] Creating staging buffer size=${sampleSize}`);
|
|
47
|
-
const stagingBuffer = device.createBuffer({
|
|
48
|
-
size: sampleSize,
|
|
49
|
-
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
|
|
50
|
-
});
|
|
51
|
-
debugTrace.loader(`[BF16->F32] Copying to staging buffer...`);
|
|
52
|
-
const encoder = device.createCommandEncoder();
|
|
53
|
-
encoder.copyBufferToBuffer(resultTensor.buffer, 0, stagingBuffer, 0, sampleSize);
|
|
54
|
-
device.queue.submit([encoder.finish()]);
|
|
55
|
-
debugTrace.loader(`[BF16->F32] Mapping staging buffer...`);
|
|
56
|
-
await stagingBuffer.mapAsync(GPUMapMode.READ);
|
|
57
|
-
debugTrace.loader(`[BF16->F32] Reading data...`);
|
|
58
|
-
const data = new Float32Array(stagingBuffer.getMappedRange().slice(0));
|
|
59
|
-
stagingBuffer.unmap();
|
|
60
|
-
stagingBuffer.destroy();
|
|
61
|
-
const nonZero = Array.from(data).filter(x => x !== 0);
|
|
62
|
-
const nanCount = data.filter(x => !Number.isFinite(x)).length;
|
|
63
|
-
debugTrace.loader(`[BF16->F32] nonZero=${nonZero.length}/${data.length}, nan=${nanCount}, sample=[${nonZero.slice(0, 5).map(x => x.toFixed(4)).join(', ')}]`);
|
|
64
|
-
} catch (err) {
|
|
65
|
-
log.error('Loader', 'BF16->F32 embed buffer check error:', (err).message);
|
|
66
|
-
}
|
|
67
|
-
}
|
|
68
|
-
|
|
69
28
|
return resultTensor.buffer;
|
|
70
29
|
}
|
|
71
30
|
|
|
@@ -84,11 +43,11 @@ function normalizeBufferDtype(locationDtype, outputDtype) {
|
|
|
84
43
|
if (explicit) {
|
|
85
44
|
return explicit;
|
|
86
45
|
}
|
|
87
|
-
const location = typeof locationDtype === 'string' ? locationDtype.
|
|
46
|
+
const location = typeof locationDtype === 'string' ? locationDtype.toUpperCase() : null;
|
|
88
47
|
if (!location) {
|
|
89
48
|
return null;
|
|
90
49
|
}
|
|
91
|
-
return selectRuleValue('loader', 'weights', 'floatLocationDtype', { locationDtype:
|
|
50
|
+
return selectRuleValue('loader', 'weights', 'floatLocationDtype', { locationDtype: location });
|
|
92
51
|
}
|
|
93
52
|
|
|
94
53
|
export function applyBufferLayout(buffer, location, outputDtype = null) {
|
|
@@ -23,6 +23,10 @@ import { releaseBuffer } from '../memory/buffer-pool.js';
|
|
|
23
23
|
const EMBEDDING_ROLE = 'embedding';
|
|
24
24
|
const EMBEDDING_GROUP = 'embed';
|
|
25
25
|
|
|
26
|
+
function isGpuBufferInstance(value) {
|
|
27
|
+
return typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
|
|
28
|
+
}
|
|
29
|
+
|
|
26
30
|
// ============================================================================
|
|
27
31
|
// Main Function
|
|
28
32
|
// ============================================================================
|
|
@@ -59,7 +63,7 @@ export async function loadEmbeddings(ctx) {
|
|
|
59
63
|
}
|
|
60
64
|
|
|
61
65
|
// Handle valid tensor types
|
|
62
|
-
if (tensor
|
|
66
|
+
if (isGpuBufferInstance(tensor) || isWeightBuffer(tensor) || tensor instanceof Float32Array) {
|
|
63
67
|
const result = await processEmbeddingTensor(ctx, tensor, name, loc, shouldStream);
|
|
64
68
|
if (result) {
|
|
65
69
|
return result;
|
|
@@ -107,7 +111,7 @@ async function processEmbeddingTensor(ctx, tensor, name, loc, shouldStream) {
|
|
|
107
111
|
}
|
|
108
112
|
|
|
109
113
|
// Raw GPUBuffer - wrap with dtype/layout metadata
|
|
110
|
-
if (promoted
|
|
114
|
+
if (isGpuBufferInstance(promoted) && loc?.shape && loc.shape.length === 2) {
|
|
111
115
|
const layout = ctx.resolveWeightLayout(loc);
|
|
112
116
|
|
|
113
117
|
const dtype = selectRuleValue('loader', 'weights', 'floatLocationDtype', {
|
|
@@ -140,7 +144,7 @@ async function maybePromoteEmbeddingsToF32(ctx, current, name, loc) {
|
|
|
140
144
|
return wrapped;
|
|
141
145
|
}
|
|
142
146
|
|
|
143
|
-
if (!(current
|
|
147
|
+
if (!isGpuBufferInstance(current)) return current;
|
|
144
148
|
|
|
145
149
|
const sourceDtype = selectRuleValue('loader', 'weights', 'floatLocationDtype', {
|
|
146
150
|
locationDtype: loc?.dtype,
|
|
@@ -3,6 +3,11 @@
|
|
|
3
3
|
import { releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
4
|
import { log, trace } from '../../debug/index.js';
|
|
5
5
|
import { getRuntimeConfig } from '../../config/runtime.js';
|
|
6
|
+
import { isWeightBuffer } from '../../gpu/weight-buffer.js';
|
|
7
|
+
|
|
8
|
+
function isGpuBufferInstance(value) {
|
|
9
|
+
return typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
|
|
10
|
+
}
|
|
6
11
|
|
|
7
12
|
|
|
8
13
|
|
|
@@ -256,12 +261,14 @@ export class ExpertCache {
|
|
|
256
261
|
];
|
|
257
262
|
|
|
258
263
|
for (const buf of buffers) {
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
264
|
+
const gpuBuffer = isWeightBuffer(buf)
|
|
265
|
+
? buf.buffer
|
|
266
|
+
: (isGpuBufferInstance(buf) ? buf : null);
|
|
267
|
+
if (!gpuBuffer) continue;
|
|
268
|
+
try {
|
|
269
|
+
releaseBuffer(gpuBuffer);
|
|
270
|
+
} catch (e) {
|
|
271
|
+
// Buffer may already be released
|
|
265
272
|
}
|
|
266
273
|
}
|
|
267
274
|
}
|
|
@@ -18,7 +18,7 @@ import { releaseBuffer } from '../../memory/buffer-pool.js';
|
|
|
18
18
|
|
|
19
19
|
export async function preloadShardsForExpert(ctx, layerIdx, expertIdx, options) {
|
|
20
20
|
// Get required shards from manifest mapping
|
|
21
|
-
const shardIndices = getShardsForExpert(layerIdx, expertIdx);
|
|
21
|
+
const shardIndices = getShardsForExpert(layerIdx, expertIdx, ctx.manifest);
|
|
22
22
|
if (shardIndices.length === 0) {
|
|
23
23
|
// No mapping available, fall back to loading all shards on demand
|
|
24
24
|
return;
|
|
@@ -69,6 +69,10 @@ export function predictNextLayerExperts(currentExperts) {
|
|
|
69
69
|
return currentExperts;
|
|
70
70
|
}
|
|
71
71
|
|
|
72
|
+
function isGpuBufferInstance(value) {
|
|
73
|
+
return typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
|
|
74
|
+
}
|
|
75
|
+
|
|
72
76
|
// ============================================================================
|
|
73
77
|
// Expert Loading
|
|
74
78
|
// ============================================================================
|
|
@@ -95,7 +99,7 @@ export async function loadExpert(ctx, layerIdx, expertIdx) {
|
|
|
95
99
|
await preloadShardsForExpert(ctx, layerIdx, expertIdx);
|
|
96
100
|
|
|
97
101
|
// Get tensor names from manifest if available (for logging/debugging)
|
|
98
|
-
const tensorNames = getTensorsForExpert(layerIdx, expertIdx);
|
|
102
|
+
const tensorNames = getTensorsForExpert(layerIdx, expertIdx, ctx.manifest);
|
|
99
103
|
if (tensorNames.length > 0) {
|
|
100
104
|
debugTrace.loader(`Expert ${layerIdx}_${expertIdx} tensors: ${tensorNames.length}`);
|
|
101
105
|
}
|
|
@@ -260,7 +264,7 @@ function getGpuBuffer(value) {
|
|
|
260
264
|
if (isWeightBuffer(value)) {
|
|
261
265
|
return value.buffer;
|
|
262
266
|
}
|
|
263
|
-
if (value
|
|
267
|
+
if (isGpuBufferInstance(value)) {
|
|
264
268
|
return value;
|
|
265
269
|
}
|
|
266
270
|
return null;
|
|
@@ -342,7 +346,7 @@ async function downcastExpertWeights(ctx, weights) {
|
|
|
342
346
|
if (!buf) continue;
|
|
343
347
|
|
|
344
348
|
// Only downcast GPUBuffer or WeightBuffer (not Float32Array)
|
|
345
|
-
if (!(buf
|
|
349
|
+
if (!isGpuBufferInstance(buf) && !isWeightBuffer(buf)) {
|
|
346
350
|
continue;
|
|
347
351
|
}
|
|
348
352
|
|
|
@@ -369,13 +373,13 @@ function calculateExpertSize(weights) {
|
|
|
369
373
|
const buf = weights[k];
|
|
370
374
|
if (isWeightBuffer(buf)) {
|
|
371
375
|
sizeBytes += buf.buffer.size;
|
|
372
|
-
} else if (buf
|
|
376
|
+
} else if (isGpuBufferInstance(buf)) {
|
|
373
377
|
sizeBytes += buf.size;
|
|
374
378
|
}
|
|
375
379
|
}
|
|
376
380
|
|
|
377
381
|
// Use manifest-provided expert size if available, otherwise use calculated
|
|
378
|
-
const manifestBytes = getExpertBytes();
|
|
382
|
+
const manifestBytes = getExpertBytes(ctx.manifest);
|
|
379
383
|
if (manifestBytes > 0) {
|
|
380
384
|
sizeBytes = manifestBytes;
|
|
381
385
|
}
|
|
@@ -20,6 +20,10 @@ const HEAD_GROUP = 'head';
|
|
|
20
20
|
const FINAL_NORM_ROLE = 'norm';
|
|
21
21
|
const LM_HEAD_ROLE = 'lm_head';
|
|
22
22
|
|
|
23
|
+
function isGpuBufferInstance(value) {
|
|
24
|
+
return typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
|
|
25
|
+
}
|
|
26
|
+
|
|
23
27
|
function isLikelyFinalNormName(name) {
|
|
24
28
|
const lower = String(name || '').toLowerCase();
|
|
25
29
|
if (!lower) return false;
|
|
@@ -148,7 +152,7 @@ async function loadLmHead(ctx) {
|
|
|
148
152
|
);
|
|
149
153
|
}
|
|
150
154
|
|
|
151
|
-
if (tensor && (tensor
|
|
155
|
+
if (tensor && (isGpuBufferInstance(tensor) || isWeightBuffer(tensor) || tensor instanceof Float32Array)) {
|
|
152
156
|
lmHeadName = name;
|
|
153
157
|
lmHeadLoc = loc;
|
|
154
158
|
lmHead = processLmHeadTensor(ctx, tensor, name, loc, shouldStream);
|
|
@@ -189,7 +193,7 @@ function processLmHeadTensor(ctx, tensor, name, loc, shouldStream) {
|
|
|
189
193
|
}
|
|
190
194
|
|
|
191
195
|
// Raw GPUBuffer - wrap with dtype/layout metadata
|
|
192
|
-
if (tensor
|
|
196
|
+
if (isGpuBufferInstance(tensor) && loc.shape && loc.shape.length === 2) {
|
|
193
197
|
const layout = ctx.resolveWeightLayout(loc);
|
|
194
198
|
|
|
195
199
|
const dtype = selectRuleValue('loader', 'weights', 'floatLocationDtype', {
|
|
@@ -209,7 +213,7 @@ async function maybeDowncastLmHead(ctx, lmHead, lmHeadName, lmHeadLoc) {
|
|
|
209
213
|
const tiedToEmbeddings =
|
|
210
214
|
lmHead === ctx.embeddings ||
|
|
211
215
|
(isWeightBuffer(lmHead) && isWeightBuffer(ctx.embeddings) && lmHead.buffer === ctx.embeddings.buffer) ||
|
|
212
|
-
(lmHead
|
|
216
|
+
(isGpuBufferInstance(lmHead) && isWeightBuffer(ctx.embeddings) && lmHead === ctx.embeddings.buffer);
|
|
213
217
|
|
|
214
218
|
if (tiedToEmbeddings) {
|
|
215
219
|
return lmHead;
|
|
@@ -234,7 +238,7 @@ async function maybeDowncastLmHead(ctx, lmHead, lmHeadName, lmHeadLoc) {
|
|
|
234
238
|
|
|
235
239
|
// Get buffer for downcast
|
|
236
240
|
const buffer = isWeightBuffer(lmHead) ? lmHead.buffer : lmHead;
|
|
237
|
-
if (!(buffer
|
|
241
|
+
if (!isGpuBufferInstance(buffer)) {
|
|
238
242
|
return lmHead;
|
|
239
243
|
}
|
|
240
244
|
|
|
@@ -224,7 +224,8 @@ function createTryLoad(ctx, prefixes) {
|
|
|
224
224
|
for (const prefix of prefixes) {
|
|
225
225
|
for (const suffix of suffixes) {
|
|
226
226
|
const tensor = await ctx.loadTensor(`${prefix}.${suffix}`, true, true);
|
|
227
|
-
|
|
227
|
+
const isGpuBuffer = typeof GPUBuffer !== 'undefined' && tensor instanceof GPUBuffer;
|
|
228
|
+
if (tensor && (isGpuBuffer || tensor instanceof Float32Array || isWeightBuffer(tensor))) {
|
|
228
229
|
return tensor;
|
|
229
230
|
}
|
|
230
231
|
}
|
|
@@ -122,14 +122,14 @@ export class LoaderState {
|
|
|
122
122
|
|
|
123
123
|
static getGPUBuffer(weight) {
|
|
124
124
|
if (!weight) return null;
|
|
125
|
-
if (weight instanceof GPUBuffer) return weight;
|
|
125
|
+
if (typeof GPUBuffer !== 'undefined' && weight instanceof GPUBuffer) return weight;
|
|
126
126
|
if (isWeightBuffer(weight)) return weight.buffer;
|
|
127
127
|
return null;
|
|
128
128
|
}
|
|
129
129
|
|
|
130
130
|
static isGPUBacked(weight) {
|
|
131
131
|
if (!weight) return false;
|
|
132
|
-
if (weight instanceof GPUBuffer) return true;
|
|
132
|
+
if (typeof GPUBuffer !== 'undefined' && weight instanceof GPUBuffer) return true;
|
|
133
133
|
if (isWeightBuffer(weight)) return true;
|
|
134
134
|
if (isCpuWeightBuffer(weight)) return false;
|
|
135
135
|
if (weight instanceof Float32Array) return false;
|
|
@@ -105,6 +105,10 @@ export class MemoryMonitor {
|
|
|
105
105
|
|
|
106
106
|
|
|
107
107
|
start(getState) {
|
|
108
|
+
if (this.#interval) {
|
|
109
|
+
clearInterval(this.#interval);
|
|
110
|
+
this.#interval = null;
|
|
111
|
+
}
|
|
108
112
|
this.#startTime = performance.now();
|
|
109
113
|
this.#snapshots = [];
|
|
110
114
|
this.#log('start', getState());
|
|
@@ -209,6 +213,10 @@ export class MemoryTimeSeries {
|
|
|
209
213
|
|
|
210
214
|
|
|
211
215
|
start() {
|
|
216
|
+
if (this.#interval) {
|
|
217
|
+
clearInterval(this.#interval);
|
|
218
|
+
this.#interval = null;
|
|
219
|
+
}
|
|
212
220
|
this.#startTime = performance.now();
|
|
213
221
|
this.#samples = [];
|
|
214
222
|
this.#capture('start');
|
|
@@ -22,6 +22,20 @@ export declare class MultiModelLoader {
|
|
|
22
22
|
baseWeights: WeightLoadResult | null;
|
|
23
23
|
adapters: Map<string, LoRAAdapter>;
|
|
24
24
|
|
|
25
|
+
_loadBaseWeights(
|
|
26
|
+
manifest: Manifest,
|
|
27
|
+
options: { storageContext?: { loadShard?: (index: number) => Promise<ArrayBuffer | Uint8Array> } },
|
|
28
|
+
runtimeConfig: unknown
|
|
29
|
+
): Promise<WeightLoadResult>;
|
|
30
|
+
|
|
31
|
+
_resolveAdapterSource(source: AdapterSource): Promise<LoRAAdapter>;
|
|
32
|
+
|
|
33
|
+
_createPipeline(): InferencePipeline;
|
|
34
|
+
|
|
35
|
+
_getBaseLoader(): { unload(): Promise<void> };
|
|
36
|
+
|
|
37
|
+
unload(): Promise<void>;
|
|
38
|
+
|
|
25
39
|
loadBase(
|
|
26
40
|
manifest: Manifest,
|
|
27
41
|
options?: { storageContext?: { loadShard?: (index: number) => Promise<ArrayBuffer | Uint8Array> } }
|
|
@@ -17,37 +17,68 @@ export class MultiModelLoader {
|
|
|
17
17
|
|
|
18
18
|
adapters = new Map();
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
const runtimeConfig = getRuntimeConfig();
|
|
20
|
+
#pipelines = new Set();
|
|
21
|
+
|
|
22
|
+
async _loadBaseWeights(manifest, options, runtimeConfig) {
|
|
24
23
|
const modelOverrides = (runtimeConfig.inference.modelOverrides);
|
|
25
24
|
const config = parseModelConfig(manifest, modelOverrides);
|
|
26
|
-
|
|
27
|
-
this.baseWeights = await loadWeights(manifest, config, {
|
|
25
|
+
return loadWeights(manifest, config, {
|
|
28
26
|
storageContext: options.storageContext,
|
|
29
27
|
keepF32Weights: runtimeConfig.inference.compute.keepF32Weights === true,
|
|
30
28
|
});
|
|
31
|
-
return this.baseWeights;
|
|
32
29
|
}
|
|
33
30
|
|
|
34
|
-
|
|
35
|
-
async loadAdapter(name, source) {
|
|
36
|
-
|
|
37
|
-
let adapter;
|
|
38
|
-
|
|
31
|
+
async _resolveAdapterSource(source) {
|
|
39
32
|
if (typeof source === 'string') {
|
|
40
|
-
|
|
41
|
-
}
|
|
33
|
+
return loadLoRAFromUrl(source);
|
|
34
|
+
}
|
|
35
|
+
if (this.#isRDRRManifest(source)) {
|
|
42
36
|
const loader = getDopplerLoader();
|
|
43
37
|
await loader.init();
|
|
44
|
-
|
|
45
|
-
}
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
38
|
+
return loader.loadLoRAWeights(source);
|
|
39
|
+
}
|
|
40
|
+
if (this.#isLoRAManifest(source)) {
|
|
41
|
+
return loadLoRAFromManifest(source);
|
|
42
|
+
}
|
|
43
|
+
return source;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
_createPipeline() {
|
|
47
|
+
return new InferencePipeline();
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
_getBaseLoader() {
|
|
51
|
+
return getDopplerLoader();
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
async unload() {
|
|
55
|
+
const pipelines = Array.from(this.#pipelines);
|
|
56
|
+
this.#pipelines.clear();
|
|
57
|
+
await Promise.all(pipelines.map(async (pipeline) => pipeline.unload()));
|
|
58
|
+
|
|
59
|
+
if (this.baseWeights) {
|
|
60
|
+
const loader = this._getBaseLoader();
|
|
61
|
+
await loader.unload();
|
|
49
62
|
}
|
|
50
63
|
|
|
64
|
+
this.baseManifest = null;
|
|
65
|
+
this.baseWeights = null;
|
|
66
|
+
this.adapters.clear();
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
async loadBase(manifest, options = {}) {
|
|
70
|
+
await this.unload();
|
|
71
|
+
|
|
72
|
+
const runtimeConfig = getRuntimeConfig();
|
|
73
|
+
const weights = await this._loadBaseWeights(manifest, options, runtimeConfig);
|
|
74
|
+
this.baseManifest = manifest;
|
|
75
|
+
this.baseWeights = weights;
|
|
76
|
+
return weights;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
async loadAdapter(name, source) {
|
|
80
|
+
const adapter = await this._resolveAdapterSource(source);
|
|
81
|
+
|
|
51
82
|
const adapterName = name || adapter.name;
|
|
52
83
|
this.adapters.set(adapterName, adapter);
|
|
53
84
|
return adapter;
|
|
@@ -68,11 +99,26 @@ export class MultiModelLoader {
|
|
|
68
99
|
if (!this.baseManifest || !this.baseWeights) {
|
|
69
100
|
throw new Error('Base model not loaded');
|
|
70
101
|
}
|
|
71
|
-
const pipeline =
|
|
72
|
-
|
|
73
|
-
pipeline.
|
|
74
|
-
|
|
75
|
-
|
|
102
|
+
const pipeline = this._createPipeline();
|
|
103
|
+
const unloadPipeline = pipeline.unload.bind(pipeline);
|
|
104
|
+
pipeline.unload = async () => {
|
|
105
|
+
try {
|
|
106
|
+
await unloadPipeline();
|
|
107
|
+
} finally {
|
|
108
|
+
this.#pipelines.delete(pipeline);
|
|
109
|
+
}
|
|
110
|
+
};
|
|
111
|
+
|
|
112
|
+
try {
|
|
113
|
+
await pipeline.initialize(contexts);
|
|
114
|
+
pipeline.setPreloadedWeights(this.baseWeights);
|
|
115
|
+
await pipeline.loadModel(this.baseManifest);
|
|
116
|
+
this.#pipelines.add(pipeline);
|
|
117
|
+
return pipeline;
|
|
118
|
+
} catch (error) {
|
|
119
|
+
await pipeline.unload().catch(() => {});
|
|
120
|
+
throw error;
|
|
121
|
+
}
|
|
76
122
|
}
|
|
77
123
|
|
|
78
124
|
|