@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
/**
 * Language/pair filters produced by `resolveDistillDataScope`.
 *
 * A `null` list means "no restriction". Each `*Set` field holds the same
 * entries as its list counterpart, for membership checks.
 */
export interface DistillDataScope {
  /** Allowed normalized source-language codes, or `null` for any. */
  sourceLangs: string[] | null;
  /** Allowed normalized target-language codes, or `null` for any. */
  targetLangs: string[] | null;
  /** Allowed `src->tgt` directions, or `null` for any. */
  pairAllowlist: string[] | null;
  /** Set view of `sourceLangs`. */
  sourceLangSet: Set<string> | null;
  /** Set view of `targetLangs`. */
  targetLangSet: Set<string> | null;
  /** Set view of `pairAllowlist`. */
  pairAllowlistSet: Set<string> | null;
  /** When true, every row must carry a `pair` that agrees with `src_lang` and `tgt_lang`/`lang`. */
  strictPairContract: boolean;
}
|
|
10
|
+
|
|
11
|
+
/** A normalized distillation row as returned by `encodeDistillRow`. */
export interface DistillSample {
  /** Position of the row in the source dataset. */
  index?: number;
  /** Translation direction in `src->tgt` form, or `'unknown'`. */
  direction?: string | null;
  /** Normalized source-language code. */
  sourceLang?: string | null;
  /** Normalized target-language code. */
  targetLang?: string | null;
  /** Text to translate. */
  source?: string | null;
  /** Preferred (positive) target text. */
  targetPos?: string | null;
  /** Optional dispreferred (negative) target text. */
  targetNeg?: string | null;
}
|
|
20
|
+
|
|
21
|
+
/** Stringify and trim `value`; `null`, `undefined`, or blank input yields `null`. */
export declare function normalizeOptionalString(value: unknown): string | null;
|
|
22
|
+
|
|
23
|
+
/** Normalize a dataset path option (trimmed string, or `null` when absent/blank). */
export declare function normalizeDistillDatasetPath(value: unknown): string | null;
|
|
24
|
+
|
|
25
|
+
/**
 * Resolve language/pair filters.
 *
 * Per-run `options` (`distillSourceLangs`, `distillTargetLangs`,
 * `distillPairAllowlist`, `strictPairContract`) take precedence over
 * `trainingConfig.distill`.
 */
export declare function resolveDistillDataScope(
  options?: Record<string, unknown>,
  trainingConfig?: Record<string, unknown> | null
): DistillDataScope;
|
|
29
|
+
|
|
30
|
+
/**
 * Convert a raw dataset record into a `DistillSample`.
 *
 * Returns `null` when the record lacks a source or positive target, or when
 * it falls outside `scope`. Throws when `scope.strictPairContract` is set and
 * the record's pair/language fields are missing or disagree.
 */
export declare function encodeDistillRow(
  record: Record<string, unknown> | null | undefined,
  index: number,
  scope?: DistillDataScope | null
): DistillSample | null;
|
|
35
|
+
|
|
36
|
+
/** Count samples per `direction`; samples without one are tallied under `'unknown'`. */
export declare function summarizeDirectionCounts(
  samples: Array<Record<string, unknown> | null | undefined>
): Record<string, number>;
|
|
39
|
+
|
|
40
|
+
/** Build the `Translate from X to Y:\n<source>\nTranslation:` prompt for a sample. */
export declare function buildDistillPrompt(sample: Record<string, unknown> | null | undefined): string;
|
|
41
|
+
|
|
42
|
+
/** Build the sample's prompt, followed by a space and the trimmed candidate when non-empty. */
export declare function buildDistillCandidatePrompt(
  sample: Record<string, unknown> | null | undefined,
  candidate: unknown
): string;
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
/**
 * Coerce an arbitrary value to a trimmed string.
 *
 * `null`, `undefined`, and whitespace-only input are all treated as
 * "absent".
 *
 * @param {unknown} value - Value to normalize; non-strings go through `String()`.
 * @returns {string|null} The trimmed text, or `null` when nothing remains.
 */
export function normalizeOptionalString(value) {
  if (value == null) {
    return null;
  }
  const text = String(value).trim();
  return text.length > 0 ? text : null;
}
|
|
6
|
+
|
|
7
|
+
/**
 * Normalize a dataset path option.
 *
 * Paths receive the same treatment as any other optional string: they are
 * trimmed, and blank values collapse to `null`.
 *
 * @param {unknown} value
 * @returns {string|null}
 */
