@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
package/src/training/suite.js
CHANGED
|
@@ -3,18 +3,11 @@ import { setPlatformsBaseUrl } from '../config/platforms/loader.js';
|
|
|
3
3
|
import { setRegistryUrl } from '../config/kernels/registry.js';
|
|
4
4
|
import { createTrainingConfig } from '../config/training-defaults.js';
|
|
5
5
|
import {
|
|
6
|
-
runAttention,
|
|
7
|
-
castF16ToF32,
|
|
8
|
-
runGather,
|
|
9
6
|
runMatmul,
|
|
10
7
|
runResidualAdd,
|
|
11
|
-
runRMSNorm,
|
|
12
|
-
runRoPE,
|
|
13
|
-
runSiLURowSplit,
|
|
14
8
|
} from '../gpu/kernels/index.js';
|
|
15
9
|
import { createTensor } from '../gpu/tensor.js';
|
|
16
10
|
import { acquireBuffer, uploadData, releaseBuffer } from '../memory/buffer-pool.js';
|
|
17
|
-
import { getBufferDtype, getWeightDtype, isCpuWeightBuffer, isWeightBuffer } from '../gpu/weight-buffer.js';
|
|
18
11
|
import { OpType } from './autograd.js';
|
|
19
12
|
import { AdamOptimizer } from './optimizer.js';
|
|
20
13
|
import { TrainingRunner } from './runner.js';
|
|
@@ -25,6 +18,16 @@ import { exportLoRAAdapter } from './export.js';
|
|
|
25
18
|
import { sha256Hex } from '../utils/sha256.js';
|
|
26
19
|
import { computeSampleStats } from '../debug/stats.js';
|
|
27
20
|
import { parseJsonl } from './datasets/jsonl.js';
|
|
21
|
+
import {
|
|
22
|
+
buildDistillCandidatePrompt,
|
|
23
|
+
buildDistillPrompt,
|
|
24
|
+
encodeDistillRow,
|
|
25
|
+
normalizeDistillDatasetPath,
|
|
26
|
+
normalizeOptionalString,
|
|
27
|
+
resolveDistillDataScope,
|
|
28
|
+
summarizeDirectionCounts,
|
|
29
|
+
} from './distillation/suite-data.js';
|
|
30
|
+
import { createDistillStudentRuntimeModelFixture } from './distillation/student-fixture.js';
|
|
28
31
|
import { initializeInference } from '../inference/test-harness.js';
|
|
29
32
|
import { createPipeline } from '../inference/pipelines/text.js';
|
|
30
33
|
import { parseManifest } from '../formats/rdrr/index.js';
|
|
@@ -128,195 +131,7 @@ function isNodeRuntime() {
|
|
|
128
131
|
return typeof process !== 'undefined' && !!process.versions?.node;
|
|
129
132
|
}
|
|
130
133
|
|
|
131
|
-
|
|
132
|
-
if (value === undefined || value === null) return null;
|
|
133
|
-
const trimmed = String(value).trim();
|
|
134
|
-
return trimmed || null;
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
function normalizeDistillDatasetPath(value) {
|
|
138
|
-
return normalizeOptionalString(value);
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
function normalizeLangCode(value) {
|
|
142
|
-
const normalized = normalizeOptionalString(value);
|
|
143
|
-
if (!normalized) return null;
|
|
144
|
-
const compact = normalized.toLowerCase().replace(/_/g, '-');
|
|
145
|
-
if (compact.startsWith('en')) return 'en';
|
|
146
|
-
if (compact.startsWith('es')) return 'es';
|
|
147
|
-
return compact;
|
|
148
|
-
}
|
|
149
|
-
|
|
150
|
-
function normalizePairDirection(value) {
|
|
151
|
-
const pair = normalizeOptionalString(value);
|
|
152
|
-
if (!pair) return null;
|
|
153
|
-
const normalized = pair.toLowerCase().replace(/_/g, '-').replace(/\s+/g, '');
|
|
154
|
-
const parts = normalized.includes('->')
|
|
155
|
-
? normalized.split('->').filter(Boolean)
|
|
156
|
-
: normalized.split('-').filter(Boolean);
|
|
157
|
-
if (parts.length !== 2) return null;
|
|
158
|
-
return `${normalizeLangCode(parts[0]) || parts[0]}->${normalizeLangCode(parts[1]) || parts[1]}`;
|
|
159
|
-
}
|
|
160
|
-
|
|
161
|
-
function normalizeOptionalStringArray(value) {
|
|
162
|
-
if (value === undefined || value === null) return null;
|
|
163
|
-
const list = Array.isArray(value)
|
|
164
|
-
? value
|
|
165
|
-
: (typeof value === 'string' ? value.split(',') : null);
|
|
166
|
-
if (!Array.isArray(list)) return null;
|
|
167
|
-
const normalized = list
|
|
168
|
-
.map((entry) => normalizeOptionalString(entry))
|
|
169
|
-
.filter(Boolean);
|
|
170
|
-
return normalized.length > 0 ? normalized : null;
|
|
171
|
-
}
|
|
172
|
-
|
|
173
|
-
function normalizeDistillLanguageAllowlist(value) {
|
|
174
|
-
const list = normalizeOptionalStringArray(value);
|
|
175
|
-
if (!list) return null;
|
|
176
|
-
const normalized = list
|
|
177
|
-
.map((entry) => normalizeLangCode(entry))
|
|
178
|
-
.filter(Boolean);
|
|
179
|
-
if (normalized.length === 0) return null;
|
|
180
|
-
return [...new Set(normalized)];
|
|
181
|
-
}
|
|
182
|
-
|
|
183
|
-
function normalizeDistillPairAllowlist(value) {
|
|
184
|
-
const list = normalizeOptionalStringArray(value);
|
|
185
|
-
if (!list) return null;
|
|
186
|
-
const normalized = list
|
|
187
|
-
.map((entry) => normalizePairDirection(entry))
|
|
188
|
-
.filter(Boolean);
|
|
189
|
-
if (normalized.length === 0) return null;
|
|
190
|
-
return [...new Set(normalized)];
|
|
191
|
-
}
|
|
192
|
-
|
|
193
|
-
function resolveDistillDataScope(options = {}, trainingConfig = null) {
|
|
194
|
-
const distillConfig = trainingConfig?.distill || {};
|
|
195
|
-
const sourceLangs = normalizeDistillLanguageAllowlist(
|
|
196
|
-
options.distillSourceLangs ?? distillConfig.sourceLangs ?? null
|
|
197
|
-
);
|
|
198
|
-
const targetLangs = normalizeDistillLanguageAllowlist(
|
|
199
|
-
options.distillTargetLangs ?? distillConfig.targetLangs ?? null
|
|
200
|
-
);
|
|
201
|
-
const pairAllowlist = normalizeDistillPairAllowlist(
|
|
202
|
-
options.distillPairAllowlist ?? distillConfig.pairAllowlist ?? null
|
|
203
|
-
);
|
|
204
|
-
const strictPairContract = (
|
|
205
|
-
options.strictPairContract === true
|
|
206
|
-
|| distillConfig.strictPairContract === true
|
|
207
|
-
);
|
|
208
|
-
return {
|
|
209
|
-
sourceLangs,
|
|
210
|
-
targetLangs,
|
|
211
|
-
pairAllowlist,
|
|
212
|
-
sourceLangSet: sourceLangs ? new Set(sourceLangs) : null,
|
|
213
|
-
targetLangSet: targetLangs ? new Set(targetLangs) : null,
|
|
214
|
-
pairAllowlistSet: pairAllowlist ? new Set(pairAllowlist) : null,
|
|
215
|
-
strictPairContract,
|
|
216
|
-
};
|
|
217
|
-
}
|
|
218
|
-
|
|
219
|
-
function resolveDistillDirection(record) {
|
|
220
|
-
const pairDirection = normalizePairDirection(record?.pair);
|
|
221
|
-
if (pairDirection) return pairDirection;
|
|
222
|
-
const srcLang = normalizeLangCode(record?.src_lang);
|
|
223
|
-
const tgtLang = normalizeLangCode(record?.tgt_lang || record?.lang);
|
|
224
|
-
if (srcLang && tgtLang) {
|
|
225
|
-
return `${srcLang}->${tgtLang}`;
|
|
226
|
-
}
|
|
227
|
-
return null;
|
|
228
|
-
}
|
|
229
|
-
|
|
230
|
-
function resolveStringCandidate(record, keys) {
|
|
231
|
-
for (const key of keys) {
|
|
232
|
-
const value = normalizeOptionalString(record?.[key]);
|
|
233
|
-
if (value) return value;
|
|
234
|
-
}
|
|
235
|
-
return null;
|
|
236
|
-
}
|
|
237
|
-
|
|
238
|
-
function encodeDistillRow(record, index, scope = null) {
|
|
239
|
-
if (!record || typeof record !== 'object') return null;
|
|
240
|
-
const source = resolveStringCandidate(record, ['source', 'query']);
|
|
241
|
-
const targetPos = resolveStringCandidate(record, ['target_pos', 'target', 'pos']);
|
|
242
|
-
const targetNeg = resolveStringCandidate(record, ['target_neg', 'neg']);
|
|
243
|
-
if (!source || !targetPos) return null;
|
|
244
|
-
const sourceLangRaw = normalizeLangCode(record?.src_lang);
|
|
245
|
-
const targetLangRaw = normalizeLangCode(record?.tgt_lang || record?.lang);
|
|
246
|
-
const pairDirection = normalizePairDirection(record?.pair);
|
|
247
|
-
const sourceTargetDirection = (
|
|
248
|
-
sourceLangRaw && targetLangRaw
|
|
249
|
-
? `${sourceLangRaw}->${targetLangRaw}`
|
|
250
|
-
: null
|
|
251
|
-
);
|
|
252
|
-
if (scope?.strictPairContract === true) {
|
|
253
|
-
if (!sourceLangRaw || !targetLangRaw) {
|
|
254
|
-
throw new Error('strictPairContract requires src_lang and tgt_lang/lang on each row.');
|
|
255
|
-
}
|
|
256
|
-
if (!pairDirection) {
|
|
257
|
-
throw new Error('strictPairContract requires pair on each row.');
|
|
258
|
-
}
|
|
259
|
-
if (pairDirection !== sourceTargetDirection) {
|
|
260
|
-
throw new Error(`pair "${record?.pair}" does not match src/tgt "${sourceLangRaw}-${targetLangRaw}".`);
|
|
261
|
-
}
|
|
262
|
-
}
|
|
263
|
-
const direction = pairDirection || sourceTargetDirection || resolveDistillDirection(record) || 'unknown';
|
|
264
|
-
const [directionSourceLang, directionTargetLang] = String(direction).split('->');
|
|
265
|
-
const sourceLang = sourceLangRaw || normalizeLangCode(directionSourceLang);
|
|
266
|
-
const targetLang = targetLangRaw || normalizeLangCode(directionTargetLang);
|
|
267
|
-
if (scope?.sourceLangSet && (!sourceLang || !scope.sourceLangSet.has(sourceLang))) {
|
|
268
|
-
return null;
|
|
269
|
-
}
|
|
270
|
-
if (scope?.targetLangSet && (!targetLang || !scope.targetLangSet.has(targetLang))) {
|
|
271
|
-
return null;
|
|
272
|
-
}
|
|
273
|
-
if (scope?.pairAllowlistSet && !scope.pairAllowlistSet.has(direction)) {
|
|
274
|
-
return null;
|
|
275
|
-
}
|
|
276
|
-
|
|
277
|
-
return {
|
|
278
|
-
index,
|
|
279
|
-
direction,
|
|
280
|
-
sourceLang: sourceLang || null,
|
|
281
|
-
targetLang: targetLang || null,
|
|
282
|
-
source,
|
|
283
|
-
targetPos,
|
|
284
|
-
targetNeg: targetNeg || null,
|
|
285
|
-
};
|
|
286
|
-
}
|
|
287
|
-
|
|
288
|
-
function summarizeDirectionCounts(samples) {
|
|
289
|
-
const counts = {};
|
|
290
|
-
for (const sample of samples) {
|
|
291
|
-
const key = sample?.direction || 'unknown';
|
|
292
|
-
counts[key] = (counts[key] || 0) + 1;
|
|
293
|
-
}
|
|
294
|
-
return counts;
|
|
295
|
-
}
|
|
296
|
-
|
|
297
|
-
function resolveLanguageName(langCode) {
|
|
298
|
-
const normalized = normalizeLangCode(langCode);
|
|
299
|
-
if (normalized === 'en') return 'English';
|
|
300
|
-
if (normalized === 'es') return 'Spanish';
|
|
301
|
-
return normalized || 'target';
|
|
302
|
-
}
|
|
303
|
-
|
|
304
|
-
function buildDistillPrompt(sample) {
|
|
305
|
-
const direction = String(sample?.direction || '').trim();
|
|
306
|
-
const [srcCodeRaw, tgtCodeRaw] = direction.split('->');
|
|
307
|
-
const srcCode = normalizeLangCode(srcCodeRaw) || srcCodeRaw || 'source';
|
|
308
|
-
const tgtCode = normalizeLangCode(tgtCodeRaw) || tgtCodeRaw || 'target';
|
|
309
|
-
const srcName = resolveLanguageName(srcCode);
|
|
310
|
-
const tgtName = resolveLanguageName(tgtCode);
|
|
311
|
-
const source = String(sample?.source || '').trim();
|
|
312
|
-
return `Translate from ${srcName} to ${tgtName}:\n${source}\nTranslation:`;
|
|
313
|
-
}
|
|
314
|
-
|
|
315
|
-
function buildDistillCandidatePrompt(sample, candidate) {
|
|
316
|
-
const base = buildDistillPrompt(sample);
|
|
317
|
-
const text = String(candidate || '').trim();
|
|
318
|
-
return text ? `${base} ${text}` : base;
|
|
319
|
-
}
|
|
134
|
+
export { buildDistillPrompt, resolveDistillDataScope };
|
|
320
135
|
|
|
321
136
|
function toFiniteNumber(value, fallback) {
|
|
322
137
|
const parsed = Number(value);
|
|
@@ -328,7 +143,7 @@ function clampDistillTopK(value) {
|
|
|
328
143
|
return Math.max(2, Math.min(256, parsed));
|
|
329
144
|
}
|
|
330
145
|
|
|
331
|
-
function normalizeDistillStudentGraphMode(value) {
|
|
146
|
+
export function normalizeDistillStudentGraphMode(value) {
|
|
332
147
|
const normalized = normalizeOptionalString(value);
|
|
333
148
|
if (!normalized) return DISTILL_STUDENT_GRAPH_FULL;
|
|
334
149
|
const compact = normalized.toLowerCase().replace(/[-\s]/g, '_');
|
|
@@ -605,7 +420,7 @@ function createDistillTensorDataset(samples, options = {}) {
|
|
|
605
420
|
};
|
|
606
421
|
}
|
|
607
422
|
|
|
608
|
-
async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
|
|
423
|
+
export async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
|
|
609
424
|
const normalizedPath = normalizeDistillDatasetPath(datasetPath);
|
|
610
425
|
if (!normalizedPath) return null;
|
|
611
426
|
if (!isNodeRuntime()) {
|
|
@@ -820,7 +635,7 @@ async function initializeInferenceFromStore(modelId) {
|
|
|
820
635
|
return { pipeline, manifest };
|
|
821
636
|
}
|
|
822
637
|
|
|
823
|
-
async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
|
|
638
|
+
export async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
|
|
824
639
|
const normalizedRef = normalizeOptionalString(modelRef);
|
|
825
640
|
if (!normalizedRef) {
|
|
826
641
|
throw new Error(`Distill ${role} model reference is required.`);
|
|
@@ -876,7 +691,7 @@ function resolveDistillModelRefs(options = {}, trainingConfig = null) {
|
|
|
876
691
|
};
|
|
877
692
|
}
|
|
878
693
|
|
|
879
|
-
async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
|
|
694
|
+
export async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
|
|
880
695
|
const { teacherModelRef, studentModelRef } = resolveDistillModelRefs(options, trainingConfig);
|
|
881
696
|
if (!teacherModelRef || !studentModelRef) {
|
|
882
697
|
throw new Error('Distill stage requires teacherModelId and studentModelId.');
|
|
@@ -967,7 +782,7 @@ async function ensureTrainingGpuRuntime() {
|
|
|
967
782
|
await initDevice();
|
|
968
783
|
}
|
|
969
784
|
|
|
970
|
-
function createToyModelFixture(overrides = {}) {
|
|
785
|
+
export function createToyModelFixture(overrides = {}) {
|
|
971
786
|
const config = createTrainingConfig({
|
|
972
787
|
...overrides,
|
|
973
788
|
training: {
|
|
@@ -1040,770 +855,7 @@ function createToyModelFixture(overrides = {}) {
|
|
|
1040
855
|
};
|
|
1041
856
|
}
|
|
1042
857
|
|
|
1043
|
-
|
|
1044
|
-
const dtype = isWeightBuffer(value)
|
|
1045
|
-
? value.dtype
|
|
1046
|
-
: (value?.dtype || getWeightDtype(value) || null);
|
|
1047
|
-
const normalized = String(dtype || '').toLowerCase();
|
|
1048
|
-
return normalized === 'f16' ? 'f16' : (normalized === 'f32' ? 'f32' : fallback);
|
|
1049
|
-
}
|
|
1050
|
-
|
|
1051
|
-
async function ensureTrainableTensor(value, shape, label, ownedTrainables = null) {
|
|
1052
|
-
if (!value) {
|
|
1053
|
-
throw new Error(`Distill full-graph student missing required weight "${label}".`);
|
|
1054
|
-
}
|
|
1055
|
-
const registerOwned = (tensor) => {
|
|
1056
|
-
if (ownedTrainables instanceof Set && tensor?.buffer instanceof GPUBuffer) {
|
|
1057
|
-
ownedTrainables.add(tensor);
|
|
1058
|
-
}
|
|
1059
|
-
return tensor;
|
|
1060
|
-
};
|
|
1061
|
-
if (isWeightBuffer(value)) {
|
|
1062
|
-
if (value.dtype === 'f32') {
|
|
1063
|
-
return value;
|
|
1064
|
-
}
|
|
1065
|
-
if (value.dtype === 'f16') {
|
|
1066
|
-
const sourceShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
|
|
1067
|
-
const source = createTensor(value.buffer, 'f16', sourceShape, `${label}_source_f16`);
|
|
1068
|
-
const promoted = await castF16ToF32(source);
|
|
1069
|
-
return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
|
|
1070
|
-
}
|
|
1071
|
-
throw new Error(`Distill full-graph student weight "${label}" uses unsupported dtype "${value.dtype}".`);
|
|
1072
|
-
}
|
|
1073
|
-
if (value instanceof GPUBuffer) {
|
|
1074
|
-
const sourceShape = [...shape];
|
|
1075
|
-
const rawDtype = String(getBufferDtype(value) || 'f32').toLowerCase();
|
|
1076
|
-
const dtype = rawDtype === 'f16' ? 'f16' : 'f32';
|
|
1077
|
-
const tensor = createTensor(value, dtype, sourceShape, label);
|
|
1078
|
-
if (dtype === 'f16') {
|
|
1079
|
-
const promoted = await castF16ToF32(tensor);
|
|
1080
|
-
return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
|
|
1081
|
-
}
|
|
1082
|
-
return tensor;
|
|
1083
|
-
}
|
|
1084
|
-
if (isCpuWeightBuffer(value)) {
|
|
1085
|
-
const sourceShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
|
|
1086
|
-
const dtype = resolveTensorDtype(value, 'f32');
|
|
1087
|
-
if (dtype === 'f32') {
|
|
1088
|
-
const tensor = makeTensorFromFloat32(value.data, sourceShape, `${label}_cpu_f32`);
|
|
1089
|
-
return registerOwned(tensor);
|
|
1090
|
-
}
|
|
1091
|
-
if (dtype === 'f16') {
|
|
1092
|
-
let raw = null;
|
|
1093
|
-
if (value.data instanceof Uint16Array) {
|
|
1094
|
-
raw = value.data;
|
|
1095
|
-
} else if (ArrayBuffer.isView(value.data)) {
|
|
1096
|
-
raw = new Uint16Array(
|
|
1097
|
-
value.data.buffer,
|
|
1098
|
-
value.data.byteOffset,
|
|
1099
|
-
Math.floor(value.data.byteLength / 2)
|
|
1100
|
-
);
|
|
1101
|
-
} else if (value.data instanceof ArrayBuffer) {
|
|
1102
|
-
raw = new Uint16Array(value.data);
|
|
1103
|
-
}
|
|
1104
|
-
if (!raw) {
|
|
1105
|
-
throw new Error(`Distill full-graph student weight "${label}" has non-typed f16 CPU data.`);
|
|
1106
|
-
}
|
|
1107
|
-
const source = makeTensorFromF16Bits(raw, sourceShape, `${label}_cpu_f16`);
|
|
1108
|
-
const promoted = await castF16ToF32(source);
|
|
1109
|
-
releaseTensor(source);
|
|
1110
|
-
return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
|
|
1111
|
-
}
|
|
1112
|
-
throw new Error(`Distill full-graph student weight "${label}" has unsupported CPU dtype "${dtype}".`);
|
|
1113
|
-
}
|
|
1114
|
-
if (value.buffer instanceof GPUBuffer) {
|
|
1115
|
-
const resolvedShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
|
|
1116
|
-
const tensor = createTensor(
|
|
1117
|
-
value.buffer,
|
|
1118
|
-
resolveTensorDtype(value, 'f32'),
|
|
1119
|
-
resolvedShape,
|
|
1120
|
-
label
|
|
1121
|
-
);
|
|
1122
|
-
if (tensor.dtype === 'f16') {
|
|
1123
|
-
const promoted = await castF16ToF32(tensor);
|
|
1124
|
-
return registerOwned(createTensor(promoted.buffer, 'f32', resolvedShape, `${label}_trainable_f32`));
|
|
1125
|
-
}
|
|
1126
|
-
return tensor;
|
|
1127
|
-
}
|
|
1128
|
-
throw new Error(`Distill full-graph student weight "${label}" is not GPU-resident.`);
|
|
1129
|
-
}
|
|
1130
|
-
|
|
1131
|
-
async function ensureNormTensor(value, hiddenSize, label, ownedTrainables = null) {
|
|
1132
|
-
return ensureTrainableTensor(value, [hiddenSize], label, ownedTrainables);
|
|
1133
|
-
}
|
|
1134
|
-
|
|
1135
|
-
function hasTensorPayload(value) {
|
|
1136
|
-
if (!value) return false;
|
|
1137
|
-
if (value instanceof GPUBuffer) return true;
|
|
1138
|
-
if (isWeightBuffer(value) || isCpuWeightBuffer(value)) return true;
|
|
1139
|
-
if (value?.buffer instanceof GPUBuffer) return true;
|
|
1140
|
-
if (ArrayBuffer.isView(value) || Array.isArray(value)) return true;
|
|
1141
|
-
return false;
|
|
1142
|
-
}
|
|
1143
|
-
|
|
1144
|
-
async function fuseGateUpTensors(gateTensor, upTensor, intermediateSize, hiddenSize, label, ownedTrainables = null) {
|
|
1145
|
-
const device = getDevice();
|
|
1146
|
-
if (!device) {
|
|
1147
|
-
throw new Error('Distill full-graph student requires active GPU device.');
|
|
1148
|
-
}
|
|
1149
|
-
if (gateTensor?.dtype !== 'f32' || upTensor?.dtype !== 'f32') {
|
|
1150
|
-
throw new Error(`Distill fused gate_up expects f32 tensors for "${label}".`);
|
|
1151
|
-
}
|
|
1152
|
-
const expectedRows = intermediateSize;
|
|
1153
|
-
const expectedCols = hiddenSize;
|
|
1154
|
-
const gateRows = Number.isFinite(gateTensor?.shape?.[0]) ? gateTensor.shape[0] : 0;
|
|
1155
|
-
const gateCols = Number.isFinite(gateTensor?.shape?.[1]) ? gateTensor.shape[1] : 0;
|
|
1156
|
-
const upRows = Number.isFinite(upTensor?.shape?.[0]) ? upTensor.shape[0] : 0;
|
|
1157
|
-
const upCols = Number.isFinite(upTensor?.shape?.[1]) ? upTensor.shape[1] : 0;
|
|
1158
|
-
if (gateRows !== expectedRows || gateCols !== expectedCols || upRows !== expectedRows || upCols !== expectedCols) {
|
|
1159
|
-
throw new Error(
|
|
1160
|
-
`Distill gate/up shape mismatch for "${label}": gate=[${gateRows},${gateCols}] up=[${upRows},${upCols}] ` +
|
|
1161
|
-
`expected=[${expectedRows},${expectedCols}]`
|
|
1162
|
-
);
|
|
1163
|
-
}
|
|
1164
|
-
const rowBytes = expectedCols * 4;
|
|
1165
|
-
const blockBytes = expectedRows * rowBytes;
|
|
1166
|
-
const fusedBuffer = acquireBuffer(blockBytes * 2, undefined, `${label}_fused`);
|
|
1167
|
-
const encoder = device.createCommandEncoder();
|
|
1168
|
-
encoder.copyBufferToBuffer(gateTensor.buffer, 0, fusedBuffer, 0, blockBytes);
|
|
1169
|
-
encoder.copyBufferToBuffer(upTensor.buffer, 0, fusedBuffer, blockBytes, blockBytes);
|
|
1170
|
-
device.queue.submit([encoder.finish()]);
|
|
1171
|
-
const fused = createTensor(fusedBuffer, 'f32', [expectedRows * 2, expectedCols], `${label}_fused`);
|
|
1172
|
-
if (ownedTrainables instanceof Set) {
|
|
1173
|
-
ownedTrainables.add(fused);
|
|
1174
|
-
}
|
|
1175
|
-
return fused;
|
|
1176
|
-
}
|
|
1177
|
-
|
|
1178
|
-
function resolvePhasePrompts(batch, phase) {
|
|
1179
|
-
const distill = batch?.distill || {};
|
|
1180
|
-
const prompts = phase === 'positive'
|
|
1181
|
-
? distill.tripletPositivePrompts
|
|
1182
|
-
: (phase === 'negative' ? distill.tripletNegativePrompts : distill.prompts);
|
|
1183
|
-
if (!Array.isArray(prompts) || prompts.length === 0) {
|
|
1184
|
-
throw new Error(`Distill student fixture requires distill prompts for phase "${phase}".`);
|
|
1185
|
-
}
|
|
1186
|
-
return prompts;
|
|
1187
|
-
}
|
|
1188
|
-
|
|
1189
|
-
function createRowSliceTensor(inputTensor, rows, cols, rowIndex, label) {
|
|
1190
|
-
const device = getDevice();
|
|
1191
|
-
if (!device) {
|
|
1192
|
-
throw new Error('Distill full-graph student requires active GPU device.');
|
|
1193
|
-
}
|
|
1194
|
-
const dtype = inputTensor?.dtype === 'f16' ? 'f16' : 'f32';
|
|
1195
|
-
const bytesPerElement = dtype === 'f16' ? 2 : 4;
|
|
1196
|
-
const rowBytes = cols * bytesPerElement;
|
|
1197
|
-
const clampedRow = Math.max(0, Math.min(rows - 1, rowIndex));
|
|
1198
|
-
const outputBuffer = acquireBuffer(rowBytes, undefined, label);
|
|
1199
|
-
const encoder = device.createCommandEncoder();
|
|
1200
|
-
encoder.copyBufferToBuffer(
|
|
1201
|
-
inputTensor.buffer,
|
|
1202
|
-
clampedRow * rowBytes,
|
|
1203
|
-
outputBuffer,
|
|
1204
|
-
0,
|
|
1205
|
-
rowBytes
|
|
1206
|
-
);
|
|
1207
|
-
device.queue.submit([encoder.finish()]);
|
|
1208
|
-
return createTensor(outputBuffer, dtype, [1, cols], label);
|
|
1209
|
-
}
|
|
1210
|
-
|
|
1211
|
-
function createDistillStudentProjectionModelFixture(overrides = {}, options = {}) {
|
|
1212
|
-
const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
|
|
1213
|
-
? options.distillRuntime
|
|
1214
|
-
: null;
|
|
1215
|
-
if (!distillRuntime?.studentPipeline) {
|
|
1216
|
-
throw new Error('Distill student fixture requires distillRuntime.studentPipeline.');
|
|
1217
|
-
}
|
|
1218
|
-
const outputDim = clampDistillTopK(
|
|
1219
|
-
options.outputDim
|
|
1220
|
-
?? options.inputDim
|
|
1221
|
-
?? DISTILL_ADAPTER_TOP_K
|
|
1222
|
-
);
|
|
1223
|
-
const inferredEmbeddingDim = Math.floor(
|
|
1224
|
-
Number(distillRuntime.studentPipeline?.modelConfig?.hiddenSize)
|
|
1225
|
-
);
|
|
1226
|
-
const embeddingDim = Number.isInteger(options.embeddingDim) && options.embeddingDim > 0
|
|
1227
|
-
? options.embeddingDim
|
|
1228
|
-
: (Number.isFinite(inferredEmbeddingDim) && inferredEmbeddingDim > 0
|
|
1229
|
-
? inferredEmbeddingDim
|
|
1230
|
-
: outputDim);
|
|
1231
|
-
const config = createTrainingConfig({
|
|
1232
|
-
...overrides,
|
|
1233
|
-
training: {
|
|
1234
|
-
enabled: true,
|
|
1235
|
-
lossScaling: { enabled: false },
|
|
1236
|
-
gradient: { maxNorm: 0 },
|
|
1237
|
-
...(overrides.training || {}),
|
|
1238
|
-
},
|
|
1239
|
-
});
|
|
1240
|
-
|
|
1241
|
-
const projectionWeights = new Float32Array(embeddingDim * outputDim);
|
|
1242
|
-
const projectionWeight = makeTensorFromFloat32(
|
|
1243
|
-
projectionWeights,
|
|
1244
|
-
[embeddingDim, outputDim],
|
|
1245
|
-
'distill_student_head_weight'
|
|
1246
|
-
);
|
|
1247
|
-
const temporaryInputs = new Set();
|
|
1248
|
-
|
|
1249
|
-
async function projectEmbeddingInput(inputTensor, tape) {
|
|
1250
|
-
const rows = Number.isFinite(inputTensor?.shape?.[0]) ? inputTensor.shape[0] : 1;
|
|
1251
|
-
return tape.record(
|
|
1252
|
-
OpType.MATMUL,
|
|
1253
|
-
(a, b) => runMatmul(a, b, rows, outputDim, embeddingDim, { transposeB: false }),
|
|
1254
|
-
[inputTensor, projectionWeight],
|
|
1255
|
-
{ M: rows, N: outputDim, K: embeddingDim, transposeB: false }
|
|
1256
|
-
);
|
|
1257
|
-
}
|
|
1258
|
-
|
|
1259
|
-
async function buildStudentEmbeddingInput(batch, phase = 'anchor') {
|
|
1260
|
-
const distill = batch?.distill || {};
|
|
1261
|
-
const prompts = phase === 'positive'
|
|
1262
|
-
? distill.tripletPositivePrompts
|
|
1263
|
-
: (phase === 'negative' ? distill.tripletNegativePrompts : distill.prompts);
|
|
1264
|
-
if (!Array.isArray(prompts) || prompts.length === 0) {
|
|
1265
|
-
throw new Error(`Distill student fixture requires distill prompts for phase "${phase}".`);
|
|
1266
|
-
}
|
|
1267
|
-
|
|
1268
|
-
const rows = prompts.length;
|
|
1269
|
-
const features = new Float32Array(rows * embeddingDim);
|
|
1270
|
-
for (let row = 0; row < rows; row += 1) {
|
|
1271
|
-
const prompt = String(prompts[row] || '').trim();
|
|
1272
|
-
const studentResult = await distillRuntime.studentPipeline.prefillWithEmbedding(prompt, {
|
|
1273
|
-
useChatTemplate: false,
|
|
1274
|
-
embeddingMode: 'last',
|
|
1275
|
-
});
|
|
1276
|
-
try {
|
|
1277
|
-
const studentEmbedding = toFloat32Array(studentResult?.embedding, 'student embedding');
|
|
1278
|
-
const rowOffset = row * embeddingDim;
|
|
1279
|
-
const copyCount = Math.min(embeddingDim, studentEmbedding.length);
|
|
1280
|
-
features.set(studentEmbedding.subarray(0, copyCount), rowOffset);
|
|
1281
|
-
} finally {
|
|
1282
|
-
disposePrefillSnapshot(studentResult);
|
|
1283
|
-
distillRuntime.studentPipeline.reset();
|
|
1284
|
-
}
|
|
1285
|
-
}
|
|
1286
|
-
const inputTensor = makeTensorFromFloat32(
|
|
1287
|
-
features,
|
|
1288
|
-
[rows, embeddingDim],
|
|
1289
|
-
`distill_student_${phase}_embedding`
|
|
1290
|
-
);
|
|
1291
|
-
temporaryInputs.add(inputTensor);
|
|
1292
|
-
return inputTensor;
|
|
1293
|
-
}
|
|
1294
|
-
|
|
1295
|
-
const model = {
|
|
1296
|
-
async forward(inputTensor, tape) {
|
|
1297
|
-
return projectEmbeddingInput(inputTensor, tape);
|
|
1298
|
-
},
|
|
1299
|
-
async forwardDistill(batch, tape, forwardOptions = {}) {
|
|
1300
|
-
const requestedPhase = String(forwardOptions?.phase || 'anchor').trim();
|
|
1301
|
-
const phase = requestedPhase === 'positive'
|
|
1302
|
-
? 'positive'
|
|
1303
|
-
: (requestedPhase === 'negative' ? 'negative' : 'anchor');
|
|
1304
|
-
const inputTensor = await buildStudentEmbeddingInput(batch, phase);
|
|
1305
|
-
const logits = await projectEmbeddingInput(inputTensor, tape);
|
|
1306
|
-
return { logits };
|
|
1307
|
-
},
|
|
1308
|
-
cleanupDistillStep() {
|
|
1309
|
-
for (const tensor of temporaryInputs) {
|
|
1310
|
-
releaseTensor(tensor);
|
|
1311
|
-
}
|
|
1312
|
-
temporaryInputs.clear();
|
|
1313
|
-
},
|
|
1314
|
-
loraParams() {
|
|
1315
|
-
return [projectionWeight];
|
|
1316
|
-
},
|
|
1317
|
-
paramGroups() {
|
|
1318
|
-
return {
|
|
1319
|
-
encoder: [],
|
|
1320
|
-
prior: [],
|
|
1321
|
-
decoder: [],
|
|
1322
|
-
base: [projectionWeight],
|
|
1323
|
-
lora: [projectionWeight],
|
|
1324
|
-
};
|
|
1325
|
-
},
|
|
1326
|
-
};
|
|
1327
|
-
|
|
1328
|
-
return {
|
|
1329
|
-
config,
|
|
1330
|
-
model,
|
|
1331
|
-
outputDim,
|
|
1332
|
-
embeddingDim,
|
|
1333
|
-
cleanup() {
|
|
1334
|
-
model.cleanupDistillStep();
|
|
1335
|
-
releaseTensor(projectionWeight);
|
|
1336
|
-
},
|
|
1337
|
-
};
|
|
1338
|
-
}
|
|
1339
|
-
|
|
1340
|
-
async function createDistillStudentTransformerModelFixture(overrides = {}, options = {}) {
|
|
1341
|
-
const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
|
|
1342
|
-
? options.distillRuntime
|
|
1343
|
-
: null;
|
|
1344
|
-
const studentPipeline = distillRuntime?.studentPipeline || null;
|
|
1345
|
-
if (!studentPipeline?.modelConfig || !(studentPipeline.weights instanceof Map)) {
|
|
1346
|
-
throw new Error('Distill full-graph student fixture requires loaded student pipeline weights.');
|
|
1347
|
-
}
|
|
1348
|
-
const modelConfig = studentPipeline.modelConfig;
|
|
1349
|
-
const hiddenSize = Math.max(1, Math.floor(Number(modelConfig.hiddenSize) || 0));
|
|
1350
|
-
const intermediateSize = Math.max(1, Math.floor(Number(modelConfig.intermediateSize) || 0));
|
|
1351
|
-
const numLayers = Math.max(1, Math.floor(Number(modelConfig.numLayers) || 0));
|
|
1352
|
-
const numHeads = Math.max(1, Math.floor(Number(modelConfig.numHeads) || 0));
|
|
1353
|
-
const numKVHeads = Math.max(1, Math.floor(Number(modelConfig.numKVHeads || numHeads) || 0));
|
|
1354
|
-
const headDim = Math.max(1, Math.floor(Number(modelConfig.headDim) || 0));
|
|
1355
|
-
const vocabSize = Math.max(1, Math.floor(Number(modelConfig.vocabSize) || 0));
|
|
1356
|
-
const rmsNormEps = Number.isFinite(modelConfig.rmsNormEps) ? modelConfig.rmsNormEps : 1e-6;
|
|
1357
|
-
const hiddenActivation = String(modelConfig.hiddenActivation || 'silu').toLowerCase();
|
|
1358
|
-
const swigluLimit = Number.isFinite(modelConfig.swigluLimit) ? modelConfig.swigluLimit : 0;
|
|
1359
|
-
const useEmbeddingTranspose = modelConfig.embeddingTranspose === true;
|
|
1360
|
-
const tieWordEmbeddings = modelConfig.useTiedEmbeddings === true;
|
|
1361
|
-
|
|
1362
|
-
const config = createTrainingConfig({
|
|
1363
|
-
...overrides,
|
|
1364
|
-
training: {
|
|
1365
|
-
enabled: true,
|
|
1366
|
-
lossScaling: { enabled: false },
|
|
1367
|
-
gradient: { maxNorm: 0 },
|
|
1368
|
-
...(overrides.training || {}),
|
|
1369
|
-
},
|
|
1370
|
-
});
|
|
1371
|
-
|
|
1372
|
-
const ownedTrainables = new Set();
|
|
1373
|
-
const embeddingWeight = await ensureTrainableTensor(
|
|
1374
|
-
studentPipeline.weights.get('embed'),
|
|
1375
|
-
[vocabSize, hiddenSize],
|
|
1376
|
-
'embed',
|
|
1377
|
-
ownedTrainables
|
|
1378
|
-
);
|
|
1379
|
-
const lmHeadWeight = tieWordEmbeddings
|
|
1380
|
-
? embeddingWeight
|
|
1381
|
-
: await ensureTrainableTensor(
|
|
1382
|
-
studentPipeline.weights.get('lm_head'),
|
|
1383
|
-
[vocabSize, hiddenSize],
|
|
1384
|
-
'lm_head',
|
|
1385
|
-
ownedTrainables
|
|
1386
|
-
);
|
|
1387
|
-
const finalNormWeight = await ensureNormTensor(
|
|
1388
|
-
studentPipeline.weights.get('final_norm'),
|
|
1389
|
-
hiddenSize,
|
|
1390
|
-
'final_norm',
|
|
1391
|
-
ownedTrainables
|
|
1392
|
-
);
|
|
1393
|
-
|
|
1394
|
-
const ropeDim = Math.max(1, Math.floor(headDim / 2));
|
|
1395
|
-
const ropeRows = Math.max(1, Math.floor(Number(modelConfig.maxSeqLen) || 1));
|
|
1396
|
-
const ropeCos = await ensureTrainableTensor(
|
|
1397
|
-
createTensor(studentPipeline.ropeFreqsCos, 'f32', [ropeRows, ropeDim], 'rope_cos'),
|
|
1398
|
-
[ropeRows, ropeDim],
|
|
1399
|
-
'rope_cos',
|
|
1400
|
-
ownedTrainables
|
|
1401
|
-
);
|
|
1402
|
-
const ropeSin = await ensureTrainableTensor(
|
|
1403
|
-
createTensor(studentPipeline.ropeFreqsSin, 'f32', [ropeRows, ropeDim], 'rope_sin'),
|
|
1404
|
-
[ropeRows, ropeDim],
|
|
1405
|
-
'rope_sin',
|
|
1406
|
-
ownedTrainables
|
|
1407
|
-
);
|
|
1408
|
-
|
|
1409
|
-
const layerParams = [];
|
|
1410
|
-
const layers = [];
|
|
1411
|
-
for (let layerIdx = 0; layerIdx < numLayers; layerIdx += 1) {
|
|
1412
|
-
const layerWeights = studentPipeline.weights.get(`layer_${layerIdx}`);
|
|
1413
|
-
if (!layerWeights) {
|
|
1414
|
-
throw new Error(`Distill full-graph student missing layer_${layerIdx} weights.`);
|
|
1415
|
-
}
|
|
1416
|
-
const gateUpWeight = layerWeights.gateUp || layerWeights.ffnGateUp || null;
|
|
1417
|
-
let layerGateUp = null;
|
|
1418
|
-
if (hasTensorPayload(gateUpWeight)) {
|
|
1419
|
-
layerGateUp = await ensureTrainableTensor(
|
|
1420
|
-
gateUpWeight,
|
|
1421
|
-
[intermediateSize * 2, hiddenSize],
|
|
1422
|
-
`layer_${layerIdx}.ffn_gate_up`,
|
|
1423
|
-
ownedTrainables
|
|
1424
|
-
);
|
|
1425
|
-
} else {
|
|
1426
|
-
const gateWeight = layerWeights.gate || layerWeights.ffnGate || null;
|
|
1427
|
-
const upWeight = layerWeights.up || layerWeights.ffnUp || null;
|
|
1428
|
-
if (!hasTensorPayload(gateWeight) || !hasTensorPayload(upWeight)) {
|
|
1429
|
-
throw new Error(
|
|
1430
|
-
`Distill full-graph student missing gate/up projections on layer ${layerIdx}.`
|
|
1431
|
-
);
|
|
1432
|
-
}
|
|
1433
|
-
const gateTensor = await ensureTrainableTensor(
|
|
1434
|
-
gateWeight,
|
|
1435
|
-
[intermediateSize, hiddenSize],
|
|
1436
|
-
`layer_${layerIdx}.ffn_gate`,
|
|
1437
|
-
ownedTrainables
|
|
1438
|
-
);
|
|
1439
|
-
const upTensor = await ensureTrainableTensor(
|
|
1440
|
-
upWeight,
|
|
1441
|
-
[intermediateSize, hiddenSize],
|
|
1442
|
-
`layer_${layerIdx}.ffn_up`,
|
|
1443
|
-
ownedTrainables
|
|
1444
|
-
);
|
|
1445
|
-
layerGateUp = await fuseGateUpTensors(
|
|
1446
|
-
gateTensor,
|
|
1447
|
-
upTensor,
|
|
1448
|
-
intermediateSize,
|
|
1449
|
-
hiddenSize,
|
|
1450
|
-
`layer_${layerIdx}.ffn_gate_up`,
|
|
1451
|
-
ownedTrainables
|
|
1452
|
-
);
|
|
1453
|
-
}
|
|
1454
|
-
const layer = {
|
|
1455
|
-
inputNorm: await ensureNormTensor(
|
|
1456
|
-
layerWeights.inputNorm,
|
|
1457
|
-
hiddenSize,
|
|
1458
|
-
`layer_${layerIdx}.input_norm`,
|
|
1459
|
-
ownedTrainables
|
|
1460
|
-
),
|
|
1461
|
-
qProj: await ensureTrainableTensor(
|
|
1462
|
-
layerWeights.qProj,
|
|
1463
|
-
[numHeads * headDim, hiddenSize],
|
|
1464
|
-
`layer_${layerIdx}.q_proj`,
|
|
1465
|
-
ownedTrainables
|
|
1466
|
-
),
|
|
1467
|
-
kProj: await ensureTrainableTensor(
|
|
1468
|
-
layerWeights.kProj,
|
|
1469
|
-
[numKVHeads * headDim, hiddenSize],
|
|
1470
|
-
`layer_${layerIdx}.k_proj`,
|
|
1471
|
-
ownedTrainables
|
|
1472
|
-
),
|
|
1473
|
-
vProj: await ensureTrainableTensor(
|
|
1474
|
-
layerWeights.vProj,
|
|
1475
|
-
[numKVHeads * headDim, hiddenSize],
|
|
1476
|
-
`layer_${layerIdx}.v_proj`,
|
|
1477
|
-
ownedTrainables
|
|
1478
|
-
),
|
|
1479
|
-
oProj: await ensureTrainableTensor(
|
|
1480
|
-
layerWeights.oProj,
|
|
1481
|
-
[hiddenSize, hiddenSize],
|
|
1482
|
-
`layer_${layerIdx}.o_proj`,
|
|
1483
|
-
ownedTrainables
|
|
1484
|
-
),
|
|
1485
|
-
postAttentionNorm: layerWeights.postAttentionNorm
|
|
1486
|
-
? await ensureNormTensor(
|
|
1487
|
-
layerWeights.postAttentionNorm,
|
|
1488
|
-
hiddenSize,
|
|
1489
|
-
`layer_${layerIdx}.post_attention_norm`,
|
|
1490
|
-
ownedTrainables
|
|
1491
|
-
)
|
|
1492
|
-
: null,
|
|
1493
|
-
gateUp: layerGateUp,
|
|
1494
|
-
down: await ensureTrainableTensor(
|
|
1495
|
-
layerWeights.down || layerWeights.ffnDown,
|
|
1496
|
-
[hiddenSize, intermediateSize],
|
|
1497
|
-
`layer_${layerIdx}.ffn_down`,
|
|
1498
|
-
ownedTrainables
|
|
1499
|
-
),
|
|
1500
|
-
};
|
|
1501
|
-
layers.push(layer);
|
|
1502
|
-
layerParams.push(layer.inputNorm, layer.qProj, layer.kProj, layer.vProj, layer.oProj, layer.gateUp, layer.down);
|
|
1503
|
-
if (layer.postAttentionNorm) {
|
|
1504
|
-
layerParams.push(layer.postAttentionNorm);
|
|
1505
|
-
}
|
|
1506
|
-
}
|
|
1507
|
-
|
|
1508
|
-
const encoderParams = [embeddingWeight, ...layerParams];
|
|
1509
|
-
const decoderParams = [finalNormWeight, lmHeadWeight];
|
|
1510
|
-
const baseParams = [...encoderParams, ...decoderParams];
|
|
1511
|
-
const temporaryInputs = new Set();
|
|
1512
|
-
|
|
1513
|
-
async function buildPromptTokens(prompt) {
|
|
1514
|
-
const normalized = String(prompt || '').trim();
|
|
1515
|
-
if (!normalized) {
|
|
1516
|
-
throw new Error('Distill full-graph student prompt is empty.');
|
|
1517
|
-
}
|
|
1518
|
-
const tokenIds = studentPipeline.tokenizer.encode(normalized);
|
|
1519
|
-
if (!Array.isArray(tokenIds) || tokenIds.length === 0) {
|
|
1520
|
-
throw new Error('Distill full-graph student tokenizer produced no tokens.');
|
|
1521
|
-
}
|
|
1522
|
-
const tokenTensor = makeTensorFromUint32(
|
|
1523
|
-
tokenIds,
|
|
1524
|
-
[tokenIds.length],
|
|
1525
|
-
'distill_student_prompt_tokens'
|
|
1526
|
-
);
|
|
1527
|
-
temporaryInputs.add(tokenTensor);
|
|
1528
|
-
return { tokenTensor, seqLen: tokenIds.length };
|
|
1529
|
-
}
|
|
1530
|
-
|
|
1531
|
-
async function runTransformerPrompt(prompt, tape) {
|
|
1532
|
-
const { tokenTensor, seqLen } = await buildPromptTokens(prompt);
|
|
1533
|
-
let hidden = await tape.record(
|
|
1534
|
-
OpType.EMBED,
|
|
1535
|
-
(indices, embeddings) => runGather(
|
|
1536
|
-
indices,
|
|
1537
|
-
embeddings,
|
|
1538
|
-
seqLen,
|
|
1539
|
-
hiddenSize,
|
|
1540
|
-
vocabSize,
|
|
1541
|
-
{
|
|
1542
|
-
embeddingDtype: resolveTensorDtype(embeddingWeight, 'f32'),
|
|
1543
|
-
outputDtype: 'f32',
|
|
1544
|
-
transpose: useEmbeddingTranspose,
|
|
1545
|
-
}
|
|
1546
|
-
),
|
|
1547
|
-
[tokenTensor, embeddingWeight],
|
|
1548
|
-
{
|
|
1549
|
-
numTokens: seqLen,
|
|
1550
|
-
hiddenSize,
|
|
1551
|
-
vocabSize,
|
|
1552
|
-
transpose: useEmbeddingTranspose,
|
|
1553
|
-
indexOffset: 0,
|
|
1554
|
-
}
|
|
1555
|
-
);
|
|
1556
|
-
|
|
1557
|
-
for (let layerIdx = 0; layerIdx < layers.length; layerIdx += 1) {
|
|
1558
|
-
const layer = layers[layerIdx];
|
|
1559
|
-
const normed = await tape.record(
|
|
1560
|
-
OpType.RMSNORM,
|
|
1561
|
-
(x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
|
|
1562
|
-
batchSize: seqLen,
|
|
1563
|
-
hiddenSize,
|
|
1564
|
-
rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
|
|
1565
|
-
}),
|
|
1566
|
-
[hidden, layer.inputNorm],
|
|
1567
|
-
{ numTokens: seqLen, hiddenSize, eps: rmsNormEps }
|
|
1568
|
-
);
|
|
1569
|
-
|
|
1570
|
-
const q2d = await tape.record(
|
|
1571
|
-
OpType.MATMUL,
|
|
1572
|
-
(x, w) => runMatmul(x, w, seqLen, numHeads * headDim, hiddenSize, {
|
|
1573
|
-
transposeB: 'auto',
|
|
1574
|
-
outputDtype: 'f32',
|
|
1575
|
-
}),
|
|
1576
|
-
[normed, layer.qProj],
|
|
1577
|
-
{ M: seqLen, N: numHeads * headDim, K: hiddenSize, transposeB: 'auto' }
|
|
1578
|
-
);
|
|
1579
|
-
const k2d = await tape.record(
|
|
1580
|
-
OpType.MATMUL,
|
|
1581
|
-
(x, w) => runMatmul(x, w, seqLen, numKVHeads * headDim, hiddenSize, {
|
|
1582
|
-
transposeB: 'auto',
|
|
1583
|
-
outputDtype: 'f32',
|
|
1584
|
-
}),
|
|
1585
|
-
[normed, layer.kProj],
|
|
1586
|
-
{ M: seqLen, N: numKVHeads * headDim, K: hiddenSize, transposeB: 'auto' }
|
|
1587
|
-
);
|
|
1588
|
-
const v2d = await tape.record(
|
|
1589
|
-
OpType.MATMUL,
|
|
1590
|
-
(x, w) => runMatmul(x, w, seqLen, numKVHeads * headDim, hiddenSize, {
|
|
1591
|
-
transposeB: 'auto',
|
|
1592
|
-
outputDtype: 'f32',
|
|
1593
|
-
}),
|
|
1594
|
-
[normed, layer.vProj],
|
|
1595
|
-
{ M: seqLen, N: numKVHeads * headDim, K: hiddenSize, transposeB: 'auto' }
|
|
1596
|
-
);
|
|
1597
|
-
|
|
1598
|
-
const q3d = createTensor(q2d.buffer, q2d.dtype, [seqLen, numHeads, headDim], `layer_${layerIdx}_q`);
|
|
1599
|
-
const k3d = createTensor(k2d.buffer, k2d.dtype, [seqLen, numKVHeads, headDim], `layer_${layerIdx}_k`);
|
|
1600
|
-
const v3d = createTensor(v2d.buffer, v2d.dtype, [seqLen, numKVHeads, headDim], `layer_${layerIdx}_v`);
|
|
1601
|
-
|
|
1602
|
-
const qRope = await tape.record(
|
|
1603
|
-
OpType.ROPE,
|
|
1604
|
-
(q, cos, sin) => runRoPE(q, cos, sin, seqLen, { numHeads, headDim, startPos: 0 }),
|
|
1605
|
-
[q3d, ropeCos, ropeSin],
|
|
1606
|
-
{ seqLen, numHeads, headDim, startPos: 0 }
|
|
1607
|
-
);
|
|
1608
|
-
const kRope = await tape.record(
|
|
1609
|
-
OpType.ROPE,
|
|
1610
|
-
(k, cos, sin) => runRoPE(k, cos, sin, seqLen, { numHeads: numKVHeads, headDim, startPos: 0 }),
|
|
1611
|
-
[k3d, ropeCos, ropeSin],
|
|
1612
|
-
{ seqLen, numHeads: numKVHeads, headDim, startPos: 0 }
|
|
1613
|
-
);
|
|
1614
|
-
|
|
1615
|
-
const attention = await tape.record(
|
|
1616
|
-
OpType.ATTENTION,
|
|
1617
|
-
(q, k, v) => runAttention(q, k, v, null, numHeads, headDim, {
|
|
1618
|
-
seqLen,
|
|
1619
|
-
kvLen: seqLen,
|
|
1620
|
-
numKVHeads,
|
|
1621
|
-
causal: true,
|
|
1622
|
-
startPos: 0,
|
|
1623
|
-
scale: 1 / Math.sqrt(headDim),
|
|
1624
|
-
}),
|
|
1625
|
-
[qRope, kRope, v3d],
|
|
1626
|
-
{ seqLen, numHeads, headDim, scale: 1 / Math.sqrt(headDim), causal: true, recomputeForward: true }
|
|
1627
|
-
);
|
|
1628
|
-
const attention2d = createTensor(
|
|
1629
|
-
attention.buffer,
|
|
1630
|
-
attention.dtype,
|
|
1631
|
-
[seqLen, hiddenSize],
|
|
1632
|
-
`layer_${layerIdx}_attn_2d`
|
|
1633
|
-
);
|
|
1634
|
-
|
|
1635
|
-
const attentionOutput = await tape.record(
|
|
1636
|
-
OpType.MATMUL,
|
|
1637
|
-
(x, w) => runMatmul(x, w, seqLen, hiddenSize, hiddenSize, {
|
|
1638
|
-
transposeB: 'auto',
|
|
1639
|
-
outputDtype: 'f32',
|
|
1640
|
-
}),
|
|
1641
|
-
[attention2d, layer.oProj],
|
|
1642
|
-
{ M: seqLen, N: hiddenSize, K: hiddenSize, transposeB: 'auto' }
|
|
1643
|
-
);
|
|
1644
|
-
const postAttention = await tape.record(
|
|
1645
|
-
OpType.RESIDUAL_ADD,
|
|
1646
|
-
(a, b) => runResidualAdd(a, b, seqLen * hiddenSize),
|
|
1647
|
-
[attentionOutput, hidden],
|
|
1648
|
-
{ size: seqLen * hiddenSize }
|
|
1649
|
-
);
|
|
1650
|
-
|
|
1651
|
-
const ffnInput = layer.postAttentionNorm
|
|
1652
|
-
? await tape.record(
|
|
1653
|
-
OpType.RMSNORM,
|
|
1654
|
-
(x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
|
|
1655
|
-
batchSize: seqLen,
|
|
1656
|
-
hiddenSize,
|
|
1657
|
-
rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
|
|
1658
|
-
}),
|
|
1659
|
-
[postAttention, layer.postAttentionNorm],
|
|
1660
|
-
{ numTokens: seqLen, hiddenSize, eps: rmsNormEps }
|
|
1661
|
-
)
|
|
1662
|
-
: postAttention;
|
|
1663
|
-
const gateUp = await tape.record(
|
|
1664
|
-
OpType.MATMUL,
|
|
1665
|
-
(x, w) => runMatmul(x, w, seqLen, intermediateSize * 2, hiddenSize, {
|
|
1666
|
-
transposeB: 'auto',
|
|
1667
|
-
outputDtype: 'f32',
|
|
1668
|
-
}),
|
|
1669
|
-
[ffnInput, layer.gateUp],
|
|
1670
|
-
{ M: seqLen, N: intermediateSize * 2, K: hiddenSize, transposeB: 'auto' }
|
|
1671
|
-
);
|
|
1672
|
-
const activated = await tape.record(
|
|
1673
|
-
OpType.SILU_ROWSPLIT,
|
|
1674
|
-
(x) => runSiLURowSplit(x, {
|
|
1675
|
-
numTokens: seqLen,
|
|
1676
|
-
dim: intermediateSize,
|
|
1677
|
-
activation: hiddenActivation === 'gelu' ? 'gelu' : 'silu',
|
|
1678
|
-
swigluLimit: hiddenActivation === 'gelu' ? null : swigluLimit,
|
|
1679
|
-
}),
|
|
1680
|
-
[gateUp],
|
|
1681
|
-
{
|
|
1682
|
-
numTokens: seqLen,
|
|
1683
|
-
dim: intermediateSize,
|
|
1684
|
-
activation: hiddenActivation === 'gelu' ? 'gelu' : 'silu',
|
|
1685
|
-
swigluLimit: hiddenActivation === 'gelu' ? 0 : swigluLimit,
|
|
1686
|
-
}
|
|
1687
|
-
);
|
|
1688
|
-
const ffnOutput = await tape.record(
|
|
1689
|
-
OpType.MATMUL,
|
|
1690
|
-
(x, w) => runMatmul(x, w, seqLen, hiddenSize, intermediateSize, {
|
|
1691
|
-
transposeB: 'auto',
|
|
1692
|
-
outputDtype: 'f32',
|
|
1693
|
-
}),
|
|
1694
|
-
[activated, layer.down],
|
|
1695
|
-
{ M: seqLen, N: hiddenSize, K: intermediateSize, transposeB: 'auto' }
|
|
1696
|
-
);
|
|
1697
|
-
hidden = await tape.record(
|
|
1698
|
-
OpType.RESIDUAL_ADD,
|
|
1699
|
-
(a, b) => runResidualAdd(a, b, seqLen * hiddenSize),
|
|
1700
|
-
[ffnOutput, postAttention],
|
|
1701
|
-
{ size: seqLen * hiddenSize }
|
|
1702
|
-
);
|
|
1703
|
-
}
|
|
1704
|
-
|
|
1705
|
-
const finalHidden = await tape.record(
|
|
1706
|
-
OpType.RMSNORM,
|
|
1707
|
-
(x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
|
|
1708
|
-
batchSize: seqLen,
|
|
1709
|
-
hiddenSize,
|
|
1710
|
-
rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
|
|
1711
|
-
}),
|
|
1712
|
-
[hidden, finalNormWeight],
|
|
1713
|
-
{ numTokens: seqLen, hiddenSize, eps: rmsNormEps }
|
|
1714
|
-
);
|
|
1715
|
-
const lastHidden = await tape.record(
|
|
1716
|
-
OpType.ROW_SLICE,
|
|
1717
|
-
(x) => createRowSliceTensor(x, seqLen, hiddenSize, seqLen - 1, 'distill_last_hidden'),
|
|
1718
|
-
[finalHidden],
|
|
1719
|
-
{ rows: seqLen, cols: hiddenSize, rowIndex: seqLen - 1 }
|
|
1720
|
-
);
|
|
1721
|
-
return tape.record(
|
|
1722
|
-
OpType.MATMUL,
|
|
1723
|
-
(x, w) => runMatmul(x, w, 1, vocabSize, hiddenSize, {
|
|
1724
|
-
transposeB: 'auto',
|
|
1725
|
-
outputDtype: 'f32',
|
|
1726
|
-
}),
|
|
1727
|
-
[lastHidden, lmHeadWeight],
|
|
1728
|
-
{ M: 1, N: vocabSize, K: hiddenSize, transposeB: 'auto' }
|
|
1729
|
-
);
|
|
1730
|
-
}
|
|
1731
|
-
|
|
1732
|
-
const model = {
|
|
1733
|
-
async forward(inputTensor, tape) {
|
|
1734
|
-
return tape.record(
|
|
1735
|
-
OpType.MATMUL,
|
|
1736
|
-
(x, w) => runMatmul(x, w, 1, vocabSize, hiddenSize, {
|
|
1737
|
-
transposeB: 'auto',
|
|
1738
|
-
outputDtype: 'f32',
|
|
1739
|
-
}),
|
|
1740
|
-
[inputTensor, lmHeadWeight],
|
|
1741
|
-
{ M: 1, N: vocabSize, K: hiddenSize, transposeB: 'auto' }
|
|
1742
|
-
);
|
|
1743
|
-
},
|
|
1744
|
-
async forwardDistill(batch, tape, forwardOptions = {}) {
|
|
1745
|
-
const requestedPhase = String(forwardOptions?.phase || 'anchor').trim();
|
|
1746
|
-
const phase = requestedPhase === 'positive'
|
|
1747
|
-
? 'positive'
|
|
1748
|
-
: (requestedPhase === 'negative' ? 'negative' : 'anchor');
|
|
1749
|
-
const prompts = resolvePhasePrompts(batch, phase);
|
|
1750
|
-
if (prompts.length !== 1) {
|
|
1751
|
-
throw new Error(
|
|
1752
|
-
`Distill full-graph student currently requires batchSize=1, got ${prompts.length}.`
|
|
1753
|
-
);
|
|
1754
|
-
}
|
|
1755
|
-
const logits = await runTransformerPrompt(prompts[0], tape);
|
|
1756
|
-
return { logits };
|
|
1757
|
-
},
|
|
1758
|
-
cleanupDistillStep() {
|
|
1759
|
-
for (const tensor of temporaryInputs) {
|
|
1760
|
-
releaseTensor(tensor);
|
|
1761
|
-
}
|
|
1762
|
-
temporaryInputs.clear();
|
|
1763
|
-
},
|
|
1764
|
-
loraParams() {
|
|
1765
|
-
return decoderParams;
|
|
1766
|
-
},
|
|
1767
|
-
paramGroups() {
|
|
1768
|
-
return {
|
|
1769
|
-
encoder: encoderParams,
|
|
1770
|
-
prior: [],
|
|
1771
|
-
decoder: decoderParams,
|
|
1772
|
-
base: baseParams,
|
|
1773
|
-
lora: [],
|
|
1774
|
-
};
|
|
1775
|
-
},
|
|
1776
|
-
};
|
|
1777
|
-
|
|
1778
|
-
return {
|
|
1779
|
-
config,
|
|
1780
|
-
model,
|
|
1781
|
-
outputDim: vocabSize,
|
|
1782
|
-
embeddingDim: hiddenSize,
|
|
1783
|
-
cleanup() {
|
|
1784
|
-
model.cleanupDistillStep();
|
|
1785
|
-
for (const tensor of ownedTrainables) {
|
|
1786
|
-
releaseTensor(tensor);
|
|
1787
|
-
}
|
|
1788
|
-
ownedTrainables.clear();
|
|
1789
|
-
},
|
|
1790
|
-
};
|
|
1791
|
-
}
|
|
1792
|
-
|
|
1793
|
-
async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
|
|
1794
|
-
const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
|
|
1795
|
-
? options.distillRuntime
|
|
1796
|
-
: null;
|
|
1797
|
-
const graphMode = normalizeDistillStudentGraphMode(
|
|
1798
|
-
options.studentGraphMode
|
|
1799
|
-
?? distillRuntime?.studentGraphMode
|
|
1800
|
-
?? overrides?.training?.distill?.studentGraphMode
|
|
1801
|
-
);
|
|
1802
|
-
if (graphMode === DISTILL_STUDENT_GRAPH_PROJECTION) {
|
|
1803
|
-
return createDistillStudentProjectionModelFixture(overrides, options);
|
|
1804
|
-
}
|
|
1805
|
-
return createDistillStudentTransformerModelFixture(overrides, options);
|
|
1806
|
-
}
|
|
858
|
+
export { createDistillStudentRuntimeModelFixture };
|
|
1807
859
|
|
|
1808
860
|
async function runRunnerSmokeTest() {
|
|
1809
861
|
const fixture = createToyModelFixture();
|
|
@@ -2085,7 +1137,7 @@ function buildUlTrainingOverrides(options = {}) {
|
|
|
2085
1137
|
};
|
|
2086
1138
|
}
|
|
2087
1139
|
|
|
2088
|
-
function buildDistillTrainingOverrides(options = {}) {
|
|
1140
|
+
export function buildDistillTrainingOverrides(options = {}) {
|
|
2089
1141
|
const trainingConfig = normalizeTrainingConfigOverride(options.trainingConfig);
|
|
2090
1142
|
const explicitStage = normalizeTrainingStage(options.trainingStage || trainingConfig?.distill?.stage);
|
|
2091
1143
|
const distillEnabled = isDistillStage(explicitStage) || trainingConfig?.distill?.enabled === true;
|
|
@@ -2160,22 +1212,6 @@ async function computeNodeFileHash(filePath) {
|
|
|
2160
1212
|
};
|
|
2161
1213
|
}
|
|
2162
1214
|
|
|
2163
|
-
async function resolveIsolatedArtifactDir(explicitDir, prefix) {
|
|
2164
|
-
const normalized = normalizeOptionalString(explicitDir);
|
|
2165
|
-
if (normalized) {
|
|
2166
|
-
return normalized;
|
|
2167
|
-
}
|
|
2168
|
-
if (!(typeof process !== 'undefined' && process.versions?.node)) {
|
|
2169
|
-
return null;
|
|
2170
|
-
}
|
|
2171
|
-
const [{ mkdtemp }, { tmpdir }, { join }] = await Promise.all([
|
|
2172
|
-
import('node:fs/promises'),
|
|
2173
|
-
import('node:os'),
|
|
2174
|
-
import('node:path'),
|
|
2175
|
-
]);
|
|
2176
|
-
return mkdtemp(join(tmpdir(), `doppler-${prefix}-`));
|
|
2177
|
-
}
|
|
2178
|
-
|
|
2179
1215
|
async function runUlStageTest(stage, options = {}) {
|
|
2180
1216
|
const ulTraining = buildUlTrainingOverrides({
|
|
2181
1217
|
...options,
|
|
@@ -2198,7 +1234,9 @@ async function runUlStageTest(stage, options = {}) {
|
|
|
2198
1234
|
}
|
|
2199
1235
|
},
|
|
2200
1236
|
};
|
|
2201
|
-
const ulArtifactDir =
|
|
1237
|
+
const ulArtifactDir = normalizeOptionalString(options.ulArtifactDir)
|
|
1238
|
+
|| normalizeOptionalString(fixture.config.training?.ul?.artifactDir)
|
|
1239
|
+
|| 'reports/training/ul';
|
|
2202
1240
|
const metrics = await runner.run(fixture.model, dataset, {
|
|
2203
1241
|
epochs: 1,
|
|
2204
1242
|
batchSize: 1,
|
|
@@ -2374,7 +1412,9 @@ async function runDistillStageTest(stage, options = {}) {
|
|
|
2374
1412
|
distillRuntime,
|
|
2375
1413
|
});
|
|
2376
1414
|
const distillRunStartMs = performance.now();
|
|
2377
|
-
const distillArtifactDir =
|
|
1415
|
+
const distillArtifactDir = normalizeOptionalString(options.distillArtifactDir)
|
|
1416
|
+
|| normalizeOptionalString(fixture.config.training?.distill?.artifactDir)
|
|
1417
|
+
|| 'reports/training/distill';
|
|
2378
1418
|
const metrics = await runner.run(fixture.model, dataset, {
|
|
2379
1419
|
epochs: 1,
|
|
2380
1420
|
batchSize: 1,
|