@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -20,6 +20,10 @@ const HEAD_GROUP = 'head';
|
|
|
20
20
|
const FINAL_NORM_ROLE = 'norm';
|
|
21
21
|
const LM_HEAD_ROLE = 'lm_head';
|
|
22
22
|
|
|
23
|
+
function isGpuBufferInstance(value) {
|
|
24
|
+
return typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
|
|
25
|
+
}
|
|
26
|
+
|
|
23
27
|
function isLikelyFinalNormName(name) {
|
|
24
28
|
const lower = String(name || '').toLowerCase();
|
|
25
29
|
if (!lower) return false;
|
|
@@ -148,7 +152,7 @@ async function loadLmHead(ctx) {
|
|
|
148
152
|
);
|
|
149
153
|
}
|
|
150
154
|
|
|
151
|
-
if (tensor && (tensor
|
|
155
|
+
if (tensor && (isGpuBufferInstance(tensor) || isWeightBuffer(tensor) || tensor instanceof Float32Array)) {
|
|
152
156
|
lmHeadName = name;
|
|
153
157
|
lmHeadLoc = loc;
|
|
154
158
|
lmHead = processLmHeadTensor(ctx, tensor, name, loc, shouldStream);
|
|
@@ -189,7 +193,7 @@ function processLmHeadTensor(ctx, tensor, name, loc, shouldStream) {
|
|
|
189
193
|
}
|
|
190
194
|
|
|
191
195
|
// Raw GPUBuffer - wrap with dtype/layout metadata
|
|
192
|
-
if (tensor
|
|
196
|
+
if (isGpuBufferInstance(tensor) && loc.shape && loc.shape.length === 2) {
|
|
193
197
|
const layout = ctx.resolveWeightLayout(loc);
|
|
194
198
|
|
|
195
199
|
const dtype = selectRuleValue('loader', 'weights', 'floatLocationDtype', {
|
|
@@ -209,7 +213,7 @@ async function maybeDowncastLmHead(ctx, lmHead, lmHeadName, lmHeadLoc) {
|
|
|
209
213
|
const tiedToEmbeddings =
|
|
210
214
|
lmHead === ctx.embeddings ||
|
|
211
215
|
(isWeightBuffer(lmHead) && isWeightBuffer(ctx.embeddings) && lmHead.buffer === ctx.embeddings.buffer) ||
|
|
212
|
-
(lmHead
|
|
216
|
+
(isGpuBufferInstance(lmHead) && isWeightBuffer(ctx.embeddings) && lmHead === ctx.embeddings.buffer);
|
|
213
217
|
|
|
214
218
|
if (tiedToEmbeddings) {
|
|
215
219
|
return lmHead;
|
|
@@ -234,7 +238,7 @@ async function maybeDowncastLmHead(ctx, lmHead, lmHeadName, lmHeadLoc) {
|
|
|
234
238
|
|
|
235
239
|
// Get buffer for downcast
|
|
236
240
|
const buffer = isWeightBuffer(lmHead) ? lmHead.buffer : lmHead;
|
|
237
|
-
if (!(buffer
|
|
241
|
+
if (!isGpuBufferInstance(buffer)) {
|
|
238
242
|
return lmHead;
|
|
239
243
|
}
|
|
240
244
|
|
|
@@ -224,7 +224,8 @@ function createTryLoad(ctx, prefixes) {
|
|
|
224
224
|
for (const prefix of prefixes) {
|
|
225
225
|
for (const suffix of suffixes) {
|
|
226
226
|
const tensor = await ctx.loadTensor(`${prefix}.${suffix}`, true, true);
|
|
227
|
-
|
|
227
|
+
const isGpuBuffer = typeof GPUBuffer !== 'undefined' && tensor instanceof GPUBuffer;
|
|
228
|
+
if (tensor && (isGpuBuffer || tensor instanceof Float32Array || isWeightBuffer(tensor))) {
|
|
228
229
|
return tensor;
|
|
229
230
|
}
|
|
230
231
|
}
|
|
@@ -122,14 +122,14 @@ export class LoaderState {
|
|
|
122
122
|
|
|
123
123
|
static getGPUBuffer(weight) {
|
|
124
124
|
if (!weight) return null;
|
|
125
|
-
if (weight instanceof GPUBuffer) return weight;
|
|
125
|
+
if (typeof GPUBuffer !== 'undefined' && weight instanceof GPUBuffer) return weight;
|
|
126
126
|
if (isWeightBuffer(weight)) return weight.buffer;
|
|
127
127
|
return null;
|
|
128
128
|
}
|
|
129
129
|
|
|
130
130
|
static isGPUBacked(weight) {
|
|
131
131
|
if (!weight) return false;
|
|
132
|
-
if (weight instanceof GPUBuffer) return true;
|
|
132
|
+
if (typeof GPUBuffer !== 'undefined' && weight instanceof GPUBuffer) return true;
|
|
133
133
|
if (isWeightBuffer(weight)) return true;
|
|
134
134
|
if (isCpuWeightBuffer(weight)) return false;
|
|
135
135
|
if (weight instanceof Float32Array) return false;
|
|
@@ -105,6 +105,10 @@ export class MemoryMonitor {
|
|
|
105
105
|
|
|
106
106
|
|
|
107
107
|
start(getState) {
|
|
108
|
+
if (this.#interval) {
|
|
109
|
+
clearInterval(this.#interval);
|
|
110
|
+
this.#interval = null;
|
|
111
|
+
}
|
|
108
112
|
this.#startTime = performance.now();
|
|
109
113
|
this.#snapshots = [];
|
|
110
114
|
this.#log('start', getState());
|
|
@@ -209,6 +213,10 @@ export class MemoryTimeSeries {
|
|
|
209
213
|
|
|
210
214
|
|
|
211
215
|
start() {
|
|
216
|
+
if (this.#interval) {
|
|
217
|
+
clearInterval(this.#interval);
|
|
218
|
+
this.#interval = null;
|
|
219
|
+
}
|
|
212
220
|
this.#startTime = performance.now();
|
|
213
221
|
this.#samples = [];
|
|
214
222
|
this.#capture('start');
|
|
@@ -22,6 +22,20 @@ export declare class MultiModelLoader {
|
|
|
22
22
|
baseWeights: WeightLoadResult | null;
|
|
23
23
|
adapters: Map<string, LoRAAdapter>;
|
|
24
24
|
|
|
25
|
+
_loadBaseWeights(
|
|
26
|
+
manifest: Manifest,
|
|
27
|
+
options: { storageContext?: { loadShard?: (index: number) => Promise<ArrayBuffer | Uint8Array> } },
|
|
28
|
+
runtimeConfig: unknown
|
|
29
|
+
): Promise<WeightLoadResult>;
|
|
30
|
+
|
|
31
|
+
_resolveAdapterSource(source: AdapterSource): Promise<LoRAAdapter>;
|
|
32
|
+
|
|
33
|
+
_createPipeline(): InferencePipeline;
|
|
34
|
+
|
|
35
|
+
_getBaseLoader(): { unload(): Promise<void> };
|
|
36
|
+
|
|
37
|
+
unload(): Promise<void>;
|
|
38
|
+
|
|
25
39
|
loadBase(
|
|
26
40
|
manifest: Manifest,
|
|
27
41
|
options?: { storageContext?: { loadShard?: (index: number) => Promise<ArrayBuffer | Uint8Array> } }
|
|
@@ -17,37 +17,68 @@ export class MultiModelLoader {
|
|
|
17
17
|
|
|
18
18
|
adapters = new Map();
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
const runtimeConfig = getRuntimeConfig();
|
|
20
|
+
#pipelines = new Set();
|
|
21
|
+
|
|
22
|
+
async _loadBaseWeights(manifest, options, runtimeConfig) {
|
|
24
23
|
const modelOverrides = (runtimeConfig.inference.modelOverrides);
|
|
25
24
|
const config = parseModelConfig(manifest, modelOverrides);
|
|
26
|
-
|
|
27
|
-
this.baseWeights = await loadWeights(manifest, config, {
|
|
25
|
+
return loadWeights(manifest, config, {
|
|
28
26
|
storageContext: options.storageContext,
|
|
29
27
|
keepF32Weights: runtimeConfig.inference.compute.keepF32Weights === true,
|
|
30
28
|
});
|
|
31
|
-
return this.baseWeights;
|
|
32
29
|
}
|
|
33
30
|
|
|
34
|
-
|
|
35
|
-
async loadAdapter(name, source) {
|
|
36
|
-
|
|
37
|
-
let adapter;
|
|
38
|
-
|
|
31
|
+
async _resolveAdapterSource(source) {
|
|
39
32
|
if (typeof source === 'string') {
|
|
40
|
-
|
|
41
|
-
}
|
|
33
|
+
return loadLoRAFromUrl(source);
|
|
34
|
+
}
|
|
35
|
+
if (this.#isRDRRManifest(source)) {
|
|
42
36
|
const loader = getDopplerLoader();
|
|
43
37
|
await loader.init();
|
|
44
|
-
|
|
45
|
-
}
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
38
|
+
return loader.loadLoRAWeights(source);
|
|
39
|
+
}
|
|
40
|
+
if (this.#isLoRAManifest(source)) {
|
|
41
|
+
return loadLoRAFromManifest(source);
|
|
42
|
+
}
|
|
43
|
+
return source;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
_createPipeline() {
|
|
47
|
+
return new InferencePipeline();
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
_getBaseLoader() {
|
|
51
|
+
return getDopplerLoader();
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
async unload() {
|
|
55
|
+
const pipelines = Array.from(this.#pipelines);
|
|
56
|
+
this.#pipelines.clear();
|
|
57
|
+
await Promise.all(pipelines.map(async (pipeline) => pipeline.unload()));
|
|
58
|
+
|
|
59
|
+
if (this.baseWeights) {
|
|
60
|
+
const loader = this._getBaseLoader();
|
|
61
|
+
await loader.unload();
|
|
49
62
|
}
|
|
50
63
|
|
|
64
|
+
this.baseManifest = null;
|
|
65
|
+
this.baseWeights = null;
|
|
66
|
+
this.adapters.clear();
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
async loadBase(manifest, options = {}) {
|
|
70
|
+
await this.unload();
|
|
71
|
+
|
|
72
|
+
const runtimeConfig = getRuntimeConfig();
|
|
73
|
+
const weights = await this._loadBaseWeights(manifest, options, runtimeConfig);
|
|
74
|
+
this.baseManifest = manifest;
|
|
75
|
+
this.baseWeights = weights;
|
|
76
|
+
return weights;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
async loadAdapter(name, source) {
|
|
80
|
+
const adapter = await this._resolveAdapterSource(source);
|
|
81
|
+
|
|
51
82
|
const adapterName = name || adapter.name;
|
|
52
83
|
this.adapters.set(adapterName, adapter);
|
|
53
84
|
return adapter;
|
|
@@ -68,11 +99,26 @@ export class MultiModelLoader {
|
|
|
68
99
|
if (!this.baseManifest || !this.baseWeights) {
|
|
69
100
|
throw new Error('Base model not loaded');
|
|
70
101
|
}
|
|
71
|
-
const pipeline =
|
|
72
|
-
|
|
73
|
-
pipeline.
|
|
74
|
-
|
|
75
|
-
|
|
102
|
+
const pipeline = this._createPipeline();
|
|
103
|
+
const unloadPipeline = pipeline.unload.bind(pipeline);
|
|
104
|
+
pipeline.unload = async () => {
|
|
105
|
+
try {
|
|
106
|
+
await unloadPipeline();
|
|
107
|
+
} finally {
|
|
108
|
+
this.#pipelines.delete(pipeline);
|
|
109
|
+
}
|
|
110
|
+
};
|
|
111
|
+
|
|
112
|
+
try {
|
|
113
|
+
await pipeline.initialize(contexts);
|
|
114
|
+
pipeline.setPreloadedWeights(this.baseWeights);
|
|
115
|
+
await pipeline.loadModel(this.baseManifest);
|
|
116
|
+
this.#pipelines.add(pipeline);
|
|
117
|
+
return pipeline;
|
|
118
|
+
} catch (error) {
|
|
119
|
+
await pipeline.unload().catch(() => {});
|
|
120
|
+
throw error;
|
|
121
|
+
}
|
|
76
122
|
}
|
|
77
123
|
|
|
78
124
|
|
|
@@ -23,6 +23,7 @@ export class ShardCache {
|
|
|
23
23
|
#inFlightLoads = 0;
|
|
24
24
|
#highPriorityQueue = [];
|
|
25
25
|
#lowPriorityQueue = [];
|
|
26
|
+
#epoch = 0;
|
|
26
27
|
|
|
27
28
|
lastSource = null;
|
|
28
29
|
|
|
@@ -123,6 +124,7 @@ export class ShardCache {
|
|
|
123
124
|
const shardInfo = this.#manifest?.shards?.[shardIndex];
|
|
124
125
|
const sizeStr = shardInfo ? formatBytes(shardInfo.size) : '';
|
|
125
126
|
const priority = options.priority === 'low' ? 'low' : 'high';
|
|
127
|
+
const epoch = this.#epoch;
|
|
126
128
|
|
|
127
129
|
// 1. Check cache first
|
|
128
130
|
if (this.#cache.has(shardIndex)) {
|
|
@@ -136,24 +138,29 @@ export class ShardCache {
|
|
|
136
138
|
}
|
|
137
139
|
|
|
138
140
|
// 2. Check if fetch is already in-flight - deduplicate concurrent requests
|
|
139
|
-
|
|
141
|
+
const inFlight = this.#fetchPromises.get(shardIndex);
|
|
142
|
+
if (inFlight && inFlight.epoch === epoch) {
|
|
140
143
|
log.verbose('ShardCache', `Shard ${shardIndex}: waiting for in-flight fetch`);
|
|
141
|
-
return
|
|
144
|
+
return inFlight.promise;
|
|
142
145
|
}
|
|
143
146
|
|
|
144
147
|
// 3. Start the actual fetch and store the promise for deduplication
|
|
145
148
|
const fetchPromise = this.#scheduleLoad(
|
|
146
149
|
priority,
|
|
147
|
-
|
|
150
|
+
epoch,
|
|
151
|
+
() => this.#doLoad(shardIndex, sizeStr, epoch)
|
|
148
152
|
);
|
|
149
|
-
|
|
153
|
+
const fetchEntry = { epoch, promise: fetchPromise };
|
|
154
|
+
this.#fetchPromises.set(shardIndex, fetchEntry);
|
|
150
155
|
|
|
151
156
|
try {
|
|
152
157
|
const result = await fetchPromise;
|
|
153
158
|
return result;
|
|
154
159
|
} finally {
|
|
155
160
|
// Remove from in-flight map when done (success or error)
|
|
156
|
-
this.#fetchPromises.
|
|
161
|
+
if (this.#fetchPromises.get(shardIndex) === fetchEntry) {
|
|
162
|
+
this.#fetchPromises.delete(shardIndex);
|
|
163
|
+
}
|
|
157
164
|
}
|
|
158
165
|
}
|
|
159
166
|
|
|
@@ -195,6 +202,13 @@ export class ShardCache {
|
|
|
195
202
|
throw new Error('Custom shard loader must return ArrayBuffer or Uint8Array.');
|
|
196
203
|
}
|
|
197
204
|
|
|
205
|
+
#throwShortStreamRead(shardIndex, start, want, produced, path) {
|
|
206
|
+
throw new Error(
|
|
207
|
+
`Shard ${shardIndex} short stream read via ${path}: ` +
|
|
208
|
+
`offset=${start}, expected=${want}, got=${produced}.`
|
|
209
|
+
);
|
|
210
|
+
}
|
|
211
|
+
|
|
198
212
|
async loadRange(shardIndex, offset = 0, length = null, options = {}) {
|
|
199
213
|
const start = this.#toRangeOffset(offset);
|
|
200
214
|
const want = length == null ? null : this.#toRangeOffset(length);
|
|
@@ -276,9 +290,15 @@ export class ShardCache {
|
|
|
276
290
|
this.#setLastSource('RAM', 0, 'stream', 'cache');
|
|
277
291
|
const view = new Uint8Array(cached);
|
|
278
292
|
const end = want == null ? view.length : Math.min(view.length, start + want);
|
|
293
|
+
let produced = 0;
|
|
279
294
|
for (let cursor = start; cursor < end; cursor += chunkBytes) {
|
|
280
295
|
const sliceEnd = Math.min(end, cursor + chunkBytes);
|
|
281
|
-
|
|
296
|
+
const chunk = view.slice(cursor, sliceEnd);
|
|
297
|
+
produced += chunk.byteLength;
|
|
298
|
+
yield chunk;
|
|
299
|
+
}
|
|
300
|
+
if (want != null && produced < want) {
|
|
301
|
+
this.#throwShortStreamRead(shardIndex, start, want, produced, 'cache');
|
|
282
302
|
}
|
|
283
303
|
return;
|
|
284
304
|
}
|
|
@@ -323,6 +343,15 @@ export class ShardCache {
|
|
|
323
343
|
resumed += bytes.byteLength;
|
|
324
344
|
yield bytes;
|
|
325
345
|
}
|
|
346
|
+
if (want != null && produced + resumed < want) {
|
|
347
|
+
this.#throwShortStreamRead(
|
|
348
|
+
shardIndex,
|
|
349
|
+
start,
|
|
350
|
+
want,
|
|
351
|
+
produced + resumed,
|
|
352
|
+
'custom-range-fallback'
|
|
353
|
+
);
|
|
354
|
+
}
|
|
326
355
|
const elapsed = (performance.now() - streamStart) / 1000;
|
|
327
356
|
this.#setLastSource(
|
|
328
357
|
'custom',
|
|
@@ -358,6 +387,15 @@ export class ShardCache {
|
|
|
358
387
|
resumed += bytes.byteLength;
|
|
359
388
|
yield bytes;
|
|
360
389
|
}
|
|
390
|
+
if (produced + resumed < want) {
|
|
391
|
+
this.#throwShortStreamRead(
|
|
392
|
+
shardIndex,
|
|
393
|
+
start,
|
|
394
|
+
want,
|
|
395
|
+
produced + resumed,
|
|
396
|
+
'custom-range-fallback'
|
|
397
|
+
);
|
|
398
|
+
}
|
|
361
399
|
const elapsed = (performance.now() - streamStart) / 1000;
|
|
362
400
|
this.#setLastSource(
|
|
363
401
|
'custom',
|
|
@@ -369,6 +407,9 @@ export class ShardCache {
|
|
|
369
407
|
return;
|
|
370
408
|
}
|
|
371
409
|
|
|
410
|
+
if (want != null && produced < want) {
|
|
411
|
+
this.#throwShortStreamRead(shardIndex, start, want, produced, 'custom-stream');
|
|
412
|
+
}
|
|
372
413
|
const elapsed = (performance.now() - streamStart) / 1000;
|
|
373
414
|
this.#setLastSource('custom', elapsed, 'stream', 'custom-stream');
|
|
374
415
|
return;
|
|
@@ -403,6 +444,9 @@ export class ShardCache {
|
|
|
403
444
|
}
|
|
404
445
|
}
|
|
405
446
|
}
|
|
447
|
+
if (want != null && produced < want) {
|
|
448
|
+
this.#throwShortStreamRead(shardIndex, start, want, produced, 'custom-range');
|
|
449
|
+
}
|
|
406
450
|
this.#setLastSource(
|
|
407
451
|
'custom',
|
|
408
452
|
(performance.now() - rangeStart) / 1000,
|
|
@@ -414,8 +458,14 @@ export class ShardCache {
|
|
|
414
458
|
}
|
|
415
459
|
|
|
416
460
|
const streamStart = performance.now();
|
|
461
|
+
let produced = 0;
|
|
417
462
|
for await (const chunk of streamShardRangeFromStore(shardIndex, start, want, { chunkBytes })) {
|
|
418
|
-
|
|
463
|
+
const bytes = chunk instanceof Uint8Array ? chunk : new Uint8Array(chunk);
|
|
464
|
+
produced += bytes.byteLength;
|
|
465
|
+
yield bytes;
|
|
466
|
+
}
|
|
467
|
+
if (want != null && produced < want) {
|
|
468
|
+
this.#throwShortStreamRead(shardIndex, start, want, produced, 'backend-stream');
|
|
419
469
|
}
|
|
420
470
|
const elapsed = (performance.now() - streamStart) / 1000;
|
|
421
471
|
const backend = getStorageBackendType() ?? 'storage';
|
|
@@ -426,7 +476,7 @@ export class ShardCache {
|
|
|
426
476
|
return this.load(shardIndex, { priority: 'low' });
|
|
427
477
|
}
|
|
428
478
|
|
|
429
|
-
async #doLoad(shardIndex, sizeStr) {
|
|
479
|
+
async #doLoad(shardIndex, sizeStr, epoch) {
|
|
430
480
|
if (this.#customLoader) {
|
|
431
481
|
const startTime = performance.now();
|
|
432
482
|
let data = await this.#customLoader(shardIndex);
|
|
@@ -453,7 +503,9 @@ export class ShardCache {
|
|
|
453
503
|
// Normalize to ArrayBuffer for downstream slicing
|
|
454
504
|
const arrayBuffer = this.#toArrayBuffer(data);
|
|
455
505
|
|
|
456
|
-
this.#
|
|
506
|
+
if (epoch === this.#epoch) {
|
|
507
|
+
this.#add(shardIndex, arrayBuffer);
|
|
508
|
+
}
|
|
457
509
|
|
|
458
510
|
const elapsed = (performance.now() - startTime) / 1000;
|
|
459
511
|
this.#setLastSource('custom', elapsed, 'full', 'custom-loader');
|
|
@@ -463,7 +515,9 @@ export class ShardCache {
|
|
|
463
515
|
|
|
464
516
|
const storageStart = performance.now();
|
|
465
517
|
const data = await loadShardFromStore(shardIndex);
|
|
466
|
-
this.#
|
|
518
|
+
if (epoch === this.#epoch) {
|
|
519
|
+
this.#add(shardIndex, data);
|
|
520
|
+
}
|
|
467
521
|
const elapsed = (performance.now() - storageStart) / 1000;
|
|
468
522
|
const backend = getStorageBackendType() ?? 'storage';
|
|
469
523
|
this.#setLastSource(backend, elapsed, 'full', 'backend-full');
|
|
@@ -471,12 +525,15 @@ export class ShardCache {
|
|
|
471
525
|
return data;
|
|
472
526
|
}
|
|
473
527
|
|
|
474
|
-
async #scheduleLoad(priority, task) {
|
|
528
|
+
async #scheduleLoad(priority, epoch, task) {
|
|
475
529
|
const limit = this.#maxConcurrentLoads > 0
|
|
476
530
|
? this.#maxConcurrentLoads
|
|
477
531
|
: Number.POSITIVE_INFINITY;
|
|
478
532
|
|
|
479
533
|
if (this.#inFlightLoads < limit) {
|
|
534
|
+
if (epoch !== this.#epoch) {
|
|
535
|
+
throw new Error('Shard load invalidated by cache clear().');
|
|
536
|
+
}
|
|
480
537
|
this.#inFlightLoads++;
|
|
481
538
|
try {
|
|
482
539
|
return await task();
|
|
@@ -487,7 +544,7 @@ export class ShardCache {
|
|
|
487
544
|
}
|
|
488
545
|
|
|
489
546
|
return new Promise((resolve, reject) => {
|
|
490
|
-
const entry = { task, resolve, reject };
|
|
547
|
+
const entry = { task, resolve, reject, epoch };
|
|
491
548
|
if (priority === 'low') {
|
|
492
549
|
this.#lowPriorityQueue.push(entry);
|
|
493
550
|
} else {
|
|
@@ -504,6 +561,10 @@ export class ShardCache {
|
|
|
504
561
|
while (this.#inFlightLoads < limit) {
|
|
505
562
|
const entry = this.#highPriorityQueue.shift() ?? this.#lowPriorityQueue.shift();
|
|
506
563
|
if (!entry) return;
|
|
564
|
+
if (entry.epoch !== this.#epoch) {
|
|
565
|
+
entry.reject(new Error('Shard load invalidated by cache clear().'));
|
|
566
|
+
continue;
|
|
567
|
+
}
|
|
507
568
|
|
|
508
569
|
this.#inFlightLoads++;
|
|
509
570
|
Promise.resolve()
|
|
@@ -529,6 +590,14 @@ export class ShardCache {
|
|
|
529
590
|
clear() {
|
|
530
591
|
const count = this.#cache.size;
|
|
531
592
|
const bytes = this.totalBytes;
|
|
593
|
+
this.#epoch++;
|
|
594
|
+
const queued = [...this.#highPriorityQueue, ...this.#lowPriorityQueue];
|
|
595
|
+
this.#highPriorityQueue = [];
|
|
596
|
+
this.#lowPriorityQueue = [];
|
|
597
|
+
this.#fetchPromises.clear();
|
|
598
|
+
for (const entry of queued) {
|
|
599
|
+
entry.reject(new Error('Shard load invalidated by cache clear().'));
|
|
600
|
+
}
|
|
532
601
|
this.#cache.clear();
|
|
533
602
|
debugTrace.loader(`Cleared shard cache: ${count} shards, ${formatBytes(bytes)} freed`);
|
|
534
603
|
}
|
|
@@ -2,6 +2,28 @@ import { loadTensorsFromStore } from '../storage/shard-manager.js';
|
|
|
2
2
|
import { parseTensorMap } from '../formats/rdrr/index.js';
|
|
3
3
|
import { log, trace as debugTrace } from '../debug/index.js';
|
|
4
4
|
|
|
5
|
+
/**
 * Normalize a tensor's optional `spans` list into plain
 * `{ shardIndex, offset, size }` records.
 *
 * A span may name its shard as `shardIndex` or as `shard`; a numeric
 * `shardIndex` takes precedence, otherwise `shard` is used.
 *
 * @param {Array<object>|undefined} spans - Raw span entries; `undefined` means the tensor has none.
 * @param {string} name - Tensor name, used only in error messages.
 * @param {string} sourceLabel - Where the entry came from, used only in error messages.
 * @returns {Array<{shardIndex: number, offset: *, size: *}>|undefined}
 *   Normalized spans, or `undefined` when `spans` is `undefined`.
 * @throws {Error} If `spans` is defined but not an array, if a span has no
 *   numeric shard index, or if that index is not a non-negative integer.
 */
function normalizeLocationSpans(spans, name, sourceLabel) {
  if (spans === undefined) {
    return undefined;
  }
  if (!Array.isArray(spans)) {
    throw new Error(`Tensor "${name}" has invalid spans in ${sourceLabel}`);
  }
  return spans.map((span, spanIndex) => {
    const shardIndex = typeof span?.shardIndex === 'number'
      ? span.shardIndex
      : span?.shard;
    if (typeof shardIndex !== 'number') {
      throw new Error(`Tensor "${name}" span[${spanIndex}] missing shard index in ${sourceLabel}`);
    }
    // `typeof` alone admits NaN, Infinity, fractions and negatives, none of
    // which can address a shard; fail here rather than at load time.
    if (!Number.isInteger(shardIndex) || shardIndex < 0) {
      throw new Error(
        `Tensor "${name}" span[${spanIndex}] has invalid shard index ${shardIndex} in ${sourceLabel}`
      );
    }
    return {
      shardIndex,
      offset: span.offset,
      size: span.size,
    };
  });
}
|
|
26
|
+
|
|
5
27
|
export async function buildTensorLocations(manifest, options = {}) {
|
|
6
28
|
const locations = new Map();
|
|
7
29
|
|
|
@@ -37,14 +59,14 @@ export async function buildTensorLocations(manifest, options = {}) {
|
|
|
37
59
|
throw new Error(`Tensor "${name}" missing role in tensors.json`);
|
|
38
60
|
}
|
|
39
61
|
locations.set(name, {
|
|
40
|
-
shardIndex: info.shard,
|
|
62
|
+
shardIndex: info.shardIndex ?? info.shard,
|
|
41
63
|
offset: info.offset,
|
|
42
64
|
size: info.size,
|
|
43
65
|
shape: info.shape,
|
|
44
66
|
dtype: info.dtype,
|
|
45
67
|
role: info.role,
|
|
46
68
|
group: info.group,
|
|
47
|
-
spans: info.spans,
|
|
69
|
+
spans: normalizeLocationSpans(info.spans, name, 'tensors.json'),
|
|
48
70
|
layout: info.layout,
|
|
49
71
|
originalShape: info.originalShape,
|
|
50
72
|
});
|
|
@@ -73,7 +95,7 @@ export async function buildTensorLocations(manifest, options = {}) {
|
|
|
73
95
|
dtype: tensorInfo.dtype,
|
|
74
96
|
role: tensorInfo.role,
|
|
75
97
|
group: tensorInfo.group,
|
|
76
|
-
spans: tensorInfo.spans,
|
|
98
|
+
spans: normalizeLocationSpans(tensorInfo.spans, name, 'manifest.tensors'),
|
|
77
99
|
layout: tensorInfo.layout,
|
|
78
100
|
originalShape: tensorInfo.originalShape,
|
|
79
101
|
});
|