export function normalizeDistillDatasetPath(value) {
  const path = normalizeOptionalString(value);
  return path;
}
|
|
10
|
+
|
|
11
|
+
// Aliases folded onto the two languages this module names explicitly.
// They are compared against the primary subtag only. "en-US" and "es_MX"
// still normalize, while codes that merely share a prefix ("est" = Estonian,
// "enm" = Middle English) are no longer misclassified.
const ENGLISH_LANG_ALIASES = new Set(['en', 'eng', 'english']);
const SPANISH_LANG_ALIASES = new Set(['es', 'esp', 'spa', 'spanish', 'espanol', 'español']);

/**
 * Canonicalize a language identifier.
 *
 * Steps:
 * - Lowercase the value and convert `_` to `-`.
 * - Fold known English/Spanish aliases (matched on the primary subtag) to
 *   `en` / `es`.
 * - Return any other code in that lowercased, dash-separated form.
 *
 * @param {unknown} value - Raw language code or name.
 * @returns {string|null} Canonical code, or `null` for absent/blank input.
 */
function normalizeLangCode(value) {
  const normalized = normalizeOptionalString(value);
  if (!normalized) return null;
  const compact = normalized.toLowerCase().replace(/_/g, '-');
  // Primary subtag per BCP 47 (text before the first '-'); whitespace is
  // also treated as a separator so names like "english (us)" still fold.
  const primary = compact.split(/[-\s]/, 1)[0];
  if (ENGLISH_LANG_ALIASES.has(primary)) return 'en';
  if (SPANISH_LANG_ALIASES.has(primary)) return 'es';
  return compact;
}
|
|
19
|
+
|
|
20
|
+
/**
 * Canonicalize a translation-pair label to `src->tgt` form.
 *
 * Accepted input:
 * - Either an arrow (`en->es`) or a dash (`en-es`) separator.
 * - Underscores, mixed case, and embedded whitespace.
 *
 * When an arrow is present it wins, so region-tagged codes such as
 * `en-US->es-MX` split correctly. Labels that do not split into exactly two
 * non-empty halves are rejected.
 *
 * @param {unknown} value
 * @returns {string|null}
 */
function normalizePairDirection(value) {
  const raw = normalizeOptionalString(value);
  if (raw === null) {
    return null;
  }
  const cleaned = raw.toLowerCase().replace(/_/g, '-').replace(/\s+/g, '');
  const separator = cleaned.includes('->') ? '->' : '-';
  const halves = cleaned.split(separator).filter((part) => part.length > 0);
  if (halves.length !== 2) {
    return null;
  }
  const [src, tgt] = halves.map((part) => normalizeLangCode(part) || part);
  return `${src}->${tgt}`;
}
|
|
30
|
+
|
|
31
|
+
/**
 * Normalize a list-like option into non-empty trimmed strings.
 *
 * Accepts an array or a comma-separated string. Any other type, or a list
 * whose entries are all blank, yields `null`, which callers treat as
 * "no restriction".
 *
 * @param {unknown} value
 * @returns {string[]|null}
 */
function normalizeOptionalStringArray(value) {
  if (value == null) {
    return null;
  }
  let entries;
  if (Array.isArray(value)) {
    entries = value;
  } else if (typeof value === 'string') {
    entries = value.split(',');
  } else {
    return null;
  }
  const cleaned = [];
  for (const entry of entries) {
    const text = normalizeOptionalString(entry);
    if (text) {
      cleaned.push(text);
    }
  }
  return cleaned.length === 0 ? null : cleaned;
}
|
|
42
|
+
|
|
43
|
+
/**
 * Normalize a language allowlist option into unique canonical codes.
 *
 * Entries are canonicalized with `normalizeLangCode` and de-duplicated,
 * keeping first-seen order.
 *
 * @param {unknown} value - Array or comma-separated string of language codes.
 * @returns {string[]|null} `null` when no usable entries remain.
 */
function normalizeDistillLanguageAllowlist(value) {
  const entries = normalizeOptionalStringArray(value);
  if (entries === null) {
    return null;
  }
  const unique = new Set();
  for (const entry of entries) {
    const code = normalizeLangCode(entry);
    if (code) {
      unique.add(code);
    }
  }
  return unique.size > 0 ? Array.from(unique) : null;
}
|
|
52
|
+
|
|
53
|
+
/**
 * Normalize a pair allowlist option into unique `src->tgt` directions.
 *
 * Entries that `normalizePairDirection` rejects are dropped. If nothing
 * valid remains, the result is `null`.
 *
 * @param {unknown} value - Array or comma-separated string of pair labels.
 * @returns {string[]|null}
 */
