@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import path from 'node:path';
|
|
2
|
+
|
|
3
|
+
import {
|
|
4
|
+
DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE,
|
|
5
|
+
DIRECT_SOURCE_RUNTIME_MODE,
|
|
6
|
+
DIRECT_SOURCE_RUNTIME_SCHEMA,
|
|
7
|
+
DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
|
|
8
|
+
getSourceRuntimeMetadata,
|
|
9
|
+
} from './source-runtime-bundle.js';
|
|
10
|
+
|
|
11
|
+
/**
 * Return a deep copy of `value`.
 *
 * Uses the global `structuredClone` when the runtime provides it and
 * otherwise falls back to a JSON serialize/parse round trip.
 *
 * @param {*} value - Value to copy.
 * @returns {*} A copy that shares no nested objects with `value`.
 */
function cloneJsonValue(value) {
  const hasNativeClone = typeof structuredClone === 'function';
  return hasNativeClone
    ? structuredClone(value)
    : JSON.parse(JSON.stringify(value));
}
|
|
17
|
+
|
|
18
|
+
/**
 * Convert a file path into a forward-slash path relative to `artifactDir`.
 *
 * Both `value` and `artifactDir` go through `path.resolve`, so relative
 * inputs are interpreted against the current working directory.
 *
 * @param {*} value - Path to convert; coerced to a string and trimmed.
 * @param {string} artifactDir - Directory the target must live inside.
 * @param {string} label - Prefix used to identify the entry in error messages.
 * @returns {string} The relative path, using `/` as the separator.
 * @throws {Error} If `value` is empty, or if it resolves to `artifactDir`
 *   itself or to a location outside of it.
 */
function toRelativeArtifactPath(value, artifactDir, label) {
  const raw = String(value || '').trim();
  if (!raw) {
    throw new Error(`${label} path is required.`);
  }
  const resolvedArtifactDir = path.resolve(artifactDir);
  const resolvedTarget = path.resolve(raw);
  const nativeRelativePath = path.relative(resolvedArtifactDir, resolvedTarget);
  const relativePath = nativeRelativePath.replace(/\\/g, '/');
  const isOutsideArtifactDir =
    !relativePath
    || relativePath === '..'
    || relativePath.startsWith('../')
    // path.relative hands back an absolute path when no relative route
    // exists (e.g. a different drive on Windows); that must never be
    // persisted as an artifact-relative path.
    || path.isAbsolute(nativeRelativePath);
  if (isOutsideArtifactDir) {
    throw new Error(
      `${label} "${raw}" must live inside artifactDir "${resolvedArtifactDir}" for a persisted direct-source manifest.`
    );
  }
  return relativePath;
}
|
|
33
|
+
|
|
34
|
+
/**
 * Build a copy of `manifest` whose `metadata.sourceRuntime` section is
 * rewritten so every recorded path is relative to `artifactDir`.
 *
 * The input manifest is cloned first and is not modified. Any other keys
 * already present on the cloned `metadata.sourceRuntime` are kept; the keys
 * written here (`mode`, `schema`, `schemaVersion`, `hashAlgorithm`,
 * `pathSemantics`, `sourceFiles`, `auxiliaryFiles`, `tokenizer`) are
 * overwritten.
 *
 * @param {object} manifest - Manifest carrying source runtime metadata, as
 *   read by `getSourceRuntimeMetadata`.
 * @param {string} artifactDir - Directory that all referenced files must live
 *   inside; coerced to a string and trimmed.
 * @returns {object} The cloned manifest with the rewritten metadata.
 * @throws {Error} If the manifest has no source runtime metadata, if
 *   `artifactDir` is empty, or if any referenced path is missing or falls
 *   outside `artifactDir`.
 */
