@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import { readFile } from 'node:fs/promises';
|
|
2
|
+
import { resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import { parseJsonl } from './datasets/jsonl.js';
|
|
5
|
+
|
|
6
|
+
/**
 * Breaks a value into whitespace-delimited tokens.
 *
 * The value is coerced to a string (null/undefined become the empty string),
 * trimmed, and split on runs of whitespace. Empty pieces are discarded, so an
 * empty or blank input yields an empty array.
 *
 * @param {unknown} text Value to tokenize.
 * @returns {string[]} Tokens in their original order.
 */
function asTokenSequence(text) {
  const source = text ?? '';
  const trimmed = String(source).trim();
  const pieces = trimmed.split(/\s+/);
  // Splitting an empty string produces [''], which must not count as a token.
  return pieces.filter((piece) => piece !== '');
}
|
|
12
|
+
|
|
13
|
+
/**
 * Counts the character n-grams of a value.
 *
 * The value is coerced to a string (null/undefined become the empty string)
 * and trimmed. It is then split by code point, so a surrogate pair counts as
 * one character. Interior whitespace is kept and takes part in n-grams.
 *
 * @param {unknown} text Value to analyse.
 * @param {number} n Number of characters per gram.
 * @returns {Map<string, number>} Occurrence count per n-gram; empty when the
 *   trimmed text has fewer than `n` characters.
 */
function extractCharacterNgrams(text, n) {
  const characters = [...String(text ?? '').trim()];
  const counts = new Map();
  // When the text is shorter than n this bound is negative and the loop is skipped.
  const lastStart = characters.length - n;
  for (let start = 0; start <= lastStart; start += 1) {
    const key = characters.slice(start, start + n).join('');
    const seen = counts.get(key) || 0;
    counts.set(key, seen + 1);
  }
  return counts;
}
|
|
25
|
+
|
|
26
|
+
/**
 * Computes the clipped overlap between two count maps.
 *
 * For every key in `source`, the smaller of its count in `source` and its
 * count in `target` is added to the total. A key missing from `target`
 * contributes zero.
 *
 * @param {Map<string, number>} source Counts to iterate over.
 * @param {Map<string, number>} target Counts to clip against.
 * @returns {number} Total clipped overlap.
 */
function countOverlap(source, target) {
  let shared = 0;
  source.forEach((sourceCount, key) => {
    const targetCount = target.get(key) || 0;
    shared += Math.min(sourceCount, targetCount);
  });
  return shared;
}
|
|
34
|
+
|
|
35
|
+
/**
 * Accumulates corpus-level BLEU statistics over paired hypotheses/references.
 *
 * Each pair is tokenized with `asTokenSequence`. For every n-gram order from
 * 1 to `maxOrder`, clipped n-gram matches and the number of hypothesis
 * n-grams are summed across all pairs. Token lengths are summed too.
 *
 * @param {unknown[]} hypotheses Candidate texts.
 * @param {unknown[]} references Reference texts, index-aligned with `hypotheses`.
 * @param {number} [maxOrder=4] Highest n-gram order to collect.
 * @returns {{matchesByOrder: number[], possibleByOrder: number[], hypothesisLength: number, referenceLength: number}}
 */
function computeBleuStats(hypotheses, references, maxOrder = 4) {
  const matchesByOrder = new Array(maxOrder).fill(0);
  const possibleByOrder = new Array(maxOrder).fill(0);
  let hypothesisLength = 0;
  let referenceLength = 0;

  // Counts the token n-grams of one order. Tokens are joined with U+0001 so
  // distinct token sequences cannot collapse into the same key.
  const tally = (tokens, order) => {
    const counts = new Map();
    for (let start = 0; start + order <= tokens.length; start += 1) {
      const key = tokens.slice(start, start + order).join('\u0001');
      counts.set(key, (counts.get(key) || 0) + 1);
    }
    return counts;
  };

  for (let row = 0; row < hypotheses.length; row += 1) {
    const hypothesisTokens = asTokenSequence(hypotheses[row]);
    const referenceTokens = asTokenSequence(references[row]);
    hypothesisLength += hypothesisTokens.length;
    referenceLength += referenceTokens.length;

    for (let order = 1; order <= maxOrder; order += 1) {
      const slot = order - 1;
      const hypothesisCounts = tally(hypothesisTokens, order);
      const referenceCounts = tally(referenceTokens, order);
      matchesByOrder[slot] += countOverlap(hypothesisCounts, referenceCounts);
      // A hypothesis shorter than `order` has no n-grams of that order.
      possibleByOrder[slot] += Math.max(0, hypothesisTokens.length - order + 1);
    }
  }

  return { matchesByOrder, possibleByOrder, hypothesisLength, referenceLength };
}
|
|
69
|
+
|
|
70
|
+
/**
 * Computes a corpus-level BLEU score for index-aligned hypotheses/references.
 *
 * Per-order precision is add-one smoothed as `(matches + 1) / (possible + 1)`.
 * An order with no hypothesis n-grams at all gets precision 0. The score is
 * the brevity penalty times the geometric mean of the per-order precisions,
 * with each precision floored at 1e-16 before taking the logarithm. The
 * result is not scaled to 100.
 *
 * @param {unknown[]} hypotheses Candidate texts.
 * @param {unknown[]} references Reference texts, same length as `hypotheses`.
 * @param {{maxOrder?: number}} [options] `maxOrder` must be a positive integer;
 *   any other value falls back to 4.
 * @returns {{score: number, brevityPenalty: number, precisions: number[], hypothesisLength: number, referenceLength: number}}
 *   All-zero statistics when both arrays are empty.
 * @throws {Error} When either argument is not an array or the lengths differ.
 */
export function computeBleuScore(hypotheses, references, options = {}) {
  const requestedOrder = options?.maxOrder;
  const maxOrder = Number.isInteger(requestedOrder) && requestedOrder > 0
    ? requestedOrder
    : 4;
  if (!Array.isArray(hypotheses) || !Array.isArray(references) || hypotheses.length !== references.length) {
    throw new Error('computeBleuScore requires equally sized hypothesis and reference arrays.');
  }
  if (hypotheses.length === 0) {
    return {
      score: 0,
      brevityPenalty: 0,
      precisions: new Array(maxOrder).fill(0),
      hypothesisLength: 0,
      referenceLength: 0,
    };
  }

  const stats = computeBleuStats(hypotheses, references, maxOrder);
  const precisions = [];
  let precisionLogSum = 0;
  for (let order = 0; order < maxOrder; order += 1) {
    const matches = stats.matchesByOrder[order];
    const possible = stats.possibleByOrder[order];
    const precision = possible === 0
      ? 0
      : ((matches + 1) / (possible + 1));
    precisions.push(precision);
    // Floor the precision so a zero does not produce log(0) = -Infinity.
    precisionLogSum += Math.log(Math.max(precision, 1e-16));
  }
  // The penalty must never exceed 1. With a strict `>` the both-empty case
  // (0 vs 0 tokens) fell into the exponential branch and produced
  // exp(1 - 0/1) = e. Using `>=` leaves every other input unchanged, because
  // equal non-zero lengths already evaluated to exp(0) = 1.
  const brevityPenalty = stats.hypothesisLength >= stats.referenceLength
    ? 1
    : Math.exp(1 - (stats.referenceLength / Math.max(stats.hypothesisLength, 1)));
  const score = brevityPenalty * Math.exp(precisionLogSum / maxOrder);
  return {
    score,
    brevityPenalty,
    precisions,
    hypothesisLength: stats.hypothesisLength,
    referenceLength: stats.referenceLength,
  };
}
|
|
111
|
+
|
|
112
|
+
/**
 * Computes a chrF-style character n-gram F-score over paired texts.
 *
 * For each order from 1 to `maxOrder`, clipped n-gram overlap and n-gram
 * totals are summed across all pairs to give a precision and a recall for
 * that order. An order with no n-grams contributes 0. The per-order values
 * are averaged over `maxOrder` and combined into an F-beta score.
 *
 * @param {unknown[]} hypotheses Candidate texts.
 * @param {unknown[]} references Reference texts, same length as `hypotheses`.
 * @param {{maxOrder?: number, beta?: number}} [options] `maxOrder` must be a
 *   positive integer (otherwise 6); `beta` must be a positive finite number
 *   (otherwise 2).
 * @returns {{score: number, precision: number, recall: number}} All zeros
 *   when both arrays are empty.
 * @throws {Error} When either argument is not an array or the lengths differ.
 */
export function computeChrfScore(hypotheses, references, options = {}) {
  const requestedOrder = options.maxOrder;
  const requestedBeta = options.beta;
  const maxOrder = Number.isInteger(requestedOrder) && requestedOrder > 0 ? requestedOrder : 6;
  let beta = 2;
  if (Number.isFinite(requestedBeta) && requestedBeta > 0) {
    beta = requestedBeta;
  }

  const aligned = Array.isArray(hypotheses)
    && Array.isArray(references)
    && hypotheses.length === references.length;
  if (!aligned) {
    throw new Error('computeChrfScore requires equally sized hypothesis and reference arrays.');
  }
  if (hypotheses.length === 0) {
    return { score: 0, precision: 0, recall: 0 };
  }

  // Sums every count stored in an n-gram map.
  const totalOf = (counts) => {
    let total = 0;
    for (const count of counts.values()) {
      total += count;
    }
    return total;
  };

  let precisionSum = 0;
  let recallSum = 0;
  for (let order = 1; order <= maxOrder; order += 1) {
    let overlap = 0;
    let hypothesisTotal = 0;
    let referenceTotal = 0;
    for (let row = 0; row < hypotheses.length; row += 1) {
      const hypothesisCounts = extractCharacterNgrams(hypotheses[row], order);
      const referenceCounts = extractCharacterNgrams(references[row], order);
      overlap += countOverlap(hypothesisCounts, referenceCounts);
      hypothesisTotal += totalOf(hypothesisCounts);
      referenceTotal += totalOf(referenceCounts);
    }
    // An order with no n-grams on one side adds nothing to that side's sum.
    if (hypothesisTotal > 0) {
      precisionSum += overlap / hypothesisTotal;
    }
    if (referenceTotal > 0) {
      recallSum += overlap / referenceTotal;
    }
  }

  const precision = precisionSum / maxOrder;
  const recall = recallSum / maxOrder;
  if ((precision + recall) === 0) {
    return { score: 0, precision, recall };
  }
  const betaSquared = beta * beta;
  const score = ((1 + betaSquared) * precision * recall) / ((betaSquared * precision) + recall);
  return { score, precision, recall };
}
|
|
157
|
+
|
|
158
|
+
/**
 * Measures how many hypotheses equal their reference after trimming.
 *
 * Both sides are coerced to strings (null/undefined become the empty string)
 * and trimmed before a strict equality comparison.
 *
 * @param {unknown[]} hypotheses Candidate texts.
 * @param {unknown[]} references Reference texts, same length as `hypotheses`.
 * @returns {{score: number, matches: number, total: number}} `score` is
 *   `matches / total`, or 0 when both arrays are empty.
 * @throws {Error} When either argument is not an array or the lengths differ.
 */
export function computeExactMatch(hypotheses, references) {
  const aligned = Array.isArray(hypotheses)
    && Array.isArray(references)
    && hypotheses.length === references.length;
  if (!aligned) {
    throw new Error('computeExactMatch requires equally sized hypothesis and reference arrays.');
  }

  const total = hypotheses.length;
  if (total === 0) {
    return { score: 0, matches: 0, total: 0 };
  }

  const normalize = (value) => String(value ?? '').trim();
  let matches = 0;
  for (let row = 0; row < total; row += 1) {
    const candidate = normalize(hypotheses[row]);
    const expected = normalize(references[row]);
    if (candidate === expected) {
      matches += 1;
    }
  }

  return { score: matches / total, matches, total };
}
|
|
177
|
+
|
|
178
|
+
/**
 * Computes classification accuracy as the exact-match rate.
 *
 * Delegates to `computeExactMatch`, passing predictions as the hypotheses and
 * labels as the references. Note the argument order: labels come first here.
 *
 * @param {unknown[]} labels Expected labels.
 * @param {unknown[]} predictions Predicted labels, same length as `labels`.
 * @returns {{score: number, matches: number, total: number}}
 */
export function computeAccuracy(labels, predictions) {
  const hypotheses = predictions;
  const references = labels;
  return computeExactMatch(hypotheses, references);
}
|
|
181
|
+
|
|
182
|
+
/**
 * Dispatches to the metric set that matches an evaluation kind.
 *
 * - `translation`: BLEU and chrF, with BLEU as the primary metric.
 * - `text_generation`: exact match.
 * - `classification`: accuracy (references are the labels).
 * - `retrieval` / `custom`: recognised but not implemented; throws.
 *
 * @param {unknown} evalKind Kind name; coerced to a string and trimmed.
 * @param {unknown[]} hypotheses Model outputs.
 * @param {unknown[]} references Expected outputs, index-aligned.
 * @param {{bleu?: object, chrf?: object}} [options] Per-metric options passed
 *   on to the BLEU and chrF scorers.
 * @returns {object} The computed metric objects plus `primaryMetric` and
 *   `primaryScore`.
 * @throws {Error} For unimplemented or unsupported kinds.
 */
export function computeEvalMetrics(evalKind, hypotheses, references, options = {}) {
  const normalizedKind = String(evalKind || '').trim();

  switch (normalizedKind) {
    case 'translation': {
      const bleu = computeBleuScore(hypotheses, references, options.bleu || {});
      const chrf = computeChrfScore(hypotheses, references, options.chrf || {});
      return { bleu, chrf, primaryMetric: 'bleu', primaryScore: bleu.score };
    }
    case 'text_generation': {
      const exactMatch = computeExactMatch(hypotheses, references);
      return { exactMatch, primaryMetric: 'exact_match', primaryScore: exactMatch.score };
    }
    case 'classification': {
      const accuracy = computeAccuracy(references, hypotheses);
      return { accuracy, primaryMetric: 'accuracy', primaryScore: accuracy.score };
    }
    case 'retrieval':
    case 'custom':
      throw new Error(`Eval kind "${normalizedKind}" requires a custom evaluator and is not yet implemented.`);
    default:
      throw new Error(`Unsupported eval kind "${normalizedKind}".`);
  }
}
|
|
215
|
+
|
|
216
|
+
/**
 * Reads an evaluation dataset from disk.
 *
 * The path is coerced to a string and resolved to an absolute path. A path
 * ending in `.json` is parsed with `JSON.parse`; any other path is handed to
 * `parseJsonl`. The parsed result must be an array.
 *
 * @param {string} datasetPath Path to a `.json` or JSONL file.
 * @returns {Promise<{absolutePath: string, rows: unknown[], raw: string}>}
 *   The resolved path, the parsed rows, and the raw UTF-8 file text.
 * @throws {Error} When the parsed content is not an array. Read and parse
 *   errors are propagated unchanged.
 */
export async function loadEvalDataset(datasetPath) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');

  let rows;
  if (absolutePath.endsWith('.json')) {
    rows = JSON.parse(raw);
  } else {
    rows = parseJsonl(raw);
  }

  if (Array.isArray(rows)) {
    return { absolutePath, rows, raw };
  }
  throw new Error(`Eval dataset "${absolutePath}" must be a JSON array or JSONL file.`);
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import { join } from 'node:path';
|
|
2
|
+
|
|
3
|
+
import { writeJsonArtifact, writeNdjsonRow } from './operator-artifacts.js';
|
|
4
|
+
|
|
5
|
+
/**
 * Looks up a finite numeric metric on a scoreboard row.
 *
 * The metric is read from `row[metric]` first. If that is not a finite
 * number, `row.metrics[metric]` is tried, provided `row.metrics` is an object.
 *
 * @param {unknown} row Scoreboard row.
 * @param {string} metric Metric name to look up.
 * @returns {number|null} The finite metric value, or null when the row is not
 *   an object or neither location holds a finite number.
 */
function resolveComparableMetric(row, metric) {
  const isFiniteNumber = (value) => typeof value === 'number' && Number.isFinite(value);

  if (!row || typeof row !== 'object') {
    return null;
  }

  const topLevel = row[metric];
  if (isFiniteNumber(topLevel)) {
    return topLevel;
  }

  const bag = row.metrics;
  if (bag && typeof bag === 'object') {
    const inner = bag[metric];
    if (isFiniteNumber(inner)) {
      return inner;
    }
  }
  return null;
}
|
|
18
|
+
|
|
19
|
+
/**
 * Appends a row to the scoreboard log and refreshes the summary artifact.
 *
 * The row is appended to `scoreboard.ndjson`. A summary holding the latest
 * row and the best row seen so far is then written to `latest.json`.
 *
 * "Best" is decided by the selection metric and goal: goal `min` prefers
 * lower values, any other goal prefers higher values. The previously stored
 * best is kept when it was recorded for the same metric and goal and the new
 * row does not strictly improve on it. Previously `best` was always the row
 * just appended, and the goal was never consulted.
 *
 * NOTE(review): reading the previous best assumes `writeJsonArtifact` stores
 * the summary object as plain JSON at `latest.json` - TODO confirm. If the
 * file is missing, unreadable or shaped differently, the new row becomes the
 * best, which matches the earlier behaviour.
 *
 * @param {string} scoreboardDir Directory holding the scoreboard files.
 * @param {object} row Row to append; may carry `selectionMetric`,
 *   `selectionGoal`, `primaryMetric` and the metric values themselves.
 * @param {{selectionMetric?: string, selectionGoal?: string}} [options]
 *   Overrides for the metric name and goal taken from the row.
 * @returns {Promise<{rowsPath: string, summaryPath: string}>} Path of the
 *   NDJSON log and the path reported by `writeJsonArtifact`.
 */
export async function appendScoreboardRow(scoreboardDir, row, options = {}) {
  const rowsPath = join(scoreboardDir, 'scoreboard.ndjson');
  const summaryTarget = join(scoreboardDir, 'latest.json');
  await writeNdjsonRow(rowsPath, row);
  const metric = String(options.selectionMetric || row.selectionMetric || row.primaryMetric || '').trim();
  const goal = String(options.selectionGoal || row.selectionGoal || 'max').trim();
  const comparable = resolveComparableMetric(row, metric);
  const candidate = comparable === null
    ? row
    : {
      ...row,
      selectionMetricValue: comparable,
    };

  // Recover the previously recorded best so it can be compared with this row.
  let previousBest = null;
  let previousValue = null;
  if (metric) {
    try {
      const { readFile } = await import('node:fs/promises');
      const previous = JSON.parse(await readFile(summaryTarget, 'utf8'));
      const storedValue = previous?.best?.selectionMetricValue;
      const sameSelection = previous?.selectionMetric === metric && previous?.selectionGoal === goal;
      if (sameSelection && typeof storedValue === 'number' && Number.isFinite(storedValue)) {
        previousBest = previous.best;
        previousValue = storedValue;
      }
    } catch {
      // No usable earlier summary (first run, missing or malformed file):
      // deliberately fall through so the current row becomes the best.
      previousBest = null;
      previousValue = null;
    }
  }

  let best = candidate;
  if (previousBest !== null) {
    // A row without a comparable value can never displace a scored best.
    const improved = comparable !== null
      && (goal === 'min' ? comparable < previousValue : comparable > previousValue);
    if (!improved) {
      best = previousBest;
    }
  }

  const summary = {
    artifactType: 'training_scoreboard',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    selectionMetric: metric || null,
    selectionGoal: goal,
    latest: row,
    best,
  };
  const summaryResult = await writeJsonArtifact(summaryTarget, summary);
  return {
    rowsPath,
    summaryPath: summaryResult.path,
  };
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer, BufferUsage } from '../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer, BufferUsage } from '../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, tensorBytes } from '../gpu/tensor.js';
|
|
3
3
|
import { runAdam } from '../gpu/kernels/backward/adam.js';
|
|
4
4
|
|
|
@@ -72,12 +72,24 @@ export class AdamOptimizer {
|
|
|
72
72
|
let entry = this.state.get(param);
|
|
73
73
|
if (!entry) {
|
|
74
74
|
const bytes = tensorBytes(param.shape, param.dtype);
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
75
|
+
let mBuf = null;
|
|
76
|
+
let vBuf = null;
|
|
77
|
+
try {
|
|
78
|
+
mBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_m');
|
|
79
|
+
vBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_v');
|
|
80
|
+
entry = {
|
|
81
|
+
m: createTensor(mBuf, param.dtype, [...param.shape], 'adam_m'),
|
|
82
|
+
v: createTensor(vBuf, param.dtype, [...param.shape], 'adam_v'),
|
|
83
|
+
};
|
|
84
|
+
} catch (error) {
|
|
85
|
+
if (mBuf) {
|
|
86
|
+
releaseBuffer(mBuf);
|
|
87
|
+
}
|
|
88
|
+
if (vBuf) {
|
|
89
|
+
releaseBuffer(vBuf);
|
|
90
|
+
}
|
|
91
|
+
throw error;
|
|
92
|
+
}
|
|
81
93
|
this.state.set(param, entry);
|
|
82
94
|
}
|
|
83
95
|
return entry;
|
package/src/training/runner.d.ts
CHANGED
|
@@ -90,6 +90,16 @@ export interface TrainingStepMetricsEntry {
|
|
|
90
90
|
/**
 * Optional hooks accepted by the training runner. Every hook may be
 * synchronous or return a promise.
 */
export interface TrainingRunnerCallbacks {
  /** Receives a single step-metrics entry. */
  onStep?: (entry: TrainingStepMetricsEntry) => Promise<void> | void;
  /** Receives an epoch summary (`epoch`, `steps`, `loss`). */
  onEpoch?: (entry: { epoch: number; steps: number; loss: number }) => Promise<void> | void;
  /**
   * Receives the details of a checkpoint. In `runner.js` the hook is awaited
   * after `saveCheckpoint` has resolved and `lastCheckpoint` has been updated.
   *
   * - `key`: the key the checkpoint was saved under (after any
   *   `resolveCheckpointKey` override).
   * - `defaultCheckpointKey`: the key before that override was applied.
   * - `path` / `metadata`: taken from the save result, or `null` when absent.
   * - `payload`: the checkpoint payload that was passed to `saveCheckpoint`.
   */
  onCheckpoint?: (entry: {
    key: string;
    defaultCheckpointKey: string | null;
    path: string | null;
    metadata: Record<string, unknown> | null;
    payload: unknown;
    step: number;
    epoch: number;
    batch: number;
  }) => Promise<void> | void;
}
|
|
94
104
|
|
|
95
105
|
export interface TrainingRunnerOptions extends TrainingRunnerCallbacks {
|
|
@@ -106,6 +116,12 @@ export interface TrainingRunnerOptions extends TrainingRunnerCallbacks {
|
|
|
106
116
|
) => Promise<ClipMetrics>;
|
|
107
117
|
lossScaler?: DynamicLossScaler;
|
|
108
118
|
trainingObjective?: TrainingObjective;
|
|
119
|
+
resolveCheckpointKey?: (entry: {
|
|
120
|
+
defaultCheckpointKey: string | null;
|
|
121
|
+
step: number;
|
|
122
|
+
epoch: number;
|
|
123
|
+
batch: number;
|
|
124
|
+
}) => Promise<string> | string;
|
|
109
125
|
}
|
|
110
126
|
|
|
111
127
|
export interface TrainingRunOptions {
|
|
@@ -159,6 +175,9 @@ export declare class TrainingRunner {
|
|
|
159
175
|
lastArtifact: UlArtifactFinalizeResult | DistillArtifactFinalizeResult | null;
|
|
160
176
|
lastCheckpoint: {
|
|
161
177
|
key: string;
|
|
178
|
+
defaultKey?: string | null;
|
|
179
|
+
path?: string | null;
|
|
180
|
+
metadata?: Record<string, unknown> | null;
|
|
162
181
|
step: number;
|
|
163
182
|
epoch: number;
|
|
164
183
|
batch: number;
|
|
@@ -194,3 +213,36 @@ export declare function runTraining(
|
|
|
194
213
|
config: TrainingConfigSchema,
|
|
195
214
|
options?: TrainingRunOptions & TrainingRunnerOptions
|
|
196
215
|
): Promise<TrainingStepMetricsEntry[]>;
|
|
216
|
+
|
|
217
|
+
/**
 * Builds the checkpoint payload for a model/optimizer pair.
 *
 * The payload shape is intentionally opaque (`unknown`) at the type level.
 *
 * @param model Object optionally exposing `loraParams()` and `paramGroups()`.
 * @param optimizer Optimizer instance; untyped here.
 * @param context Training position (`step`, `epoch`, `batch`) plus the
 *   training config.
 * @returns A promise resolving to the checkpoint payload.
 */
export declare function createTrainingCheckpointPayload(
  model: {
    loraParams?: () => Tensor[];
    paramGroups?: () => Record<string, Tensor[]>;
  },
  optimizer: unknown,
  context: {
    step: number;
    epoch: number;
    batch: number;
    config: TrainingConfigSchema;
  }
): Promise<unknown>;
|
|
230
|
+
|
|
231
|
+
/**
 * Restores training state from a previously saved checkpoint record.
 *
 * @param model Object optionally exposing `loraParams()` and `paramGroups()`.
 * @param optimizer Optimizer instance; untyped here.
 * @param checkpointRecord Stored record; untyped because it comes from storage.
 * @param config Training config for the current run.
 * @returns A promise resolving to the restored progress and checkpoint
 *   lineage fields, or `null` (the implementation returns `null` when the
 *   record does not look like a training checkpoint record).
 */
export declare function restoreTrainingCheckpointState(
  model: {
    loraParams?: () => Tensor[];
    paramGroups?: () => Record<string, Tensor[]>;
  },
  optimizer: unknown,
  checkpointRecord: unknown,
  config: TrainingConfigSchema
): Promise<{
  step: number;
  epoch: number;
  batch: number;
  checkpointHash: string | null;
  previousCheckpointHash: string | null;
  checkpointKey: string | null;
  resumeAudits: Array<Record<string, unknown>>;
  resumeAuditCount: number;
} | null>;
|
package/src/training/runner.js
CHANGED
|
@@ -617,7 +617,6 @@ function buildExpectedCheckpointMetadata(metadata) {
|
|
|
617
617
|
'configHash',
|
|
618
618
|
'datasetHash',
|
|
619
619
|
'tokenizerHash',
|
|
620
|
-
'optimizerHash',
|
|
621
620
|
'runtimePresetId',
|
|
622
621
|
'kernelPathId',
|
|
623
622
|
]) {
|
|
@@ -713,7 +712,7 @@ function looksLikeTrainingCheckpointRecord(value) {
|
|
|
713
712
|
return Number.isInteger(progress.step) && progress.step >= 0;
|
|
714
713
|
}
|
|
715
714
|
|
|
716
|
-
async function createTrainingCheckpointPayload(model, optimizer, context) {
|
|
715
|
+
export async function createTrainingCheckpointPayload(model, optimizer, context) {
|
|
717
716
|
const freezeMap = context.config?.training?.ul?.freeze
|
|
718
717
|
?? context.config?.training?.distill?.freeze
|
|
719
718
|
?? {};
|
|
@@ -747,7 +746,7 @@ async function createTrainingCheckpointPayload(model, optimizer, context) {
|
|
|
747
746
|
};
|
|
748
747
|
}
|
|
749
748
|
|
|
750
|
-
async function restoreTrainingCheckpointState(model, optimizer, checkpointRecord, config) {
|
|
749
|
+
export async function restoreTrainingCheckpointState(model, optimizer, checkpointRecord, config) {
|
|
751
750
|
if (!looksLikeTrainingCheckpointRecord(checkpointRecord)) {
|
|
752
751
|
return null;
|
|
753
752
|
}
|
|
@@ -837,12 +836,16 @@ export class TrainingRunner {
|
|
|
837
836
|
this.lossScaler = options.lossScaler || new DynamicLossScaler(config.training.lossScaling);
|
|
838
837
|
this.onStep = options.onStep || null;
|
|
839
838
|
this.onEpoch = options.onEpoch || null;
|
|
839
|
+
this.onCheckpoint = options.onCheckpoint || null;
|
|
840
|
+
this.resolveCheckpointKey = options.resolveCheckpointKey || null;
|
|
840
841
|
this.lastArtifact = null;
|
|
841
842
|
this.lastCheckpoint = null;
|
|
842
843
|
this.resumeState = null;
|
|
843
844
|
}
|
|
844
845
|
|
|
845
846
|
async run(model, dataset, options = {}) {
|
|
847
|
+
this.lastCheckpoint = null;
|
|
848
|
+
this.lastArtifact = null;
|
|
846
849
|
const {
|
|
847
850
|
epochs = 1,
|
|
848
851
|
batchSize = 1,
|
|
@@ -911,16 +914,39 @@ export class TrainingRunner {
|
|
|
911
914
|
batch: checkpointContext.batch,
|
|
912
915
|
config: this.config,
|
|
913
916
|
});
|
|
914
|
-
|
|
917
|
+
const resolvedCheckpointKey = this.resolveCheckpointKey
|
|
918
|
+
? await this.resolveCheckpointKey({
|
|
919
|
+
defaultCheckpointKey: checkpointKey,
|
|
920
|
+
step: checkpointContext.step,
|
|
921
|
+
epoch: checkpointContext.epoch,
|
|
922
|
+
batch: checkpointContext.batch,
|
|
923
|
+
})
|
|
924
|
+
: checkpointKey;
|
|
925
|
+
const saveResult = await saveCheckpoint(resolvedCheckpointKey, payload, {
|
|
915
926
|
...checkpointMetadata,
|
|
916
927
|
optimizerHash: hashStableJson(payload?.trainingState?.optimizerSlots || {}),
|
|
917
928
|
});
|
|
918
929
|
this.lastCheckpoint = {
|
|
919
|
-
key:
|
|
930
|
+
key: resolvedCheckpointKey,
|
|
931
|
+
defaultKey: checkpointKey,
|
|
932
|
+
path: saveResult?.path || null,
|
|
933
|
+
metadata: saveResult?.metadata || null,
|
|
920
934
|
step: checkpointContext.step,
|
|
921
935
|
epoch: checkpointContext.epoch,
|
|
922
936
|
batch: checkpointContext.batch,
|
|
923
937
|
};
|
|
938
|
+
if (this.onCheckpoint) {
|
|
939
|
+
await this.onCheckpoint({
|
|
940
|
+
key: resolvedCheckpointKey,
|
|
941
|
+
defaultCheckpointKey: checkpointKey,
|
|
942
|
+
path: saveResult?.path || null,
|
|
943
|
+
metadata: saveResult?.metadata || null,
|
|
944
|
+
payload,
|
|
945
|
+
step: checkpointContext.step,
|
|
946
|
+
epoch: checkpointContext.epoch,
|
|
947
|
+
batch: checkpointContext.batch,
|
|
948
|
+
});
|
|
949
|
+
}
|
|
924
950
|
};
|
|
925
951
|
|
|
926
952
|
const artifactSession = distillContract.enabled
|
package/src/training/suite.d.ts
CHANGED
|
@@ -176,6 +176,66 @@ export interface RunTrainingSuiteOptions {
|
|
|
176
176
|
timestamp?: string | Date;
|
|
177
177
|
}
|
|
178
178
|
|
|
179
|
+
/**
 * Language/pair scope for distillation data.
 *
 * Each list field has a companion `Set` field with the same element type;
 * presumably the set mirrors the list for membership checks — confirm against
 * the implementation. `null` is allowed on every list and set field.
 */
export interface DistillDataScope {
  sourceLangs: string[] | null;
  targetLangs: string[] | null;
  pairAllowlist: string[] | null;
  sourceLangSet: Set<string> | null;
  targetLangSet: Set<string> | null;
  pairAllowlistSet: Set<string> | null;
  strictPairContract: boolean;
}
|
|
188
|
+
|
|
189
|
+
/**
 * Result of loading a distillation dataset.
 *
 * `dataScope` carries the list fields and `strictPairContract` flag of
 * `DistillDataScope` without the `Set` companions, or `null`.
 * `shardCount` and `shardPaths` are optional.
 */
export interface DistillDatasetReport {
  absolutePath: string;
  rowCount: number;
  sampleCount: number;
  directionCounts: Record<string, number>;
  dataScope: {
    sourceLangs: string[] | null;
    targetLangs: string[] | null;
    pairAllowlist: string[] | null;
    strictPairContract: boolean;
  } | null;
  shardCount?: number;
  shardPaths?: string[];
  /** Creates a dataset object whose `batches()` is an async generator of records. */
  createDataset(options?: Record<string, unknown>): {
    batches(): AsyncGenerator<Record<string, unknown>, void, unknown>;
  };
}
|
|
206
|
+
|
|
207
|
+
/**
 * Runtime context for a distillation run: the stage, the teacher and student
 * pipelines with their model ids/URLs, the numeric loss settings, and the two
 * mode strings. `cleanup()` is asynchronous.
 */
export interface DistillRuntimeContext {
  stage: 'stage_a' | 'stage_b';
  teacherPipeline: Record<string, unknown>;
  studentPipeline: Record<string, unknown>;
  teacherModelId: string;
  studentModelId: string;
  teacherModelUrl: string | null;
  studentModelUrl: string | null;
  topK: number;
  temperature: number;
  alphaKd: number;
  alphaCe: number;
  tripletMargin: number;
  studentGraphMode: string;
  targetTokenMode: string;
  cleanup(): Promise<void>;
}
|
|
224
|
+
|
|
225
|
+
/**
 * Student-model fixture for distillation.
 *
 * Only `model.forward` is required on the model; `forwardDistill`,
 * `cleanupDistillStep`, `loraParams` and `paramGroups` are optional, as are
 * `outputDim` and `embeddingDim`. `cleanup()` is synchronous.
 */
export interface DistillStudentFixture {
  config: Record<string, unknown>;
  model: {
    forward: (input: unknown, tape: unknown) => Promise<unknown>;
    forwardDistill?: (batch: unknown, tape: unknown, options?: Record<string, unknown>) => Promise<{ logits: unknown }>;
    cleanupDistillStep?: () => void;
    loraParams?: () => unknown[];
    paramGroups?: () => Record<string, unknown[]>;
  };
  outputDim?: number;
  embeddingDim?: number;
  cleanup(): void;
}
|
|
238
|
+
|
|
179
239
|
export declare const trainingHarness: TrainingHarness;
|
|
180
240
|
|
|
181
241
|
export declare function runTrainingSuite(
|
|
@@ -185,3 +245,55 @@ export declare function runTrainingSuite(
|
|
|
185
245
|
export declare function runTrainingBenchSuite(
|
|
186
246
|
options?: RunTrainingSuiteOptions
|
|
187
247
|
): Promise<TrainingBenchSuiteResult>;
|
|
248
|
+
|
|
249
|
+
/**
 * Resolves the distillation data scope from suite options and an optional
 * training config. Both arguments may be omitted.
 *
 * @returns A `DistillDataScope`.
 */
export declare function resolveDistillDataScope(
  options?: RunTrainingSuiteOptions,
  trainingConfig?: Record<string, unknown> | null
): DistillDataScope;
|
|
253
|
+
|
|
254
|
+
/**
 * Builds a prompt string from a single sample record.
 * NOTE(review): the prompt format is defined by the implementation, which is
 * not visible from this declaration.
 */
export declare function buildDistillPrompt(sample: Record<string, unknown>): string;
|
|
255
|
+
|
|
256
|
+
/**
 * Normalizes an arbitrary value to a student-graph-mode string.
 * NOTE(review): the accepted values and the fallback are defined by the
 * implementation, which is not visible from this declaration.
 */
export declare function normalizeDistillStudentGraphMode(value: unknown): string;
|
|
257
|
+
|
|
258
|
+
/**
 * Loads a distillation dataset from a JSONL path, optionally restricted by a
 * data scope.
 *
 * @returns A promise resolving to a `DistillDatasetReport`, or `null`.
 *   NOTE(review): the conditions that produce `null` are not visible from
 *   this declaration.
 */
export declare function loadDistillDatasetFromJsonl(
  datasetPath: string,
  scopeOptions?: DistillDataScope | null
): Promise<DistillDatasetReport | null>;
|
|
262
|
+
|
|
263
|
+
/**
 * Loads a model for use in distillation.
 *
 * @param modelRef Model reference string.
 * @param role Role label for the model (presumably teacher vs. student —
 *   confirm against the implementation).
 * @param loadOptions Optional loader options.
 * @returns A promise resolving to the model reference, its URL (or `null`),
 *   the manifest and the pipeline.
 */
export declare function loadDistillModelHandle(
  modelRef: string,
  role: string,
  loadOptions?: Record<string, unknown>
): Promise<{
  modelRef: string;
  modelUrl: string | null;
  manifest: Record<string, unknown>;
  pipeline: Record<string, unknown>;
}>;
|
|
273
|
+
|
|
274
|
+
/**
 * Creates the distillation runtime context from suite options and an
 * optional training config. Both arguments may be omitted.
 *
 * @returns A promise resolving to a `DistillRuntimeContext`; the context
 *   exposes an async `cleanup()`.
 */
export declare function createDistillRuntimeContext(
  options?: RunTrainingSuiteOptions,
  trainingConfig?: Record<string, unknown> | null
): Promise<DistillRuntimeContext>;
|
|
278
|
+
|
|
279
|
+
/**
 * Creates a toy model fixture synchronously.
 *
 * @param overrides Optional overrides record.
 * @returns The fixture's config, a model exposing `forward`, `loraParams`
 *   and `paramGroups`, a sample batch, and a synchronous `cleanup()`.
 */
export declare function createToyModelFixture(
  overrides?: Record<string, unknown>
): {
  config: Record<string, unknown>;
  model: {
    forward: (input: unknown, tape: unknown) => Promise<unknown>;
    loraParams(): unknown[];
    paramGroups(): Record<string, unknown[]>;
  };
  batch: Record<string, unknown>;
  cleanup(): void;
};
|
|
291
|
+
|
|
292
|
+
/**
 * Asynchronously creates a student-model fixture for distillation.
 *
 * @param overrides Optional overrides record.
 * @param options Optional options record.
 * @returns A promise resolving to a `DistillStudentFixture`.
 */
export declare function createDistillStudentRuntimeModelFixture(
  overrides?: Record<string, unknown>,
  options?: Record<string, unknown>
): Promise<DistillStudentFixture>;
|
|
296
|
+
|
|
297
|
+
/**
 * Derives training-config overrides for distillation from suite options.
 *
 * @returns An overrides record, or `null`.
 *   NOTE(review): the conditions that produce `null` are not visible from
 *   this declaration.
 */
export declare function buildDistillTrainingOverrides(
  options?: RunTrainingSuiteOptions
): Record<string, unknown> | null;
|