function normalizeDistillPairAllowlist(value) {
  const entries = normalizeOptionalStringArray(value);
  if (entries === null) {
    return null;
  }
  const pairs = entries
    .map((entry) => normalizePairDirection(entry))
    .filter((pair) => pair !== null);
  return pairs.length === 0 ? null : Array.from(new Set(pairs));
}
|
|
62
|
+
|
|
63
|
+
/**
 * Resolve the language/pair filtering scope for distillation data.
 *
 * Precedence and normalization:
 * - Per-run `options` take precedence over `trainingConfig.distill`.
 * - Each list may be an array or a comma-separated string.
 * - Lists are normalized: language codes are canonicalized, pairs are
 *   rewritten to `src->tgt`, and duplicates are dropped.
 * - `strictPairContract` is on if either source sets it to `true`.
 *
 * @param {Record<string, unknown>|null} [options] - Per-run overrides; `null` is treated like `{}`.
 * @param {Record<string, unknown>|null} [trainingConfig] - Config whose `distill` section supplies defaults.
 * @returns {{
 *   sourceLangs: string[]|null, targetLangs: string[]|null, pairAllowlist: string[]|null,
 *   sourceLangSet: Set<string>|null, targetLangSet: Set<string>|null,
 *   pairAllowlistSet: Set<string>|null, strictPairContract: boolean
 * }}
 */
export function resolveDistillDataScope(options = {}, trainingConfig = null) {
  // Default parameters only cover `undefined`. Also guard `null`, so callers
  // forwarding an absent options object don't hit a TypeError.
  const opts = options ?? {};
  const distillConfig = trainingConfig?.distill || {};
  const sourceLangs = normalizeDistillLanguageAllowlist(
    opts.distillSourceLangs ?? distillConfig.sourceLangs ?? null
  );
  const targetLangs = normalizeDistillLanguageAllowlist(
    opts.distillTargetLangs ?? distillConfig.targetLangs ?? null
  );
  const pairAllowlist = normalizeDistillPairAllowlist(
    opts.distillPairAllowlist ?? distillConfig.pairAllowlist ?? null
  );
  const strictPairContract = (
    opts.strictPairContract === true
    || distillConfig.strictPairContract === true
  );
  return {
    sourceLangs,
    targetLangs,
    pairAllowlist,
    sourceLangSet: sourceLangs ? new Set(sourceLangs) : null,
    targetLangSet: targetLangs ? new Set(targetLangs) : null,
    pairAllowlistSet: pairAllowlist ? new Set(pairAllowlist) : null,
    strictPairContract,
  };
}
|
|
88
|
+
|
|
89
|
+
/**
 * Derive a `src->tgt` direction for a record.
 *
 * An explicit `pair` field is used when it normalizes. Otherwise the
 * direction is built from `src_lang` and `tgt_lang` (falling back to
 * `lang`).
 *
 * @param {Record<string, unknown>|null|undefined} record
 * @returns {string|null} `null` when neither source of information is complete.
 */
function resolveDistillDirection(record) {
  const fromPair = normalizePairDirection(record?.pair);
  if (fromPair) {
    return fromPair;
  }
  const src = normalizeLangCode(record?.src_lang);
  const tgt = normalizeLangCode(record?.tgt_lang || record?.lang);
  return src && tgt ? `${src}->${tgt}` : null;
}
|
|
99
|
+
|
|
100
|
+
/**
 * Return the first non-blank string found under any of `keys`, in order.
 *
 * Lookup stops at the first hit, so later keys are never read.
 *
 * @param {Record<string, unknown>|null|undefined} record
 * @param {string[]} keys - Candidate field names in priority order.
 * @returns {string|null}
 */
function resolveStringCandidate(record, keys) {
  let resolved = null;
  for (let i = 0; i < keys.length && resolved === null; i += 1) {
    resolved = normalizeOptionalString(record?.[keys[i]]);
  }
  return resolved;
}
|
|
107
|
+
|
|
108
|
+
/**
 * Convert one raw dataset record into a normalized distillation sample and
 * apply the scope filters.
 *
 * Field aliases:
 * - source text: `source` | `query`
 * - positive target: `target_pos` | `target` | `pos`
 * - negative target: `target_neg` | `neg`
 * - source language: `src_lang`
 * - target language: `tgt_lang` | `lang`
 * - direction: `pair`
 *
 * @param {Record<string, unknown>|null|undefined} record - Raw dataset row.
 * @param {number} index - Row index, copied onto the sample.
 * @param {object|null} [scope] - Filters as produced by `resolveDistillDataScope`.
 * @returns {object|null} The sample. Returns `null` when the row has no
 *   source or positive target, or when it falls outside `scope`.
 * @throws {Error} When `scope.strictPairContract` is true and `pair`,
 *   `src_lang`, or `tgt_lang`/`lang` is missing, or they disagree.
 */