export function materializeSourceRuntimeManifest(manifest, artifactDir) {
  const sourceRuntime = getSourceRuntimeMetadata(manifest);
  if (!sourceRuntime) {
    throw new Error('materializeSourceRuntimeManifest requires manifest.metadata.sourceRuntime.');
  }
  const resolvedArtifactDir = String(artifactDir || '').trim();
  if (!resolvedArtifactDir) {
    throw new Error('materializeSourceRuntimeManifest requires artifactDir.');
  }

  // Small local adapters so each call site only states the path and its label.
  const relativize = (filePath, label) => toRelativeArtifactPath(filePath, resolvedArtifactDir, label);
  const relativizeOptional = (filePath, label) => (filePath ? relativize(filePath, label) : null);

  const nextManifest = cloneJsonValue(manifest);
  const hasMetadataObject = Boolean(nextManifest.metadata) && typeof nextManifest.metadata === 'object';
  if (!hasMetadataObject) {
    nextManifest.metadata = {};
  }
  const previousSourceRuntime = nextManifest.metadata.sourceRuntime;
  const sourceMetadata = previousSourceRuntime && typeof previousSourceRuntime === 'object'
    ? cloneJsonValue(previousSourceRuntime)
    : {};

  // Property order below fixes both the key order of the result and the
  // order in which path validation errors can surface.
  Object.assign(sourceMetadata, {
    mode: DIRECT_SOURCE_RUNTIME_MODE,
    schema: DIRECT_SOURCE_RUNTIME_SCHEMA,
    schemaVersion: DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
    hashAlgorithm: sourceRuntime.hashAlgorithm,
    pathSemantics: DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE,
    sourceFiles: sourceRuntime.sourceFiles.map(
      ({ index, filename, path: filePath, size, hash, hashAlgorithm }) => ({
        index,
        filename: filename ?? null,
        path: relativize(filePath, `source runtime source file ${index}`),
        size,
        hash,
        hashAlgorithm,
      })
    ),
    auxiliaryFiles: sourceRuntime.auxiliaryFiles.map(
      ({ path: filePath, size, hash, hashAlgorithm, kind }) => ({
        path: relativize(filePath, `source runtime auxiliary file ${kind}`),
        size,
        hash,
        hashAlgorithm,
        kind,
      })
    ),
    tokenizer: {
      jsonPath: relativizeOptional(sourceRuntime.tokenizer.jsonPath, 'source runtime tokenizer json'),
      configPath: relativizeOptional(sourceRuntime.tokenizer.configPath, 'source runtime tokenizer config'),
      modelPath: relativizeOptional(sourceRuntime.tokenizer.modelPath, 'source runtime tokenizer model'),
    },
  });

  nextManifest.metadata.sourceRuntime = sourceMetadata;
  return nextManifest;
}
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
import { acquireBuffer, uploadData, readBuffer } from '../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, uploadData, readBuffer, releaseBuffer } from '../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, tensorBytes } from '../gpu/tensor.js';
|
|
3
3
|
import { f16ToF32Array } from '../inference/kv-cache/types.js';
|
|
4
|
+
import { createUploadedTensor } from './tensor-factory.js';
|
|
4
5
|
|
|
5
6
|
function toFloat32(buffer, dtype) {
|
|
6
7
|
if (dtype === 'f16') {
|
|
@@ -67,9 +68,7 @@ export async function buildAttentionSoftmaxCache(q, k, options) {
|
|
|
67
68
|
const kData = toFloat32(kBuf, k.dtype);
|
|
68
69
|
const sData = computeSoftmax(qData, kData, options);
|
|
69
70
|
const { seqLen, numHeads } = options;
|
|
70
|
-
|
|
71
|
-
uploadData(outBuf, sData);
|
|
72
|
-
return createTensor(outBuf, 'f32', [numHeads, seqLen, seqLen], 'attn_softmax_cache');
|
|
71
|
+
return createUploadedTensor(sData, 'f32', [numHeads, seqLen, seqLen], 'attn_softmax_cache');
|
|
73
72
|
}
|
|
74
73
|
|
|
75
74
|
export async function attentionBackwardCpu(
|
|
@@ -201,17 +200,33 @@ export async function attentionBackwardCpu(
|
|
|
201
200
|
}
|
|
202
201
|
}
|
|
203
202
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
203
|
+
let qBufOut = null;
|
|
204
|
+
let kBufOut = null;
|
|
205
|
+
let vBufOut = null;
|
|
206
|
+
try {
|
|
207
|
+
qBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_q');
|
|
208
|
+
kBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_k');
|
|
209
|
+
vBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_v');
|
|
210
|
+
|
|
211
|
+
uploadData(qBufOut, dQ);
|
|
212
|
+
uploadData(kBufOut, dK);
|
|
213
|
+
uploadData(vBufOut, dV);
|
|
214
|
+
|
|
215
|
+
return {
|
|
216
|
+
gradQ: createTensor(qBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_q'),
|
|
217
|
+
gradK: createTensor(kBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_k'),
|
|
218
|
+
gradV: createTensor(vBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_v'),
|
|
219
|
+
};
|
|
220
|
+
} catch (error) {
|
|
221
|
+
if (qBufOut) {
|
|
222
|
+
releaseBuffer(qBufOut);
|
|
223
|
+
}
|
|
224
|
+
if (kBufOut) {
|
|
225
|
+
releaseBuffer(kBufOut);
|
|
226
|
+
}
|
|
227
|
+
if (vBufOut) {
|
|
228
|
+
releaseBuffer(vBufOut);
|
|
229
|
+
}
|
|
230
|
+
throw error;
|
|
231
|
+
}
|
|
217
232
|
}
|
package/src/training/autograd.js
CHANGED
|
@@ -6,6 +6,7 @@ import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/
|
|
|
6
6
|
import { createTensor } from '../gpu/tensor.js';
|
|
7
7
|
import { attentionBackwardCpu } from './attention-backward.js';
|
|
8
8
|
import { f16ToF32Array, f32ToF16Array } from '../inference/kv-cache/types.js';
|
|
9
|
+
import { createUploadedTensor } from './tensor-factory.js';
|
|
9
10
|
|
|
10
11
|
export const OpType = {
|
|
11
12
|
EMBED: 'embed',
|
|
@@ -35,6 +36,7 @@ export class AutogradTape {
|
|
|
35
36
|
constructor(registry) {
|
|
36
37
|
this.registry = registry;
|
|
37
38
|
this.records = [];
|
|
39
|
+
this.retainedBuffers = new Set();
|
|
38
40
|
}
|
|
39
41
|
|
|
40
42
|
watch(tensor) {
|
|
@@ -43,6 +45,13 @@ export class AutogradTape {
|
|
|
43
45
|
|
|
44
46
|
async record(op, fn, inputs, options = {}) {
|
|
45
47
|
const output = await fn(...inputs);
|
|
48
|
+
if (Array.isArray(options.retainBuffers)) {
|
|
49
|
+
for (const buffer of options.retainBuffers) {
|
|
50
|
+
if (buffer) {
|
|
51
|
+
this.retainedBuffers.add(buffer);
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
}
|
|
46
55
|
this.records.push({ op, inputs, output, options });
|
|
47
56
|
return output;
|
|
48
57
|
}
|
|
@@ -50,31 +59,40 @@ export class AutogradTape {
|
|
|
50
59
|
async backward(gradOutput) {
|
|
51
60
|
const grads = new Map();
|
|
52
61
|
const seeds = this.normalizeBackwardSeeds(gradOutput);
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
for (let i = this.records.length - 1; i >= 0; i -= 1) {
|
|
58
|
-
const record = this.records[i];
|
|
59
|
-
const entry = this.registry.ops[record.op];
|
|
60
|
-
if (!entry) {
|
|
61
|
-
continue;
|
|
62
|
+
try {
|
|
63
|
+
for (const seed of seeds) {
|
|
64
|
+
await this.accumulateGrad(grads, seed.tensor, seed.grad);
|
|
62
65
|
}
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
67
|
+
for (let i = this.records.length - 1; i >= 0; i -= 1) {
|
|
68
|
+
const record = this.records[i];
|
|
69
|
+
const entry = this.registry.ops[record.op];
|
|
70
|
+
if (!entry) {
|
|
71
|
+
continue;
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
const gradOut = grads.get(record.output);
|
|
75
|
+
if (!gradOut) {
|
|
76
|
+
continue;
|
|
77
|
+
}
|
|
68
78
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
79
|
+
const gradsOut = await this.runBackward(entry.backward, record, gradOut);
|
|
80
|
+
for (const { input, grad } of gradsOut) {
|
|
81
|
+
if (input && grad) {
|
|
82
|
+
await this.accumulateGrad(grads, input, grad);
|
|
83
|
+
}
|
|
73
84
|
}
|
|
74
85
|
}
|
|
75
|
-
}
|
|
76
86
|
|
|
77
|
-
|
|
87
|
+
return grads;
|
|
88
|
+
} finally {
|
|
89
|
+
for (const buffer of this.retainedBuffers) {
|
|
90
|
+
try {
|
|
91
|
+
releaseBuffer(buffer);
|
|
92
|
+
} catch {}
|
|
93
|
+
}
|
|
94
|
+
this.retainedBuffers.clear();
|
|
95
|
+
}
|
|
78
96
|
}
|
|
79
97
|
|
|
80
98
|
isTensorLike(value) {
|
|
@@ -245,9 +263,7 @@ export class AutogradTape {
|
|
|
245
263
|
expanded.set(gradRow.subarray(0, copyCount), rowOffset);
|
|
246
264
|
const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
|
|
247
265
|
const payload = dtype === 'f16' ? f32ToF16Array(expanded) : expanded;
|
|
248
|
-
|
|
249
|
-
uploadData(outBuffer, payload);
|
|
250
|
-
return createTensor(outBuffer, dtype, [rows, cols], 'row_slice_backward_output');
|
|
266
|
+
return createUploadedTensor(payload, dtype, [rows, cols], 'row_slice_backward_output');
|
|
251
267
|
}
|
|
252
268
|
|
|
253
269
|
resolveSiluRowsplitGate(gateValue, activation) {
|
|
@@ -305,9 +321,7 @@ export class AutogradTape {
|
|
|
305
321
|
|
|
306
322
|
const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
|
|
307
323
|
const payload = dtype === 'f16' ? f32ToF16Array(output) : output;
|
|
308
|
-
|
|
309
|
-
uploadData(outBuffer, payload);
|
|
310
|
-
return createTensor(outBuffer, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
|
|
324
|
+
return createUploadedTensor(payload, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
|
|
311
325
|
}
|
|
312
326
|
|
|
313
327
|
async accumulateLargeGradF32(existing, grad, size, shape) {
|
|
@@ -317,35 +331,49 @@ export class AutogradTape {
|
|
|
317
331
|
}
|
|
318
332
|
const bytesPerElement = 4;
|
|
319
333
|
const outputBuffer = acquireBuffer(size * bytesPerElement, undefined, 'grad_accum_large_output');
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
334
|
+
try {
|
|
335
|
+
for (let offset = 0; offset < size; offset += MAX_RESIDUAL_ELEMENTS_PER_DISPATCH) {
|
|
336
|
+
const chunkElements = Math.min(MAX_RESIDUAL_ELEMENTS_PER_DISPATCH, size - offset);
|
|
337
|
+
const chunkBytes = chunkElements * bytesPerElement;
|
|
338
|
+
const chunkOffsetBytes = offset * bytesPerElement;
|
|
339
|
+
|
|
340
|
+
let aChunkBuffer = null;
|
|
341
|
+
let bChunkBuffer = null;
|
|
342
|
+
let summedChunkBuffer = null;
|
|
343
|
+
try {
|
|
344
|
+
aChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_a_chunk');
|
|
345
|
+
bChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_b_chunk');
|
|
346
|
+
const copyIn = device.createCommandEncoder();
|
|
347
|
+
copyIn.copyBufferToBuffer(existing.buffer, chunkOffsetBytes, aChunkBuffer, 0, chunkBytes);
|
|
348
|
+
copyIn.copyBufferToBuffer(grad.buffer, chunkOffsetBytes, bChunkBuffer, 0, chunkBytes);
|
|
349
|
+
device.queue.submit([copyIn.finish()]);
|
|
350
|
+
|
|
351
|
+
const aChunk = createTensor(aChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_a_tensor');
|
|
352
|
+
const bChunk = createTensor(bChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_b_tensor');
|
|
353
|
+
const summedChunk = await runResidualAdd(aChunk, bChunk, chunkElements);
|
|
354
|
+
summedChunkBuffer = summedChunk?.buffer ?? null;
|
|
355
|
+
|
|
356
|
+
const copyOut = device.createCommandEncoder();
|
|
357
|
+
copyOut.copyBufferToBuffer(summedChunk.buffer, 0, outputBuffer, chunkOffsetBytes, chunkBytes);
|
|
358
|
+
device.queue.submit([copyOut.finish()]);
|
|
359
|
+
} finally {
|
|
360
|
+
if (aChunkBuffer) {
|
|
361
|
+
releaseBuffer(aChunkBuffer);
|
|
362
|
+
}
|
|
363
|
+
if (bChunkBuffer) {
|
|
364
|
+
releaseBuffer(bChunkBuffer);
|
|
365
|
+
}
|
|
366
|
+
if (summedChunkBuffer && summedChunkBuffer !== outputBuffer) {
|
|
367
|
+
releaseBuffer(summedChunkBuffer);
|
|
368
|
+
}
|
|
369
|
+
}
|
|
345
370
|
}
|
|
346
|
-
}
|
|
347
371
|
|
|
348
|
-
|
|
372
|
+
return createTensor(outputBuffer, 'f32', [...shape], 'grad_accum_large_output');
|
|
373
|
+
} catch (error) {
|
|
374
|
+
releaseBuffer(outputBuffer);
|
|
375
|
+
throw error;
|
|
376
|
+
}
|
|
349
377
|
}
|
|
350
378
|
|
|
351
379
|
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
/**
 * Polls a checkpoints directory for `checkpoint.complete.json` markers and
 * invokes `onCheckpoint` for every marker not yet recorded in the manifest.
 */
export declare function watchFinalizedCheckpoints(options: {
  /** Directory that is scanned recursively for checkpoint markers. */
  checkpointsDir: string;
  /** JSON manifest file listing the markers that were already processed. */
  manifestPath: string;
  /** Delay between scans in milliseconds. */
  pollIntervalMs?: number | null;
  /** When `true`, resolve after a scan that finds no new marker. */
  stopWhenIdle?: boolean;
  /** Aborting this signal stops the watcher. */
  signal?: AbortSignal | null;
  /** Callback invoked with the path of each newly discovered marker. */
  onCheckpoint: (markerPath: string) => Promise<void> | void;
}): Promise<{
  ok: true;
  processedCount: number;
  manifestPath: string;
  aborted?: boolean;
}>;
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import { readdir, readFile } from 'node:fs/promises';
|
|
2
|
+
import { join, resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import { writeJsonArtifact } from './operator-artifacts.js';
|
|
5
|
+
|
|
6
|
+
/**
 * Recursively collects every `checkpoint.complete.json` marker below a directory.
 *
 * A sub-directory that contains a marker is reported and not descended into;
 * a sub-directory without a marker is searched recursively. Non-directory
 * entries are ignored.
 *
 * @param {string} checkpointsDir - Root directory to scan (resolved to an absolute path).
 * @returns {Promise<string[]>} Absolute marker paths, sorted with `localeCompare`.
 * @throws Re-throws any filesystem error other than a missing marker file or a
 *   nested directory that disappeared during the scan. A missing root
 *   directory still rejects with `ENOENT`.
 */
async function listCheckpointMarkers(checkpointsDir) {
  const markers = [];

  const scanDirectory = async (directory, isRoot) => {
    let entries;
    try {
      entries = await readdir(directory, { withFileTypes: true });
    } catch (error) {
      // A nested directory can be removed between the parent listing and this
      // call (e.g. while checkpoints are being pruned); treat it as empty
      // instead of failing the whole scan.
      if (!isRoot && error?.code === 'ENOENT') {
        return;
      }
      throw error;
    }

    for (const entry of entries) {
      if (!entry.isDirectory()) {
        continue;
      }
      const entryPath = join(directory, entry.name);
      const markerPath = join(entryPath, 'checkpoint.complete.json');
      let hasMarker = true;
      try {
        // Reading the file doubles as the existence/readability probe.
        await readFile(markerPath, 'utf8');
      } catch (error) {
        if (error?.code !== 'ENOENT') {
          throw error;
        }
        hasMarker = false;
      }
      if (hasMarker) {
        markers.push(markerPath);
      } else {
        await scanDirectory(entryPath, false);
      }
    }
  };

  await scanDirectory(resolve(String(checkpointsDir)), true);
  // Sort once at the end rather than once per recursion level.
  return markers.sort((left, right) => left.localeCompare(right));
}
|
|
29
|
+
|
|
30
|
+
/**
 * Reports whether `directoryPath` can currently be listed.
 *
 * @param {string} directoryPath - Directory to probe.
 * @returns {Promise<boolean>} `true` when the listing succeeds, `false` when
 *   the path does not exist (`ENOENT`).
 * @throws Re-throws every listing error other than `ENOENT`.
 */
async function ensureDirectoryExists(directoryPath) {
  let listing;
  try {
    listing = await readdir(directoryPath, { withFileTypes: true });
  } catch (error) {
    if (error?.code !== 'ENOENT') {
      throw error;
    }
    return false;
  }
  return Array.isArray(listing);
}
|
|
41
|
+
|
|
42
|
+
/**
 * Loads the set of already processed marker paths from the watch manifest.
 *
 * @param {string} manifestPath - Path of the JSON manifest file.
 * @returns {Promise<Set<string>>} The string entries of
 *   `processedCheckpointMarkers`; an empty set when the file does not exist
 *   or the field is not an array.
 * @throws Re-throws read errors other than `ENOENT` and JSON parse errors.
 */
async function readProcessedManifest(manifestPath) {
  let text;
  try {
    text = await readFile(manifestPath, 'utf8');
  } catch (error) {
    if (error?.code === 'ENOENT') {
      return new Set();
    }
    throw error;
  }

  const manifest = JSON.parse(text);
  const listed = manifest?.processedCheckpointMarkers;
  const seen = new Set();
  if (Array.isArray(listed)) {
    for (const item of listed) {
      // Non-string entries are dropped.
      if (typeof item === 'string') {
        seen.add(item);
      }
    }
  }
  return seen;
}
|
|
57
|
+
|
|
58
|
+
/**
 * Builds the summary object returned by the checkpoint watcher.
 *
 * @param {Set<string>} processed - Marker paths handled so far; only `size` is read.
 * @param {string} manifestPath - Path of the watch manifest.
 * @param {boolean} [aborted=false] - Whether the watcher stopped because of an abort.
 * @returns {{ ok: true, processedCount: number, manifestPath: string, aborted: boolean }}
 */
function createWatchResult(processed, manifestPath, aborted = false) {
  const processedCount = processed.size;
  return { ok: true, processedCount, manifestPath, aborted };
}
|
|
66
|
+
|
|
67
|
+
/**
 * Sleeps for one poll interval, waking early when the signal is aborted.
 *
 * @param {number} pollIntervalMs - Delay in milliseconds.
 * @param {AbortSignal | null | undefined} signal - Optional abort signal.
 * @returns {Promise<boolean>} `true` when the full interval elapsed, `false`
 *   when the signal was already aborted or aborted during the wait.
 */
async function waitForPollInterval(pollIntervalMs, signal) {
  if (signal?.aborted) {
    return false;
  }
  if (!signal) {
    await new Promise((wake) => {
      setTimeout(wake, pollIntervalMs);
    });
    return true;
  }
  return new Promise((settle) => {
    let timerId = null;
    const handleAbort = () => {
      // Abort wins: cancel the pending timer so it cannot settle later.
      clearTimeout(timerId);
      settle(false);
    };
    timerId = setTimeout(() => {
      // Interval elapsed: detach the abort listener so it does not accumulate.
      signal.removeEventListener('abort', handleAbort);
      settle(true);
    }, pollIntervalMs);
    signal.addEventListener('abort', handleAbort, { once: true });
  });
}
|
|
87
|
+
|
|
88
|
+
/**
 * Polls `checkpointsDir` for finalized checkpoint markers and hands each new
 * marker to `onCheckpoint`, recording progress in a manifest file.
 *
 * Markers already listed in the manifest are skipped. After each successful
 * `onCheckpoint` call the marker is added to the processed set and the
 * manifest is rewritten via `writeJsonArtifact`, so progress made before a
 * later failure is already recorded.
 *
 * @param {object} options
 * @param {string} options.checkpointsDir - Directory to scan; a missing directory counts as empty.
 * @param {string} options.manifestPath - Manifest file tracking processed markers.
 * @param {number | null} [options.pollIntervalMs] - Delay between scans; floored,
 *   clamped to at least 100 ms, defaults to 2000 ms when not a finite number.
 * @param {boolean} [options.stopWhenIdle] - When exactly `true`, return after the
 *   first scan that finds no new marker.
 * @param {AbortSignal | null} [options.signal] - Aborting stops the watcher before
 *   the next scan, before the next unprocessed marker, or during the poll wait.
 * @param {(markerPath: string) => Promise<void> | void} options.onCheckpoint - Marker handler.
 * @returns {Promise<{ ok: true, processedCount: number, manifestPath: string, aborted: boolean }>}
 * @throws {Error} When `onCheckpoint` is not a function; errors from
 *   `onCheckpoint`, the filesystem, or the manifest write propagate unchanged.
 */
export async function watchFinalizedCheckpoints(options) {
  const checkpointsDir = resolve(String(options.checkpointsDir));
  const manifestPath = resolve(String(options.manifestPath));
  const pollIntervalMs = Number.isFinite(options.pollIntervalMs)
    ? Math.max(100, Math.floor(options.pollIntervalMs))
    : 2000;
  const stopWhenIdle = options.stopWhenIdle === true;
  const signal = options.signal ?? null;
  if (typeof options.onCheckpoint !== 'function') {
    throw new Error('watchFinalizedCheckpoints requires onCheckpoint(markerPath).');
  }
  const { onCheckpoint } = options;

  const processed = await readProcessedManifest(manifestPath);
  for (;;) {
    if (signal?.aborted) {
      return createWatchResult(processed, manifestPath, true);
    }
    const checkpointsExist = await ensureDirectoryExists(checkpointsDir);
    const markers = checkpointsExist
      ? await listCheckpointMarkers(checkpointsDir)
      : [];
    let sawNewMarker = false;
    for (const markerPath of markers) {
      if (processed.has(markerPath)) {
        continue;
      }
      // Honor an abort raised while an earlier marker was being handled so a
      // large backlog does not keep running after cancellation.
      if (signal?.aborted) {
        return createWatchResult(processed, manifestPath, true);
      }
      sawNewMarker = true;
      await onCheckpoint(markerPath);
      processed.add(markerPath);
      await writeJsonArtifact(manifestPath, {
        artifactType: 'training_checkpoint_watch_manifest',
        schemaVersion: 1,
        generatedAt: new Date().toISOString(),
        processedCheckpointMarkers: [...processed].sort((left, right) => left.localeCompare(right)),
      });
    }
    if (stopWhenIdle && !sawNewMarker) {
      return createWatchResult(processed, manifestPath);
    }
    const shouldContinue = await waitForPollInterval(pollIntervalMs, signal);
    if (!shouldContinue) {
      return createWatchResult(processed, manifestPath, true);
    }
  }
}
|
|
@@ -23,7 +23,12 @@ export declare function saveCheckpoint(
|
|
|
23
23
|
key: string,
|
|
24
24
|
data: unknown,
|
|
25
25
|
options?: CheckpointStoreOptions
|
|
26
|
-
): Promise<
|
|
26
|
+
): Promise<{
|
|
27
|
+
key: string;
|
|
28
|
+
path: string | null;
|
|
29
|
+
metadata: Record<string, unknown>;
|
|
30
|
+
data: unknown;
|
|
31
|
+
}>;
|
|
27
32
|
|
|
28
33
|
export declare function loadCheckpoint(
|
|
29
34
|
key: string,
|
|
@@ -31,6 +31,13 @@ function openCheckpointDB(options = {}) {
|
|
|
31
31
|
});
|
|
32
32
|
}
|
|
33
33
|
|
|
34
|
+
/**
 * Closes a checkpoint database handle when it exposes a `close()` method.
 *
 * @param {{ close?: Function } | null | undefined} db - Handle to close; a
 *   falsy value or one without a callable `close` is ignored.
 * @returns {void}
 */
function closeCheckpointDB(db) {
  const closable = Boolean(db) && typeof db.close === 'function';
  if (closable) {
    db.close();
  }
}
|
|
40
|
+
|
|
34
41
|
async function resolveNodeCheckpointPath(key, options = {}) {
|
|
35
42
|
const [{ resolve, join, dirname }, { mkdir }] = await Promise.all([
|
|
36
43
|
import('node:path'),
|
|
@@ -140,9 +147,15 @@ export async function saveCheckpoint(key, payload, options = {}) {
|
|
|
140
147
|
const useNodeStore = isNodeRuntime() && typeof indexedDB === 'undefined';
|
|
141
148
|
const nodePath = useNodeStore ? await resolveNodeCheckpointPath(key, options) : null;
|
|
142
149
|
const browserStore = useNodeStore ? null : await openCheckpointDB(options);
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
150
|
+
let previousData;
|
|
151
|
+
try {
|
|
152
|
+
previousData = useNodeStore
|
|
153
|
+
? await readNodeCheckpointRecord(nodePath)
|
|
154
|
+
: await readCheckpointRecord(browserStore.db, browserStore.storeName, key);
|
|
155
|
+
} catch (error) {
|
|
156
|
+
closeCheckpointDB(browserStore?.db);
|
|
157
|
+
throw error;
|
|
158
|
+
}
|
|
146
159
|
const previousMetadata = previousData?.metadata || {};
|
|
147
160
|
const previousLineage = previousMetadata.lineage || {};
|
|
148
161
|
const previousCheckpointHash = options.priorCheckpointHash
|
|
@@ -184,13 +197,35 @@ export async function saveCheckpoint(key, payload, options = {}) {
|
|
|
184
197
|
|
|
185
198
|
if (useNodeStore) {
|
|
186
199
|
await writeNodeCheckpointRecord(nodePath, data);
|
|
187
|
-
return
|
|
200
|
+
return {
|
|
201
|
+
key,
|
|
202
|
+
path: nodePath,
|
|
203
|
+
metadata: data.metadata,
|
|
204
|
+
data,
|
|
205
|
+
};
|
|
188
206
|
}
|
|
189
207
|
|
|
190
208
|
return new Promise((resolve, reject) => {
|
|
191
209
|
const tx = browserStore.db.transaction(browserStore.storeName, 'readwrite');
|
|
192
|
-
tx.oncomplete = () =>
|
|
193
|
-
|
|
210
|
+
tx.oncomplete = () => {
|
|
211
|
+
closeCheckpointDB(browserStore.db);
|
|
212
|
+
resolve({
|
|
213
|
+
key,
|
|
214
|
+
path: null,
|
|
215
|
+
metadata: data.metadata,
|
|
216
|
+
data,
|
|
217
|
+
});
|
|
218
|
+
};
|
|
219
|
+
tx.onerror = () => {
|
|
220
|
+
const error = tx.error;
|
|
221
|
+
closeCheckpointDB(browserStore.db);
|
|
222
|
+
reject(error);
|
|
223
|
+
};
|
|
224
|
+
tx.onabort = () => {
|
|
225
|
+
const error = tx.error ?? new Error('Checkpoint transaction aborted');
|
|
226
|
+
closeCheckpointDB(browserStore.db);
|
|
227
|
+
reject(error);
|
|
228
|
+
};
|
|
194
229
|
const store = tx.objectStore(browserStore.storeName);
|
|
195
230
|
store.put(data, key);
|
|
196
231
|
});
|
|
@@ -203,7 +238,11 @@ export async function loadCheckpoint(key, options = {}) {
|
|
|
203
238
|
? await readNodeCheckpointRecord(nodePath)
|
|
204
239
|
: await (async () => {
|
|
205
240
|
const { db, storeName } = await openCheckpointDB(options);
|
|
206
|
-
|
|
241
|
+
try {
|
|
242
|
+
return await readCheckpointRecord(db, storeName, key);
|
|
243
|
+
} finally {
|
|
244
|
+
closeCheckpointDB(db);
|
|
245
|
+
}
|
|
207
246
|
})();
|
|
208
247
|
|
|
209
248
|
if (!data || !data.metadata || !options.expectedMetadata) {
|
package/src/training/clip.js
CHANGED
|
@@ -12,7 +12,8 @@ async function readGradData(grad) {
|
|
|
12
12
|
}
|
|
13
13
|
|
|
14
14
|
export async function clipGradients(grads, config) {
|
|
15
|
-
const maxNorm = config?.training?.
|
|
15
|
+
const maxNorm = config?.training?.gradientClipping?.maxNorm
|
|
16
|
+
?? config?.training?.gradient?.maxNorm;
|
|
16
17
|
let sumSq = 0;
|
|
17
18
|
let totalParamCount = 0;
|
|
18
19
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
|
|
2
|
-
import { acquireBuffer, uploadData } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, uploadData, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { createTensor } from '../../gpu/tensor.js';
|
|
4
4
|
|
|
5
5
|
function flattenTokenBatch(samples, key) {
|
|
@@ -27,14 +27,26 @@ export function buildTokenBatch(samples) {
|
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
/**
 * Uploads a flattened token batch and wraps both buffers as tensors.
 *
 * @param {{ inputFlat: ArrayBufferView, targetFlat: ArrayBufferView, offsets: unknown }} batch
 *   Batch whose `inputFlat` and `targetFlat` are uploaded; `offsets` is passed through.
 * @returns {{ input: unknown, targets: unknown, offsets: unknown }} Tensors created
 *   with dtype `'f32'` and a 1-D shape equal to each flat array's length.
 * @throws Re-throws any error from acquisition, upload or tensor creation after
 *   releasing every buffer acquired so far.
 */
export function createTokenBatchTensors(batch) {
  const acquired = [];
  const uploadTokens = (flat, label) => {
    const buffer = acquireBuffer(flat.byteLength, undefined, label);
    // Track the buffer before uploading so a failed upload still releases it.
    acquired.push(buffer);
    uploadData(buffer, flat);
    return buffer;
  };

  try {
    const inputBuffer = uploadTokens(batch.inputFlat, 'train_input_tokens');
    const targetBuffer = uploadTokens(batch.targetFlat, 'train_target_tokens');
    return {
      input: createTensor(inputBuffer, 'f32', [batch.inputFlat.length], 'train_input_tokens'),
      targets: createTensor(targetBuffer, 'f32', [batch.targetFlat.length], 'train_target_tokens'),
      offsets: batch.offsets,
    };
  } catch (error) {
    for (const buffer of acquired) {
      if (buffer) {
        releaseBuffer(buffer);
      }
    }
    throw error;
  }
}
|