export function encodeDistillRow(record, index, scope = null) {
  if (!record || typeof record !== 'object') return null;
  const source = resolveStringCandidate(record, ['source', 'query']);
  const targetPos = resolveStringCandidate(record, ['target_pos', 'target', 'pos']);
  const targetNeg = resolveStringCandidate(record, ['target_neg', 'neg']);
  if (!source || !targetPos) return null;

  const sourceLangRaw = normalizeLangCode(record.src_lang);
  const targetLangRaw = normalizeLangCode(record.tgt_lang || record.lang);
  const pairDirection = normalizePairDirection(record.pair);
  const sourceTargetDirection = (
    sourceLangRaw && targetLangRaw
      ? `${sourceLangRaw}->${targetLangRaw}`
      : null
  );

  if (scope?.strictPairContract === true) {
    if (!sourceLangRaw || !targetLangRaw) {
      throw new Error('strictPairContract requires src_lang and tgt_lang/lang on each row.');
    }
    if (!pairDirection) {
      throw new Error('strictPairContract requires pair on each row.');
    }
    if (pairDirection !== sourceTargetDirection) {
      throw new Error(`pair "${record?.pair}" does not match src/tgt "${sourceLangRaw}-${targetLangRaw}".`);
    }
  }

  // An explicit `pair` wins over the src/tgt fields. No other source of
  // direction exists, so anything else is 'unknown'.
  const direction = pairDirection || sourceTargetDirection || 'unknown';

  // Only a real pair label may supply missing languages. Splitting the
  // 'unknown' sentinel would leak "unknown" into sourceLang while leaving
  // targetLang null.
  const [pairSourceLang, pairTargetLang] = pairDirection ? pairDirection.split('->') : [];
  const sourceLang = sourceLangRaw || normalizeLangCode(pairSourceLang);
  const targetLang = targetLangRaw || normalizeLangCode(pairTargetLang);

  if (scope?.sourceLangSet && (!sourceLang || !scope.sourceLangSet.has(sourceLang))) {
    return null;
  }
  if (scope?.targetLangSet && (!targetLang || !scope.targetLangSet.has(targetLang))) {
    return null;
  }
  if (scope?.pairAllowlistSet && !scope.pairAllowlistSet.has(direction)) {
    return null;
  }

  return {
    index,
    direction,
    sourceLang: sourceLang || null,
    targetLang: targetLang || null,
    source,
    targetPos,
    targetNeg: targetNeg || null,
  };
}
|
|
157
|
+
|
|
158
|
+
/**
 * Count samples per translation direction.
 *
 * Samples without a `direction` (including `null` entries) are tallied under
 * `'unknown'`.
 *
 * A Map is used as the accumulator rather than a plain object. On a plain
 * object, direction strings such as "constructor" or "toString" would read
 * inherited prototype members and produce garbage counts, and "__proto__"
 * would be silently dropped. `Object.fromEntries` then defines every key as
 * an own property.
 *
 * @param {Iterable<Record<string, unknown>|null|undefined>|null|undefined} samples
 * @returns {Record<string, number>} Direction → count.
 */
export function summarizeDirectionCounts(samples) {
  const counts = new Map();
  for (const sample of samples ?? []) {
    // Coerce to string so that keys collapse exactly as object property keys would.
    const key = String(sample?.direction || 'unknown');
    counts.set(key, (counts.get(key) ?? 0) + 1);
  }
  return Object.fromEntries(counts);
}
|
|
166
|
+
|
|
167
|
+
/**
 * Map a language code to the human-readable name used in prompts.
 *
 * Only English and Spanish have display names. Other codes are echoed in
 * normalized form, and an absent code falls back to the word "target".
 *
 * @param {unknown} langCode
 * @returns {string}
 */
function resolveLanguageName(langCode) {
  const code = normalizeLangCode(langCode);
  switch (code) {
    case 'en':
      return 'English';
    case 'es':
      return 'Spanish';
    default:
      return code || 'target';
  }
}
|
|
173
|
+
|
|
174
|
+
/**
 * Build the translation prompt for a distillation sample.
 *
 * The two sides of `sample.direction` (`src->tgt`) are turned into display
 * names. A missing side falls back to the words "source" / "target".
 *
 * @param {Record<string, unknown>|null|undefined} sample
 * @returns {string} `Translate from <src> to <tgt>:\n<source>\nTranslation:`
 */
export function buildDistillPrompt(sample) {
  const direction = String(sample?.direction || '').trim();
  const [srcPart, tgtPart] = direction.split('->');
  const srcName = resolveLanguageName(normalizeLangCode(srcPart) || srcPart || 'source');
  const tgtName = resolveLanguageName(normalizeLangCode(tgtPart) || tgtPart || 'target');
  const sourceText = String(sample?.source || '').trim();
  return [
    `Translate from ${srcName} to ${tgtName}:`,
    sourceText,
    'Translation:',
  ].join('\n');
}
|
|
184
|
+
|
|
185
|
+
/**
 * Build the sample's prompt with a candidate completion appended.
 *
 * The candidate is trimmed and joined with a single space. A blank or
 * falsy candidate leaves the prompt unchanged.
 *
 * @param {Record<string, unknown>|null|undefined} sample
 * @param {unknown} candidate
 * @returns {string}
 */
export function buildDistillCandidatePrompt(sample, candidate) {
  const prompt = buildDistillPrompt(sample);
  const completion = String(candidate || '').trim();
  if (!completion) {
    return prompt;
  }
  return `${prompt} ${completion}`;
}
|
|
@@ -3,7 +3,6 @@ import { join, resolve } from 'node:path';
|
|
|
3
3
|
|
|
4
4
|
import { loadBackwardRegistry } from '../config/backward-registry-loader.js';
|
|
5
5
|
import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/buffer-pool.js';
|
|
6
|
-
import { createTensor } from '../gpu/tensor.js';
|
|
7
6
|
import { runMatmul } from '../gpu/kernels/index.js';
|
|
8
7
|
import { runResidualAdd } from '../gpu/kernels/residual.js';
|
|
9
8
|
import { parseJsonl } from './datasets/jsonl.js';
|
|
@@ -27,6 +26,7 @@ import {
|
|
|
27
26
|
} from './operator-artifacts.js';
|
|
28
27
|
import { watchFinalizedCheckpoints } from './checkpoint-watch.js';
|
|
29
28
|
import { loadLoRAFromManifest } from '../adapters/lora-loader.js';
|
|
29
|
+
import { createUploadedTensor } from './tensor-factory.js';
|
|
30
30
|
|
|
31
31
|
function stableSortObject(value) {
|
|
32
32
|
if (Array.isArray(value)) {
|
|
@@ -48,16 +48,12 @@ function stableJson(value) {
|
|
|
48
48
|
|
|
49
49
|
/**
 * Upload `values` as an f32 tensor via `createUploadedTensor`.
 *
 * An existing Float32Array is passed through as-is. Any other array-like or
 * iterable input is copied into a new Float32Array first.
 */
function makeTensorFromFloat32(values, shape, label) {
  let payload = values;
  if (!(payload instanceof Float32Array)) {
    payload = new Float32Array(values);
  }
  return createUploadedTensor(payload, 'f32', shape, label);
}
|
|
55
53
|
|
|
56
54
|
/**
 * Upload `values` as a u32 tensor via `createUploadedTensor`.
 *
 * An existing Uint32Array is passed through as-is. Any other array-like or
 * iterable input is copied into a new Uint32Array first.
 */
function makeTensorFromUint32(values, shape, label) {
  let payload = values;
  if (!(payload instanceof Uint32Array)) {
    payload = new Uint32Array(values);
  }
  return createUploadedTensor(payload, 'u32', shape, label);
}
|
|
62
58
|
|
|
63
59
|
function releaseTensor(tensor) {
|
|
@@ -709,6 +705,7 @@ export async function watchLoraCheckpoints(options) {
|
|
|
709
705
|
manifestPath: join(options.runRoot, 'scoreboard', 'watch-manifest.json'),
|
|
710
706
|
pollIntervalMs: options.pollIntervalMs || 2000,
|
|
711
707
|
stopWhenIdle: options.stopWhenIdle === true,
|
|
708
|
+
signal: options.signal ?? null,
|
|
712
709
|
onCheckpoint: async (markerPath) => {
|
|
713
710
|
const raw = await readFile(markerPath, 'utf8');
|
|
714
711
|
const marker = JSON.parse(raw);
|
package/src/training/lora.js
CHANGED
|
@@ -12,18 +12,32 @@ export class LoraAdapter {
|
|
|
12
12
|
const aBytes = tensorBytes([inDim, rank], dtype);
|
|
13
13
|
const bBytes = tensorBytes([rank, outDim], dtype);
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
'
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
15
|
+
let aBuffer = null;
|
|
16
|
+
let bBuffer = null;
|
|
17
|
+
try {
|
|
18
|
+
aBuffer = acquireBuffer(aBytes, BufferUsage.STORAGE, 'lora_A');
|
|
19
|
+
bBuffer = acquireBuffer(bBytes, BufferUsage.STORAGE, 'lora_B');
|
|
20
|
+
this.A = createTensor(
|
|
21
|
+
aBuffer,
|
|
22
|
+
dtype,
|
|
23
|
+
[inDim, rank],
|
|
24
|
+
'lora_A'
|
|
25
|
+
);
|
|
26
|
+
this.B = createTensor(
|
|
27
|
+
bBuffer,
|
|
28
|
+
dtype,
|
|
29
|
+
[rank, outDim],
|
|
30
|
+
'lora_B'
|
|
31
|
+
);
|
|
32
|
+
} catch (error) {
|
|
33
|
+
if (aBuffer) {
|
|
34
|
+
releaseBuffer(aBuffer);
|
|
35
|
+
}
|
|
36
|
+
if (bBuffer) {
|
|
37
|
+
releaseBuffer(bBuffer);
|
|
38
|
+
}
|
|
39
|
+
throw error;
|
|
40
|
+
}
|
|
27
41
|
this.alpha = alpha;
|
|
28
42
|
this.rank = rank;
|
|
29
43
|
}
|
package/src/training/loss.js
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
|
|
2
2
|
import { runSoftmax, runCrossEntropyLoss, castF16ToF32 } from '../gpu/kernels/index.js';
|
|
3
|
-
import { releaseBuffer } from '../memory/buffer-pool.js';
|
|
4
3
|
import { OpType } from './autograd.js';
|
|
5
4
|
|
|
6
5
|
export async function crossEntropyLoss(logits, targets, config, tape) {
|
|
@@ -25,13 +24,13 @@ export async function crossEntropyLoss(logits, targets, config, tape) {
|
|
|
25
24
|
OpType.SOFTMAX,
|
|
26
25
|
(input) => runSoftmax(input, -1, { batchSize: numTokens, size: vocabSize }),
|
|
27
26
|
[logitsF32],
|
|
28
|
-
{
|
|
27
|
+
{
|
|
28
|
+
rows: numTokens,
|
|
29
|
+
cols: vocabSize,
|
|
30
|
+
retainBuffers: logitsF32 !== logits ? [logitsF32.buffer] : [],
|
|
31
|
+
}
|
|
29
32
|
);
|
|
30
33
|
|
|
31
|
-
if (logitsF32 !== logits) {
|
|
32
|
-
releaseBuffer(logitsF32.buffer);
|
|
33
|
-
}
|
|
34
|
-
|
|
35
34
|
return tape.record(
|
|
36
35
|
OpType.CROSS_ENTROPY,
|
|
37
36
|
(input, target) => runCrossEntropyLoss(input, target, { numTokens, vocabSize }),
|
|
@@ -1,15 +1,12 @@
|
|
|
1
1
|
import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
|
|
2
|
-
import { acquireBuffer, uploadData } from '../../memory/buffer-pool.js';
|
|
3
|
-
import { createTensor } from '../../gpu/tensor.js';
|
|
4
2
|
import { createTrainingObjective } from './base.js';
|
|
3
|
+
import { createUploadedTensor } from '../tensor-factory.js';
|
|
5
4
|
|
|
6
5
|
function createLossGradient(loss, lossScale) {
|
|
7
6
|
const lossElements = loss.shape.reduce((acc, value) => acc * value, 1);
|
|
8
7
|
const gradData = new Float32Array(lossElements);
|
|
9
8
|
gradData.fill(lossScale);
|
|
10
|
-
|
|
11
|
-
uploadData(gradBuf, gradData);
|
|
12
|
-
return createTensor(gradBuf, 'f32', [...loss.shape], 'loss_grad_output');
|
|
9
|
+
return createUploadedTensor(gradData, 'f32', loss.shape, 'loss_grad_output');
|
|
13
10
|
}
|
|
14
11
|
|
|
15
12
|
export function createCrossEntropyObjective(options = {}) {
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
|
|
2
2
|
import { createTrainingObjective } from './base.js';
|
|
3
|
-
import {
|
|
4
|
-
import { createTensor } from '../../gpu/tensor.js';
|
|
3
|
+
import { readBuffer } from '../../memory/buffer-pool.js';
|
|
5
4
|
import { f16ToF32Array, f32ToF16Array } from '../../inference/kv-cache/types.js';
|
|
5
|
+
import { createUploadedTensor } from '../tensor-factory.js';
|
|
6
6
|
|
|
7
7
|
const EPS = 1e-8;
|
|
8
8
|
|
|
@@ -31,9 +31,7 @@ function createLossGradient(loss, lossScale) {
|
|
|
31
31
|
const lossElements = loss.shape.reduce((acc, value) => acc * value, 1);
|
|
32
32
|
const gradData = new Float32Array(lossElements);
|
|
33
33
|
gradData.fill(lossScale);
|
|
34
|
-
|
|
35
|
-
uploadData(gradBuf, gradData);
|
|
36
|
-
return createTensor(gradBuf, 'f32', [...loss.shape], 'distill_kd_loss_grad_output');
|
|
34
|
+
return createUploadedTensor(gradData, 'f32', loss.shape, 'distill_kd_loss_grad_output');
|
|
37
35
|
}
|
|
38
36
|
|
|
39
37
|
function createGradientTensor(values, shape, dtype, label) {
|
|
@@ -42,9 +40,7 @@ function createGradientTensor(values, shape, dtype, label) {
|
|
|
42
40
|
const payload = tensorDtype === 'f16'
|
|
43
41
|
? f32ToF16Array(floatValues)
|
|
44
42
|
: floatValues;
|
|
45
|
-
|
|
46
|
-
uploadData(gradBuf, payload);
|
|
47
|
-
return createTensor(gradBuf, tensorDtype, [...shape], label);
|
|
43
|
+
return createUploadedTensor(payload, tensorDtype, shape, label);
|
|
48
44
|
}
|
|
49
45
|
|
|
50
46
|
async function readLogitsRows(logitsTensor) {
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
|
|
2
2
|
import { createTrainingObjective } from './base.js';
|
|
3
|
-
import {
|
|
4
|
-
import { createTensor } from '../../gpu/tensor.js';
|
|
3
|
+
import { readBuffer } from '../../memory/buffer-pool.js';
|
|
5
4
|
import { f16ToF32Array, f32ToF16Array } from '../../inference/kv-cache/types.js';
|
|
5
|
+
import { createUploadedTensor } from '../tensor-factory.js';
|
|
6
6
|
|
|
7
7
|
function toFinite(value, fallback) {
|
|
8
8
|
const parsed = Number(value);
|
|
@@ -29,9 +29,7 @@ function createLossGradient(loss, lossScale) {
|
|
|
29
29
|
const lossElements = loss.shape.reduce((acc, value) => acc * value, 1);
|
|
30
30
|
const gradData = new Float32Array(lossElements);
|
|
31
31
|
gradData.fill(lossScale);
|
|
32
|
-
|
|
33
|
-
uploadData(gradBuf, gradData);
|
|
34
|
-
return createTensor(gradBuf, 'f32', [...loss.shape], 'distill_triplet_loss_grad_output');
|
|
32
|
+
return createUploadedTensor(gradData, 'f32', loss.shape, 'distill_triplet_loss_grad_output');
|
|
35
33
|
}
|
|
36
34
|
|
|
37
35
|
function createGradientTensor(values, shape, dtype, label) {
|
|
@@ -40,9 +38,7 @@ function createGradientTensor(values, shape, dtype, label) {
|
|
|
40
38
|
const payload = tensorDtype === 'f16'
|
|
41
39
|
? f32ToF16Array(floatValues)
|
|
42
40
|
: floatValues;
|
|
43
|
-
|
|
44
|
-
uploadData(gradBuf, payload);
|
|
45
|
-
return createTensor(gradBuf, tensorDtype, [...shape], label);
|
|
41
|
+
return createUploadedTensor(payload, tensorDtype, shape, label);
|
|
46
42
|
}
|
|
47
43
|
|
|
48
44
|
async function readLogitsRows(logitsTensor) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
|
|
2
2
|
import { createTrainingObjective } from './base.js';
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
3
|
+
import { releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { createUploadedTensor } from '../tensor-factory.js';
|
|
5
5
|
|
|
6
6
|
function sigmoid(value) {
|
|
7
7
|
return 1 / (1 + Math.exp(-value));
|
|
@@ -9,17 +9,13 @@ function sigmoid(value) {
|
|
|
9
9
|
|
|
10
10
|
function createF32Tensor(values, shape, label) {
|
|
11
11
|
const data = values instanceof Float32Array ? values : new Float32Array(values);
|
|
12
|
-
|
|
13
|
-
uploadData(buffer, data);
|
|
14
|
-
return createTensor(buffer, 'f32', [...shape], label);
|
|
12
|
+
return createUploadedTensor(data, 'f32', shape, label);
|
|
15
13
|
}
|
|
16
14
|
|
|
17
15
|
function createU32TokenTensor(values, shape, label) {
|
|
18
16
|
const data = values instanceof Uint32Array ? values : new Uint32Array(values);
|
|
19
|
-
const buffer = acquireBuffer(data.byteLength, undefined, label);
|
|
20
|
-
uploadData(buffer, data);
|
|
21
17
|
// Token targets are consumed as raw u32 bytes by loss kernels.
|
|
22
|
-
return
|
|
18
|
+
return createUploadedTensor(data, 'f32', shape, label);
|
|
23
19
|
}
|
|
24
20
|
|
|
25
21
|
function releaseTensor(tensor) {
|
|
@@ -316,6 +316,7 @@ async function runDistillCommand(request) {
|
|
|
316
316
|
layout: runArtifacts.layout,
|
|
317
317
|
pollIntervalMs: request.pollIntervalMs || null,
|
|
318
318
|
stopWhenIdle: request.stopWhenIdle === true,
|
|
319
|
+
signal: request.signal ?? null,
|
|
319
320
|
})),
|
|
320
321
|
};
|
|
321
322
|
}
|
|
@@ -378,6 +379,7 @@ async function runLoraCommand(request) {
|
|
|
378
379
|
runRoot: resolve(String(request.runRoot)),
|
|
379
380
|
pollIntervalMs: request.pollIntervalMs || null,
|
|
380
381
|
stopWhenIdle: request.stopWhenIdle === true,
|
|
382
|
+
signal: request.signal ?? null,
|
|
381
383
|
})),
|
|
382
384
|
};
|
|
383
385
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer, BufferUsage } from '../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer, BufferUsage } from '../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, tensorBytes } from '../gpu/tensor.js';
|
|
3
3
|
import { runAdam } from '../gpu/kernels/backward/adam.js';
|
|
4
4
|
|
|
@@ -72,12 +72,24 @@ export class AdamOptimizer {
|
|
|
72
72
|
let entry = this.state.get(param);
|
|
73
73
|
if (!entry) {
|
|
74
74
|
const bytes = tensorBytes(param.shape, param.dtype);
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
75
|
+
let mBuf = null;
|
|
76
|
+
let vBuf = null;
|
|
77
|
+
try {
|
|
78
|
+
mBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_m');
|
|
79
|
+
vBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_v');
|
|
80
|
+
entry = {
|
|
81
|
+
m: createTensor(mBuf, param.dtype, [...param.shape], 'adam_m'),
|
|
82
|
+
v: createTensor(vBuf, param.dtype, [...param.shape], 'adam_v'),
|
|
83
|
+
};
|
|
84
|
+
} catch (error) {
|
|
85
|
+
if (mBuf) {
|
|
86
|
+
releaseBuffer(mBuf);
|
|
87
|
+
}
|
|
88
|
+
if (vBuf) {
|
|
89
|
+
releaseBuffer(vBuf);
|
|
90
|
+
}
|
|
91
|
+
throw error;
|
|
92
|
+
}
|
|
81
93
|
this.state.set(param, entry);
|
|
82
94
|
}
|
|
83
95
|
return entry;
|
package/src/training/runner.js
CHANGED
|
@@ -617,7 +617,6 @@ function buildExpectedCheckpointMetadata(metadata) {
|
|
|
617
617
|
'configHash',
|
|
618
618
|
'datasetHash',
|
|
619
619
|
'tokenizerHash',
|
|
620
|
-
'optimizerHash',
|
|
621
620
|
'runtimePresetId',
|
|
622
621
|
'kernelPathId',
|
|
623
622
|
]) {
|
|
@@ -845,6 +844,8 @@ export class TrainingRunner {
|
|
|
845
844
|
}
|
|
846
845
|
|
|
847
846
|
async run(model, dataset, options = {}) {
|
|
847
|
+
this.lastCheckpoint = null;
|
|
848
|
+
this.lastArtifact = null;
|
|
848
849
|
const {
|
|
849
850
|
epochs = 1,
|
|
850
851
|
batchSize = 1,
|