@simulatte/doppler 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/BRANDING.md +14 -0
- package/LICENSE +201 -0
- package/NOTICE +5 -0
- package/README.md +85 -0
- package/SECURITY.md +19 -0
- package/package.json +144 -0
- package/src/adapters/adapter-manager.d.ts +200 -0
- package/src/adapters/adapter-manager.js +509 -0
- package/src/adapters/adapter-manifest.d.ts +290 -0
- package/src/adapters/adapter-manifest.js +320 -0
- package/src/adapters/adapter-registry.d.ts +192 -0
- package/src/adapters/adapter-registry.js +466 -0
- package/src/adapters/index.d.ts +89 -0
- package/src/adapters/index.js +42 -0
- package/src/adapters/lora-loader.d.ts +105 -0
- package/src/adapters/lora-loader.js +397 -0
- package/src/bootstrap.d.ts +1 -0
- package/src/bootstrap.js +30 -0
- package/src/bridge/extension/background.d.ts +14 -0
- package/src/bridge/extension/background.js +168 -0
- package/src/bridge/extension/manifest.json +34 -0
- package/src/bridge/extension-client.d.ts +109 -0
- package/src/bridge/extension-client.js +369 -0
- package/src/bridge/index.d.ts +68 -0
- package/src/bridge/index.js +51 -0
- package/src/bridge/protocol.d.ts +96 -0
- package/src/bridge/protocol.js +130 -0
- package/src/browser/browser-converter.d.ts +71 -0
- package/src/browser/browser-converter.js +947 -0
- package/src/browser/file-picker.d.ts +63 -0
- package/src/browser/file-picker.js +275 -0
- package/src/browser/gguf-importer.d.ts +136 -0
- package/src/browser/gguf-importer.js +532 -0
- package/src/browser/gguf-parser-browser.d.ts +14 -0
- package/src/browser/gguf-parser-browser.js +17 -0
- package/src/browser/quantization.d.ts +69 -0
- package/src/browser/quantization.js +328 -0
- package/src/browser/safetensors-parser-browser.d.ts +193 -0
- package/src/browser/safetensors-parser-browser.js +264 -0
- package/src/browser/shard-io-browser.d.ts +57 -0
- package/src/browser/shard-io-browser.js +89 -0
- package/src/browser/tensor-source-download.d.ts +27 -0
- package/src/browser/tensor-source-download.js +239 -0
- package/src/browser/tensor-source-file.d.ts +26 -0
- package/src/browser/tensor-source-file.js +53 -0
- package/src/browser/tensor-source-http.d.ts +28 -0
- package/src/browser/tensor-source-http.js +126 -0
- package/src/client/doppler-provider/generation.d.ts +25 -0
- package/src/client/doppler-provider/generation.js +114 -0
- package/src/client/doppler-provider/index.d.ts +2 -0
- package/src/client/doppler-provider/index.js +3 -0
- package/src/client/doppler-provider/model-manager.d.ts +61 -0
- package/src/client/doppler-provider/model-manager.js +667 -0
- package/src/client/doppler-provider/provider.d.ts +5 -0
- package/src/client/doppler-provider/provider.js +102 -0
- package/src/client/doppler-provider/source-runtime.d.ts +22 -0
- package/src/client/doppler-provider/source-runtime.js +522 -0
- package/src/client/doppler-provider/types.d.ts +127 -0
- package/src/client/doppler-provider/types.js +17 -0
- package/src/client/doppler-provider.d.ts +46 -0
- package/src/client/doppler-provider.js +36 -0
- package/src/config/README.md +69 -0
- package/src/config/backward-registry-loader.d.ts +3 -0
- package/src/config/backward-registry-loader.js +8 -0
- package/src/config/index.d.ts +63 -0
- package/src/config/index.js +31 -0
- package/src/config/kernel-path-loader.d.ts +149 -0
- package/src/config/kernel-path-loader.js +534 -0
- package/src/config/kernels/backward-registry.json +99 -0
- package/src/config/kernels/kernel-ref-digests.d.ts +1 -0
- package/src/config/kernels/kernel-ref-digests.js +214 -0
- package/src/config/kernels/kernel-ref.d.ts +17 -0
- package/src/config/kernels/kernel-ref.js +75 -0
- package/src/config/kernels/moe/gpt-oss.paths.json +49 -0
- package/src/config/kernels/registry.d.ts +86 -0
- package/src/config/kernels/registry.js +103 -0
- package/src/config/kernels/registry.json +6771 -0
- package/src/config/loader.d.ts +57 -0
- package/src/config/loader.js +513 -0
- package/src/config/merge.d.ts +142 -0
- package/src/config/merge.js +389 -0
- package/src/config/param-categories.d.ts +17 -0
- package/src/config/param-categories.js +72 -0
- package/src/config/param-validator.d.ts +26 -0
- package/src/config/param-validator.js +235 -0
- package/src/config/platforms/amd-rdna3.json +16 -0
- package/src/config/platforms/apple-m1.json +16 -0
- package/src/config/platforms/apple-m2.json +16 -0
- package/src/config/platforms/apple-m3.json +16 -0
- package/src/config/platforms/generic.json +14 -0
- package/src/config/platforms/loader.d.ts +65 -0
- package/src/config/platforms/loader.js +153 -0
- package/src/config/platforms/nvidia-rtx30.json +16 -0
- package/src/config/platforms/nvidia-rtx40.json +16 -0
- package/src/config/presets/kernel-paths/embeddinggemma-f16-f32a.json +60 -0
- package/src/config/presets/kernel-paths/embeddinggemma-f32-f32a.json +60 -0
- package/src/config/presets/kernel-paths/embeddinggemma-q4k-dequant-f32a.json +60 -0
- package/src/config/presets/kernel-paths/gemma2-f16-f16a.json +61 -0
- package/src/config/presets/kernel-paths/gemma2-f16-f32a.json +60 -0
- package/src/config/presets/kernel-paths/gemma2-q4k-dequant-f16a.json +61 -0
- package/src/config/presets/kernel-paths/gemma2-q4k-dequant-f32a.json +60 -0
- package/src/config/presets/kernel-paths/gemma2-q4k-fused-f32a.json +57 -0
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f16a-online.json +200 -0
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online.json +223 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f16a-online.json +60 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +61 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a.json +61 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-online.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +103 -0
- package/src/config/presets/models/deepseek.json +20 -0
- package/src/config/presets/models/diffusion.json +10 -0
- package/src/config/presets/models/embeddinggemma.json +74 -0
- package/src/config/presets/models/functiongemma.json +31 -0
- package/src/config/presets/models/gemma2.json +59 -0
- package/src/config/presets/models/gemma3.json +75 -0
- package/src/config/presets/models/gpt-oss.json +68 -0
- package/src/config/presets/models/kimi-k2.json +25 -0
- package/src/config/presets/models/lfm2.json +83 -0
- package/src/config/presets/models/llama3.json +40 -0
- package/src/config/presets/models/mamba.json +34 -0
- package/src/config/presets/models/mixtral.json +37 -0
- package/src/config/presets/models/modernbert.json +32 -0
- package/src/config/presets/models/qwen3.json +41 -0
- package/src/config/presets/models/transformer.json +73 -0
- package/src/config/presets/models/translategemma.json +30 -0
- package/src/config/presets/platforms/nvidia-gb200-8gpu.json +45 -0
- package/src/config/presets/platforms/nvidia-gb200-nvl72.json +45 -0
- package/src/config/presets/platforms/nvidia-gh200-nvl2.json +44 -0
- package/src/config/presets/platforms/nvidia-gh200.json +44 -0
- package/src/config/presets/runtime/compute/f16-activations.json +30 -0
- package/src/config/presets/runtime/compute/f16-batched.json +32 -0
- package/src/config/presets/runtime/default.json +101 -0
- package/src/config/presets/runtime/diagnostics/debug-logits.json +53 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +53 -0
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +210 -0
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +39 -0
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +20 -0
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +20 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +20 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +20 -0
- package/src/config/presets/runtime/model/gemma2-debug.json +77 -0
- package/src/config/presets/runtime/model/gemma2-pipeline-debug.json +66 -0
- package/src/config/presets/runtime/model/gemma2-pipeline.json +75 -0
- package/src/config/presets/runtime/model/gemma3-layer-probe.json +85 -0
- package/src/config/presets/runtime/modes/bench.json +37 -0
- package/src/config/presets/runtime/modes/debug.json +39 -0
- package/src/config/presets/runtime/modes/default.json +10 -0
- package/src/config/presets/runtime/modes/embedding-bench.json +28 -0
- package/src/config/presets/runtime/modes/embedding.json +54 -0
- package/src/config/presets/runtime/modes/low-memory.json +40 -0
- package/src/config/presets/runtime/modes/production.json +48 -0
- package/src/config/presets/runtime/modes/simulation.json +30 -0
- package/src/config/presets/runtime/modes/trace-layers.json +126 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +11 -0
- package/src/config/runtime-merge.d.ts +5 -0
- package/src/config/runtime-merge.js +21 -0
- package/src/config/runtime.d.ts +28 -0
- package/src/config/runtime.js +56 -0
- package/src/config/schema/adapter.schema.d.ts +53 -0
- package/src/config/schema/adapter.schema.js +60 -0
- package/src/config/schema/backward-registry.schema.d.ts +14 -0
- package/src/config/schema/backward-registry.schema.js +46 -0
- package/src/config/schema/benchmark.schema.d.ts +54 -0
- package/src/config/schema/benchmark.schema.js +74 -0
- package/src/config/schema/bridge.schema.d.ts +25 -0
- package/src/config/schema/bridge.schema.js +22 -0
- package/src/config/schema/buffer-pool.schema.d.ts +92 -0
- package/src/config/schema/buffer-pool.schema.js +50 -0
- package/src/config/schema/conversion.schema.d.ts +183 -0
- package/src/config/schema/conversion.schema.js +13 -0
- package/src/config/schema/converter.schema.d.ts +123 -0
- package/src/config/schema/converter.schema.js +136 -0
- package/src/config/schema/debug.schema.d.ts +245 -0
- package/src/config/schema/debug.schema.js +106 -0
- package/src/config/schema/diffusion.schema.d.ts +88 -0
- package/src/config/schema/diffusion.schema.js +62 -0
- package/src/config/schema/distill-training.schema.d.ts +48 -0
- package/src/config/schema/distill-training.schema.js +139 -0
- package/src/config/schema/distribution.schema.d.ts +155 -0
- package/src/config/schema/distribution.schema.js +81 -0
- package/src/config/schema/doppler.schema.d.ts +75 -0
- package/src/config/schema/doppler.schema.js +352 -0
- package/src/config/schema/ecosystem.schema.d.ts +255 -0
- package/src/config/schema/ecosystem.schema.js +534 -0
- package/src/config/schema/emulation.schema.d.ts +351 -0
- package/src/config/schema/emulation.schema.js +299 -0
- package/src/config/schema/energy.schema.d.ts +102 -0
- package/src/config/schema/energy.schema.js +72 -0
- package/src/config/schema/execution-v0.schema.d.ts +187 -0
- package/src/config/schema/execution-v0.schema.js +55 -0
- package/src/config/schema/gpu-cache.schema.d.ts +26 -0
- package/src/config/schema/gpu-cache.schema.js +8 -0
- package/src/config/schema/harness.schema.d.ts +32 -0
- package/src/config/schema/harness.schema.js +20 -0
- package/src/config/schema/hotswap.schema.d.ts +55 -0
- package/src/config/schema/hotswap.schema.js +18 -0
- package/src/config/schema/index.d.ts +863 -0
- package/src/config/schema/index.js +471 -0
- package/src/config/schema/inference-defaults.schema.d.ts +276 -0
- package/src/config/schema/inference-defaults.schema.js +185 -0
- package/src/config/schema/inference.schema.d.ts +289 -0
- package/src/config/schema/inference.schema.js +39 -0
- package/src/config/schema/intent-bundle.schema.d.ts +28 -0
- package/src/config/schema/intent-bundle.schema.js +12 -0
- package/src/config/schema/kernel-path.schema.d.ts +173 -0
- package/src/config/schema/kernel-path.schema.js +9 -0
- package/src/config/schema/kernel-registry.schema.d.ts +199 -0
- package/src/config/schema/kernel-registry.schema.js +46 -0
- package/src/config/schema/kernel-thresholds.schema.d.ts +302 -0
- package/src/config/schema/kernel-thresholds.schema.js +187 -0
- package/src/config/schema/kernel-warmup.schema.d.ts +19 -0
- package/src/config/schema/kernel-warmup.schema.js +5 -0
- package/src/config/schema/kvcache.schema.d.ts +131 -0
- package/src/config/schema/kvcache.schema.js +31 -0
- package/src/config/schema/loading.schema.d.ts +153 -0
- package/src/config/schema/loading.schema.js +84 -0
- package/src/config/schema/lora.schema.d.ts +12 -0
- package/src/config/schema/lora.schema.js +12 -0
- package/src/config/schema/manifest.schema.d.ts +500 -0
- package/src/config/schema/manifest.schema.js +130 -0
- package/src/config/schema/memory-limits.schema.d.ts +107 -0
- package/src/config/schema/memory-limits.schema.js +57 -0
- package/src/config/schema/moe.schema.d.ts +78 -0
- package/src/config/schema/moe.schema.js +31 -0
- package/src/config/schema/platform.schema.d.ts +121 -0
- package/src/config/schema/platform.schema.js +1 -0
- package/src/config/schema/preset.schema.d.ts +124 -0
- package/src/config/schema/preset.schema.js +1 -0
- package/src/config/schema/quantization-defaults.schema.d.ts +34 -0
- package/src/config/schema/quantization-defaults.schema.js +5 -0
- package/src/config/schema/quantization.schema.d.ts +10 -0
- package/src/config/schema/quantization.schema.js +33 -0
- package/src/config/schema/shared-runtime.schema.d.ts +75 -0
- package/src/config/schema/shared-runtime.schema.js +45 -0
- package/src/config/schema/speculative.schema.d.ts +21 -0
- package/src/config/schema/speculative.schema.js +11 -0
- package/src/config/schema/storage.schema.d.ts +123 -0
- package/src/config/schema/storage.schema.js +66 -0
- package/src/config/schema/tooling.schema.d.ts +29 -0
- package/src/config/schema/tooling.schema.js +12 -0
- package/src/config/schema/training-metrics.schema.d.ts +89 -0
- package/src/config/schema/training-metrics.schema.js +374 -0
- package/src/config/schema/training.schema.d.ts +88 -0
- package/src/config/schema/training.schema.js +106 -0
- package/src/config/schema/tuner.schema.d.ts +39 -0
- package/src/config/schema/tuner.schema.js +13 -0
- package/src/config/schema/ul-training.schema.d.ts +61 -0
- package/src/config/schema/ul-training.schema.js +140 -0
- package/src/config/schema/units.schema.d.ts +27 -0
- package/src/config/schema/units.schema.js +26 -0
- package/src/config/training-defaults.d.ts +24 -0
- package/src/config/training-defaults.js +91 -0
- package/src/converter/conversion-plan.d.ts +64 -0
- package/src/converter/conversion-plan.js +472 -0
- package/src/converter/core.d.ts +247 -0
- package/src/converter/core.js +1329 -0
- package/src/converter/execution-v0-manifest.d.ts +15 -0
- package/src/converter/execution-v0-manifest.js +146 -0
- package/src/converter/index.d.ts +98 -0
- package/src/converter/index.js +59 -0
- package/src/converter/manifest-inference.d.ts +20 -0
- package/src/converter/manifest-inference.js +492 -0
- package/src/converter/parsers/diffusion.d.ts +50 -0
- package/src/converter/parsers/diffusion.js +270 -0
- package/src/converter/parsers/gguf.d.ts +22 -0
- package/src/converter/parsers/gguf.js +46 -0
- package/src/converter/parsers/index.d.ts +21 -0
- package/src/converter/parsers/index.js +12 -0
- package/src/converter/parsers/transformer.d.ts +16 -0
- package/src/converter/parsers/transformer.js +25 -0
- package/src/converter/quantization-info.d.ts +37 -0
- package/src/converter/quantization-info.js +398 -0
- package/src/converter/quantizer.d.ts +96 -0
- package/src/converter/quantizer.js +422 -0
- package/src/converter/rope-config.d.ts +15 -0
- package/src/converter/rope-config.js +218 -0
- package/src/converter/shard-packer.d.ts +138 -0
- package/src/converter/shard-packer.js +422 -0
- package/src/converter/tokenizer-utils.d.ts +11 -0
- package/src/converter/tokenizer-utils.js +87 -0
- package/src/debug/config.d.ts +78 -0
- package/src/debug/config.js +235 -0
- package/src/debug/history.d.ts +65 -0
- package/src/debug/history.js +71 -0
- package/src/debug/index.d.ts +268 -0
- package/src/debug/index.js +192 -0
- package/src/debug/log.d.ts +46 -0
- package/src/debug/log.js +132 -0
- package/src/debug/perf.d.ts +33 -0
- package/src/debug/perf.js +51 -0
- package/src/debug/reference/README.md +114 -0
- package/src/debug/reference/hf_attn_debug.py +114 -0
- package/src/debug/reference/hf_embed_check.py +89 -0
- package/src/debug/reference/hf_layer_out.py +100 -0
- package/src/debug/reference/hf_rope_check.py +116 -0
- package/src/debug/reference/hf_weights.py +75 -0
- package/src/debug/signals.d.ts +63 -0
- package/src/debug/signals.js +33 -0
- package/src/debug/stats.d.ts +47 -0
- package/src/debug/stats.js +160 -0
- package/src/debug/tensor.d.ts +123 -0
- package/src/debug/tensor.js +257 -0
- package/src/debug/trace.d.ts +17 -0
- package/src/debug/trace.js +167 -0
- package/src/diffusion/image-regression.d.ts +31 -0
- package/src/diffusion/image-regression.js +107 -0
- package/src/diffusion/index.d.ts +8 -0
- package/src/diffusion/index.js +8 -0
- package/src/distribution/p2p-control-plane.d.ts +52 -0
- package/src/distribution/p2p-control-plane.js +232 -0
- package/src/distribution/p2p-observability.d.ts +116 -0
- package/src/distribution/p2p-observability.js +267 -0
- package/src/distribution/p2p-transport-contract.d.ts +57 -0
- package/src/distribution/p2p-transport-contract.js +310 -0
- package/src/distribution/p2p-webrtc-browser.d.ts +37 -0
- package/src/distribution/p2p-webrtc-browser.js +434 -0
- package/src/distribution/shard-delivery.d.ts +251 -0
- package/src/distribution/shard-delivery.js +2096 -0
- package/src/energy/index.d.ts +2 -0
- package/src/energy/index.js +2 -0
- package/src/errors/doppler-error.d.ts +21 -0
- package/src/errors/doppler-error.js +25 -0
- package/src/errors/index.d.ts +1 -0
- package/src/errors/index.js +1 -0
- package/src/formats/gguf/index.d.ts +8 -0
- package/src/formats/gguf/index.js +4 -0
- package/src/formats/gguf/types.d.ts +137 -0
- package/src/formats/gguf/types.js +443 -0
- package/src/formats/index.d.ts +51 -0
- package/src/formats/index.js +13 -0
- package/src/formats/rdrr/classification.d.ts +39 -0
- package/src/formats/rdrr/classification.js +275 -0
- package/src/formats/rdrr/groups.d.ts +27 -0
- package/src/formats/rdrr/groups.js +76 -0
- package/src/formats/rdrr/index.d.ts +25 -0
- package/src/formats/rdrr/index.js +19 -0
- package/src/formats/rdrr/manifest.d.ts +32 -0
- package/src/formats/rdrr/manifest.js +108 -0
- package/src/formats/rdrr/parsing.d.ts +23 -0
- package/src/formats/rdrr/parsing.js +101 -0
- package/src/formats/rdrr/tensor-config-validator.d.ts +42 -0
- package/src/formats/rdrr/tensor-config-validator.js +156 -0
- package/src/formats/rdrr/types.d.ts +200 -0
- package/src/formats/rdrr/types.js +16 -0
- package/src/formats/rdrr/validation.d.ts +9 -0
- package/src/formats/rdrr/validation.js +200 -0
- package/src/formats/safetensors/index.d.ts +8 -0
- package/src/formats/safetensors/index.js +4 -0
- package/src/formats/safetensors/types.d.ts +67 -0
- package/src/formats/safetensors/types.js +102 -0
- package/src/formats/tokenizer/index.d.ts +5 -0
- package/src/formats/tokenizer/index.js +3 -0
- package/src/formats/tokenizer/types.d.ts +9 -0
- package/src/formats/tokenizer/types.js +22 -0
- package/src/generation/index.d.ts +18 -0
- package/src/generation/index.js +12 -0
- package/src/gpu/command-recorder.d.ts +175 -0
- package/src/gpu/command-recorder.js +473 -0
- package/src/gpu/device.d.ts +141 -0
- package/src/gpu/device.js +350 -0
- package/src/gpu/kernel-runtime.d.ts +20 -0
- package/src/gpu/kernel-runtime.js +37 -0
- package/src/gpu/kernel-selection-cache.d.ts +13 -0
- package/src/gpu/kernel-selection-cache.js +13 -0
- package/src/gpu/kernel-selection-log.d.ts +12 -0
- package/src/gpu/kernel-selection-log.js +28 -0
- package/src/gpu/kernel-selector.d.ts +11 -0
- package/src/gpu/kernel-selector.js +10 -0
- package/src/gpu/kernel-tuner/benchmarks.d.ts +144 -0
- package/src/gpu/kernel-tuner/benchmarks.js +892 -0
- package/src/gpu/kernel-tuner/cache.d.ts +55 -0
- package/src/gpu/kernel-tuner/cache.js +66 -0
- package/src/gpu/kernel-tuner/index.d.ts +59 -0
- package/src/gpu/kernel-tuner/index.js +38 -0
- package/src/gpu/kernel-tuner/tuner.d.ts +82 -0
- package/src/gpu/kernel-tuner/tuner.js +229 -0
- package/src/gpu/kernel-tuner/types.d.ts +101 -0
- package/src/gpu/kernel-tuner/types.js +4 -0
- package/src/gpu/kernel-tuner.d.ts +33 -0
- package/src/gpu/kernel-tuner.js +12 -0
- package/src/gpu/kernels/README.md +127 -0
- package/src/gpu/kernels/attention.d.ts +236 -0
- package/src/gpu/kernels/attention.js +1359 -0
- package/src/gpu/kernels/attention.wgsl +249 -0
- package/src/gpu/kernels/attention_bdpa_decode_f16.wgsl +246 -0
- package/src/gpu/kernels/attention_decode.wgsl +233 -0
- package/src/gpu/kernels/attention_decode_chunked_f16.wgsl +183 -0
- package/src/gpu/kernels/attention_decode_chunked_f16kv.wgsl +208 -0
- package/src/gpu/kernels/attention_decode_f16.wgsl +202 -0
- package/src/gpu/kernels/attention_decode_f16kv.wgsl +224 -0
- package/src/gpu/kernels/attention_decode_online_f16.wgsl +223 -0
- package/src/gpu/kernels/attention_decode_online_f16kv.wgsl +225 -0
- package/src/gpu/kernels/attention_decode_optimized.wgsl +445 -0
- package/src/gpu/kernels/attention_decode_paged_f16.wgsl +172 -0
- package/src/gpu/kernels/attention_decode_paged_f16kv.wgsl +174 -0
- package/src/gpu/kernels/attention_decode_subgroup.wgsl +233 -0
- package/src/gpu/kernels/attention_decode_tiered_f16.wgsl +218 -0
- package/src/gpu/kernels/attention_decode_tiered_f16kv.wgsl +220 -0
- package/src/gpu/kernels/attention_decode_tiered_int4_f16kv.wgsl +242 -0
- package/src/gpu/kernels/attention_decode_tiered_int8_f16kv.wgsl +242 -0
- package/src/gpu/kernels/attention_f16.wgsl +214 -0
- package/src/gpu/kernels/attention_f16kv.wgsl +242 -0
- package/src/gpu/kernels/attention_small.wgsl +260 -0
- package/src/gpu/kernels/attention_small_f16.wgsl +240 -0
- package/src/gpu/kernels/attention_small_f16kv.wgsl +266 -0
- package/src/gpu/kernels/attention_streaming.wgsl +149 -0
- package/src/gpu/kernels/attention_streaming_f16.wgsl +147 -0
- package/src/gpu/kernels/attention_streaming_f16kv.wgsl +151 -0
- package/src/gpu/kernels/backward/adam.d.ts +28 -0
- package/src/gpu/kernels/backward/adam.js +199 -0
- package/src/gpu/kernels/backward/adam.wgsl +50 -0
- package/src/gpu/kernels/backward/attention_backward.d.ts +22 -0
- package/src/gpu/kernels/backward/attention_backward.js +276 -0
- package/src/gpu/kernels/backward/attention_backward.wgsl +49 -0
- package/src/gpu/kernels/backward/bias_add_backward.d.ts +17 -0
- package/src/gpu/kernels/backward/bias_add_backward.js +24 -0
- package/src/gpu/kernels/backward/bias_add_backward.wgsl +33 -0
- package/src/gpu/kernels/backward/conv2d_backward.d.ts +31 -0
- package/src/gpu/kernels/backward/conv2d_backward.js +135 -0
- package/src/gpu/kernels/backward/conv2d_backward_input.wgsl +83 -0
- package/src/gpu/kernels/backward/conv2d_backward_weight.wgsl +70 -0
- package/src/gpu/kernels/backward/cross_entropy_backward.d.ts +23 -0
- package/src/gpu/kernels/backward/cross_entropy_backward.js +29 -0
- package/src/gpu/kernels/backward/cross_entropy_backward.wgsl +39 -0
- package/src/gpu/kernels/backward/embed_backward.d.ts +29 -0
- package/src/gpu/kernels/backward/embed_backward.js +118 -0
- package/src/gpu/kernels/backward/embed_backward.wgsl +73 -0
- package/src/gpu/kernels/backward/gelu_backward.d.ts +16 -0
- package/src/gpu/kernels/backward/gelu_backward.js +39 -0
- package/src/gpu/kernels/backward/gelu_backward.wgsl +38 -0
- package/src/gpu/kernels/backward/groupnorm_backward.d.ts +24 -0
- package/src/gpu/kernels/backward/groupnorm_backward.js +29 -0
- package/src/gpu/kernels/backward/groupnorm_backward.wgsl +143 -0
- package/src/gpu/kernels/backward/index.d.ts +17 -0
- package/src/gpu/kernels/backward/index.js +23 -0
- package/src/gpu/kernels/backward/layernorm_backward.d.ts +22 -0
- package/src/gpu/kernels/backward/layernorm_backward.js +135 -0
- package/src/gpu/kernels/backward/layernorm_backward.wgsl +194 -0
- package/src/gpu/kernels/backward/matmul_backward.d.ts +32 -0
- package/src/gpu/kernels/backward/matmul_backward.js +124 -0
- package/src/gpu/kernels/backward/matmul_backward.wgsl +90 -0
- package/src/gpu/kernels/backward/matmul_transpose_a.wgsl +84 -0
- package/src/gpu/kernels/backward/pixel_shuffle_backward.d.ts +22 -0
- package/src/gpu/kernels/backward/pixel_shuffle_backward.js +30 -0
- package/src/gpu/kernels/backward/pixel_shuffle_backward.wgsl +54 -0
- package/src/gpu/kernels/backward/rmsnorm_backward.d.ts +24 -0
- package/src/gpu/kernels/backward/rmsnorm_backward.js +101 -0
- package/src/gpu/kernels/backward/rmsnorm_backward.wgsl +78 -0
- package/src/gpu/kernels/backward/rope_backward.d.ts +25 -0
- package/src/gpu/kernels/backward/rope_backward.js +109 -0
- package/src/gpu/kernels/backward/rope_backward.wgsl +59 -0
- package/src/gpu/kernels/backward/scale_backward.d.ts +16 -0
- package/src/gpu/kernels/backward/scale_backward.js +84 -0
- package/src/gpu/kernels/backward/scale_backward.wgsl +27 -0
- package/src/gpu/kernels/backward/silu_backward.d.ts +16 -0
- package/src/gpu/kernels/backward/silu_backward.js +39 -0
- package/src/gpu/kernels/backward/silu_backward.wgsl +31 -0
- package/src/gpu/kernels/backward/softmax_backward.d.ts +16 -0
- package/src/gpu/kernels/backward/softmax_backward.js +43 -0
- package/src/gpu/kernels/backward/softmax_backward.wgsl +44 -0
- package/src/gpu/kernels/backward/upsample2d_backward.d.ts +21 -0
- package/src/gpu/kernels/backward/upsample2d_backward.js +30 -0
- package/src/gpu/kernels/backward/upsample2d_backward.wgsl +59 -0
- package/src/gpu/kernels/backward/utils.d.ts +45 -0
- package/src/gpu/kernels/backward/utils.js +371 -0
- package/src/gpu/kernels/bf16_to_f16.wgsl +54 -0
- package/src/gpu/kernels/bf16_to_f32.wgsl +70 -0
- package/src/gpu/kernels/bias_add.wgsl +40 -0
- package/src/gpu/kernels/bias_add_f16.wgsl +44 -0
- package/src/gpu/kernels/cast.d.ts +67 -0
- package/src/gpu/kernels/cast.js +422 -0
- package/src/gpu/kernels/cast_f16_to_f32.wgsl +31 -0
- package/src/gpu/kernels/cast_f32_to_f16.wgsl +36 -0
- package/src/gpu/kernels/check-finiteness.d.ts +15 -0
- package/src/gpu/kernels/check-finiteness.js +149 -0
- package/src/gpu/kernels/check-stop.d.ts +31 -0
- package/src/gpu/kernels/check-stop.js +181 -0
- package/src/gpu/kernels/clamp.d.ts +22 -0
- package/src/gpu/kernels/clamp.js +42 -0
- package/src/gpu/kernels/clamp.wgsl +24 -0
- package/src/gpu/kernels/constants.d.ts +168 -0
- package/src/gpu/kernels/constants.js +129 -0
- package/src/gpu/kernels/conv2d.d.ts +34 -0
- package/src/gpu/kernels/conv2d.js +81 -0
- package/src/gpu/kernels/conv2d.wgsl +71 -0
- package/src/gpu/kernels/conv2d_f16.wgsl +73 -0
- package/src/gpu/kernels/cross_entropy_loss.d.ts +21 -0
- package/src/gpu/kernels/cross_entropy_loss.js +54 -0
- package/src/gpu/kernels/cross_entropy_loss.wgsl +39 -0
- package/src/gpu/kernels/dequant.d.ts +108 -0
- package/src/gpu/kernels/dequant.js +524 -0
- package/src/gpu/kernels/dequant_f16_out.wgsl +151 -0
- package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +149 -0
- package/src/gpu/kernels/dequant_f16_rowwise.wgsl +139 -0
- package/src/gpu/kernels/dequant_f32_rowwise.wgsl +133 -0
- package/src/gpu/kernels/dequant_mxfp4.wgsl +120 -0
- package/src/gpu/kernels/dequant_mxfp4_expert.wgsl +129 -0
- package/src/gpu/kernels/dequant_mxfp4_expert_f16.wgsl +105 -0
- package/src/gpu/kernels/dequant_mxfp4_vec4.wgsl +116 -0
- package/src/gpu/kernels/dequant_q6k.wgsl +140 -0
- package/src/gpu/kernels/dequant_q8_0.wgsl +98 -0
- package/src/gpu/kernels/dequant_shared.wgsl +202 -0
- package/src/gpu/kernels/dequant_shared_vec4.wgsl +153 -0
- package/src/gpu/kernels/dequant_subgroup.wgsl +202 -0
- package/src/gpu/kernels/dispatch.d.ts +157 -0
- package/src/gpu/kernels/dispatch.js +235 -0
- package/src/gpu/kernels/energy.d.ts +131 -0
- package/src/gpu/kernels/energy.js +425 -0
- package/src/gpu/kernels/energy_eval.wgsl +26 -0
- package/src/gpu/kernels/energy_eval_f16.wgsl +30 -0
- package/src/gpu/kernels/energy_quintel_grad.wgsl +92 -0
- package/src/gpu/kernels/energy_quintel_grad_f16.wgsl +96 -0
- package/src/gpu/kernels/energy_quintel_reduce.wgsl +112 -0
- package/src/gpu/kernels/energy_quintel_reduce_f16.wgsl +116 -0
- package/src/gpu/kernels/energy_quintel_update.wgsl +92 -0
- package/src/gpu/kernels/energy_quintel_update_f16.wgsl +96 -0
- package/src/gpu/kernels/energy_update.wgsl +25 -0
- package/src/gpu/kernels/energy_update_f16.wgsl +30 -0
- package/src/gpu/kernels/feature-check.d.ts +42 -0
- package/src/gpu/kernels/feature-check.js +70 -0
- package/src/gpu/kernels/fused_ffn.d.ts +65 -0
- package/src/gpu/kernels/fused_ffn.js +318 -0
- package/src/gpu/kernels/fused_ffn.wgsl +420 -0
- package/src/gpu/kernels/fused_ffn_f16.wgsl +213 -0
- package/src/gpu/kernels/fused_ffn_q4k.wgsl +375 -0
- package/src/gpu/kernels/fused_matmul_q4.wgsl +404 -0
- package/src/gpu/kernels/fused_matmul_q4_batched.wgsl +194 -0
- package/src/gpu/kernels/fused_matmul_q4_batched_f16.wgsl +170 -0
- package/src/gpu/kernels/fused_matmul_q4_batched_f16a.wgsl +154 -0
- package/src/gpu/kernels/fused_matmul_q4_f16a.wgsl +219 -0
- package/src/gpu/kernels/fused_matmul_q4_multicol_f16.wgsl +216 -0
- package/src/gpu/kernels/fused_matmul_q4_multicol_f16a.wgsl +204 -0
- package/src/gpu/kernels/fused_matmul_residual.d.ts +46 -0
- package/src/gpu/kernels/fused_matmul_residual.js +152 -0
- package/src/gpu/kernels/fused_matmul_rmsnorm.d.ts +64 -0
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +273 -0
- package/src/gpu/kernels/fused_matmul_rmsnorm.wgsl +324 -0
- package/src/gpu/kernels/fused_matmul_rmsnorm_f16.wgsl +303 -0
- package/src/gpu/kernels/fused_swiglu.wgsl +63 -0
- package/src/gpu/kernels/fused_swiglu_f16.wgsl +57 -0
- package/src/gpu/kernels/gather.d.ts +64 -0
- package/src/gpu/kernels/gather.js +119 -0
- package/src/gpu/kernels/gather.wgsl +61 -0
- package/src/gpu/kernels/gather_f16.wgsl +65 -0
- package/src/gpu/kernels/gather_f16_f16_out.wgsl +55 -0
- package/src/gpu/kernels/gather_f16_out.wgsl +55 -0
- package/src/gpu/kernels/gather_f16_vec4.wgsl +76 -0
- package/src/gpu/kernels/gather_f16_vec4_f16_out.wgsl +68 -0
- package/src/gpu/kernels/gather_vec4.wgsl +74 -0
- package/src/gpu/kernels/gather_vec4_f16_out.wgsl +68 -0
- package/src/gpu/kernels/gelu.d.ts +33 -0
- package/src/gpu/kernels/gelu.js +47 -0
- package/src/gpu/kernels/gelu.wgsl +64 -0
- package/src/gpu/kernels/gelu_f16.wgsl +66 -0
- package/src/gpu/kernels/gptoss_mxfp4_expert_fused.wgsl +127 -0
- package/src/gpu/kernels/gptoss_router_topk.wgsl +119 -0
- package/src/gpu/kernels/groupnorm.d.ts +31 -0
- package/src/gpu/kernels/groupnorm.js +91 -0
- package/src/gpu/kernels/groupnorm_apply.wgsl +41 -0
- package/src/gpu/kernels/groupnorm_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/groupnorm_stats.wgsl +76 -0
- package/src/gpu/kernels/groupnorm_stats_f16.wgsl +79 -0
- package/src/gpu/kernels/index.d.ts +336 -0
- package/src/gpu/kernels/index.js +284 -0
- package/src/gpu/kernels/kernel-base.d.ts +33 -0
- package/src/gpu/kernels/kernel-base.js +46 -0
- package/src/gpu/kernels/kernel-configs.d.ts +65 -0
- package/src/gpu/kernels/kernel-configs.js +50 -0
- package/src/gpu/kernels/kernel-tuning.d.ts +42 -0
- package/src/gpu/kernels/kernel-tuning.js +149 -0
- package/src/gpu/kernels/kv-quantize.d.ts +37 -0
- package/src/gpu/kernels/kv-quantize.js +138 -0
- package/src/gpu/kernels/kv_quantize_int4.wgsl +119 -0
- package/src/gpu/kernels/kv_quantize_int8.wgsl +119 -0
- package/src/gpu/kernels/layernorm.d.ts +37 -0
- package/src/gpu/kernels/layernorm.js +80 -0
- package/src/gpu/kernels/layernorm.wgsl +121 -0
- package/src/gpu/kernels/layernorm_f16.wgsl +103 -0
- package/src/gpu/kernels/linear-attention-core.d.ts +39 -0
- package/src/gpu/kernels/linear-attention-core.js +535 -0
- package/src/gpu/kernels/logit-merge.d.ts +110 -0
- package/src/gpu/kernels/logit-merge.js +392 -0
- package/src/gpu/kernels/matmul-dispatch.d.ts +38 -0
- package/src/gpu/kernels/matmul-dispatch.js +155 -0
- package/src/gpu/kernels/matmul-selection.d.ts +87 -0
- package/src/gpu/kernels/matmul-selection.js +474 -0
- package/src/gpu/kernels/matmul.d.ts +109 -0
- package/src/gpu/kernels/matmul.js +271 -0
- package/src/gpu/kernels/matmul_f16.wgsl +170 -0
- package/src/gpu/kernels/matmul_f16_tiled.wgsl +165 -0
- package/src/gpu/kernels/matmul_f16w_f32a.wgsl +89 -0
- package/src/gpu/kernels/matmul_f16w_f32a_tiled.wgsl +154 -0
- package/src/gpu/kernels/matmul_f32.wgsl +100 -0
- package/src/gpu/kernels/matmul_gemv.wgsl +80 -0
- package/src/gpu/kernels/matmul_gemv_f16a.wgsl +81 -0
- package/src/gpu/kernels/matmul_gemv_residual.wgsl +119 -0
- package/src/gpu/kernels/matmul_gemv_residual_f16.wgsl +78 -0
- package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +345 -0
- package/src/gpu/kernels/matmul_gemv_subgroup_f16a.wgsl +514 -0
- package/src/gpu/kernels/modulate.d.ts +29 -0
- package/src/gpu/kernels/modulate.js +49 -0
- package/src/gpu/kernels/modulate.wgsl +40 -0
- package/src/gpu/kernels/modulate_f16.wgsl +43 -0
- package/src/gpu/kernels/moe.d.ts +164 -0
- package/src/gpu/kernels/moe.js +496 -0
- package/src/gpu/kernels/moe_gather.wgsl +170 -0
- package/src/gpu/kernels/moe_gather_f16.wgsl +82 -0
- package/src/gpu/kernels/moe_gather_vec4.wgsl +74 -0
- package/src/gpu/kernels/moe_offsets.wgsl +48 -0
- package/src/gpu/kernels/pipeline-cache.d.ts +88 -0
- package/src/gpu/kernels/pipeline-cache.js +305 -0
- package/src/gpu/kernels/pixel_shuffle.d.ts +27 -0
- package/src/gpu/kernels/pixel_shuffle.js +49 -0
- package/src/gpu/kernels/pixel_shuffle.wgsl +44 -0
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +47 -0
- package/src/gpu/kernels/residual.d.ts +74 -0
- package/src/gpu/kernels/residual.js +127 -0
- package/src/gpu/kernels/residual.wgsl +53 -0
- package/src/gpu/kernels/residual_f16.wgsl +35 -0
- package/src/gpu/kernels/residual_f16_vec4.wgsl +47 -0
- package/src/gpu/kernels/residual_vec4.wgsl +46 -0
- package/src/gpu/kernels/rmsnorm.d.ts +53 -0
- package/src/gpu/kernels/rmsnorm.js +140 -0
- package/src/gpu/kernels/rmsnorm.wgsl +417 -0
- package/src/gpu/kernels/rmsnorm_f16.wgsl +164 -0
- package/src/gpu/kernels/rope.d.ts +48 -0
- package/src/gpu/kernels/rope.js +53 -0
- package/src/gpu/kernels/rope.wgsl +328 -0
- package/src/gpu/kernels/rope_f16.wgsl +271 -0
- package/src/gpu/kernels/rule-matcher.d.ts +30 -0
- package/src/gpu/kernels/rule-matcher.js +42 -0
- package/src/gpu/kernels/rule-registry.d.ts +7 -0
- package/src/gpu/kernels/rule-registry.js +41 -0
- package/src/gpu/kernels/sample.d.ts +75 -0
- package/src/gpu/kernels/sample.js +578 -0
- package/src/gpu/kernels/sample.wgsl +377 -0
- package/src/gpu/kernels/sample_f16.wgsl +331 -0
- package/src/gpu/kernels/scale.d.ts +35 -0
- package/src/gpu/kernels/scale.js +37 -0
- package/src/gpu/kernels/scale.wgsl +38 -0
- package/src/gpu/kernels/scatter_add.wgsl +88 -0
- package/src/gpu/kernels/scatter_add_dynamic.wgsl +59 -0
- package/src/gpu/kernels/scatter_add_dynamic_f16.wgsl +52 -0
- package/src/gpu/kernels/scatter_add_dynamic_f16_weights.wgsl +50 -0
- package/src/gpu/kernels/scatter_add_vec4.wgsl +70 -0
- package/src/gpu/kernels/shader-cache.d.ts +56 -0
- package/src/gpu/kernels/shader-cache.js +206 -0
- package/src/gpu/kernels/silu.d.ts +75 -0
- package/src/gpu/kernels/silu.js +340 -0
- package/src/gpu/kernels/silu.wgsl +99 -0
- package/src/gpu/kernels/silu_f16.wgsl +98 -0
- package/src/gpu/kernels/softmax.d.ts +57 -0
- package/src/gpu/kernels/softmax.js +106 -0
- package/src/gpu/kernels/softmax.wgsl +388 -0
- package/src/gpu/kernels/softmax_subgroup.wgsl +175 -0
- package/src/gpu/kernels/split_qkv.d.ts +51 -0
- package/src/gpu/kernels/split_qkv.js +41 -0
- package/src/gpu/kernels/split_qkv.wgsl +71 -0
- package/src/gpu/kernels/split_qkv_f16.wgsl +75 -0
- package/src/gpu/kernels/topk.wgsl +243 -0
- package/src/gpu/kernels/topk_f16.wgsl +108 -0
- package/src/gpu/kernels/topk_f16_weights.wgsl +101 -0
- package/src/gpu/kernels/transpose.d.ts +21 -0
- package/src/gpu/kernels/transpose.js +30 -0
- package/src/gpu/kernels/transpose.wgsl +32 -0
- package/src/gpu/kernels/types.d.ts +21 -0
- package/src/gpu/kernels/types.js +4 -0
- package/src/gpu/kernels/uniform-utils.d.ts +48 -0
- package/src/gpu/kernels/uniform-utils.js +94 -0
- package/src/gpu/kernels/upsample2d.d.ts +25 -0
- package/src/gpu/kernels/upsample2d.js +58 -0
- package/src/gpu/kernels/upsample2d.wgsl +37 -0
- package/src/gpu/kernels/upsample2d_f16.wgsl +41 -0
- package/src/gpu/kernels/utils.d.ts +106 -0
- package/src/gpu/kernels/utils.js +224 -0
- package/src/gpu/multi-model-recorder.d.ts +21 -0
- package/src/gpu/multi-model-recorder.js +31 -0
- package/src/gpu/partitioned-buffer-pool.d.ts +28 -0
- package/src/gpu/partitioned-buffer-pool.js +49 -0
- package/src/gpu/perf-guards.d.ts +25 -0
- package/src/gpu/perf-guards.js +140 -0
- package/src/gpu/profiler.d.ts +114 -0
- package/src/gpu/profiler.js +391 -0
- package/src/gpu/submit-tracker.d.ts +111 -0
- package/src/gpu/submit-tracker.js +229 -0
- package/src/gpu/tensor.d.ts +69 -0
- package/src/gpu/tensor.js +75 -0
- package/src/gpu/uniform-cache.d.ts +108 -0
- package/src/gpu/uniform-cache.js +242 -0
- package/src/gpu/weight-buffer.d.ts +115 -0
- package/src/gpu/weight-buffer.js +118 -0
- package/src/hotswap/intent-bundle.d.ts +37 -0
- package/src/hotswap/intent-bundle.js +123 -0
- package/src/hotswap/manifest.d.ts +33 -0
- package/src/hotswap/manifest.js +114 -0
- package/src/hotswap/runtime.d.ts +31 -0
- package/src/hotswap/runtime.js +128 -0
- package/src/index-browser.d.ts +47 -0
- package/src/index-browser.js +53 -0
- package/src/index-internal.d.ts +2 -0
- package/src/index-internal.js +2 -0
- package/src/index.d.ts +102 -0
- package/src/index.js +75 -0
- package/src/inference/README.md +593 -0
- package/src/inference/browser-harness.d.ts +234 -0
- package/src/inference/browser-harness.js +2665 -0
- package/src/inference/decode-buffers.d.ts +108 -0
- package/src/inference/decode-buffers.js +181 -0
- package/src/inference/decode-ring.d.ts +52 -0
- package/src/inference/decode-ring.js +273 -0
- package/src/inference/expert-router.d.ts +27 -0
- package/src/inference/expert-router.js +55 -0
- package/src/inference/functiongemma.d.ts +15 -0
- package/src/inference/functiongemma.js +1 -0
- package/src/inference/kv-cache/base.d.ts +150 -0
- package/src/inference/kv-cache/base.js +1037 -0
- package/src/inference/kv-cache/basis-decomposed-paged.d.ts +50 -0
- package/src/inference/kv-cache/basis-decomposed-paged.js +276 -0
- package/src/inference/kv-cache/index.d.ts +35 -0
- package/src/inference/kv-cache/index.js +20 -0
- package/src/inference/kv-cache/sliding-window.d.ts +72 -0
- package/src/inference/kv-cache/sliding-window.js +243 -0
- package/src/inference/kv-cache/tiered.d.ts +89 -0
- package/src/inference/kv-cache/tiered.js +574 -0
- package/src/inference/kv-cache/types.d.ts +188 -0
- package/src/inference/kv-cache/types.js +80 -0
- package/src/inference/kv-cache.d.ts +36 -0
- package/src/inference/kv-cache.js +18 -0
- package/src/inference/moe-router.d.ts +212 -0
- package/src/inference/moe-router.js +553 -0
- package/src/inference/multi-model-network.d.ts +139 -0
- package/src/inference/multi-model-network.js +769 -0
- package/src/inference/multi-pipeline-pool.d.ts +62 -0
- package/src/inference/multi-pipeline-pool.js +161 -0
- package/src/inference/network-evolution.d.ts +46 -0
- package/src/inference/network-evolution.js +80 -0
- package/src/inference/pipelines/context.d.ts +18 -0
- package/src/inference/pipelines/context.js +44 -0
- package/src/inference/pipelines/diffusion/helpers.d.ts +29 -0
- package/src/inference/pipelines/diffusion/helpers.js +112 -0
- package/src/inference/pipelines/diffusion/index.d.ts +3 -0
- package/src/inference/pipelines/diffusion/index.js +3 -0
- package/src/inference/pipelines/diffusion/init.d.ts +24 -0
- package/src/inference/pipelines/diffusion/init.js +124 -0
- package/src/inference/pipelines/diffusion/pipeline.d.ts +38 -0
- package/src/inference/pipelines/diffusion/pipeline.js +632 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +19 -0
- package/src/inference/pipelines/diffusion/scheduler.js +65 -0
- package/src/inference/pipelines/diffusion/sd3-transformer.d.ts +20 -0
- package/src/inference/pipelines/diffusion/sd3-transformer.js +1194 -0
- package/src/inference/pipelines/diffusion/sd3-weights.d.ts +21 -0
- package/src/inference/pipelines/diffusion/sd3-weights.js +287 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +80 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +935 -0
- package/src/inference/pipelines/diffusion/text-encoder.d.ts +29 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +178 -0
- package/src/inference/pipelines/diffusion/types.d.ts +112 -0
- package/src/inference/pipelines/diffusion/types.js +1 -0
- package/src/inference/pipelines/diffusion/vae.d.ts +20 -0
- package/src/inference/pipelines/diffusion/vae.js +675 -0
- package/src/inference/pipelines/diffusion/weights.d.ts +40 -0
- package/src/inference/pipelines/diffusion/weights.js +150 -0
- package/src/inference/pipelines/dream/energy-head-pipeline.d.ts +29 -0
- package/src/inference/pipelines/dream/energy-head-pipeline.js +6 -0
- package/src/inference/pipelines/dream/pipeline.d.ts +17 -0
- package/src/inference/pipelines/dream/pipeline.js +8 -0
- package/src/inference/pipelines/energy/index.d.ts +1 -0
- package/src/inference/pipelines/energy/index.js +1 -0
- package/src/inference/pipelines/energy/pipeline.d.ts +27 -0
- package/src/inference/pipelines/energy/pipeline.js +680 -0
- package/src/inference/pipelines/energy/quintel.d.ts +87 -0
- package/src/inference/pipelines/energy/quintel.js +207 -0
- package/src/inference/pipelines/energy/types.d.ts +63 -0
- package/src/inference/pipelines/energy/types.js +1 -0
- package/src/inference/pipelines/energy-head/index.d.ts +6 -0
- package/src/inference/pipelines/energy-head/index.js +6 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.d.ts +103 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +487 -0
- package/src/inference/pipelines/factory.d.ts +10 -0
- package/src/inference/pipelines/factory.js +6 -0
- package/src/inference/pipelines/index.d.ts +22 -0
- package/src/inference/pipelines/index.js +19 -0
- package/src/inference/pipelines/registry.d.ts +15 -0
- package/src/inference/pipelines/registry.js +23 -0
- package/src/inference/pipelines/rng.d.ts +2 -0
- package/src/inference/pipelines/rng.js +17 -0
- package/src/inference/pipelines/structured/index.d.ts +8 -0
- package/src/inference/pipelines/structured/index.js +8 -0
- package/src/inference/pipelines/structured/json-head-pipeline.d.ts +58 -0
- package/src/inference/pipelines/structured/json-head-pipeline.js +181 -0
- package/src/inference/pipelines/text/attention/index.d.ts +24 -0
- package/src/inference/pipelines/text/attention/index.js +17 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +101 -0
- package/src/inference/pipelines/text/attention/projections.js +435 -0
- package/src/inference/pipelines/text/attention/record.d.ts +36 -0
- package/src/inference/pipelines/text/attention/record.js +613 -0
- package/src/inference/pipelines/text/attention/run.d.ts +38 -0
- package/src/inference/pipelines/text/attention/run.js +826 -0
- package/src/inference/pipelines/text/attention/types.d.ts +98 -0
- package/src/inference/pipelines/text/attention/types.js +67 -0
- package/src/inference/pipelines/text/attention.d.ts +23 -0
- package/src/inference/pipelines/text/attention.js +12 -0
- package/src/inference/pipelines/text/bdpa-steamroller.d.ts +22 -0
- package/src/inference/pipelines/text/bdpa-steamroller.js +158 -0
- package/src/inference/pipelines/text/buffer-types.d.ts +7 -0
- package/src/inference/pipelines/text/buffer-types.js +4 -0
- package/src/inference/pipelines/text/chat-format.d.ts +46 -0
- package/src/inference/pipelines/text/chat-format.js +366 -0
- package/src/inference/pipelines/text/config.d.ts +235 -0
- package/src/inference/pipelines/text/config.js +623 -0
- package/src/inference/pipelines/text/debug-utils/config.d.ts +144 -0
- package/src/inference/pipelines/text/debug-utils/config.js +156 -0
- package/src/inference/pipelines/text/debug-utils/index.d.ts +53 -0
- package/src/inference/pipelines/text/debug-utils/index.js +44 -0
- package/src/inference/pipelines/text/debug-utils/logging.d.ts +106 -0
- package/src/inference/pipelines/text/debug-utils/logging.js +152 -0
- package/src/inference/pipelines/text/debug-utils/tensor.d.ts +119 -0
- package/src/inference/pipelines/text/debug-utils/tensor.js +268 -0
- package/src/inference/pipelines/text/debug-utils/utils.d.ts +77 -0
- package/src/inference/pipelines/text/debug-utils/utils.js +139 -0
- package/src/inference/pipelines/text/debug-utils.d.ts +42 -0
- package/src/inference/pipelines/text/debug-utils.js +34 -0
- package/src/inference/pipelines/text/embed.d.ts +67 -0
- package/src/inference/pipelines/text/embed.js +461 -0
- package/src/inference/pipelines/text/execution-plan.d.ts +116 -0
- package/src/inference/pipelines/text/execution-plan.js +314 -0
- package/src/inference/pipelines/text/execution-v0.d.ts +66 -0
- package/src/inference/pipelines/text/execution-v0.js +1139 -0
- package/src/inference/pipelines/text/ffn/dense.d.ts +40 -0
- package/src/inference/pipelines/text/ffn/dense.js +759 -0
- package/src/inference/pipelines/text/ffn/index.d.ts +23 -0
- package/src/inference/pipelines/text/ffn/index.js +16 -0
- package/src/inference/pipelines/text/ffn/moe.d.ts +21 -0
- package/src/inference/pipelines/text/ffn/moe.js +49 -0
- package/src/inference/pipelines/text/ffn/sandwich.d.ts +25 -0
- package/src/inference/pipelines/text/ffn/sandwich.js +196 -0
- package/src/inference/pipelines/text/ffn/standard.d.ts +23 -0
- package/src/inference/pipelines/text/ffn/standard.js +84 -0
- package/src/inference/pipelines/text/ffn/types.d.ts +30 -0
- package/src/inference/pipelines/text/ffn/types.js +25 -0
- package/src/inference/pipelines/text/ffn.d.ts +31 -0
- package/src/inference/pipelines/text/ffn.js +18 -0
- package/src/inference/pipelines/text/finiteness-guard-status.d.ts +11 -0
- package/src/inference/pipelines/text/finiteness-guard-status.js +21 -0
- package/src/inference/pipelines/text/finiteness-policy.d.ts +35 -0
- package/src/inference/pipelines/text/finiteness-policy.js +45 -0
- package/src/inference/pipelines/text/generator-helpers.d.ts +34 -0
- package/src/inference/pipelines/text/generator-helpers.js +175 -0
- package/src/inference/pipelines/text/generator-runtime.d.ts +93 -0
- package/src/inference/pipelines/text/generator-runtime.js +373 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +75 -0
- package/src/inference/pipelines/text/generator-steps.js +1078 -0
- package/src/inference/pipelines/text/generator.d.ts +41 -0
- package/src/inference/pipelines/text/generator.js +1345 -0
- package/src/inference/pipelines/text/index.d.ts +5 -0
- package/src/inference/pipelines/text/index.js +6 -0
- package/src/inference/pipelines/text/init.d.ts +295 -0
- package/src/inference/pipelines/text/init.js +965 -0
- package/src/inference/pipelines/text/kernel-path-auto-select.d.ts +12 -0
- package/src/inference/pipelines/text/kernel-path-auto-select.js +90 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +150 -0
- package/src/inference/pipelines/text/kernel-trace.js +324 -0
- package/src/inference/pipelines/text/layer-plan.d.ts +65 -0
- package/src/inference/pipelines/text/layer-plan.js +249 -0
- package/src/inference/pipelines/text/layer.d.ts +56 -0
- package/src/inference/pipelines/text/layer.js +916 -0
- package/src/inference/pipelines/text/linear-attention.d.ts +94 -0
- package/src/inference/pipelines/text/linear-attention.js +803 -0
- package/src/inference/pipelines/text/logits/cpu.d.ts +81 -0
- package/src/inference/pipelines/text/logits/cpu.js +91 -0
- package/src/inference/pipelines/text/logits/gpu.d.ts +113 -0
- package/src/inference/pipelines/text/logits/gpu.js +406 -0
- package/src/inference/pipelines/text/logits/index.d.ts +57 -0
- package/src/inference/pipelines/text/logits/index.js +305 -0
- package/src/inference/pipelines/text/logits/types.d.ts +46 -0
- package/src/inference/pipelines/text/logits/types.js +4 -0
- package/src/inference/pipelines/text/logits/utils.d.ts +49 -0
- package/src/inference/pipelines/text/logits/utils.js +59 -0
- package/src/inference/pipelines/text/logits.d.ts +27 -0
- package/src/inference/pipelines/text/logits.js +16 -0
- package/src/inference/pipelines/text/lora-apply.d.ts +28 -0
- package/src/inference/pipelines/text/lora-apply.js +58 -0
- package/src/inference/pipelines/text/lora-types.d.ts +39 -0
- package/src/inference/pipelines/text/lora-types.js +18 -0
- package/src/inference/pipelines/text/lora.d.ts +18 -0
- package/src/inference/pipelines/text/lora.js +12 -0
- package/src/inference/pipelines/text/model-load.d.ts +58 -0
- package/src/inference/pipelines/text/model-load.js +561 -0
- package/src/inference/pipelines/text/moe-cache.d.ts +32 -0
- package/src/inference/pipelines/text/moe-cache.js +107 -0
- package/src/inference/pipelines/text/moe-cpu-gptoss.d.ts +9 -0
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +110 -0
- package/src/inference/pipelines/text/moe-cpu.d.ts +13 -0
- package/src/inference/pipelines/text/moe-cpu.js +116 -0
- package/src/inference/pipelines/text/moe-gpu.d.ts +13 -0
- package/src/inference/pipelines/text/moe-gpu.js +611 -0
- package/src/inference/pipelines/text/moe-helpers.d.ts +12 -0
- package/src/inference/pipelines/text/moe-helpers.js +21 -0
- package/src/inference/pipelines/text/moe-impl.d.ts +117 -0
- package/src/inference/pipelines/text/moe-impl.js +9 -0
- package/src/inference/pipelines/text/moe-shape-validator.d.ts +31 -0
- package/src/inference/pipelines/text/moe-shape-validator.js +78 -0
- package/src/inference/pipelines/text/ops.d.ts +167 -0
- package/src/inference/pipelines/text/ops.js +367 -0
- package/src/inference/pipelines/text/probes.d.ts +31 -0
- package/src/inference/pipelines/text/probes.js +170 -0
- package/src/inference/pipelines/text/sampling.d.ts +54 -0
- package/src/inference/pipelines/text/sampling.js +203 -0
- package/src/inference/pipelines/text/state.d.ts +112 -0
- package/src/inference/pipelines/text/state.js +152 -0
- package/src/inference/pipelines/text/types.d.ts +627 -0
- package/src/inference/pipelines/text/types.js +4 -0
- package/src/inference/pipelines/text/weights.d.ts +110 -0
- package/src/inference/pipelines/text/weights.js +163 -0
- package/src/inference/pipelines/text.d.ts +157 -0
- package/src/inference/pipelines/text.js +586 -0
- package/src/inference/speculative.d.ts +239 -0
- package/src/inference/speculative.js +416 -0
- package/src/inference/test-harness.d.ts +178 -0
- package/src/inference/test-harness.js +349 -0
- package/src/inference/tokenizer.d.ts +77 -0
- package/src/inference/tokenizer.js +258 -0
- package/src/inference/tokenizers/base.d.ts +39 -0
- package/src/inference/tokenizers/base.js +69 -0
- package/src/inference/tokenizers/bpe.d.ts +27 -0
- package/src/inference/tokenizers/bpe.js +171 -0
- package/src/inference/tokenizers/bundled.d.ts +63 -0
- package/src/inference/tokenizers/bundled.js +866 -0
- package/src/inference/tokenizers/sentencepiece.d.ts +28 -0
- package/src/inference/tokenizers/sentencepiece.js +389 -0
- package/src/inference/tokenizers/types.d.ts +166 -0
- package/src/inference/tokenizers/types.js +7 -0
- package/src/loader/doppler-loader.d.ts +134 -0
- package/src/loader/doppler-loader.js +1036 -0
- package/src/loader/dtype-utils.d.ts +40 -0
- package/src/loader/dtype-utils.js +102 -0
- package/src/loader/embedding-loader.d.ts +56 -0
- package/src/loader/embedding-loader.js +207 -0
- package/src/loader/experts/expert-cache.d.ts +156 -0
- package/src/loader/experts/expert-cache.js +375 -0
- package/src/loader/experts/expert-loader.d.ts +108 -0
- package/src/loader/experts/expert-loader.js +384 -0
- package/src/loader/final-weights-loader.d.ts +68 -0
- package/src/loader/final-weights-loader.js +262 -0
- package/src/loader/index.d.ts +150 -0
- package/src/loader/index.js +124 -0
- package/src/loader/layer-loader.d.ts +63 -0
- package/src/loader/layer-loader.js +417 -0
- package/src/loader/loader-state.d.ts +51 -0
- package/src/loader/loader-state.js +142 -0
- package/src/loader/loader-types.d.ts +236 -0
- package/src/loader/loader-types.js +4 -0
- package/src/loader/manifest-config.d.ts +97 -0
- package/src/loader/manifest-config.js +132 -0
- package/src/loader/memory-monitor.d.ts +112 -0
- package/src/loader/memory-monitor.js +276 -0
- package/src/loader/multi-model-loader.d.ts +37 -0
- package/src/loader/multi-model-loader.js +87 -0
- package/src/loader/quantization-constants.d.ts +23 -0
- package/src/loader/quantization-constants.js +14 -0
- package/src/loader/shard-cache.d.ts +60 -0
- package/src/loader/shard-cache.js +568 -0
- package/src/loader/shard-resolver.d.ts +12 -0
- package/src/loader/shard-resolver.js +83 -0
- package/src/loader/tensors/tensor-loader.d.ts +154 -0
- package/src/loader/tensors/tensor-loader.js +427 -0
- package/src/loader/tensors/tensor-reader.d.ts +22 -0
- package/src/loader/tensors/tensor-reader.js +56 -0
- package/src/loader/tensors/tensor-role.d.ts +7 -0
- package/src/loader/tensors/tensor-role.js +12 -0
- package/src/loader/weight-downcast.d.ts +62 -0
- package/src/loader/weight-downcast.js +213 -0
- package/src/loader/weights.d.ts +22 -0
- package/src/loader/weights.js +4 -0
- package/src/memory/address-table.d.ts +104 -0
- package/src/memory/address-table.js +114 -0
- package/src/memory/buffer-pool.d.ts +196 -0
- package/src/memory/buffer-pool.js +756 -0
- package/src/memory/capability.d.ts +49 -0
- package/src/memory/capability.js +95 -0
- package/src/memory/heap-manager.d.ts +104 -0
- package/src/memory/heap-manager.js +264 -0
- package/src/memory/unified-detect.d.ts +59 -0
- package/src/memory/unified-detect.js +192 -0
- package/src/rules/converter/execution.rules.json +20 -0
- package/src/rules/converter/tensor-roles.rules.json +13 -0
- package/src/rules/converter/tokenizer.rules.json +7 -0
- package/src/rules/inference/attention.rules.json +54 -0
- package/src/rules/inference/config.rules.json +58 -0
- package/src/rules/inference/dtype.rules.json +94 -0
- package/src/rules/inference/execution.rules.json +45 -0
- package/src/rules/inference/ffn.rules.json +35 -0
- package/src/rules/inference/kernel-path.rules.json +76 -0
- package/src/rules/inference/layer-pattern.rules.json +16 -0
- package/src/rules/inference/layer.rules.json +7 -0
- package/src/rules/inference/moe.rules.json +48 -0
- package/src/rules/kernels/attention.rules.json +61 -0
- package/src/rules/kernels/conv2d.rules.json +6 -0
- package/src/rules/kernels/dequant.rules.json +58 -0
- package/src/rules/kernels/energy.rules.json +22 -0
- package/src/rules/kernels/fused-ffn.rules.json +13 -0
- package/src/rules/kernels/fused-matmul-residual.rules.json +6 -0
- package/src/rules/kernels/fused-matmul-rmsnorm.rules.json +8 -0
- package/src/rules/kernels/gather.rules.json +12 -0
- package/src/rules/kernels/gelu.rules.json +11 -0
- package/src/rules/kernels/groupnorm.rules.json +10 -0
- package/src/rules/kernels/kernel-validator.d.ts +24 -0
- package/src/rules/kernels/kernel-validator.js +160 -0
- package/src/rules/kernels/kv_quantize.rules.json +7 -0
- package/src/rules/kernels/layernorm.rules.json +6 -0
- package/src/rules/kernels/matmul.rules.json +60 -0
- package/src/rules/kernels/modulate.rules.json +6 -0
- package/src/rules/kernels/moe.rules.gptoss.json +105 -0
- package/src/rules/kernels/moe.rules.json +11 -0
- package/src/rules/kernels/pixel_shuffle.rules.json +6 -0
- package/src/rules/kernels/residual.rules.json +12 -0
- package/src/rules/kernels/rmsnorm.rules.json +11 -0
- package/src/rules/kernels/rope.rules.json +6 -0
- package/src/rules/kernels/sample.rules.json +6 -0
- package/src/rules/kernels/scale.rules.json +6 -0
- package/src/rules/kernels/silu.rules.json +21 -0
- package/src/rules/kernels/softmax.rules.json +23 -0
- package/src/rules/kernels/split-qkv.rules.json +6 -0
- package/src/rules/kernels/upsample2d.rules.json +6 -0
- package/src/rules/loader/tensor-loader.rules.json +15 -0
- package/src/rules/loader/weights.rules.json +41 -0
- package/src/rules/rule-registry.d.ts +48 -0
- package/src/rules/rule-registry.js +177 -0
- package/src/rules/tooling/command-runtime.rules.json +38 -0
- package/src/storage/backends/idb-store.d.ts +52 -0
- package/src/storage/backends/idb-store.js +590 -0
- package/src/storage/backends/memory-store.d.ts +36 -0
- package/src/storage/backends/memory-store.js +242 -0
- package/src/storage/backends/opfs-store.d.ts +41 -0
- package/src/storage/backends/opfs-store.js +429 -0
- package/src/storage/blake3.d.ts +17 -0
- package/src/storage/blake3.js +269 -0
- package/src/storage/download-types.d.ts +157 -0
- package/src/storage/download-types.js +48 -0
- package/src/storage/downloader.d.ts +103 -0
- package/src/storage/downloader.js +839 -0
- package/src/storage/emulated-vram.d.ts +264 -0
- package/src/storage/emulated-vram.js +576 -0
- package/src/storage/export.d.ts +20 -0
- package/src/storage/export.js +159 -0
- package/src/storage/index.d.ts +253 -0
- package/src/storage/index.js +185 -0
- package/src/storage/inventory.d.ts +26 -0
- package/src/storage/inventory.js +218 -0
- package/src/storage/preflight.d.ts +144 -0
- package/src/storage/preflight.js +294 -0
- package/src/storage/quickstart-downloader.d.ts +154 -0
- package/src/storage/quickstart-downloader.js +265 -0
- package/src/storage/quota.d.ts +150 -0
- package/src/storage/quota.js +304 -0
- package/src/storage/registry.d.ts +28 -0
- package/src/storage/registry.js +125 -0
- package/src/storage/reports.d.ts +20 -0
- package/src/storage/reports.js +94 -0
- package/src/storage/shard-manager.d.ts +137 -0
- package/src/storage/shard-manager.js +801 -0
- package/src/sw.d.ts +1 -0
- package/src/sw.js +187 -0
- package/src/tooling/browser-command-runner.d.ts +28 -0
- package/src/tooling/browser-command-runner.js +82 -0
- package/src/tooling/command-api.d.ts +147 -0
- package/src/tooling/command-api.js +523 -0
- package/src/tooling/command-envelope.d.ts +81 -0
- package/src/tooling/command-envelope.js +195 -0
- package/src/tooling/command-runner-shared.d.ts +73 -0
- package/src/tooling/command-runner-shared.js +146 -0
- package/src/tooling/command-runner.html +45 -0
- package/src/tooling/node-browser-command-runner.d.ts +30 -0
- package/src/tooling/node-browser-command-runner.js +868 -0
- package/src/tooling/node-command-runner.d.ts +36 -0
- package/src/tooling/node-command-runner.js +127 -0
- package/src/tooling/node-convert-worker-pool.d.ts +16 -0
- package/src/tooling/node-convert-worker-pool.js +186 -0
- package/src/tooling/node-convert-worker.d.ts +1 -0
- package/src/tooling/node-convert-worker.js +60 -0
- package/src/tooling/node-convert.d.ts +44 -0
- package/src/tooling/node-converter.d.ts +1 -0
- package/src/tooling/node-converter.js +1227 -0
- package/src/tooling/node-file-fetch.d.ts +1 -0
- package/src/tooling/node-file-fetch.js +38 -0
- package/src/tooling/node-source-runtime.d.ts +19 -0
- package/src/tooling/node-source-runtime.js +469 -0
- package/src/tooling/node-webgpu.d.ts +6 -0
- package/src/tooling/node-webgpu.js +321 -0
- package/src/tooling/opfs-cache.d.ts +11 -0
- package/src/tooling/opfs-cache.js +174 -0
- package/src/tooling/source-runtime-bundle.d.ts +102 -0
- package/src/tooling/source-runtime-bundle.js +484 -0
- package/src/tooling-exports.browser.d.ts +7 -0
- package/src/tooling-exports.browser.js +2 -0
- package/src/tooling-exports.d.ts +22 -0
- package/src/tooling-exports.js +7 -0
- package/src/tooling-exports.shared.d.ts +105 -0
- package/src/tooling-exports.shared.js +92 -0
- package/src/training/README.md +153 -0
- package/src/training/artifacts.d.ts +160 -0
- package/src/training/artifacts.js +896 -0
- package/src/training/attention-backward.d.ts +30 -0
- package/src/training/attention-backward.js +217 -0
- package/src/training/attention-forward.d.ts +22 -0
- package/src/training/attention-forward.js +82 -0
- package/src/training/autograd.d.ts +51 -0
- package/src/training/autograd.js +380 -0
- package/src/training/checkpoint.d.ts +31 -0
- package/src/training/checkpoint.js +238 -0
- package/src/training/clip.d.ts +9 -0
- package/src/training/clip.js +54 -0
- package/src/training/dataloader.d.ts +8 -0
- package/src/training/dataloader.js +44 -0
- package/src/training/datasets/index.d.ts +12 -0
- package/src/training/datasets/index.js +6 -0
- package/src/training/datasets/jsonl.d.ts +11 -0
- package/src/training/datasets/jsonl.js +50 -0
- package/src/training/datasets/reploid.d.ts +3 -0
- package/src/training/datasets/reploid.js +36 -0
- package/src/training/datasets/text-pairs.d.ts +21 -0
- package/src/training/datasets/text-pairs.js +42 -0
- package/src/training/datasets/token-batch.d.ts +21 -0
- package/src/training/datasets/token-batch.js +40 -0
- package/src/training/datasets/translation-pairs.d.ts +34 -0
- package/src/training/datasets/translation-pairs.js +49 -0
- package/src/training/export.d.ts +32 -0
- package/src/training/export.js +112 -0
- package/src/training/index.d.ts +52 -0
- package/src/training/index.js +41 -0
- package/src/training/lora.d.ts +19 -0
- package/src/training/lora.js +57 -0
- package/src/training/loss-scaling.d.ts +21 -0
- package/src/training/loss-scaling.js +80 -0
- package/src/training/loss.d.ts +10 -0
- package/src/training/loss.js +41 -0
- package/src/training/objectives/base.d.ts +58 -0
- package/src/training/objectives/base.js +38 -0
- package/src/training/objectives/cross_entropy.d.ts +18 -0
- package/src/training/objectives/cross_entropy.js +37 -0
- package/src/training/objectives/distill_kd.d.ts +16 -0
- package/src/training/objectives/distill_kd.js +369 -0
- package/src/training/objectives/distill_triplet.d.ts +16 -0
- package/src/training/objectives/distill_triplet.js +412 -0
- package/src/training/objectives/index.d.ts +12 -0
- package/src/training/objectives/index.js +6 -0
- package/src/training/objectives/ul_stage1_joint.d.ts +16 -0
- package/src/training/objectives/ul_stage1_joint.js +188 -0
- package/src/training/objectives/ul_stage2_base.d.ts +16 -0
- package/src/training/objectives/ul_stage2_base.js +222 -0
- package/src/training/optimizer.d.ts +22 -0
- package/src/training/optimizer.js +115 -0
- package/src/training/runner.d.ts +196 -0
- package/src/training/runner.js +1194 -0
- package/src/training/suite.d.ts +187 -0
- package/src/training/suite.js +3156 -0
- package/src/training/trainer.d.ts +89 -0
- package/src/training/trainer.js +301 -0
- package/src/training/ul_dataset.d.ts +47 -0
- package/src/training/ul_dataset.js +153 -0
- package/src/training/ul_schedule.d.ts +6 -0
- package/src/training/ul_schedule.js +29 -0
- package/src/types/chrome.d.ts +36 -0
- package/src/types/chrome.js +1 -0
- package/src/types/gpu.d.ts +185 -0
- package/src/types/gpu.js +5 -0
- package/src/types/index.d.ts +3 -0
- package/src/types/index.js +3 -0
- package/src/types/inference.d.ts +197 -0
- package/src/types/inference.js +5 -0
- package/src/types/model.d.ts +125 -0
- package/src/types/model.js +5 -0
- package/src/utils/index.d.ts +7 -0
- package/src/utils/index.js +7 -0
- package/src/utils/load-json.d.ts +5 -0
- package/src/utils/load-json.js +23 -0
- package/src/utils/plain-object.d.ts +1 -0
- package/src/utils/plain-object.js +3 -0
- package/src/utils/sha256.d.ts +4 -0
- package/src/utils/sha256.js +135 -0
- package/tools/convert-safetensors-node.js +180 -0
- package/tools/doppler-cli.js +1170 -0
|
@@ -0,0 +1,3156 @@
|
|
|
1
|
+
import { initDevice, getKernelCapabilities, getDevice } from '../gpu/device.js';
|
|
2
|
+
import { setPlatformsBaseUrl } from '../config/platforms/loader.js';
|
|
3
|
+
import { setRegistryUrl } from '../config/kernels/registry.js';
|
|
4
|
+
import { createTrainingConfig } from '../config/training-defaults.js';
|
|
5
|
+
import {
|
|
6
|
+
runAttention,
|
|
7
|
+
castF16ToF32,
|
|
8
|
+
runGather,
|
|
9
|
+
runMatmul,
|
|
10
|
+
runResidualAdd,
|
|
11
|
+
runRMSNorm,
|
|
12
|
+
runRoPE,
|
|
13
|
+
runSiLURowSplit,
|
|
14
|
+
} from '../gpu/kernels/index.js';
|
|
15
|
+
import { createTensor } from '../gpu/tensor.js';
|
|
16
|
+
import { acquireBuffer, uploadData, releaseBuffer } from '../memory/buffer-pool.js';
|
|
17
|
+
import { getBufferDtype, getWeightDtype, isCpuWeightBuffer, isWeightBuffer } from '../gpu/weight-buffer.js';
|
|
18
|
+
import { OpType } from './autograd.js';
|
|
19
|
+
import { AdamOptimizer } from './optimizer.js';
|
|
20
|
+
import { TrainingRunner } from './runner.js';
|
|
21
|
+
import { trainStep } from './trainer.js';
|
|
22
|
+
import { crossEntropyLoss } from './loss.js';
|
|
23
|
+
import { clipGradients } from './clip.js';
|
|
24
|
+
import { exportLoRAAdapter } from './export.js';
|
|
25
|
+
import { sha256Hex } from '../utils/sha256.js';
|
|
26
|
+
import { computeSampleStats } from '../debug/stats.js';
|
|
27
|
+
import { parseJsonl } from './datasets/jsonl.js';
|
|
28
|
+
import { initializeInference } from '../inference/test-harness.js';
|
|
29
|
+
import { createPipeline } from '../inference/pipelines/text.js';
|
|
30
|
+
import { parseManifest } from '../formats/rdrr/index.js';
|
|
31
|
+
import { openModelStore, loadManifestFromStore } from '../storage/shard-manager.js';
|
|
32
|
+
|
|
33
|
+
// Frozen list of legacy browser training test names (frozen so the shared list cannot be mutated).
const LEGACY_BROWSER_TESTS = Object.freeze([
  'loss-forward', 'softmax-backward', 'cross-entropy-backward',
  'rmsnorm-backward', 'layernorm-backward', 'conv2d-backward',
  'matmul-backward', 'embed-backward',
  'ebm-state-optimize', 'ebm-recorded-bench',
  'parity-fixture', 'training-leak-perf', 'autograd-branching',
]);

// The only trainingSchemaVersion accepted by assertTrainingSchemaVersion().
const TRAINING_COMMAND_SCHEMA_VERSION = 1;

// Default teacher top-k used when the configured value is missing or invalid (see clampDistillTopK).
const DISTILL_ADAPTER_TOP_K = 64;

// Logit substituted for non-finite teacher logits and out-of-range token indices.
const DISTILL_LOGIT_FALLBACK = -80;

// Student graph modes returned by normalizeDistillStudentGraphMode().
const DISTILL_STUDENT_GRAPH_PROJECTION = 'projection_head';
const DISTILL_STUDENT_GRAPH_FULL = 'transformer_full';
|
|
53
|
+
|
|
54
|
+
/**
 * Tally a list of test results into a suite summary.
 *
 * A result counts as skipped when `skipped` is truthy, otherwise as passed when `passed` is
 * truthy, otherwise as failed.
 *
 * @param {string} suiteName - Name reported as `suite`.
 * @param {Array<{skipped?: boolean, passed?: boolean}>} results - Individual test results.
 * @param {number} startTimeMs - `performance.now()` timestamp taken when the suite started.
 * @returns {{suite: string, passed: number, failed: number, skipped: number, duration: number, results: Array}}
 */
function buildSuiteSummary(suiteName, results, startTimeMs) {
  const tally = { passed: 0, failed: 0, skipped: 0 };
  for (const entry of results) {
    let bucket = 'failed';
    if (entry.skipped) {
      bucket = 'skipped';
    } else if (entry.passed) {
      bucket = 'passed';
    }
    tally[bucket] += 1;
  }
  // Clamp so a clock skew or a future start time never reports a negative duration.
  const elapsed = performance.now() - startTimeMs;
  return {
    suite: suiteName,
    passed: tally.passed,
    failed: tally.failed,
    skipped: tally.skipped,
    duration: Math.max(0, elapsed),
    results,
  };
}
|
|
76
|
+
|
|
77
|
+
/**
 * Trim a list of requested test names, dropping empty entries.
 *
 * @param {unknown} names - Candidate list of test names.
 * @returns {string[]|null} Trimmed non-empty names, or null when `names` is not an array or
 *   nothing usable remains.
 */
function normalizeTrainingTestNames(names) {
  if (!Array.isArray(names)) {
    return null;
  }
  const cleaned = [];
  for (const raw of names) {
    const name = String(raw || '').trim();
    if (name) {
      cleaned.push(name);
    }
  }
  return cleaned.length === 0 ? null : cleaned;
}
|
|
84
|
+
|
|
85
|
+
/**
 * Validate a training command schema version.
 *
 * @param {unknown} value - Requested schema version; null/undefined selects the current version.
 * @returns {number} The accepted schema version.
 * @throws {Error} When the value is not exactly TRAINING_COMMAND_SCHEMA_VERSION.
 */
function assertTrainingSchemaVersion(value) {
  if (value == null) {
    return TRAINING_COMMAND_SCHEMA_VERSION;
  }
  const version = Number(value);
  const isSupported = Number.isInteger(version) && version === TRAINING_COMMAND_SCHEMA_VERSION;
  if (!isSupported) {
    throw new Error(`trainingSchemaVersion must be ${TRAINING_COMMAND_SCHEMA_VERSION}.`);
  }
  return version;
}
|
|
95
|
+
|
|
96
|
+
/**
 * Upload float values into a pooled buffer and wrap it as an `f32` tensor.
 *
 * @param {Float32Array|ArrayLike<number>} values - Values to upload (copied unless already a Float32Array).
 * @param {number[]} shape - Tensor shape.
 * @param {string} [label] - Buffer/tensor label; defaults to 'train_tensor'.
 * @returns {*} Tensor from createTensor().
 */
function makeTensorFromFloat32(values, shape, label) {
  const floats = values instanceof Float32Array ? values : new Float32Array(values);
  const tensorLabel = label || 'train_tensor';
  const gpuBuffer = acquireBuffer(floats.byteLength, undefined, tensorLabel);
  uploadData(gpuBuffer, floats);
  return createTensor(gpuBuffer, 'f32', shape, tensorLabel);
}
|
|
102
|
+
|
|
103
|
+
/**
 * Upload raw half-precision bit patterns into a pooled buffer and wrap it as an `f16` tensor.
 *
 * @param {Uint16Array|ArrayLike<number>} values - f16 bit patterns (copied unless already a Uint16Array).
 * @param {number[]} shape - Tensor shape.
 * @param {string} [label] - Buffer/tensor label; defaults to 'train_tensor_f16'.
 * @returns {*} Tensor from createTensor().
 */
function makeTensorFromF16Bits(values, shape, label) {
  const halfBits = values instanceof Uint16Array ? values : new Uint16Array(values);
  const tensorLabel = label || 'train_tensor_f16';
  const gpuBuffer = acquireBuffer(halfBits.byteLength, undefined, tensorLabel);
  uploadData(gpuBuffer, halfBits);
  return createTensor(gpuBuffer, 'f16', shape, tensorLabel);
}
|
|
109
|
+
|
|
110
|
+
/**
 * Upload 32-bit unsigned token ids into a pooled buffer and wrap the buffer as a tensor.
 *
 * @param {Uint32Array|ArrayLike<number>} values - Token ids (copied unless already a Uint32Array).
 * @param {number[]} shape - Tensor shape.
 * @param {string} [label] - Buffer/tensor label; defaults to 'train_tokens'.
 * @returns {*} Tensor from createTensor(), tagged 'f32'.
 */
function makeTensorFromUint32(values, shape, label) {
  const tokenIds = values instanceof Uint32Array ? values : new Uint32Array(values);
  const tensorLabel = label || 'train_tokens';
  const gpuBuffer = acquireBuffer(tokenIds.byteLength, undefined, tensorLabel);
  uploadData(gpuBuffer, tokenIds);
  // By contract token tensors carry an 'f32' dtype tag; consuming kernels read the raw u32 bytes.
  return createTensor(gpuBuffer, 'f32', shape, tensorLabel);
}
|
|
117
|
+
|
|
118
|
+
/**
 * Return a tensor's backing buffer to the pool; a no-op for null tensors or tensors without a buffer.
 *
 * @param {{buffer?: *}|null|undefined} tensor
 */
function releaseTensor(tensor) {
  const backing = tensor?.buffer;
  if (backing) {
    releaseBuffer(backing);
  }
}
|
|
122
|
+
|
|
123
|
+
/**
 * True only for primitive numbers that are finite (no coercion from strings or other types).
 *
 * @param {unknown} value
 * @returns {boolean}
 */
function isFiniteNumber(value) {
  if (typeof value !== 'number') {
    return false;
  }
  return Number.isFinite(value);
}
|
|
126
|
+
|
|
127
|
+
/**
 * Detect whether the code is running under Node.js by probing `process.versions.node`.
 *
 * @returns {boolean}
 */
function isNodeRuntime() {
  if (typeof process === 'undefined') {
    return false;
  }
  return Boolean(process.versions?.node);
}
|
|
130
|
+
|
|
131
|
+
/**
 * Convert a value to a trimmed string, or null when it is null/undefined or blank after trimming.
 *
 * @param {unknown} value
 * @returns {string|null}
 */
function normalizeOptionalString(value) {
  if (value == null) {
    return null;
  }
  const text = String(value).trim();
  return text.length > 0 ? text : null;
}
|
|
136
|
+
|
|
137
|
+
/**
 * Normalize a distill dataset path option.
 *
 * @param {unknown} value - Raw option value.
 * @returns {string|null} Trimmed path, or null when missing/blank.
 */
function normalizeDistillDatasetPath(value) {
  const datasetPath = normalizeOptionalString(value);
  return datasetPath;
}
|
|
140
|
+
|
|
141
|
+
/**
 * Normalize a language tag to a compact lower-case code.
 *
 * English and Spanish aliases collapse to `en` / `es`. Regional or script variants such as
 * `en_US` or `es-419` also collapse. Any other tag is returned lower-cased with `_` replaced by `-`.
 *
 * Only the primary subtag is compared. Codes that merely share a prefix are left alone, for
 * example `est` (Estonian) or `esperanto`. The previous `startsWith('es')` check folded those
 * into Spanish.
 *
 * @param {unknown} value - Raw language tag.
 * @returns {string|null} Normalized code, or null when the value is missing/blank.
 */
function normalizeLangCode(value) {
  const normalized = normalizeOptionalString(value);
  if (!normalized) return null;
  const compact = normalized.toLowerCase().replace(/_/g, '-');
  // Primary subtag = leading run of letters ("en" in "en-us", "english" in "english (us)").
  const primary = compact.split(/[^\p{L}]+/u)[0];
  switch (primary) {
    case 'en':
    case 'eng':
    case 'english':
      return 'en';
    case 'es':
    case 'spa':
    case 'spanish':
    case 'espanol':
    case 'español':
      return 'es';
    default:
      return compact;
  }
}
|
|
149
|
+
|
|
150
|
+
/**
 * Normalize a language-pair string ("en-es", "en->es", "EN_US -> es") to the form "src->tgt".
 *
 * When the text contains "->" it is split there; otherwise it is split on "-". Exactly two
 * non-empty sides are required.
 *
 * @param {unknown} value - Raw pair string.
 * @returns {string|null} "src->tgt" with each side passed through normalizeLangCode, or null.
 */
function normalizePairDirection(value) {
  const raw = normalizeOptionalString(value);
  if (!raw) return null;
  const compact = raw
    .toLowerCase()
    .replace(/_/g, '-')
    .replace(/\s+/g, '');
  const separator = compact.includes('->') ? '->' : '-';
  const sides = compact.split(separator).filter((part) => part.length > 0);
  if (sides.length !== 2) return null;
  const [from, to] = sides;
  const fromCode = normalizeLangCode(from) || from;
  const toCode = normalizeLangCode(to) || to;
  return `${fromCode}->${toCode}`;
}
|
|
160
|
+
|
|
161
|
+
/**
 * Normalize an array or comma-separated string into trimmed, non-empty strings.
 *
 * @param {unknown} value - Array of values or a comma-separated string.
 * @returns {string[]|null} Non-empty trimmed entries, or null for other input types / no entries.
 */
function normalizeOptionalStringArray(value) {
  if (value == null) return null;
  let entries;
  if (Array.isArray(value)) {
    entries = value;
  } else if (typeof value === 'string') {
    entries = value.split(',');
  } else {
    return null;
  }
  const cleaned = [];
  for (const entry of entries) {
    const text = normalizeOptionalString(entry);
    if (text) cleaned.push(text);
  }
  return cleaned.length > 0 ? cleaned : null;
}
|
|
172
|
+
|
|
173
|
+
/**
 * Normalize a language allowlist into unique language codes, in first-seen order.
 *
 * @param {unknown} value - Array or comma-separated string of language tags.
 * @returns {string[]|null} Unique normalized codes, or null when none remain.
 */
function normalizeDistillLanguageAllowlist(value) {
  const entries = normalizeOptionalStringArray(value);
  if (!entries) return null;
  const unique = new Set();
  for (const entry of entries) {
    const code = normalizeLangCode(entry);
    if (code) unique.add(code);
  }
  return unique.size > 0 ? Array.from(unique) : null;
}
|
|
182
|
+
|
|
183
|
+
/**
 * Normalize a language-pair allowlist into unique "src->tgt" directions, in first-seen order.
 *
 * @param {unknown} value - Array or comma-separated string of pair strings.
 * @returns {string[]|null} Unique normalized directions, or null when none remain.
 */
function normalizeDistillPairAllowlist(value) {
  const entries = normalizeOptionalStringArray(value);
  if (!entries) return null;
  const unique = new Set();
  for (const entry of entries) {
    const direction = normalizePairDirection(entry);
    if (direction) unique.add(direction);
  }
  return unique.size > 0 ? Array.from(unique) : null;
}
|
|
192
|
+
|
|
193
|
+
/**
 * Resolve the distill data-scope filters. Explicit options take precedence over `trainingConfig.distill`.
 *
 * @param {object} [options] - May contain distillSourceLangs, distillTargetLangs,
 *   distillPairAllowlist and strictPairContract.
 * @param {object|null} [trainingConfig] - Training config whose `distill` section supplies fallbacks.
 * @returns {{sourceLangs: string[]|null, targetLangs: string[]|null, pairAllowlist: string[]|null,
 *   sourceLangSet: Set<string>|null, targetLangSet: Set<string>|null,
 *   pairAllowlistSet: Set<string>|null, strictPairContract: boolean}}
 */
function resolveDistillDataScope(options = {}, trainingConfig = null) {
  const configured = trainingConfig?.distill || {};
  const pick = (optionKey, configKey) => options[optionKey] ?? configured[configKey] ?? null;
  const toSet = (list) => (list ? new Set(list) : null);

  const sourceLangs = normalizeDistillLanguageAllowlist(pick('distillSourceLangs', 'sourceLangs'));
  const targetLangs = normalizeDistillLanguageAllowlist(pick('distillTargetLangs', 'targetLangs'));
  const pairAllowlist = normalizeDistillPairAllowlist(pick('distillPairAllowlist', 'pairAllowlist'));
  // Strict mode is on when either source enables it with a literal `true`.
  const strictPairContract = [options.strictPairContract, configured.strictPairContract].includes(true);

  return {
    sourceLangs,
    targetLangs,
    pairAllowlist,
    sourceLangSet: toSet(sourceLangs),
    targetLangSet: toSet(targetLangs),
    pairAllowlistSet: toSet(pairAllowlist),
    strictPairContract,
  };
}
|
|
218
|
+
|
|
219
|
+
/**
 * Derive a "src->tgt" direction for a dataset record.
 *
 * Uses `pair` when it normalizes; otherwise combines `src_lang` with `tgt_lang` (or `lang`).
 *
 * @param {object|null|undefined} record - Raw dataset row.
 * @returns {string|null} Direction, or null when it cannot be determined.
 */
function resolveDistillDirection(record) {
  const fromPair = normalizePairDirection(record?.pair);
  if (fromPair) return fromPair;
  const src = normalizeLangCode(record?.src_lang);
  const tgt = normalizeLangCode(record?.tgt_lang || record?.lang);
  return src && tgt ? `${src}->${tgt}` : null;
}
|
|
229
|
+
|
|
230
|
+
/**
 * Return the first non-blank (trimmed) string among `record[key]` for the given keys, in order.
 *
 * @param {object|null|undefined} record - Source object.
 * @param {Iterable<string>} keys - Candidate property names, highest priority first.
 * @returns {string|null}
 */
function resolveStringCandidate(record, keys) {
  for (const candidateKey of keys) {
    const candidate = normalizeOptionalString(record?.[candidateKey]);
    if (candidate !== null) {
      return candidate;
    }
  }
  return null;
}
|
|
237
|
+
|
|
238
|
+
/**
 * Convert a raw JSONL row into a distill sample, applying the optional data-scope filters.
 *
 * Text fields are read in priority order:
 * - source: `source`, then `query`.
 * - positive target: `target_pos`, then `target`, then `pos`.
 * - negative target: `target_neg`, then `neg`.
 *
 * The direction comes from `pair`, else from `src_lang` + `tgt_lang`/`lang`, else 'unknown'.
 * Explicit `src_lang` / `tgt_lang` win over languages derived from the direction.
 *
 * @param {object} record - Raw dataset row.
 * @param {number} index - Row index, copied into the result.
 * @param {object|null} [scope] - Scope from resolveDistillDataScope().
 * @returns {object|null} Encoded sample, or null when the row lacks source/target text or is
 *   filtered out by the scope.
 * @throws {Error} Under `strictPairContract` when language metadata is missing or inconsistent.
 */
function encodeDistillRow(record, index, scope = null) {
  if (!record || typeof record !== 'object') return null;
  const source = resolveStringCandidate(record, ['source', 'query']);
  const targetPos = resolveStringCandidate(record, ['target_pos', 'target', 'pos']);
  const targetNeg = resolveStringCandidate(record, ['target_neg', 'neg']);
  if (!source || !targetPos) return null;

  const sourceLangRaw = normalizeLangCode(record.src_lang);
  const targetLangRaw = normalizeLangCode(record.tgt_lang || record.lang);
  const pairDirection = normalizePairDirection(record.pair);
  const sourceTargetDirection = (
    sourceLangRaw && targetLangRaw
      ? `${sourceLangRaw}->${targetLangRaw}`
      : null
  );

  if (scope?.strictPairContract === true) {
    if (!sourceLangRaw || !targetLangRaw) {
      throw new Error('strictPairContract requires src_lang and tgt_lang/lang on each row.');
    }
    if (!pairDirection) {
      throw new Error('strictPairContract requires pair on each row.');
    }
    if (pairDirection !== sourceTargetDirection) {
      throw new Error(`pair "${record?.pair}" does not match src/tgt "${sourceLangRaw}-${targetLangRaw}".`);
    }
  }

  // pairDirection / sourceTargetDirection already cover everything resolveDistillDirection() derives.
  const knownDirection = pairDirection || sourceTargetDirection;
  const direction = knownDirection || 'unknown';
  // Derive languages only from a real direction. Splitting the 'unknown' sentinel previously
  // reported sourceLang 'unknown' (with a null targetLang).
  const [directionSourceLang, directionTargetLang] = knownDirection ? knownDirection.split('->') : [];
  const sourceLang = sourceLangRaw || normalizeLangCode(directionSourceLang);
  const targetLang = targetLangRaw || normalizeLangCode(directionTargetLang);

  if (scope?.sourceLangSet && (!sourceLang || !scope.sourceLangSet.has(sourceLang))) {
    return null;
  }
  if (scope?.targetLangSet && (!targetLang || !scope.targetLangSet.has(targetLang))) {
    return null;
  }
  if (scope?.pairAllowlistSet && !scope.pairAllowlistSet.has(direction)) {
    return null;
  }

  return {
    index,
    direction,
    sourceLang: sourceLang || null,
    targetLang: targetLang || null,
    source,
    targetPos,
    targetNeg: targetNeg || null,
  };
}
|
|
287
|
+
|
|
288
|
+
/**
 * Count samples per `direction`; samples without one are counted under 'unknown'.
 *
 * @param {Iterable<{direction?: string}|null>} samples
 * @returns {Object<string, number>} Map from direction to count.
 */
function summarizeDirectionCounts(samples) {
  const tally = {};
  for (const sample of samples) {
    const direction = sample?.direction || 'unknown';
    tally[direction] = 1 + (tally[direction] ?? 0);
  }
  return tally;
}
|
|
296
|
+
|
|
297
|
+
/**
 * Map a language code to a display name used in prompts.
 *
 * @param {unknown} langCode
 * @returns {string} 'English' / 'Spanish' for en/es, otherwise the normalized code, or 'target'.
 */
function resolveLanguageName(langCode) {
  const code = normalizeLangCode(langCode);
  switch (code) {
    case 'en':
      return 'English';
    case 'es':
      return 'Spanish';
    default:
      return code || 'target';
  }
}
|
|
303
|
+
|
|
304
|
+
/**
 * Build the translation prompt fed to the teacher for a sample.
 *
 * @param {{direction?: string, source?: string}} sample - Encoded distill sample.
 * @returns {string} "Translate from X to Y:\n<source>\nTranslation:".
 */
function buildDistillPrompt(sample) {
  const [rawSource, rawTarget] = String(sample?.direction || '').trim().split('->');
  const sourceCode = normalizeLangCode(rawSource) || rawSource || 'source';
  const targetCode = normalizeLangCode(rawTarget) || rawTarget || 'target';
  const header = `Translate from ${resolveLanguageName(sourceCode)} to ${resolveLanguageName(targetCode)}:`;
  const text = String(sample?.source || '').trim();
  return [header, text, 'Translation:'].join('\n');
}
|
|
314
|
+
|
|
315
|
+
/**
 * Append a candidate translation to the sample's prompt, separated by one space.
 *
 * @param {object} sample - Encoded distill sample.
 * @param {unknown} candidate - Candidate translation text; blank values leave the prompt unchanged.
 * @returns {string}
 */
function buildDistillCandidatePrompt(sample, candidate) {
  const prompt = buildDistillPrompt(sample);
  const continuation = String(candidate || '').trim();
  if (!continuation) {
    return prompt;
  }
  return `${prompt} ${continuation}`;
}
|
|
320
|
+
|
|
321
|
+
/**
 * Parse a numeric option, falling back when it is missing or not a finite number.
 *
 * `Number(null)`, `Number('')` and `Number('  ')` all evaluate to 0. Without the explicit
 * checks, an unset option would silently become a real zero: for example a temperature near 0,
 * alphaKd 0, or topK clamped to 2. Those values are therefore treated as missing.
 *
 * @param {unknown} value - Raw option value.
 * @param {*} fallback - Value returned when `value` is missing, blank, or not finite.
 * @returns {number|*}
 */
function toFiniteNumber(value, fallback) {
  if (value === undefined || value === null) return fallback;
  if (typeof value === 'string' && value.trim() === '') return fallback;
  const parsed = Number(value);
  return Number.isFinite(parsed) ? parsed : fallback;
}
|
|
325
|
+
|
|
326
|
+
/**
 * Resolve the teacher top-k: floor the value and clamp it to [2, 256].
 * Missing or invalid values use DISTILL_ADAPTER_TOP_K.
 *
 * @param {unknown} value
 * @returns {number}
 */
function clampDistillTopK(value) {
  const requested = Math.floor(toFiniteNumber(value, DISTILL_ADAPTER_TOP_K));
  if (requested < 2) return 2;
  if (requested > 256) return 256;
  return requested;
}
|
|
330
|
+
|
|
331
|
+
/**
 * Normalize the student graph mode.
 *
 * 'projection_head' or 'projection' (case-insensitive; '-' or whitespace may stand in for '_')
 * selects the projection graph. Anything else, including missing values, selects the full
 * transformer graph.
 *
 * @param {unknown} value
 * @returns {string} DISTILL_STUDENT_GRAPH_PROJECTION or DISTILL_STUDENT_GRAPH_FULL.
 */
function normalizeDistillStudentGraphMode(value) {
  const text = normalizeOptionalString(value);
  const compact = text ? text.toLowerCase().replace(/[-\s]/g, '_') : '';
  const isProjection = compact === 'projection_head' || compact === 'projection';
  return isProjection ? DISTILL_STUDENT_GRAPH_PROJECTION : DISTILL_STUDENT_GRAPH_FULL;
}
|
|
340
|
+
|
|
341
|
+
/**
 * Coerce a logits payload into a Float32Array.
 *
 * - Float32Array: returned as-is.
 * - Float64Array: converted element-wise. Reinterpreting its bytes as f32 would produce
 *   meaningless values and twice as many elements.
 * - Any other ArrayBuffer view: its bytes are treated as raw f32 data and copied, so the
 *   result starts 4-byte aligned. NOTE(review): confirm callers never pass integer-valued views.
 * - Plain arrays: converted element-wise.
 *
 * @param {unknown} values - Logits payload.
 * @param {string} [label] - Name used in the error message.
 * @returns {Float32Array}
 * @throws {Error} When `values` is not array-like.
 */
function toFloat32Array(values, label = 'values') {
  if (values instanceof Float32Array) return values;
  if (values instanceof Float64Array) return Float32Array.from(values);
  if (ArrayBuffer.isView(values)) {
    return new Float32Array(values.buffer.slice(values.byteOffset, values.byteOffset + values.byteLength));
  }
  if (Array.isArray(values)) {
    return new Float32Array(values);
  }
  throw new Error(`Distill ${label} must be an array-like float buffer.`);
}
|
|
351
|
+
|
|
352
|
+
/**
 * Select the indices of the `topK` largest logits, ordered from highest to lowest.
 *
 * Non-finite logits are scored as DISTILL_LOGIT_FALLBACK. On ties the earlier index wins, because
 * a later value must be strictly greater to displace or pass an existing entry. Slots that stay
 * empty (fewer logits than `topK`) hold -1.
 *
 * @param {ArrayLike<number>} logits
 * @param {number} topK - Number of slots; floored and at least 1.
 * @returns {Int32Array} Selected token indices.
 */
function selectTopKIndices(logits, topK) {
  const slotCount = Math.max(1, Math.floor(topK));
  const lastSlot = slotCount - 1;
  const bestIndices = new Int32Array(slotCount).fill(-1);
  // Stored as Float32 on purpose: comparisons are made against the f32-rounded kept values.
  const bestValues = new Float32Array(slotCount).fill(-Infinity);

  for (let tokenIndex = 0; tokenIndex < logits.length; tokenIndex += 1) {
    const raw = logits[tokenIndex];
    const score = Number.isFinite(raw) ? raw : DISTILL_LOGIT_FALLBACK;
    if (score <= bestValues[lastSlot]) continue;

    // Insertion step: shift smaller entries down one slot until the insertion point is found.
    let slot = lastSlot;
    for (; slot > 0 && score > bestValues[slot - 1]; slot -= 1) {
      bestValues[slot] = bestValues[slot - 1];
      bestIndices[slot] = bestIndices[slot - 1];
    }
    bestValues[slot] = score;
    bestIndices[slot] = tokenIndex;
  }

  for (let slot = 0; slot < slotCount; slot += 1) {
    if (bestIndices[slot] < 0) {
      bestIndices[slot] = slot < logits.length ? slot : -1;
    }
  }
  return bestIndices;
}
|
|
378
|
+
|
|
379
|
+
/**
 * Gather `logits[indices[i]]` into a new Float32Array.
 *
 * Indices out of range and non-finite logits are replaced with `fallback`.
 *
 * @param {ArrayLike<number>} logits
 * @param {ArrayLike<number>} indices
 * @param {number} [fallback] - Replacement value; defaults to DISTILL_LOGIT_FALLBACK.
 * @returns {Float32Array}
 */
function gatherLogitsByIndices(logits, indices, fallback = DISTILL_LOGIT_FALLBACK) {
  const out = new Float32Array(indices.length);
  for (let slot = 0; slot < indices.length; slot += 1) {
    const tokenIndex = indices[slot];
    const inRange = tokenIndex >= 0 && tokenIndex < logits.length;
    const logit = inRange ? logits[tokenIndex] : fallback;
    out[slot] = Number.isFinite(logit) ? logit : fallback;
  }
  return out;
}
|
|
392
|
+
|
|
393
|
+
/**
 * Index of the largest finite value (earliest on ties).
 *
 * Non-finite entries are ignored. Returns 0 when nothing finite is present, including for empty input.
 *
 * @param {Iterable<number>} values
 * @returns {number}
 */
function argmax(values) {
  let winner = 0;
  let best = Number.NEGATIVE_INFINITY;
  let position = 0;
  for (const raw of values) {
    const score = Number.isFinite(raw) ? raw : Number.NEGATIVE_INFINITY;
    if (score > best) {
      best = score;
      winner = position;
    }
    position += 1;
  }
  return winner;
}
|
|
405
|
+
|
|
406
|
+
/**
 * Temperature-scaled, max-shifted softmax.
 *
 * The temperature is parsed with toFiniteNumber (default 1) and floored at 1e-4. If the
 * normalizer is non-finite or non-positive (e.g. NaN inputs), a uniform distribution is returned.
 *
 * @param {ArrayLike<number>} values - Logits.
 * @param {number} [temperature=1]
 * @returns {Float32Array} Probabilities.
 */
function softmax(values, temperature = 1) {
  const temp = Math.max(1e-4, toFiniteNumber(temperature, 1));
  const scaled = Array.from(values, (v) => v / temp);

  // Comparison (not Math.max) so NaN entries are skipped when finding the peak.
  let peak = Number.NEGATIVE_INFINITY;
  for (const s of scaled) {
    if (s > peak) peak = s;
  }

  const probs = new Float32Array(scaled.length);
  let total = 0;
  scaled.forEach((s, i) => {
    const weight = Math.exp(s - peak);
    probs[i] = weight;
    total += weight;
  });

  if (!(Number.isFinite(total) && total > 0)) {
    probs.fill(1 / Math.max(1, scaled.length));
    return probs;
  }
  for (let i = 0; i < probs.length; i += 1) {
    probs[i] /= total;
  }
  return probs;
}
|
|
430
|
+
|
|
431
|
+
/**
 * Clear the cache attached to a prefill result, if that cache exposes a `clear()` method.
 *
 * @param {{cache?: {clear?: Function}}|null|undefined} result
 */
function disposePrefillSnapshot(result) {
  const cache = result?.cache;
  if (typeof cache?.clear === 'function') {
    cache.clear();
  }
}
|
|
437
|
+
|
|
438
|
+
/**
 * Deterministic Fisher–Yates shuffle of `[0, length)` driven by a 32-bit LCG.
 *
 * LCG step: state = state * 1664525 + 1013904223 (mod 2^32). A seed that coerces to 0 is
 * replaced with 0x6d2b79f5.
 *
 * @param {number} length - Number of indices.
 * @param {number} [seed=1337]
 * @returns {number[]} A permutation of the indices.
 */
function buildShuffledIndices(length, seed = 1337) {
  const order = Array.from({ length }, (_, position) => position);
  let rng = (Number(seed) >>> 0) || 0x6d2b79f5;
  const nextState = () => {
    // Product stays below 2^53, so the float arithmetic is exact before truncation to 32 bits.
    rng = ((rng * 1664525) + 1013904223) >>> 0;
    return rng;
  };
  for (let top = order.length - 1; top > 0; top -= 1) {
    const pick = nextState() % (top + 1);
    [order[top], order[pick]] = [order[pick], order[top]];
  }
  return order;
}
|
|
450
|
+
|
|
451
|
+
/**
 * Normalize the distill stage. Only a (trimmed) 'stage_b' selects stage B; everything else is 'stage_a'.
 *
 * @param {unknown} value
 * @returns {'stage_a'|'stage_b'}
 */
function normalizeDistillStage(value) {
  const isStageB = String(value || '').trim() === 'stage_b';
  return isStageB ? 'stage_b' : 'stage_a';
}
|
|
455
|
+
|
|
456
|
+
/**
 * Run a teacher prefill for `prompt` and extract top-k distillation features from its logits.
 *
 * The prefill snapshot is disposed and the teacher pipeline is reset whether or not the prefill
 * succeeds. The prefill call used to sit outside `try`, so a failing prefill skipped the reset.
 *
 * @param {{source: string, direction: string}} sample - Encoded distill sample.
 * @param {string} prompt - Prompt passed to the teacher.
 * @param {{teacherPipeline: object, topK: number, temperature?: number}} runtime
 * @returns {Promise<{source: string, direction: string, targetClass: number,
 *   topTokenIndices: number[], teacherTopLogits: Float32Array, teacherTopProbs: Float32Array}>}
 */
async function computeTeacherPromptDistillFeatures(sample, prompt, runtime) {
  let teacherResult = null;
  try {
    teacherResult = await runtime.teacherPipeline.prefillWithLogits(prompt, {
      useChatTemplate: false,
    });
    const teacherLogits = toFloat32Array(teacherResult?.logits, 'teacher logits');
    const topTokenIndices = selectTopKIndices(teacherLogits, runtime.topK);
    const teacherTopLogits = gatherLogitsByIndices(teacherLogits, topTokenIndices, DISTILL_LOGIT_FALLBACK);
    const teacherTopProbs = softmax(teacherTopLogits, runtime.temperature);
    // targetClass is a position within the top-k list, not a vocabulary token id.
    const targetClass = argmax(teacherTopLogits);
    return {
      source: sample.source,
      direction: sample.direction,
      targetClass,
      topTokenIndices: Array.from(topTokenIndices),
      teacherTopLogits,
      teacherTopProbs,
    };
  } finally {
    disposePrefillSnapshot(teacherResult);
    runtime.teacherPipeline.reset();
  }
}
|
|
479
|
+
|
|
480
|
+
/**
 * Build a batched dataset whose inputs are the teacher's top-k logits for each sample.
 *
 * Each batch yields:
 * - `input`: f32 tensor of shape [batch, topK].
 * - `targets`: u32 targets. Each is the teacher's argmax position within the top-k, or the
 *   teacher's top vocabulary token when `targetTokenMode === 'teacher_top_token'`.
 * - `distill`: metadata and teacher distributions.
 * In stage B, positive/negative candidate prompts are also emitted for the triplet objective.
 *
 * The input/target tensors are reused across same-sized batches and re-uploaded each step, so a
 * yielded batch is only valid until the next one is requested. Tensors are released when
 * iteration finishes or is abandoned.
 *
 * @param {object[]} samples - Encoded distill samples.
 * @param {object} [options] - `distillRuntime` (teacherPipeline, stage, topK, temperature,
 *   alphaKd, alphaCe, tripletMargin, targetTokenMode, model ids), `batchSize`, `shuffle`, `seed`.
 * @returns {{batches: () => AsyncGenerator<object>}}
 * @throws {Error} When there are no samples or no teacher pipeline.
 */
function createDistillTensorDataset(samples, options = {}) {
  if (!Array.isArray(samples) || samples.length === 0) {
    throw new Error('Distill dataset has no usable rows.');
  }
  const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
    ? options.distillRuntime
    : null;
  if (!distillRuntime?.teacherPipeline) {
    throw new Error('Distill dataset requires teacherPipeline.');
  }
  const batchSize = Math.max(1, Math.floor(Number(options.batchSize) || 1));
  const shuffle = options.shuffle === true;
  const seed = Number.isInteger(options.seed) ? options.seed : 1337;
  const stage = normalizeDistillStage(distillRuntime.stage);
  const topK = clampDistillTopK(distillRuntime.topK);

  return {
    async *batches() {
      const order = shuffle
        ? buildShuffledIndices(samples.length, seed)
        : Array.from({ length: samples.length }, (_, idx) => idx);
      let inputTensor = null;
      let targetTensor = null;
      let tensorBatchSize = 0;
      try {
        for (let offset = 0; offset < order.length; offset += batchSize) {
          const batchIndices = order.slice(offset, offset + batchSize);
          const batchLength = batchIndices.length;
          // The runtime is still read once per batch, so callers that adjust it between steps
          // take effect. The reported mode must describe what was actually written into
          // `targets`; unknown modes fall back to top-k class targets.
          const targetTokenMode = distillRuntime.targetTokenMode === 'teacher_top_token'
            ? 'teacher_top_token'
            : 'topk_class';
          const teacherRuntime = { ...distillRuntime, topK };

          const features = new Float32Array(batchLength * topK);
          const targets = new Uint32Array(batchLength);
          const teacherTopProbs = [];
          const teacherTopTokenIndices = [];
          const teacherTopLogits = [];
          const teacherTargetIndices = [];
          const teacherTargetTokenIds = [];
          const prompts = [];
          const tripletPositivePrompts = [];
          const tripletNegativePrompts = [];
          const tripletMask = [];
          const directionCounts = {};

          for (let i = 0; i < batchLength; i += 1) {
            const sample = samples[batchIndices[i]];
            const prompt = buildDistillPrompt(sample);
            const baseDistill = await computeTeacherPromptDistillFeatures(sample, prompt, teacherRuntime);

            features.set(baseDistill.teacherTopLogits, i * topK);
            const { targetClass } = baseDistill;
            const targetToken = Number.isInteger(baseDistill.topTokenIndices?.[targetClass])
              ? baseDistill.topTokenIndices[targetClass]
              : targetClass;
            targets[i] = targetTokenMode === 'teacher_top_token' ? targetToken : targetClass;
            teacherTargetIndices.push(targetClass);
            teacherTargetTokenIds.push(targetToken);
            teacherTopProbs.push(baseDistill.teacherTopProbs);
            teacherTopTokenIndices.push(baseDistill.topTokenIndices);
            teacherTopLogits.push(baseDistill.teacherTopLogits);
            prompts.push(prompt);

            if (stage === 'stage_b') {
              const posPrompt = buildDistillCandidatePrompt(sample, sample.targetPos);
              const negPrompt = sample.targetNeg
                ? buildDistillCandidatePrompt(sample, sample.targetNeg)
                : null;
              tripletPositivePrompts.push(posPrompt);
              // Rows without a negative reuse the positive prompt and are masked out.
              tripletNegativePrompts.push(negPrompt || posPrompt);
              tripletMask.push(Boolean(negPrompt));
            }

            directionCounts[sample.direction] = (directionCounts[sample.direction] || 0) + 1;
          }

          if (!inputTensor || !targetTensor || tensorBatchSize !== batchLength) {
            releaseTensor(inputTensor);
            releaseTensor(targetTensor);
            inputTensor = makeTensorFromFloat32(features, [batchLength, topK], 'distill_jsonl_input');
            targetTensor = makeTensorFromUint32(targets, [batchLength], 'distill_jsonl_targets');
            tensorBatchSize = batchLength;
          } else {
            uploadData(inputTensor.buffer, features);
            uploadData(targetTensor.buffer, targets);
          }

          yield {
            input: inputTensor,
            targets: targetTensor,
            distill: {
              prompts,
              tripletPositivePrompts,
              tripletNegativePrompts,
              tripletMask,
              teacherTopProbs,
              teacherTopTokenIndices,
              teacherTopLogits,
              teacherTargetIndices,
              teacherTargetTokenIds,
              targetTokenMode,
              batchSampleCount: batchLength,
              directionCounts,
              distillStage: stage,
              temperature: toFiniteNumber(distillRuntime.temperature, 1),
              alphaKd: toFiniteNumber(distillRuntime.alphaKd, 1),
              alphaCe: toFiniteNumber(distillRuntime.alphaCe, 0),
              tripletMargin: Math.max(0, toFiniteNumber(distillRuntime.tripletMargin, 0.2)),
              teacherModelId: distillRuntime.teacherModelId || null,
              studentModelId: distillRuntime.studentModelId || null,
            },
          };
        }
      } finally {
        releaseTensor(inputTensor);
        releaseTensor(targetTensor);
      }
    },
  };
}
|
|
607
|
+
|
|
608
|
+
/**
 * Load a distillation dataset from a JSONL file or a JSON shard manifest (Node only).
 *
 * A shard manifest is a JSON object whose non-empty `shards` array lists shard paths, either as
 * strings or as `{ path }` objects. Shard paths are resolved as follows:
 * - Absolute paths are used as-is.
 * - `./` and `../` paths are resolved against the manifest directory.
 * - `projects/...` paths are resolved against the workspace root that contains the manifest's
 *   `projects` directory.
 * - Anything else is joined to the manifest directory.
 *
 * Shards are scanned once up front for counts and re-read lazily when batches are produced.
 *
 * @param {unknown} datasetPath - Path to a `.jsonl` file or shard manifest.
 * @param {object|null} [scopeOptions] - Scope from resolveDistillDataScope(); defaults to an unfiltered scope.
 * @returns {Promise<object|null>} Descriptor with row/sample/direction counts, the data scope, and
 *   `createDataset(runOptions)`; null when no path was given.
 * @throws {Error} Outside Node, when a file cannot be read, on invalid rows under
 *   strictPairContract, or when no usable rows remain.
 */
async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
  const normalizedPath = normalizeDistillDatasetPath(datasetPath);
  if (!normalizedPath) return null;
  if (!isNodeRuntime()) {
    throw new Error('distillDatasetPath currently requires Node runtime.');
  }
  const normalizedScope = (
    scopeOptions && typeof scopeOptions === 'object'
      ? scopeOptions
      : resolveDistillDataScope()
  );

  const [{ readFile }, { resolve, dirname, isAbsolute, join, sep }] = await Promise.all([
    import('node:fs/promises'),
    import('node:path'),
  ]);

  // Accept '/' as well as the platform separator, so manifests written with forward slashes
  // also resolve on Windows. On POSIX both checks are the same.
  const hasDirPrefix = (value, prefix) => (
    value.startsWith(`${prefix}${sep}`) || value.startsWith(`${prefix}/`)
  );

  const isShardManifest = (candidate) => {
    if (!candidate || typeof candidate !== 'object' || Array.isArray(candidate)) return false;
    if (!Array.isArray(candidate.shards) || candidate.shards.length === 0) return false;
    return candidate.shards.every((entry) => {
      if (typeof entry === 'string' && entry.trim()) return true;
      if (entry && typeof entry === 'object' && typeof entry.path === 'string' && entry.path.trim()) return true;
      return false;
    });
  };

  const resolveShardPath = (entry, manifestDir) => {
    const rawPath = typeof entry === 'string' ? entry : entry.path;
    const normalized = String(rawPath || '').trim();
    if (!normalized) return null;
    if (isAbsolute(normalized)) return normalized;
    if (hasDirPrefix(normalized, '.') || hasDirPrefix(normalized, '..')) {
      return resolve(manifestDir, normalized);
    }
    if (hasDirPrefix(normalized, 'projects')) {
      const marker = `${sep}projects${sep}`;
      const markerIndex = manifestDir.lastIndexOf(marker);
      if (markerIndex >= 0) {
        const workspaceRoot = manifestDir.slice(0, markerIndex);
        return join(workspaceRoot, normalized);
      }
    }
    return join(manifestDir, normalized);
  };

  // Read one shard. I/O errors are wrapped with the shard path, keeping `code` and `cause`, so
  // a failure inside a multi-shard manifest names the offending file.
  const readShard = async (shardPath) => {
    try {
      return await readFile(shardPath, 'utf8');
    } catch (error) {
      const message = error?.message ? String(error.message) : String(error);
      const wrapped = new Error(`Failed to read distill shard "${shardPath}": ${message}`, { cause: error });
      if (error?.code !== undefined) wrapped.code = error.code;
      throw wrapped;
    }
  };

  const loadEncodedRows = (rawRows, contextLabel) => {
    const encodedRows = [];
    for (let i = 0; i < rawRows.length; i += 1) {
      let encoded = null;
      try {
        encoded = encodeDistillRow(rawRows[i], i, normalizedScope);
      } catch (error) {
        const message = error?.message ? String(error.message) : String(error);
        throw new Error(`${contextLabel}: row ${i + 1}: ${message}`);
      }
      if (encoded) encodedRows.push(encoded);
    }
    return encodedRows;
  };

  // Snapshot of the effective data scope, reported next to the counts.
  const describeDataScope = () => ({
    sourceLangs: normalizedScope.sourceLangs || null,
    targetLangs: normalizedScope.targetLangs || null,
    pairAllowlist: normalizedScope.pairAllowlist || null,
    strictPairContract: normalizedScope.strictPairContract === true,
  });

  const absolutePath = resolve(normalizedPath);
  let raw;
  try {
    raw = await readFile(absolutePath, 'utf8');
  } catch (error) {
    const message = error?.message ? String(error.message) : String(error);
    throw new Error(`Failed to read distillDatasetPath "${absolutePath}": ${message}`);
  }

  // A JSONL file with several rows fails JSON.parse; only a single JSON document can be a manifest.
  let parsedJson = null;
  try {
    parsedJson = JSON.parse(raw);
  } catch {
    parsedJson = null;
  }
  if (isShardManifest(parsedJson)) {
    const manifestDir = dirname(absolutePath);
    const shardPaths = parsedJson.shards
      .map((entry) => resolveShardPath(entry, manifestDir))
      .filter(Boolean);
    if (shardPaths.length === 0) {
      throw new Error(`Distill shard manifest "${absolutePath}" has no valid shard paths.`);
    }
    let rowCount = 0;
    let sampleCount = 0;
    const directionCounts = {};
    for (const shardPath of shardPaths) {
      const shardRows = parseJsonl(await readShard(shardPath));
      const encodedRows = loadEncodedRows(shardRows, `distill shard "${shardPath}"`);
      rowCount += shardRows.length;
      sampleCount += encodedRows.length;
      for (const [direction, count] of Object.entries(summarizeDirectionCounts(encodedRows))) {
        directionCounts[direction] = (directionCounts[direction] || 0) + count;
      }
    }
    if (sampleCount <= 0) {
      throw new Error(`Distill shard manifest "${absolutePath}" has no usable rows across shards.`);
    }
    return {
      absolutePath,
      rowCount,
      sampleCount,
      directionCounts,
      dataScope: describeDataScope(),
      shardCount: shardPaths.length,
      shardPaths,
      createDataset(runOptions = {}) {
        const shardSeedBase = Number.isInteger(runOptions.seed) ? runOptions.seed : 1337;
        return {
          async *batches() {
            for (let shardIndex = 0; shardIndex < shardPaths.length; shardIndex += 1) {
              const shardPath = shardPaths[shardIndex];
              const shardRows = parseJsonl(await readShard(shardPath));
              const encodedRows = loadEncodedRows(shardRows, `distill shard "${shardPath}"`);
              if (encodedRows.length === 0) continue;
              // Each shard gets its own seed so shuffles differ between shards but stay reproducible.
              const shardDataset = createDistillTensorDataset(encodedRows, {
                ...runOptions,
                seed: shardSeedBase + shardIndex,
              });
              for await (const batch of shardDataset.batches()) {
                if (batch?.distill && typeof batch.distill === 'object') {
                  batch.distill.datasetShardIndex = shardIndex + 1;
                  batch.distill.datasetShardCount = shardPaths.length;
                  batch.distill.datasetShardPath = shardPath;
                }
                yield batch;
              }
            }
          },
        };
      },
    };
  }

  const rows = parseJsonl(raw);
  const encodedRows = loadEncodedRows(rows, `distill dataset "${absolutePath}"`);
  if (encodedRows.length === 0) {
    throw new Error(`Distill dataset "${absolutePath}" has no usable rows.`);
  }

  return {
    absolutePath,
    rowCount: rows.length,
    sampleCount: encodedRows.length,
    directionCounts: summarizeDirectionCounts(encodedRows),
    dataScope: describeDataScope(),
    createDataset(runOptions = {}) {
      return createDistillTensorDataset(encodedRows, runOptions);
    },
  };
}
|
|
772
|
+
|
|
773
|
+
/**
 * Report whether a value starts with a URI scheme prefix such as "https:" or "file:".
 * The scheme grammar is a letter followed by letters, digits, "+", "." or "-", then ":".
 * NOTE: a Windows drive path like "C:\\models" also matches this pattern.
 *
 * @param {*} value - Anything; falsy values are treated as the empty string.
 * @returns {boolean} True when the trimmed string begins with a scheme prefix.
 */
function looksLikeUrl(value) {
  const schemePrefix = /^[a-zA-Z][a-zA-Z0-9+.-]*:/;
  const text = String(value || '').trim();
  return schemePrefix.test(text);
}
|
|
776
|
+
|
|
777
|
+
/**
 * Report whether a value is written as an explicit POSIX-style filesystem path:
 * absolute ("/..."), or relative with a leading "./" or "../".
 *
 * @param {*} value - Anything; falsy values are treated as the empty string.
 * @returns {boolean} True when the trimmed string starts with one of those prefixes.
 */
function looksLikeFilesystemPath(value) {
  const text = String(value || '').trim();
  const pathPrefixes = ['/', './', '../'];
  return pathPrefixes.some((prefix) => text.startsWith(prefix));
}
|
|
781
|
+
|
|
782
|
+
/**
 * In Node, locate a model directory containing a readable manifest.json and return its file:// URL.
 *
 * Directories are tried in order: the ref itself, then models/local/<ref>, then
 * models/curated/<ref>. Relative candidates are resolved against the process working directory.
 *
 * @param {*} modelRef - Model reference; it is stringified and trimmed.
 * @returns {Promise<string|null>} file:// URL of the first matching directory, or null when not
 *   running in Node, when the ref is empty, or when no candidate has a readable manifest.
 */
async function resolveNodeModelUrlFromRef(modelRef) {
  if (!isNodeRuntime()) {
    return null;
  }
  const [fsModule, pathModule, urlModule] = await Promise.all([
    import('node:fs/promises'),
    import('node:path'),
    import('node:url'),
  ]);

  const ref = String(modelRef || '').trim();
  if (ref === '') {
    return null;
  }
  const candidateDirs = [
    ref,
    pathModule.join('models', 'local', ref),
    pathModule.join('models', 'curated', ref),
  ];
  for (const dir of candidateDirs) {
    const absoluteDir = pathModule.resolve(dir);
    const manifestFile = pathModule.join(absoluteDir, 'manifest.json');
    try {
      await fsModule.access(manifestFile, fsModule.constants.R_OK);
      return urlModule.pathToFileURL(absoluteDir).href;
    } catch {
      // No readable manifest in this directory; fall through to the next location.
    }
  }
  return null;
}
|
|
809
|
+
|
|
810
|
+
/**
 * Open the model store for `modelId`, parse its stored manifest, and build an inference
 * pipeline on the current GPU device.
 *
 * @param {string} modelId - Identifier passed to openModelStore.
 * @returns {Promise<{pipeline: object, manifest: object}>}
 * @throws {Error} When the store has no manifest for the model.
 */
async function initializeInferenceFromStore(modelId) {
  await openModelStore(modelId);
  const manifestText = await loadManifestFromStore();
  if (manifestText) {
    const manifest = parseManifest(manifestText);
    const device = getDevice();
    const pipeline = await createPipeline(manifest, { gpu: { device } });
    return { pipeline, manifest };
  }
  throw new Error(`Manifest not found in store for model "${modelId}".`);
}
|
|
822
|
+
|
|
823
|
+
/**
 * Load a teacher or student model for distillation and return its pipeline and manifest.
 *
 * Resolution order:
 *   1. refs that look like URLs are loaded directly;
 *   2. in Node, a directory with a readable manifest.json (see resolveNodeModelUrlFromRef);
 *   3. in Node, explicit filesystem paths ("/", "./", "../") converted to file:// URLs;
 *   4. otherwise the ref is treated as a model-store id.
 *
 * @param {*} modelRef - Model reference; normalized with normalizeOptionalString.
 * @param {string} role - "teacher" or "student"; used only in the error message.
 * @param {{runtime?: object}} [loadOptions] - `runtime` is forwarded to initializeInference.
 * @returns {Promise<{modelRef: string, modelUrl: string|null, manifest: object, pipeline: object}>}
 * @throws {Error} When the reference is empty.
 */
async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
  const ref = normalizeOptionalString(modelRef);
  if (!ref) {
    throw new Error(`Distill ${role} model reference is required.`);
  }

  // Logging and progress callbacks are intentionally silenced for distill loads.
  const handleFromUrl = async (modelUrl) => {
    const { manifest, pipeline } = await initializeInference(modelUrl, {
      log: () => {},
      onProgress: () => {},
      runtime: loadOptions.runtime || undefined,
    });
    return { modelRef: ref, modelUrl, manifest, pipeline };
  };

  if (looksLikeUrl(ref)) {
    return handleFromUrl(ref);
  }

  if (isNodeRuntime()) {
    const discoveredUrl = await resolveNodeModelUrlFromRef(ref);
    if (discoveredUrl) {
      return handleFromUrl(discoveredUrl);
    }
  }

  if (looksLikeFilesystemPath(ref) && isNodeRuntime()) {
    const [pathModule, urlModule] = await Promise.all([
      import('node:path'),
      import('node:url'),
    ]);
    const fileUrl = urlModule.pathToFileURL(pathModule.resolve(ref)).href;
    return handleFromUrl(fileUrl);
  }

  // NOTE(review): loadOptions.runtime is not forwarded on this store path, so a student
  // loaded from the store does not receive the full-graph runtime overrides — confirm intent.
  const stored = await initializeInferenceFromStore(ref);
  return {
    modelRef: ref,
    modelUrl: null,
    manifest: stored.manifest,
    pipeline: stored.pipeline,
  };
}
|
|
870
|
+
|
|
871
|
+
/**
 * Pick the teacher/student model references. Explicit options win over `trainingConfig.distill`.
 * A nullish option falls back to the config value.
 *
 * @param {object} [options] - May carry teacherModelId / studentModelId.
 * @param {object|null} [trainingConfig] - May carry distill.teacherModelId / distill.studentModelId.
 * @returns {{teacherModelRef: *, studentModelRef: *}} Values normalized by normalizeOptionalString.
 */
function resolveDistillModelRefs(options = {}, trainingConfig = null) {
  const {
    teacherModelId: configTeacherId,
    studentModelId: configStudentId,
  } = trainingConfig?.distill || {};
  const teacherModelRef = normalizeOptionalString(options.teacherModelId ?? configTeacherId);
  const studentModelRef = normalizeOptionalString(options.studentModelId ?? configStudentId);
  return { teacherModelRef, studentModelRef };
}
|
|
878
|
+
|
|
879
|
+
/**
 * Load the teacher and student pipelines and assemble the distillation runtime settings.
 *
 * In full-graph student mode, the student is loaded with f32 activations, f32 weights kept,
 * and a 'debug' default log level. The returned `cleanup()` unloads both pipelines.
 * It attempts every unload even if one fails, then rethrows the first failure.
 *
 * @param {object} [options] - teacherModelId, studentModelId, studentGraphMode, trainingStage.
 * @param {object|null} [trainingConfig] - Reads `trainingConfig.distill` for fallbacks and
 *   hyperparameters (topK, temperature, alphaKd, alphaCe, tripletMargin, stage).
 * @returns {Promise<object>} Runtime context including both pipelines and a cleanup() method.
 * @throws {Error} When either model reference is missing, or a model fails to load.
 */
async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
  const { teacherModelRef, studentModelRef } = resolveDistillModelRefs(options, trainingConfig);
  if (!teacherModelRef || !studentModelRef) {
    throw new Error('Distill stage requires teacherModelId and studentModelId.');
  }

  const distillConfig = trainingConfig?.distill || {};
  const studentGraphMode = normalizeDistillStudentGraphMode(
    options.studentGraphMode
      ?? distillConfig.studentGraphMode
  );

  // Unload the pipeline of each handle in order. Every unload is attempted even if an earlier
  // one throws, so no pipeline is leaked; the errors encountered are returned in order.
  const unloadPipelines = async (handles) => {
    const failures = [];
    for (const handle of handles) {
      const pipeline = handle?.pipeline;
      if (!pipeline || typeof pipeline.unload !== 'function') {
        continue;
      }
      try {
        await pipeline.unload();
      } catch (unloadError) {
        failures.push(unloadError);
      }
    }
    return failures;
  };

  const teacher = await loadDistillModelHandle(teacherModelRef, 'teacher');
  const studentRuntime = studentGraphMode === DISTILL_STUDENT_GRAPH_FULL
    ? {
      runtimeConfig: {
        shared: {
          debug: {
            logLevel: {
              defaultLogLevel: 'debug',
            },
          },
        },
        inference: {
          compute: {
            activationDtype: 'f32',
            keepF32Weights: true,
          },
        },
      },
    }
    : null;
  let student = null;
  try {
    student = await loadDistillModelHandle(studentModelRef, 'student', {
      runtime: studentRuntime,
    });
  } catch (error) {
    // Best-effort release of the already-loaded teacher. The student load error is the
    // actionable failure, so it is rethrown even if unloading the teacher also fails.
    await unloadPipelines([teacher]);
    throw error;
  }

  const runtime = {
    stage: normalizeDistillStage(options.trainingStage || distillConfig.stage),
    teacherPipeline: teacher.pipeline,
    studentPipeline: student.pipeline,
    teacherModelId: teacher.manifest?.modelId || teacherModelRef,
    studentModelId: student.manifest?.modelId || studentModelRef,
    teacherModelUrl: teacher.modelUrl || null,
    studentModelUrl: student.modelUrl || null,
    topK: clampDistillTopK(distillConfig.topK ?? DISTILL_ADAPTER_TOP_K),
    temperature: Math.max(1e-4, toFiniteNumber(distillConfig.temperature, 1)),
    alphaKd: toFiniteNumber(distillConfig.alphaKd, 1),
    alphaCe: toFiniteNumber(distillConfig.alphaCe, 0),
    tripletMargin: Math.max(0, toFiniteNumber(distillConfig.tripletMargin, 0.2)),
    studentGraphMode,
    targetTokenMode: studentGraphMode === DISTILL_STUDENT_GRAPH_FULL
      ? 'teacher_top_token'
      : 'topk_class',
    async cleanup() {
      // Previously a failing teacher unload skipped the student unload entirely.
      const failures = await unloadPipelines([teacher, student]);
      if (failures.length > 0) {
        throw failures[0];
      }
    },
  };
  return runtime;
}
|
|
950
|
+
|
|
951
|
+
/**
 * Resolve the distill dataset path: `options.distillDatasetPath` first, then
 * `trainingConfig.distill.datasetPath`, then null.
 *
 * @param {object} [options]
 * @param {object|null} [trainingConfig]
 * @returns {*} The value produced by normalizeDistillDatasetPath.
 */
function resolveDistillDatasetPath(options = {}, trainingConfig = null) {
  const configuredPath = options.distillDatasetPath ?? trainingConfig?.distill?.datasetPath;
  return normalizeDistillDatasetPath(configuredPath ?? null);
}
|
|
956
|
+
|
|
957
|
+
/**
 * Resolve a runtime asset path.
 *
 * When a page location is available (browser), the path is returned unchanged and left for the
 * page to resolve. Otherwise it is resolved against this module's URL.
 *
 * @param {string} pathname - Relative asset path.
 * @returns {string} The original path, or an absolute URL string.
 */
function resolveRuntimeUrl(pathname) {
  const pageHref = typeof globalThis.location === 'undefined'
    ? undefined
    : globalThis.location?.href;
  if (!pageHref) {
    return new URL(pathname, import.meta.url).toString();
  }
  return pathname;
}
|
|
963
|
+
|
|
964
|
+
/**
 * Point the platform-config and kernel-registry loaders at the bundled config directory,
 * then initialize the GPU device.
 *
 * @returns {Promise<void>}
 */
async function ensureTrainingGpuRuntime() {
  const platformsBaseUrl = resolveRuntimeUrl('../config/platforms/');
  setPlatformsBaseUrl(platformsBaseUrl);
  const kernelRegistryUrl = resolveRuntimeUrl('../config/kernels/registry.json');
  setRegistryUrl(kernelRegistryUrl);
  await initDevice();
}
|
|
969
|
+
|
|
970
|
+
/**
 * Build a tiny self-contained training fixture: a [2, 3] f32 input, two uint32 targets, and four
 * [3, 2] f32 weights exposed as encoder/prior/decoder/base parameter groups. The forward pass is
 * a single matmul of the input with `baseWeight`, which is also the sole LoRA parameter.
 *
 * @param {object} [overrides] - Training-config overrides; `overrides.training` is merged last.
 * @returns {{config: object, model: object, batch: {input: object, targets: object}, cleanup: () => void}}
 */
function createToyModelFixture(overrides = {}) {
  const trainingDefaults = {
    enabled: true,
    lossScaling: { enabled: false },
    gradient: { maxNorm: 0 },
  };
  const config = createTrainingConfig({
    ...overrides,
    training: { ...trainingDefaults, ...(overrides.training || {}) },
  });

  const toyWeight = (values, role) => makeTensorFromFloat32(values, [3, 2], `training_suite_${role}_weight`);
  const encoderWeight = toyWeight([0.1, -0.2, 0.3, 0.4, 0.05, -0.1], 'encoder');
  const priorWeight = toyWeight([0.02, -0.01, 0.03, -0.05, 0.04, -0.02], 'prior');
  const decoderWeight = toyWeight([0.03, 0.02, -0.01, 0.06, -0.04, 0.02], 'decoder');
  const baseWeight = toyWeight([0.08, -0.12, 0.16, 0.22, -0.03, 0.09], 'base');
  const input = makeTensorFromFloat32([0.5, 0.1, -0.3, 0.2, 0.4, -0.1], [2, 3], 'training_suite_input');
  const targets = makeTensorFromUint32([1, 0], [2], 'training_suite_targets');
  const batch = { input, targets };

  // input [2, 3] x baseWeight [3, 2] (no transpose).
  const matmulDims = { M: 2, N: 2, K: 3 };

  const model = {
    async forward(inputTensor, tape) {
      const { M, N, K } = matmulDims;
      return tape.record(
        OpType.MATMUL,
        (a, b) => runMatmul(a, b, M, N, K, { transposeB: false }),
        [inputTensor, baseWeight],
        { ...matmulDims, transposeB: false }
      );
    },
    loraParams() {
      return [baseWeight];
    },
    paramGroups() {
      return {
        encoder: [encoderWeight],
        prior: [priorWeight],
        decoder: [decoderWeight],
        base: [baseWeight],
        lora: [baseWeight],
      };
    },
  };

  return {
    config,
    model,
    batch,
    cleanup() {
      const ownedTensors = [encoderWeight, priorWeight, decoderWeight, baseWeight, input, targets];
      for (const tensor of ownedTensors) {
        releaseTensor(tensor);
      }
    },
  };
}
|
|
1042
|
+
|
|
1043
|
+
/**
 * Determine whether a weight-like value holds f16 or f32 data.
 *
 * WeightBuffers report `dtype` directly. Other values use their own `dtype` field,
 * then getWeightDtype(value). Matching is case-insensitive.
 *
 * @param {*} value - Weight, tensor, or buffer.
 * @param {string} [fallback] - Returned when the dtype is neither f16 nor f32.
 * @returns {string} 'f16', 'f32', or `fallback`.
 */
function resolveTensorDtype(value, fallback = 'f32') {
  let rawDtype;
  if (isWeightBuffer(value)) {
    rawDtype = value.dtype;
  } else {
    rawDtype = value?.dtype || getWeightDtype(value) || null;
  }
  switch (String(rawDtype || '').toLowerCase()) {
    case 'f16':
      return 'f16';
    case 'f32':
      return 'f32';
    default:
      return fallback;
  }
}
|
|
1050
|
+
|
|
1051
|
+
/**
 * Convert a student-pipeline weight into an f32 tensor usable as a trainable parameter.
 *
 * Accepted inputs:
 *   - WeightBuffer: returned as-is when f32; f16 is promoted to a new f32 tensor.
 *   - raw GPUBuffer: dtype taken from getBufferDtype (default f32); f16 is promoted.
 *   - CPU weight buffer: uploaded to the GPU; f16 bit patterns are uploaded and then promoted.
 *   - any object with a GPUBuffer `buffer` field: wrapped; f16 is promoted.
 *
 * Newly allocated f32 tensors are added to `ownedTrainables` (when it is a Set) so the caller can
 * release them. Tensors that merely wrap the supplied buffer are returned without registration.
 *
 * @param {*} value - Weight from the student pipeline.
 * @param {number[]} shape - Shape used when `value` has no non-empty `shape` of its own
 *   (always used for raw GPUBuffers).
 * @param {string} label - Name used in tensor labels and error messages.
 * @param {Set|null} [ownedTrainables] - Collector for tensors allocated by this call.
 * @returns {Promise<object>} An f32 tensor (or the original f32 WeightBuffer).
 * @throws {Error} When the weight is missing, not GPU/CPU-resident, or has an unsupported dtype.
 */
async function ensureTrainableTensor(value, shape, label, ownedTrainables = null) {
  if (!value) {
    throw new Error(`Distill full-graph student missing required weight "${label}".`);
  }
  const registerOwned = (tensor) => {
    if (ownedTrainables instanceof Set && tensor?.buffer instanceof GPUBuffer) {
      ownedTrainables.add(tensor);
    }
    return tensor;
  };
  // View CPU f16 payloads as Uint16 half-float bit patterns. `new Uint16Array(buffer, offset, n)`
  // throws a RangeError when `offset` is odd, and `new Uint16Array(arrayBuffer)` throws when the
  // byte length is odd. So unaligned views are copied into a fresh buffer, and a trailing odd
  // byte is dropped consistently in every case.
  const toF16Bits = (data) => {
    if (data instanceof Uint16Array) {
      return data;
    }
    if (ArrayBuffer.isView(data)) {
      const halfCount = Math.floor(data.byteLength / 2);
      if (data.byteOffset % 2 === 0) {
        return new Uint16Array(data.buffer, data.byteOffset, halfCount);
      }
      const start = data.byteOffset;
      return new Uint16Array(data.buffer.slice(start, start + halfCount * 2));
    }
    if (data instanceof ArrayBuffer) {
      return new Uint16Array(data, 0, Math.floor(data.byteLength / 2));
    }
    return null;
  };
  if (isWeightBuffer(value)) {
    if (value.dtype === 'f32') {
      return value;
    }
    if (value.dtype === 'f16') {
      const sourceShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
      // `source` wraps the WeightBuffer's own GPU buffer and is not released here.
      const source = createTensor(value.buffer, 'f16', sourceShape, `${label}_source_f16`);
      const promoted = await castF16ToF32(source);
      return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
    }
    throw new Error(`Distill full-graph student weight "${label}" uses unsupported dtype "${value.dtype}".`);
  }
  if (value instanceof GPUBuffer) {
    const sourceShape = [...shape];
    const rawDtype = String(getBufferDtype(value) || 'f32').toLowerCase();
    const dtype = rawDtype === 'f16' ? 'f16' : 'f32';
    const tensor = createTensor(value, dtype, sourceShape, label);
    if (dtype === 'f16') {
      const promoted = await castF16ToF32(tensor);
      return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
    }
    return tensor;
  }
  if (isCpuWeightBuffer(value)) {
    const sourceShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
    const dtype = resolveTensorDtype(value, 'f32');
    if (dtype === 'f32') {
      const tensor = makeTensorFromFloat32(value.data, sourceShape, `${label}_cpu_f32`);
      return registerOwned(tensor);
    }
    if (dtype === 'f16') {
      const raw = toF16Bits(value.data);
      if (!raw) {
        throw new Error(`Distill full-graph student weight "${label}" has non-typed f16 CPU data.`);
      }
      const source = makeTensorFromF16Bits(raw, sourceShape, `${label}_cpu_f16`);
      const promoted = await castF16ToF32(source);
      // The uploaded f16 staging tensor was allocated here, so it is released after promotion.
      releaseTensor(source);
      return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
    }
    throw new Error(`Distill full-graph student weight "${label}" has unsupported CPU dtype "${dtype}".`);
  }
  if (value.buffer instanceof GPUBuffer) {
    const resolvedShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
    const tensor = createTensor(
      value.buffer,
      resolveTensorDtype(value, 'f32'),
      resolvedShape,
      label
    );
    if (tensor.dtype === 'f16') {
      const promoted = await castF16ToF32(tensor);
      return registerOwned(createTensor(promoted.buffer, 'f32', resolvedShape, `${label}_trainable_f32`));
    }
    return tensor;
  }
  throw new Error(`Distill full-graph student weight "${label}" is not GPU-resident.`);
}
|
|
1130
|
+
|
|
1131
|
+
/**
 * Convert a normalization weight into a trainable f32 tensor, using [hiddenSize] as the fallback shape.
 *
 * @param {*} value - Norm weight from the student pipeline.
 * @param {number} hiddenSize - Length of the norm vector.
 * @param {string} label - Name used in tensor labels and error messages.
 * @param {Set|null} [ownedTrainables] - Collector for tensors allocated by the conversion.
 * @returns {Promise<object>} See ensureTrainableTensor.
 */
async function ensureNormTensor(value, hiddenSize, label, ownedTrainables = null) {
  const normShape = [hiddenSize];
  return ensureTrainableTensor(value, normShape, label, ownedTrainables);
}
|
|
1134
|
+
|
|
1135
|
+
/**
 * Report whether a value carries tensor data that ensureTrainableTensor-style code can consume.
 * That means a GPUBuffer, a (CPU) weight buffer, an object wrapping a GPUBuffer, a typed
 * array/DataView, or a plain array.
 *
 * @param {*} value
 * @returns {boolean}
 */
function hasTensorPayload(value) {
  if (!value) {
    return false;
  }
  return Boolean(
    value instanceof GPUBuffer
      || isWeightBuffer(value)
      || isCpuWeightBuffer(value)
      || value.buffer instanceof GPUBuffer
      || ArrayBuffer.isView(value)
      || Array.isArray(value)
  );
}
|
|
1143
|
+
|
|
1144
|
+
/**
 * Concatenate separate f32 gate and up projection weights into one fused
 * [2 * intermediateSize, hiddenSize] tensor. Gate rows come first, followed by up rows.
 * The concatenation is done with GPU buffer-to-buffer copies.
 *
 * @param {object} gateTensor - f32 tensor of shape [intermediateSize, hiddenSize].
 * @param {object} upTensor - f32 tensor of shape [intermediateSize, hiddenSize].
 * @param {number} intermediateSize
 * @param {number} hiddenSize
 * @param {string} label - Prefix for the fused buffer/tensor label and error messages.
 * @param {Set|null} [ownedTrainables] - Receives the fused tensor when it is a Set.
 * @returns {Promise<object>} The fused f32 tensor.
 * @throws {Error} Without a GPU device, on non-f32 inputs, or on shape mismatch.
 */
async function fuseGateUpTensors(gateTensor, upTensor, intermediateSize, hiddenSize, label, ownedTrainables = null) {
  const device = getDevice();
  if (!device) {
    throw new Error('Distill full-graph student requires active GPU device.');
  }
  const bothF32 = gateTensor?.dtype === 'f32' && upTensor?.dtype === 'f32';
  if (!bothF32) {
    throw new Error(`Distill fused gate_up expects f32 tensors for "${label}".`);
  }

  // Missing or non-finite dimensions read as 0 so they fail the shape check below.
  const dimOrZero = (tensor, axis) => {
    const extent = tensor?.shape?.[axis];
    return Number.isFinite(extent) ? extent : 0;
  };
  const gateRows = dimOrZero(gateTensor, 0);
  const gateCols = dimOrZero(gateTensor, 1);
  const upRows = dimOrZero(upTensor, 0);
  const upCols = dimOrZero(upTensor, 1);
  const shapesMatch = gateRows === intermediateSize
    && gateCols === hiddenSize
    && upRows === intermediateSize
    && upCols === hiddenSize;
  if (!shapesMatch) {
    throw new Error(
      `Distill gate/up shape mismatch for "${label}": gate=[${gateRows},${gateCols}] up=[${upRows},${upCols}] `
      + `expected=[${intermediateSize},${hiddenSize}]`
    );
  }

  // Byte size of one [intermediateSize, hiddenSize] f32 block.
  const halfBytes = intermediateSize * (hiddenSize * 4);
  const fusedBuffer = acquireBuffer(halfBytes * 2, undefined, `${label}_fused`);
  const commands = device.createCommandEncoder();
  commands.copyBufferToBuffer(gateTensor.buffer, 0, fusedBuffer, 0, halfBytes);
  commands.copyBufferToBuffer(upTensor.buffer, 0, fusedBuffer, halfBytes, halfBytes);
  device.queue.submit([commands.finish()]);

  const fused = createTensor(fusedBuffer, 'f32', [intermediateSize * 2, hiddenSize], `${label}_fused`);
  if (ownedTrainables instanceof Set) {
    ownedTrainables.add(fused);
  }
  return fused;
}
|
|
1177
|
+
|
|
1178
|
+
/**
 * Select the prompt list for a distill phase from `batch.distill`.
 * "positive" and "negative" map to the triplet prompt lists; any other phase uses `prompts`.
 *
 * @param {object} batch - Batch whose `distill` object holds the prompt arrays.
 * @param {string} phase - 'anchor', 'positive', or 'negative'.
 * @returns {Array} The non-empty prompt array (same reference as stored on the batch).
 * @throws {Error} When the selected list is missing or empty.
 */
function resolvePhasePrompts(batch, phase) {
  const distill = batch?.distill || {};
  let prompts;
  switch (phase) {
    case 'positive':
      prompts = distill.tripletPositivePrompts;
      break;
    case 'negative':
      prompts = distill.tripletNegativePrompts;
      break;
    default:
      prompts = distill.prompts;
  }
  const hasPrompts = Array.isArray(prompts) && prompts.length > 0;
  if (!hasPrompts) {
    throw new Error(`Distill student fixture requires distill prompts for phase "${phase}".`);
  }
  return prompts;
}
|
|
1188
|
+
|
|
1189
|
+
/**
 * Copy one row of a row-major [rows, cols] GPU tensor into a new [1, cols] tensor of the same dtype.
 *
 * `rowIndex` is floored and clamped to [0, rows - 1]. Arguments are validated before any GPU
 * work. WebGPU's copyBufferToBuffer requires 4-byte-aligned offsets and sizes; a misaligned
 * copy is rejected by validation, which would leave the output buffer unwritten without a
 * thrown error. Such copies (f16 rows with an odd column count) are therefore refused explicitly.
 *
 * @param {object} inputTensor - Source tensor; f16 when `dtype === 'f16'`, otherwise treated as f32.
 * @param {number} rows - Number of rows in the source (positive integer).
 * @param {number} cols - Number of columns in the source (positive integer).
 * @param {number} rowIndex - Row to copy (finite number).
 * @param {string} label - Label for the new buffer/tensor.
 * @returns {object} The [1, cols] tensor.
 * @throws {Error} On invalid dimensions, a non-finite row index, a misaligned copy, or no GPU device.
 */
function createRowSliceTensor(inputTensor, rows, cols, rowIndex, label) {
  const dtype = inputTensor?.dtype === 'f16' ? 'f16' : 'f32';
  const bytesPerElement = dtype === 'f16' ? 2 : 4;
  if (!Number.isInteger(rows) || rows < 1 || !Number.isInteger(cols) || cols < 1) {
    throw new Error(
      `Distill row slice "${label}" requires positive integer rows/cols (got rows=${rows}, cols=${cols}).`
    );
  }
  if (!Number.isFinite(rowIndex)) {
    throw new Error(`Distill row slice "${label}" requires a finite row index (got ${rowIndex}).`);
  }
  const rowBytes = cols * bytesPerElement;
  // With rowBytes a multiple of 4, every row offset (clampedRow * rowBytes) is too.
  if (rowBytes % 4 !== 0) {
    throw new Error(
      `Distill row slice "${label}" is not 4-byte aligned (dtype=${dtype}, cols=${cols}); `
      + 'copyBufferToBuffer requires offsets and sizes that are multiples of 4.'
    );
  }
  const clampedRow = Math.max(0, Math.min(rows - 1, Math.floor(rowIndex)));

  const device = getDevice();
  if (!device) {
    throw new Error('Distill full-graph student requires active GPU device.');
  }
  const outputBuffer = acquireBuffer(rowBytes, undefined, label);
  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(
    inputTensor.buffer,
    clampedRow * rowBytes,
    outputBuffer,
    0,
    rowBytes
  );
  device.queue.submit([encoder.finish()]);
  return createTensor(outputBuffer, dtype, [1, cols], label);
}
|
|
1210
|
+
|
|
1211
|
+
/**
 * Build a distill training fixture whose only trainable parameter is a linear head
 * ([embeddingDim, outputDim], zero-initialized). The head maps student-pipeline prefill
 * embeddings ('last' embedding mode, no chat template) to `outputDim` logits.
 *
 * @param {object} [overrides] - Training-config overrides; `overrides.training` is merged last.
 * @param {object} [options] - Must carry `distillRuntime.studentPipeline`. Optional keys:
 *   `outputDim` (falling back to `inputDim`, then DISTILL_ADAPTER_TOP_K; clamped by
 *   clampDistillTopK) and `embeddingDim` (falling back to the student's modelConfig.hiddenSize,
 *   then to outputDim).
 * @returns {{config: object, model: object, outputDim: number, embeddingDim: number, cleanup: () => void}}
 * @throws {Error} When no student pipeline is supplied.
 */
function createDistillStudentProjectionModelFixture(overrides = {}, options = {}) {
  const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
    ? options.distillRuntime
    : null;
  if (!distillRuntime?.studentPipeline) {
    throw new Error('Distill student fixture requires distillRuntime.studentPipeline.');
  }
  const outputDim = clampDistillTopK(
    options.outputDim
      ?? options.inputDim
      ?? DISTILL_ADAPTER_TOP_K
  );
  const inferredEmbeddingDim = Math.floor(
    Number(distillRuntime.studentPipeline?.modelConfig?.hiddenSize)
  );
  const embeddingDim = Number.isInteger(options.embeddingDim) && options.embeddingDim > 0
    ? options.embeddingDim
    : (Number.isFinite(inferredEmbeddingDim) && inferredEmbeddingDim > 0
      ? inferredEmbeddingDim
      : outputDim);
  const config = createTrainingConfig({
    ...overrides,
    training: {
      enabled: true,
      lossScaling: { enabled: false },
      gradient: { maxNorm: 0 },
      ...(overrides.training || {}),
    },
  });

  const projectionWeights = new Float32Array(embeddingDim * outputDim);
  const projectionWeight = makeTensorFromFloat32(
    projectionWeights,
    [embeddingDim, outputDim],
    'distill_student_head_weight'
  );
  // Per-step embedding input tensors; released by cleanupDistillStep().
  const temporaryInputs = new Set();

  // logits [rows, outputDim] = input [rows, embeddingDim] x head [embeddingDim, outputDim].
  async function projectEmbeddingInput(inputTensor, tape) {
    const rows = Number.isFinite(inputTensor?.shape?.[0]) ? inputTensor.shape[0] : 1;
    return tape.record(
      OpType.MATMUL,
      (a, b) => runMatmul(a, b, rows, outputDim, embeddingDim, { transposeB: false }),
      [inputTensor, projectionWeight],
      { M: rows, N: outputDim, K: embeddingDim, transposeB: false }
    );
  }

  // Run one student prefill per prompt and pack the embeddings into a [rows, embeddingDim] tensor.
  // Shorter embeddings are zero-padded; longer ones are truncated to embeddingDim.
  async function buildStudentEmbeddingInput(batch, phase = 'anchor') {
    // Same prompt selection and empty-list error as every other phase-aware path in this file.
    const prompts = resolvePhasePrompts(batch, phase);

    const rows = prompts.length;
    const features = new Float32Array(rows * embeddingDim);
    for (let row = 0; row < rows; row += 1) {
      const prompt = String(prompts[row] || '').trim();
      const studentResult = await distillRuntime.studentPipeline.prefillWithEmbedding(prompt, {
        useChatTemplate: false,
        embeddingMode: 'last',
      });
      try {
        const studentEmbedding = toFloat32Array(studentResult?.embedding, 'student embedding');
        const rowOffset = row * embeddingDim;
        const copyCount = Math.min(embeddingDim, studentEmbedding.length);
        features.set(studentEmbedding.subarray(0, copyCount), rowOffset);
      } finally {
        // Release the prefill snapshot and reset pipeline state before the next prompt.
        disposePrefillSnapshot(studentResult);
        distillRuntime.studentPipeline.reset();
      }
    }
    const inputTensor = makeTensorFromFloat32(
      features,
      [rows, embeddingDim],
      `distill_student_${phase}_embedding`
    );
    temporaryInputs.add(inputTensor);
    return inputTensor;
  }

  const model = {
    async forward(inputTensor, tape) {
      return projectEmbeddingInput(inputTensor, tape);
    },
    async forwardDistill(batch, tape, forwardOptions = {}) {
      const requestedPhase = String(forwardOptions?.phase || 'anchor').trim();
      const phase = requestedPhase === 'positive'
        ? 'positive'
        : (requestedPhase === 'negative' ? 'negative' : 'anchor');
      const inputTensor = await buildStudentEmbeddingInput(batch, phase);
      const logits = await projectEmbeddingInput(inputTensor, tape);
      return { logits };
    },
    cleanupDistillStep() {
      for (const tensor of temporaryInputs) {
        releaseTensor(tensor);
      }
      temporaryInputs.clear();
    },
    loraParams() {
      return [projectionWeight];
    },
    paramGroups() {
      return {
        encoder: [],
        prior: [],
        decoder: [],
        base: [projectionWeight],
        lora: [projectionWeight],
      };
    },
  };

  return {
    config,
    model,
    outputDim,
    embeddingDim,
    cleanup() {
      model.cleanupDistillStep();
      releaseTensor(projectionWeight);
    },
  };
}
|
|
1339
|
+
|
|
1340
|
+
async function createDistillStudentTransformerModelFixture(overrides = {}, options = {}) {
|
|
1341
|
+
const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
|
|
1342
|
+
? options.distillRuntime
|
|
1343
|
+
: null;
|
|
1344
|
+
const studentPipeline = distillRuntime?.studentPipeline || null;
|
|
1345
|
+
if (!studentPipeline?.modelConfig || !(studentPipeline.weights instanceof Map)) {
|
|
1346
|
+
throw new Error('Distill full-graph student fixture requires loaded student pipeline weights.');
|
|
1347
|
+
}
|
|
1348
|
+
const modelConfig = studentPipeline.modelConfig;
|
|
1349
|
+
const hiddenSize = Math.max(1, Math.floor(Number(modelConfig.hiddenSize) || 0));
|
|
1350
|
+
const intermediateSize = Math.max(1, Math.floor(Number(modelConfig.intermediateSize) || 0));
|
|
1351
|
+
const numLayers = Math.max(1, Math.floor(Number(modelConfig.numLayers) || 0));
|
|
1352
|
+
const numHeads = Math.max(1, Math.floor(Number(modelConfig.numHeads) || 0));
|
|
1353
|
+
const numKVHeads = Math.max(1, Math.floor(Number(modelConfig.numKVHeads || numHeads) || 0));
|
|
1354
|
+
const headDim = Math.max(1, Math.floor(Number(modelConfig.headDim) || 0));
|
|
1355
|
+
const vocabSize = Math.max(1, Math.floor(Number(modelConfig.vocabSize) || 0));
|
|
1356
|
+
const rmsNormEps = Number.isFinite(modelConfig.rmsNormEps) ? modelConfig.rmsNormEps : 1e-6;
|
|
1357
|
+
const hiddenActivation = String(modelConfig.hiddenActivation || 'silu').toLowerCase();
|
|
1358
|
+
const swigluLimit = Number.isFinite(modelConfig.swigluLimit) ? modelConfig.swigluLimit : 0;
|
|
1359
|
+
const useEmbeddingTranspose = modelConfig.embeddingTranspose === true;
|
|
1360
|
+
const tieWordEmbeddings = modelConfig.useTiedEmbeddings === true;
|
|
1361
|
+
|
|
1362
|
+
const config = createTrainingConfig({
|
|
1363
|
+
...overrides,
|
|
1364
|
+
training: {
|
|
1365
|
+
enabled: true,
|
|
1366
|
+
lossScaling: { enabled: false },
|
|
1367
|
+
gradient: { maxNorm: 0 },
|
|
1368
|
+
...(overrides.training || {}),
|
|
1369
|
+
},
|
|
1370
|
+
});
|
|
1371
|
+
|
|
1372
|
+
const ownedTrainables = new Set();
|
|
1373
|
+
const embeddingWeight = await ensureTrainableTensor(
|
|
1374
|
+
studentPipeline.weights.get('embed'),
|
|
1375
|
+
[vocabSize, hiddenSize],
|
|
1376
|
+
'embed',
|
|
1377
|
+
ownedTrainables
|
|
1378
|
+
);
|
|
1379
|
+
const lmHeadWeight = tieWordEmbeddings
|
|
1380
|
+
? embeddingWeight
|
|
1381
|
+
: await ensureTrainableTensor(
|
|
1382
|
+
studentPipeline.weights.get('lm_head'),
|
|
1383
|
+
[vocabSize, hiddenSize],
|
|
1384
|
+
'lm_head',
|
|
1385
|
+
ownedTrainables
|
|
1386
|
+
);
|
|
1387
|
+
const finalNormWeight = await ensureNormTensor(
|
|
1388
|
+
studentPipeline.weights.get('final_norm'),
|
|
1389
|
+
hiddenSize,
|
|
1390
|
+
'final_norm',
|
|
1391
|
+
ownedTrainables
|
|
1392
|
+
);
|
|
1393
|
+
|
|
1394
|
+
const ropeDim = Math.max(1, Math.floor(headDim / 2));
|
|
1395
|
+
const ropeRows = Math.max(1, Math.floor(Number(modelConfig.maxSeqLen) || 1));
|
|
1396
|
+
const ropeCos = await ensureTrainableTensor(
|
|
1397
|
+
createTensor(studentPipeline.ropeFreqsCos, 'f32', [ropeRows, ropeDim], 'rope_cos'),
|
|
1398
|
+
[ropeRows, ropeDim],
|
|
1399
|
+
'rope_cos',
|
|
1400
|
+
ownedTrainables
|
|
1401
|
+
);
|
|
1402
|
+
const ropeSin = await ensureTrainableTensor(
|
|
1403
|
+
createTensor(studentPipeline.ropeFreqsSin, 'f32', [ropeRows, ropeDim], 'rope_sin'),
|
|
1404
|
+
[ropeRows, ropeDim],
|
|
1405
|
+
'rope_sin',
|
|
1406
|
+
ownedTrainables
|
|
1407
|
+
);
|
|
1408
|
+
|
|
1409
|
+
const layerParams = [];
|
|
1410
|
+
const layers = [];
|
|
1411
|
+
for (let layerIdx = 0; layerIdx < numLayers; layerIdx += 1) {
|
|
1412
|
+
const layerWeights = studentPipeline.weights.get(`layer_${layerIdx}`);
|
|
1413
|
+
if (!layerWeights) {
|
|
1414
|
+
throw new Error(`Distill full-graph student missing layer_${layerIdx} weights.`);
|
|
1415
|
+
}
|
|
1416
|
+
const gateUpWeight = layerWeights.gateUp || layerWeights.ffnGateUp || null;
|
|
1417
|
+
let layerGateUp = null;
|
|
1418
|
+
if (hasTensorPayload(gateUpWeight)) {
|
|
1419
|
+
layerGateUp = await ensureTrainableTensor(
|
|
1420
|
+
gateUpWeight,
|
|
1421
|
+
[intermediateSize * 2, hiddenSize],
|
|
1422
|
+
`layer_${layerIdx}.ffn_gate_up`,
|
|
1423
|
+
ownedTrainables
|
|
1424
|
+
);
|
|
1425
|
+
} else {
|
|
1426
|
+
const gateWeight = layerWeights.gate || layerWeights.ffnGate || null;
|
|
1427
|
+
const upWeight = layerWeights.up || layerWeights.ffnUp || null;
|
|
1428
|
+
if (!hasTensorPayload(gateWeight) || !hasTensorPayload(upWeight)) {
|
|
1429
|
+
throw new Error(
|
|
1430
|
+
`Distill full-graph student missing gate/up projections on layer ${layerIdx}.`
|
|
1431
|
+
);
|
|
1432
|
+
}
|
|
1433
|
+
const gateTensor = await ensureTrainableTensor(
|
|
1434
|
+
gateWeight,
|
|
1435
|
+
[intermediateSize, hiddenSize],
|
|
1436
|
+
`layer_${layerIdx}.ffn_gate`,
|
|
1437
|
+
ownedTrainables
|
|
1438
|
+
);
|
|
1439
|
+
const upTensor = await ensureTrainableTensor(
|
|
1440
|
+
upWeight,
|
|
1441
|
+
[intermediateSize, hiddenSize],
|
|
1442
|
+
`layer_${layerIdx}.ffn_up`,
|
|
1443
|
+
ownedTrainables
|
|
1444
|
+
);
|
|
1445
|
+
layerGateUp = await fuseGateUpTensors(
|
|
1446
|
+
gateTensor,
|
|
1447
|
+
upTensor,
|
|
1448
|
+
intermediateSize,
|
|
1449
|
+
hiddenSize,
|
|
1450
|
+
`layer_${layerIdx}.ffn_gate_up`,
|
|
1451
|
+
ownedTrainables
|
|
1452
|
+
);
|
|
1453
|
+
}
|
|
1454
|
+
const layer = {
|
|
1455
|
+
inputNorm: await ensureNormTensor(
|
|
1456
|
+
layerWeights.inputNorm,
|
|
1457
|
+
hiddenSize,
|
|
1458
|
+
`layer_${layerIdx}.input_norm`,
|
|
1459
|
+
ownedTrainables
|
|
1460
|
+
),
|
|
1461
|
+
qProj: await ensureTrainableTensor(
|
|
1462
|
+
layerWeights.qProj,
|
|
1463
|
+
[numHeads * headDim, hiddenSize],
|
|
1464
|
+
`layer_${layerIdx}.q_proj`,
|
|
1465
|
+
ownedTrainables
|
|
1466
|
+
),
|
|
1467
|
+
kProj: await ensureTrainableTensor(
|
|
1468
|
+
layerWeights.kProj,
|
|
1469
|
+
[numKVHeads * headDim, hiddenSize],
|
|
1470
|
+
`layer_${layerIdx}.k_proj`,
|
|
1471
|
+
ownedTrainables
|
|
1472
|
+
),
|
|
1473
|
+
vProj: await ensureTrainableTensor(
|
|
1474
|
+
layerWeights.vProj,
|
|
1475
|
+
[numKVHeads * headDim, hiddenSize],
|
|
1476
|
+
`layer_${layerIdx}.v_proj`,
|
|
1477
|
+
ownedTrainables
|
|
1478
|
+
),
|
|
1479
|
+
oProj: await ensureTrainableTensor(
|
|
1480
|
+
layerWeights.oProj,
|
|
1481
|
+
[hiddenSize, hiddenSize],
|
|
1482
|
+
`layer_${layerIdx}.o_proj`,
|
|
1483
|
+
ownedTrainables
|
|
1484
|
+
),
|
|
1485
|
+
postAttentionNorm: layerWeights.postAttentionNorm
|
|
1486
|
+
? await ensureNormTensor(
|
|
1487
|
+
layerWeights.postAttentionNorm,
|
|
1488
|
+
hiddenSize,
|
|
1489
|
+
`layer_${layerIdx}.post_attention_norm`,
|
|
1490
|
+
ownedTrainables
|
|
1491
|
+
)
|
|
1492
|
+
: null,
|
|
1493
|
+
gateUp: layerGateUp,
|
|
1494
|
+
down: await ensureTrainableTensor(
|
|
1495
|
+
layerWeights.down || layerWeights.ffnDown,
|
|
1496
|
+
[hiddenSize, intermediateSize],
|
|
1497
|
+
`layer_${layerIdx}.ffn_down`,
|
|
1498
|
+
ownedTrainables
|
|
1499
|
+
),
|
|
1500
|
+
};
|
|
1501
|
+
layers.push(layer);
|
|
1502
|
+
layerParams.push(layer.inputNorm, layer.qProj, layer.kProj, layer.vProj, layer.oProj, layer.gateUp, layer.down);
|
|
1503
|
+
if (layer.postAttentionNorm) {
|
|
1504
|
+
layerParams.push(layer.postAttentionNorm);
|
|
1505
|
+
}
|
|
1506
|
+
}
|
|
1507
|
+
|
|
1508
|
+
const encoderParams = [embeddingWeight, ...layerParams];
|
|
1509
|
+
const decoderParams = [finalNormWeight, lmHeadWeight];
|
|
1510
|
+
const baseParams = [...encoderParams, ...decoderParams];
|
|
1511
|
+
const temporaryInputs = new Set();
|
|
1512
|
+
|
|
1513
|
+
async function buildPromptTokens(prompt) {
|
|
1514
|
+
const normalized = String(prompt || '').trim();
|
|
1515
|
+
if (!normalized) {
|
|
1516
|
+
throw new Error('Distill full-graph student prompt is empty.');
|
|
1517
|
+
}
|
|
1518
|
+
const tokenIds = studentPipeline.tokenizer.encode(normalized);
|
|
1519
|
+
if (!Array.isArray(tokenIds) || tokenIds.length === 0) {
|
|
1520
|
+
throw new Error('Distill full-graph student tokenizer produced no tokens.');
|
|
1521
|
+
}
|
|
1522
|
+
const tokenTensor = makeTensorFromUint32(
|
|
1523
|
+
tokenIds,
|
|
1524
|
+
[tokenIds.length],
|
|
1525
|
+
'distill_student_prompt_tokens'
|
|
1526
|
+
);
|
|
1527
|
+
temporaryInputs.add(tokenTensor);
|
|
1528
|
+
return { tokenTensor, seqLen: tokenIds.length };
|
|
1529
|
+
}
|
|
1530
|
+
|
|
1531
|
+
async function runTransformerPrompt(prompt, tape) {
|
|
1532
|
+
const { tokenTensor, seqLen } = await buildPromptTokens(prompt);
|
|
1533
|
+
let hidden = await tape.record(
|
|
1534
|
+
OpType.EMBED,
|
|
1535
|
+
(indices, embeddings) => runGather(
|
|
1536
|
+
indices,
|
|
1537
|
+
embeddings,
|
|
1538
|
+
seqLen,
|
|
1539
|
+
hiddenSize,
|
|
1540
|
+
vocabSize,
|
|
1541
|
+
{
|
|
1542
|
+
embeddingDtype: resolveTensorDtype(embeddingWeight, 'f32'),
|
|
1543
|
+
outputDtype: 'f32',
|
|
1544
|
+
transpose: useEmbeddingTranspose,
|
|
1545
|
+
}
|
|
1546
|
+
),
|
|
1547
|
+
[tokenTensor, embeddingWeight],
|
|
1548
|
+
{
|
|
1549
|
+
numTokens: seqLen,
|
|
1550
|
+
hiddenSize,
|
|
1551
|
+
vocabSize,
|
|
1552
|
+
transpose: useEmbeddingTranspose,
|
|
1553
|
+
indexOffset: 0,
|
|
1554
|
+
}
|
|
1555
|
+
);
|
|
1556
|
+
|
|
1557
|
+
for (let layerIdx = 0; layerIdx < layers.length; layerIdx += 1) {
|
|
1558
|
+
const layer = layers[layerIdx];
|
|
1559
|
+
const normed = await tape.record(
|
|
1560
|
+
OpType.RMSNORM,
|
|
1561
|
+
(x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
|
|
1562
|
+
batchSize: seqLen,
|
|
1563
|
+
hiddenSize,
|
|
1564
|
+
rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
|
|
1565
|
+
}),
|
|
1566
|
+
[hidden, layer.inputNorm],
|
|
1567
|
+
{ numTokens: seqLen, hiddenSize, eps: rmsNormEps }
|
|
1568
|
+
);
|
|
1569
|
+
|
|
1570
|
+
const q2d = await tape.record(
|
|
1571
|
+
OpType.MATMUL,
|
|
1572
|
+
(x, w) => runMatmul(x, w, seqLen, numHeads * headDim, hiddenSize, {
|
|
1573
|
+
transposeB: 'auto',
|
|
1574
|
+
outputDtype: 'f32',
|
|
1575
|
+
}),
|
|
1576
|
+
[normed, layer.qProj],
|
|
1577
|
+
{ M: seqLen, N: numHeads * headDim, K: hiddenSize, transposeB: 'auto' }
|
|
1578
|
+
);
|
|
1579
|
+
const k2d = await tape.record(
|
|
1580
|
+
OpType.MATMUL,
|
|
1581
|
+
(x, w) => runMatmul(x, w, seqLen, numKVHeads * headDim, hiddenSize, {
|
|
1582
|
+
transposeB: 'auto',
|
|
1583
|
+
outputDtype: 'f32',
|
|
1584
|
+
}),
|
|
1585
|
+
[normed, layer.kProj],
|
|
1586
|
+
{ M: seqLen, N: numKVHeads * headDim, K: hiddenSize, transposeB: 'auto' }
|
|
1587
|
+
);
|
|
1588
|
+
const v2d = await tape.record(
|
|
1589
|
+
OpType.MATMUL,
|
|
1590
|
+
(x, w) => runMatmul(x, w, seqLen, numKVHeads * headDim, hiddenSize, {
|
|
1591
|
+
transposeB: 'auto',
|
|
1592
|
+
outputDtype: 'f32',
|
|
1593
|
+
}),
|
|
1594
|
+
[normed, layer.vProj],
|
|
1595
|
+
{ M: seqLen, N: numKVHeads * headDim, K: hiddenSize, transposeB: 'auto' }
|
|
1596
|
+
);
|
|
1597
|
+
|
|
1598
|
+
const q3d = createTensor(q2d.buffer, q2d.dtype, [seqLen, numHeads, headDim], `layer_${layerIdx}_q`);
|
|
1599
|
+
const k3d = createTensor(k2d.buffer, k2d.dtype, [seqLen, numKVHeads, headDim], `layer_${layerIdx}_k`);
|
|
1600
|
+
const v3d = createTensor(v2d.buffer, v2d.dtype, [seqLen, numKVHeads, headDim], `layer_${layerIdx}_v`);
|
|
1601
|
+
|
|
1602
|
+
const qRope = await tape.record(
|
|
1603
|
+
OpType.ROPE,
|
|
1604
|
+
(q, cos, sin) => runRoPE(q, cos, sin, seqLen, { numHeads, headDim, startPos: 0 }),
|
|
1605
|
+
[q3d, ropeCos, ropeSin],
|
|
1606
|
+
{ seqLen, numHeads, headDim, startPos: 0 }
|
|
1607
|
+
);
|
|
1608
|
+
const kRope = await tape.record(
|
|
1609
|
+
OpType.ROPE,
|
|
1610
|
+
(k, cos, sin) => runRoPE(k, cos, sin, seqLen, { numHeads: numKVHeads, headDim, startPos: 0 }),
|
|
1611
|
+
[k3d, ropeCos, ropeSin],
|
|
1612
|
+
{ seqLen, numHeads: numKVHeads, headDim, startPos: 0 }
|
|
1613
|
+
);
|
|
1614
|
+
|
|
1615
|
+
const attention = await tape.record(
|
|
1616
|
+
OpType.ATTENTION,
|
|
1617
|
+
(q, k, v) => runAttention(q, k, v, null, numHeads, headDim, {
|
|
1618
|
+
seqLen,
|
|
1619
|
+
kvLen: seqLen,
|
|
1620
|
+
numKVHeads,
|
|
1621
|
+
causal: true,
|
|
1622
|
+
startPos: 0,
|
|
1623
|
+
scale: 1 / Math.sqrt(headDim),
|
|
1624
|
+
}),
|
|
1625
|
+
[qRope, kRope, v3d],
|
|
1626
|
+
{ seqLen, numHeads, headDim, scale: 1 / Math.sqrt(headDim), causal: true, recomputeForward: true }
|
|
1627
|
+
);
|
|
1628
|
+
const attention2d = createTensor(
|
|
1629
|
+
attention.buffer,
|
|
1630
|
+
attention.dtype,
|
|
1631
|
+
[seqLen, hiddenSize],
|
|
1632
|
+
`layer_${layerIdx}_attn_2d`
|
|
1633
|
+
);
|
|
1634
|
+
|
|
1635
|
+
const attentionOutput = await tape.record(
|
|
1636
|
+
OpType.MATMUL,
|
|
1637
|
+
(x, w) => runMatmul(x, w, seqLen, hiddenSize, hiddenSize, {
|
|
1638
|
+
transposeB: 'auto',
|
|
1639
|
+
outputDtype: 'f32',
|
|
1640
|
+
}),
|
|
1641
|
+
[attention2d, layer.oProj],
|
|
1642
|
+
{ M: seqLen, N: hiddenSize, K: hiddenSize, transposeB: 'auto' }
|
|
1643
|
+
);
|
|
1644
|
+
const postAttention = await tape.record(
|
|
1645
|
+
OpType.RESIDUAL_ADD,
|
|
1646
|
+
(a, b) => runResidualAdd(a, b, seqLen * hiddenSize),
|
|
1647
|
+
[attentionOutput, hidden],
|
|
1648
|
+
{ size: seqLen * hiddenSize }
|
|
1649
|
+
);
|
|
1650
|
+
|
|
1651
|
+
const ffnInput = layer.postAttentionNorm
|
|
1652
|
+
? await tape.record(
|
|
1653
|
+
OpType.RMSNORM,
|
|
1654
|
+
(x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
|
|
1655
|
+
batchSize: seqLen,
|
|
1656
|
+
hiddenSize,
|
|
1657
|
+
rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
|
|
1658
|
+
}),
|
|
1659
|
+
[postAttention, layer.postAttentionNorm],
|
|
1660
|
+
{ numTokens: seqLen, hiddenSize, eps: rmsNormEps }
|
|
1661
|
+
)
|
|
1662
|
+
: postAttention;
|
|
1663
|
+
const gateUp = await tape.record(
|
|
1664
|
+
OpType.MATMUL,
|
|
1665
|
+
(x, w) => runMatmul(x, w, seqLen, intermediateSize * 2, hiddenSize, {
|
|
1666
|
+
transposeB: 'auto',
|
|
1667
|
+
outputDtype: 'f32',
|
|
1668
|
+
}),
|
|
1669
|
+
[ffnInput, layer.gateUp],
|
|
1670
|
+
{ M: seqLen, N: intermediateSize * 2, K: hiddenSize, transposeB: 'auto' }
|
|
1671
|
+
);
|
|
1672
|
+
const activated = await tape.record(
|
|
1673
|
+
OpType.SILU_ROWSPLIT,
|
|
1674
|
+
(x) => runSiLURowSplit(x, {
|
|
1675
|
+
numTokens: seqLen,
|
|
1676
|
+
dim: intermediateSize,
|
|
1677
|
+
activation: hiddenActivation === 'gelu' ? 'gelu' : 'silu',
|
|
1678
|
+
swigluLimit: hiddenActivation === 'gelu' ? null : swigluLimit,
|
|
1679
|
+
}),
|
|
1680
|
+
[gateUp],
|
|
1681
|
+
{
|
|
1682
|
+
numTokens: seqLen,
|
|
1683
|
+
dim: intermediateSize,
|
|
1684
|
+
activation: hiddenActivation === 'gelu' ? 'gelu' : 'silu',
|
|
1685
|
+
swigluLimit: hiddenActivation === 'gelu' ? 0 : swigluLimit,
|
|
1686
|
+
}
|
|
1687
|
+
);
|
|
1688
|
+
const ffnOutput = await tape.record(
|
|
1689
|
+
OpType.MATMUL,
|
|
1690
|
+
(x, w) => runMatmul(x, w, seqLen, hiddenSize, intermediateSize, {
|
|
1691
|
+
transposeB: 'auto',
|
|
1692
|
+
outputDtype: 'f32',
|
|
1693
|
+
}),
|
|
1694
|
+
[activated, layer.down],
|
|
1695
|
+
{ M: seqLen, N: hiddenSize, K: intermediateSize, transposeB: 'auto' }
|
|
1696
|
+
);
|
|
1697
|
+
hidden = await tape.record(
|
|
1698
|
+
OpType.RESIDUAL_ADD,
|
|
1699
|
+
(a, b) => runResidualAdd(a, b, seqLen * hiddenSize),
|
|
1700
|
+
[ffnOutput, postAttention],
|
|
1701
|
+
{ size: seqLen * hiddenSize }
|
|
1702
|
+
);
|
|
1703
|
+
}
|
|
1704
|
+
|
|
1705
|
+
const finalHidden = await tape.record(
|
|
1706
|
+
OpType.RMSNORM,
|
|
1707
|
+
(x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
|
|
1708
|
+
batchSize: seqLen,
|
|
1709
|
+
hiddenSize,
|
|
1710
|
+
rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
|
|
1711
|
+
}),
|
|
1712
|
+
[hidden, finalNormWeight],
|
|
1713
|
+
{ numTokens: seqLen, hiddenSize, eps: rmsNormEps }
|
|
1714
|
+
);
|
|
1715
|
+
const lastHidden = await tape.record(
|
|
1716
|
+
OpType.ROW_SLICE,
|
|
1717
|
+
(x) => createRowSliceTensor(x, seqLen, hiddenSize, seqLen - 1, 'distill_last_hidden'),
|
|
1718
|
+
[finalHidden],
|
|
1719
|
+
{ rows: seqLen, cols: hiddenSize, rowIndex: seqLen - 1 }
|
|
1720
|
+
);
|
|
1721
|
+
return tape.record(
|
|
1722
|
+
OpType.MATMUL,
|
|
1723
|
+
(x, w) => runMatmul(x, w, 1, vocabSize, hiddenSize, {
|
|
1724
|
+
transposeB: 'auto',
|
|
1725
|
+
outputDtype: 'f32',
|
|
1726
|
+
}),
|
|
1727
|
+
[lastHidden, lmHeadWeight],
|
|
1728
|
+
{ M: 1, N: vocabSize, K: hiddenSize, transposeB: 'auto' }
|
|
1729
|
+
);
|
|
1730
|
+
}
|
|
1731
|
+
|
|
1732
|
+
const model = {
|
|
1733
|
+
async forward(inputTensor, tape) {
|
|
1734
|
+
return tape.record(
|
|
1735
|
+
OpType.MATMUL,
|
|
1736
|
+
(x, w) => runMatmul(x, w, 1, vocabSize, hiddenSize, {
|
|
1737
|
+
transposeB: 'auto',
|
|
1738
|
+
outputDtype: 'f32',
|
|
1739
|
+
}),
|
|
1740
|
+
[inputTensor, lmHeadWeight],
|
|
1741
|
+
{ M: 1, N: vocabSize, K: hiddenSize, transposeB: 'auto' }
|
|
1742
|
+
);
|
|
1743
|
+
},
|
|
1744
|
+
async forwardDistill(batch, tape, forwardOptions = {}) {
|
|
1745
|
+
const requestedPhase = String(forwardOptions?.phase || 'anchor').trim();
|
|
1746
|
+
const phase = requestedPhase === 'positive'
|
|
1747
|
+
? 'positive'
|
|
1748
|
+
: (requestedPhase === 'negative' ? 'negative' : 'anchor');
|
|
1749
|
+
const prompts = resolvePhasePrompts(batch, phase);
|
|
1750
|
+
if (prompts.length !== 1) {
|
|
1751
|
+
throw new Error(
|
|
1752
|
+
`Distill full-graph student currently requires batchSize=1, got ${prompts.length}.`
|
|
1753
|
+
);
|
|
1754
|
+
}
|
|
1755
|
+
const logits = await runTransformerPrompt(prompts[0], tape);
|
|
1756
|
+
return { logits };
|
|
1757
|
+
},
|
|
1758
|
+
cleanupDistillStep() {
|
|
1759
|
+
for (const tensor of temporaryInputs) {
|
|
1760
|
+
releaseTensor(tensor);
|
|
1761
|
+
}
|
|
1762
|
+
temporaryInputs.clear();
|
|
1763
|
+
},
|
|
1764
|
+
loraParams() {
|
|
1765
|
+
return decoderParams;
|
|
1766
|
+
},
|
|
1767
|
+
paramGroups() {
|
|
1768
|
+
return {
|
|
1769
|
+
encoder: encoderParams,
|
|
1770
|
+
prior: [],
|
|
1771
|
+
decoder: decoderParams,
|
|
1772
|
+
base: baseParams,
|
|
1773
|
+
lora: [],
|
|
1774
|
+
};
|
|
1775
|
+
},
|
|
1776
|
+
};
|
|
1777
|
+
|
|
1778
|
+
return {
|
|
1779
|
+
config,
|
|
1780
|
+
model,
|
|
1781
|
+
outputDim: vocabSize,
|
|
1782
|
+
embeddingDim: hiddenSize,
|
|
1783
|
+
cleanup() {
|
|
1784
|
+
model.cleanupDistillStep();
|
|
1785
|
+
for (const tensor of ownedTrainables) {
|
|
1786
|
+
releaseTensor(tensor);
|
|
1787
|
+
}
|
|
1788
|
+
ownedTrainables.clear();
|
|
1789
|
+
},
|
|
1790
|
+
};
|
|
1791
|
+
}
|
|
1792
|
+
|
|
1793
|
+
/**
 * Creates the distillation student fixture for the configured graph mode.
 *
 * The graph mode is taken from, in order of precedence (null/undefined fall
 * through): `options.studentGraphMode`, `options.distillRuntime.studentGraphMode`
 * (only when `distillRuntime` is an object), and
 * `overrides.training.distill.studentGraphMode`; the result is normalized by
 * normalizeDistillStudentGraphMode. The projection mode builds the projection
 * fixture; any other mode builds the full transformer fixture.
 *
 * @param {object} [overrides] - forwarded unchanged to the selected builder.
 * @param {object} [options] - forwarded unchanged to the selected builder.
 * @returns {Promise<object>} whatever the selected fixture builder returns.
 */
async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
  const runtime = (options.distillRuntime && typeof options.distillRuntime === 'object')
    ? options.distillRuntime
    : null;
  const requestedMode = options.studentGraphMode
    ?? runtime?.studentGraphMode
    ?? overrides?.training?.distill?.studentGraphMode;
  const graphMode = normalizeDistillStudentGraphMode(requestedMode);

  const buildFixture = graphMode === DISTILL_STUDENT_GRAPH_PROJECTION
    ? createDistillStudentProjectionModelFixture
    : createDistillStudentTransformerModelFixture;
  return buildFixture(overrides, options);
}
|
|
1807
|
+
|
|
1808
|
+
/**
 * Smoke test for TrainingRunner: drives three steps of the toy fixture through
 * `runner.run` and checks that metrics were emitted and that every metric row
 * carries finite `total_loss` and `step_time_ms` values. The fixture is always
 * cleaned up, even when the runner throws.
 *
 * @returns {Promise<{passed: boolean, error?: string}>}
 */
async function runRunnerSmokeTest() {
  const fixture = createToyModelFixture();
  try {
    const stepCount = 3;
    const runner = new TrainingRunner(fixture.config, {
      optimizer: new AdamOptimizer(fixture.config),
      crossEntropyLoss,
      clipGradients,
    });
    // Replays the same toy batch `stepCount` times.
    const dataset = {
      async *batches() {
        for (let emitted = 0; emitted < stepCount; emitted += 1) {
          yield fixture.batch;
        }
      },
    };

    const metrics = await runner.run(fixture.model, dataset, {
      epochs: 1,
      batchSize: 1,
      shuffle: false,
      maxSteps: stepCount,
    });

    const emittedAny = Array.isArray(metrics) && metrics.length > 0;
    if (!emittedAny) {
      return { passed: false, error: 'Training runner produced no metrics.' };
    }
    const firstBadRow = metrics.findIndex(
      (row) => !(isFiniteNumber(row.total_loss) && isFiniteNumber(row.step_time_ms))
    );
    if (firstBadRow !== -1) {
      return { passed: false, error: 'Training runner emitted non-finite metrics.' };
    }
    return { passed: true };
  } finally {
    fixture.cleanup();
  }
}
|
|
1844
|
+
|
|
1845
|
+
/**
 * Runs one `trainStep` on the toy fixture and checks that the result reports
 * finite forward/backward timings, gradient-clipping metrics and optimizer
 * metrics. Checks run in that order and the first failing check's message is
 * returned. The fixture is always cleaned up.
 *
 * @returns {Promise<{passed: boolean, error?: string}>}
 */
async function runTrainStepMetricsTest() {
  const fixture = createToyModelFixture();
  try {
    const result = await trainStep(fixture.model, fixture.batch, fixture.config, {
      crossEntropyLoss,
      clipGradients,
      optimizer: new AdamOptimizer(fixture.config),
    });

    // Each entry pairs a lazily evaluated requirement with its failure message.
    const requirements = [
      [
        () => isFiniteNumber(result.forward_ms) && isFiniteNumber(result.backward_ms),
        'trainStep did not report finite phase timings.',
      ],
      [
        () => Boolean(result.clipMetrics)
          && isFiniteNumber(result.clipMetrics.gradient_norm_unclipped),
        'trainStep did not report clipping metrics.',
      ],
      [
        () => Boolean(result.optimizerMetrics)
          && isFiniteNumber(result.optimizerMetrics.optimizer_ms),
        'trainStep did not report optimizer metrics.',
      ],
    ];
    for (const [isMet, failureMessage] of requirements) {
      if (!isMet()) {
        return { passed: false, error: failureMessage };
      }
    }
    return { passed: true };
  } finally {
    fixture.cleanup();
  }
}
|
|
1869
|
+
|
|
1870
|
+
// Stage names recognized as UL training stages.
const UL_STAGE_SET = Object.freeze(['stage1_joint', 'stage2_base']);
// Stage names recognized as distillation training stages.
const DISTILL_STAGE_SET = Object.freeze(['stage_a', 'stage_b']);
// Every accepted training stage: the UL stages followed by the distill stages.
const TRAINING_STAGE_SET = Object.freeze(UL_STAGE_SET.concat(DISTILL_STAGE_SET));
|
|
1873
|
+
|
|
1874
|
+
/**
 * Normalizes a user-supplied training stage name.
 *
 * @param {unknown} stage - stage name; falsy values count as "not provided".
 * @returns {string|null} the trimmed stage name, or null when empty/absent.
 * @throws {Error} when the trimmed name is non-empty but not in TRAINING_STAGE_SET.
 */
function normalizeTrainingStage(stage) {
  const candidate = String(stage || '').trim();
  if (candidate === '') {
    return null;
  }
  if (TRAINING_STAGE_SET.includes(candidate)) {
    return candidate;
  }
  const expected = TRAINING_STAGE_SET.join(', ');
  throw new Error(`Unknown training stage "${candidate}". Expected one of: ${expected}.`);
}
|
|
1882
|
+
|
|
1883
|
+
/**
 * Reports whether `stage` (coerced to a string, untrimmed) names a UL stage.
 *
 * @param {unknown} stage
 * @returns {boolean}
 */
function isUlStage(stage) {
  const name = String(stage || '');
  return UL_STAGE_SET.some((candidate) => candidate === name);
}
|
|
1886
|
+
|
|
1887
|
+
/**
 * Reports whether `stage` (coerced to a string, untrimmed) names a distill stage.
 *
 * @param {unknown} stage
 * @returns {boolean}
 */
function isDistillStage(stage) {
  const name = String(stage || '');
  return DISTILL_STAGE_SET.some((candidate) => candidate === name);
}
|
|
1890
|
+
|
|
1891
|
+
/**
 * Validates an optional `trainingConfig` override.
 *
 * @param {unknown} value
 * @returns {object|null} the same object (not copied), or null when `value` is falsy.
 * @throws {Error} when `value` is truthy but not a non-array object.
 */
function normalizeTrainingConfigOverride(value) {
  if (!value) {
    return null;
  }
  const isObjectRecord = typeof value === 'object' && !Array.isArray(value);
  if (!isObjectRecord) {
    throw new Error('trainingConfig must be an object when provided.');
  }
  return value;
}
|
|
1898
|
+
|
|
1899
|
+
/**
 * Resolves the adapter-activation settings for a training run.
 *
 * `options.adapterActivation` takes precedence over
 * `options.trainingConfig.adapterActivation`; only object values are used.
 * Without either, activation is reported as disabled.
 *
 * @param {object} [options]
 * @returns {{enabled: boolean, autoActivate: boolean, adapterPayload: object|null, exportConfig: object|null}}
 *   `enabled` defaults to true unless explicitly false; `autoActivate` is true
 *   only when explicitly true. `adapterPayload` holds at most one adapter source,
 *   chosen in priority order: adapterManifest (object), then adapterManifestJson,
 *   adapterManifestUrl, adapterManifestPath (non-blank strings, kept untrimmed),
 *   then adapter (any non-nullish value).
 * @throws {Error} when `options.trainingConfig` is provided but is not an object.
 */
function normalizeAdapterActivationConfig(options = {}) {
  const isObjectLike = (candidate) => Boolean(candidate) && typeof candidate === 'object';
  const isNonBlankString = (candidate) => typeof candidate === 'string' && candidate.trim().length > 0;

  // Validated even when a direct adapterActivation is supplied.
  const trainingConfig = normalizeTrainingConfigOverride(options.trainingConfig);
  const fromOptions = options.adapterActivation;
  const fromTrainingConfig = trainingConfig?.adapterActivation;

  let config = null;
  if (isObjectLike(fromOptions)) {
    config = fromOptions;
  } else if (isObjectLike(fromTrainingConfig)) {
    config = fromTrainingConfig;
  }
  if (config === null) {
    return {
      enabled: false,
      autoActivate: false,
      adapterPayload: null,
      exportConfig: null,
    };
  }

  const exportConfig = isObjectLike(config.export) ? config.export : null;

  let adapterPayload = null;
  if (isObjectLike(config.adapterManifest)) {
    adapterPayload = { adapterManifest: config.adapterManifest };
  } else if (isNonBlankString(config.adapterManifestJson)) {
    adapterPayload = { adapterManifestJson: config.adapterManifestJson };
  } else if (isNonBlankString(config.adapterManifestUrl)) {
    adapterPayload = { adapterManifestUrl: config.adapterManifestUrl };
  } else if (isNonBlankString(config.adapterManifestPath)) {
    adapterPayload = { adapterManifestPath: config.adapterManifestPath };
  } else if (config.adapter != null) {
    adapterPayload = { adapter: config.adapter };
  }

  return {
    enabled: config.enabled !== false,
    autoActivate: config.autoActivate === true,
    adapterPayload,
    exportConfig,
  };
}
|
|
1940
|
+
|
|
1941
|
+
/**
 * Validates and normalizes the `adapterActivation.export` section, which names
 * the trainable tensors to export as a LoRA adapter.
 *
 * @param {unknown} value - raw export config.
 * @returns {object|null} `{ id, name, baseModel, rank, alpha, targetModules,
 *   tensors: [{ name, paramIndex }], format, pretty }`, or null when `value` is
 *   not a non-array object or lists no tensors. `format` is 'array' only when
 *   requested, otherwise 'base64'; `pretty` is true only when explicitly true.
 * @throws {Error} when a tensor entry has no name or a `paramIndex` that is not
 *   a non-negative integer, when `targetModules` is empty, when id/name/baseModel
 *   are missing, when `rank` is not a positive integer, or when `alpha` is not a
 *   positive number.
 */
function normalizeLoRAExportConfig(value) {
  if (!value || typeof value !== 'object' || Array.isArray(value)) {
    return null;
  }
  const tensors = Array.isArray(value.tensors) ? value.tensors : [];
  if (tensors.length === 0) {
    return null;
  }
  const normalizedTensors = tensors.map((entry, index) => {
    const name = normalizeOptionalString(entry?.name);
    if (!name) {
      throw new Error(`adapterActivation.export.tensors[${index}].name is required.`);
    }
    // paramIndex indexes into model.loraParams() at export time. Fractional
    // values are rejected instead of floored: truncating e.g. 1.5 to 1 would
    // silently export a different tensor than the one requested.
    const paramIndex = entry?.paramIndex;
    if (!Number.isInteger(paramIndex) || paramIndex < 0) {
      throw new Error(`adapterActivation.export.tensors[${index}].paramIndex must be a non-negative integer.`);
    }
    return { name, paramIndex };
  });
  const targetModules = Array.isArray(value.targetModules)
    ? value.targetModules.map((moduleName) => String(moduleName || '').trim()).filter(Boolean)
    : [];
  if (targetModules.length === 0) {
    throw new Error('adapterActivation.export.targetModules must contain at least one module.');
  }
  const id = normalizeOptionalString(value.id);
  const name = normalizeOptionalString(value.name);
  const baseModel = normalizeOptionalString(value.baseModel);
  const rank = Number(value.rank);
  const alpha = Number(value.alpha);
  if (!id || !name || !baseModel) {
    throw new Error('adapterActivation.export requires id, name, and baseModel.');
  }
  if (!Number.isFinite(rank) || rank <= 0 || !Number.isInteger(rank)) {
    throw new Error('adapterActivation.export.rank must be a positive integer.');
  }
  if (!Number.isFinite(alpha) || alpha <= 0) {
    throw new Error('adapterActivation.export.alpha must be a positive number.');
  }
  return {
    id,
    name,
    baseModel,
    rank,
    alpha,
    targetModules,
    tensors: normalizedTensors,
    format: value.format === 'array' ? 'array' : 'base64',
    pretty: value.pretty === true,
  };
}
|
|
1994
|
+
|
|
1995
|
+
/**
 * Exports selected trainable tensors of `model` as a LoRA adapter.
 *
 * @param {object} model - must expose `loraParams()` returning a non-empty
 *   array; each export tensor entry's `paramIndex` indexes into that array.
 * @param {unknown} exportConfig - raw export config, validated by
 *   normalizeLoRAExportConfig.
 * @param {*} [runIndex=null] - echoed back on the result.
 * @returns {Promise<object|null>} null when the export config normalizes to
 *   nothing; otherwise `{ runIndex, manifest, json, hash }` where `manifest` and
 *   `json` come from exportLoRAAdapter and `hash` is sha256Hex(json).
 * @throws {Error} when the model has no `loraParams()`, it returns an empty or
 *   non-array value, or a `paramIndex` selects no tensor.
 */
async function exportLoRAAdapterFromModel(model, exportConfig, runIndex = null) {
  const spec = normalizeLoRAExportConfig(exportConfig);
  if (!spec) {
    return null;
  }
  const canListParams = Boolean(model) && typeof model.loraParams === 'function';
  if (!canListParams) {
    throw new Error('adapterActivation.export requires model.loraParams() support.');
  }
  const params = model.loraParams();
  const hasParams = Array.isArray(params) && params.length > 0;
  if (!hasParams) {
    throw new Error('adapterActivation.export requires non-empty model.loraParams().');
  }

  const tensors = [];
  for (const entry of spec.tensors) {
    const tensor = params[entry.paramIndex];
    if (!tensor) {
      throw new Error(`adapterActivation.export tensor paramIndex ${entry.paramIndex} is out of range.`);
    }
    tensors.push({ name: entry.name, tensor });
  }

  const { id, name, baseModel, rank, alpha, targetModules, format, pretty } = spec;
  const { manifest, json } = await exportLoRAAdapter({
    id,
    name,
    baseModel,
    rank,
    alpha,
    targetModules,
    tensors,
    format,
    pretty,
  });
  return { runIndex, manifest, json, hash: sha256Hex(json) };
}
|
|
2033
|
+
|
|
2034
|
+
/**
 * Best-effort activation of a trained LoRA adapter.
 *
 * Never rejects: a missing payload, a provider module that cannot be loaded,
 * and an error thrown by `activateLoRAFromTrainingOutput` are all reported as
 * an inactive result whose `reason` describes what happened.
 *
 * @param {object|null} payload - adapter source descriptor, forwarded as-is to
 *   `activateLoRAFromTrainingOutput`; falsy means there is nothing to activate.
 * @returns {Promise<object>} the provider's result on success, otherwise
 *   `{ activated: false, adapterName: null, source: null, reason }`.
 */
async function tryActivateAdapterPayload(payload) {
  const inactive = (reason) => ({
    activated: false,
    adapterName: null,
    source: null,
    reason,
  });
  if (!payload) {
    return inactive('no_adapter_payload');
  }
  try {
    // The import lives inside the try so that a provider module which cannot be
    // loaded in this environment is reported as a failed activation instead of
    // escaping this best-effort helper as a rejection.
    const { activateLoRAFromTrainingOutput } = await import('../client/doppler-provider/model-manager.js');
    return await activateLoRAFromTrainingOutput(payload);
  } catch (error) {
    return inactive(String(error?.message || error));
  }
}
|
|
2055
|
+
|
|
2056
|
+
/**
 * Builds the `training` override object for a UL run.
 *
 * UL is enabled when the requested stage (options.trainingStage, falling back
 * to trainingConfig.ul.stage) is a UL stage, or when trainingConfig.ul.enabled
 * is true, in which case the stage defaults to 'stage1_joint'. When UL is not
 * enabled, the normalized trainingConfig (or null) is returned unchanged.
 * Each UL field takes the explicit option first, then the existing
 * trainingConfig.ul value, then its fallback. For 'stage2_base', default freeze
 * flags (encoder/prior/decoder frozen, base/lora not) are applied and may be
 * overridden by trainingConfig.ul.freeze.
 *
 * @param {object} [options]
 * @returns {object|null}
 * @throws {Error} on an unknown stage name or a non-object trainingConfig.
 */
function buildUlTrainingOverrides(options = {}) {
  const trainingConfig = normalizeTrainingConfigOverride(options.trainingConfig);
  const ulConfig = trainingConfig?.ul;
  const requestedStage = normalizeTrainingStage(options.trainingStage || ulConfig?.stage);
  const requestedIsUlStage = isUlStage(requestedStage);
  if (!requestedIsUlStage && ulConfig?.enabled !== true) {
    return trainingConfig || null;
  }

  const pick = (optionValue, configKey, fallback = null) => optionValue ?? ulConfig?.[configKey] ?? fallback;
  const stage = requestedIsUlStage ? requestedStage : 'stage1_joint';
  // Spread first so the fields listed after it take precedence over existing UL values.
  const ulOverride = {
    ...(ulConfig || {}),
    enabled: true,
    stage,
    stage1Artifact: pick(options.stage1Artifact, 'stage1Artifact'),
    stage1ArtifactHash: pick(options.stage1ArtifactHash, 'stage1ArtifactHash'),
    artifactDir: pick(options.ulArtifactDir, 'artifactDir', 'reports/training/ul'),
  };
  if (stage === 'stage2_base') {
    const defaultFreeze = {
      encoder: true,
      prior: true,
      decoder: true,
      base: false,
      lora: false,
    };
    ulOverride.freeze = { ...defaultFreeze, ...(ulConfig?.freeze || {}) };
  }
  return { ...(trainingConfig || {}), ul: ulOverride };
}
|
|
2087
|
+
|
|
2088
|
+
/**
 * Builds the `training` override object for a distillation run.
 *
 * Distillation is enabled when the requested stage (options.trainingStage,
 * falling back to trainingConfig.distill.stage) is a distill stage, or when
 * trainingConfig.distill.enabled is true, in which case the stage defaults to
 * 'stage_a'. When distillation is not enabled, the normalized trainingConfig
 * (or null) is returned unchanged. Each distill field takes the explicit option
 * first, then the existing trainingConfig.distill value, then null
 * (artifactDir falls back to 'reports/training/distill'). strictPairContract is
 * true when either source sets it to exactly true. For 'stage_b', default freeze
 * flags (encoder/prior/decoder frozen, base/lora not) are applied and may be
 * overridden by trainingConfig.distill.freeze.
 *
 * @param {object} [options]
 * @returns {object|null}
 * @throws {Error} on an unknown stage name or a non-object trainingConfig.
 */
function buildDistillTrainingOverrides(options = {}) {
  const trainingConfig = normalizeTrainingConfigOverride(options.trainingConfig);
  const distillConfig = trainingConfig?.distill;
  const requestedStage = normalizeTrainingStage(options.trainingStage || distillConfig?.stage);
  const requestedIsDistillStage = isDistillStage(requestedStage);
  if (!requestedIsDistillStage && distillConfig?.enabled !== true) {
    return trainingConfig || null;
  }

  const pick = (optionValue, configKey, fallback = null) => optionValue ?? distillConfig?.[configKey] ?? fallback;
  const stage = requestedIsDistillStage ? requestedStage : 'stage_a';
  // Spread first so the fields listed after it take precedence over existing distill values.
  const distillOverride = {
    ...(distillConfig || {}),
    enabled: true,
    stage,
    teacherModelId: pick(options.teacherModelId, 'teacherModelId'),
    studentModelId: pick(options.studentModelId, 'studentModelId'),
    datasetId: pick(options.distillDatasetId, 'datasetId'),
    datasetPath: pick(options.distillDatasetPath, 'datasetPath'),
    languagePair: pick(options.distillLanguagePair, 'languagePair'),
    sourceLangs: pick(options.distillSourceLangs, 'sourceLangs'),
    targetLangs: pick(options.distillTargetLangs, 'targetLangs'),
    pairAllowlist: pick(options.distillPairAllowlist, 'pairAllowlist'),
    strictPairContract: options.strictPairContract === true || distillConfig?.strictPairContract === true,
    shardIndex: pick(options.distillShardIndex, 'shardIndex'),
    shardCount: pick(options.distillShardCount, 'shardCount'),
    resumeFrom: pick(options.resumeFrom, 'resumeFrom'),
    stageAArtifact: pick(options.stageAArtifact, 'stageAArtifact'),
    stageAArtifactHash: pick(options.stageAArtifactHash, 'stageAArtifactHash'),
    artifactDir: pick(options.distillArtifactDir, 'artifactDir', 'reports/training/distill'),
  };
  if (stage === 'stage_b') {
    const defaultFreeze = {
      encoder: true,
      prior: true,
      decoder: true,
      base: false,
      lora: false,
    };
    distillOverride.freeze = { ...defaultFreeze, ...(distillConfig?.freeze || {}) };
  }
  return { ...(trainingConfig || {}), distill: distillOverride };
}
|
|
2146
|
+
|
|
2147
|
+
/**
 * Hashes a file's contents when running under Node.js.
 *
 * The file is read as UTF-8 text and the decoded string is passed to sha256Hex.
 * NOTE(review): for content that is not valid UTF-8 the decoded text differs
 * from the raw bytes, so the hash would not match a byte-level digest — confirm
 * this is only used on text files such as manifests.
 *
 * @param {string} filePath - path, resolved against the current working directory.
 * @returns {Promise<{absolutePath: string, hash: *}|null>} null outside Node.js.
 */
async function computeNodeFileHash(filePath) {
  const runningInNode = typeof process !== 'undefined' && Boolean(process.versions?.node);
  if (!runningInNode) {
    return null;
  }
  const [fsPromises, pathModule] = await Promise.all([
    import('node:fs/promises'),
    import('node:path'),
  ]);
  const absolutePath = pathModule.resolve(String(filePath));
  const contents = await fsPromises.readFile(absolutePath, 'utf8');
  return { absolutePath, hash: sha256Hex(contents) };
}
|
|
2162
|
+
|
|
2163
|
+
/**
 * Chooses the artifact directory for a run.
 *
 * Returns the explicit directory (as normalized by normalizeOptionalString)
 * when one is given. Otherwise, under Node.js, creates and returns a fresh
 * directory in the OS temp dir named `doppler-<prefix>-` plus a unique suffix
 * from `mkdtemp`. Outside Node.js with no explicit directory, returns null.
 *
 * @param {unknown} explicitDir - caller-provided directory, if any.
 * @param {string} prefix - label embedded in the temp directory name.
 * @returns {Promise<string|null>}
 */
async function resolveIsolatedArtifactDir(explicitDir, prefix) {
  const requested = normalizeOptionalString(explicitDir);
  if (requested) {
    return requested;
  }
  const runningInNode = typeof process !== 'undefined' && Boolean(process.versions?.node);
  if (!runningInNode) {
    return null;
  }
  const [fsPromises, os, path] = await Promise.all([
    import('node:fs/promises'),
    import('node:os'),
    import('node:path'),
  ]);
  const template = path.join(os.tmpdir(), `doppler-${prefix}-`);
  return fsPromises.mkdtemp(template);
}
|
|
2178
|
+
|
|
2179
|
+
/**
 * Run a two-step UL training pass for one stage on the toy model fixture and
 * verify that the first step's metrics contain the fields expected for that
 * stage and that the runner produced a manifest artifact.
 *
 * @param {string} stage - UL stage; 'stage1_joint' and 'stage2_base' add
 *   stage-specific required metric fields on top of the common set.
 * @param {object} [options] - Harness options forwarded to the UL override
 *   builder and to `runner.run`.
 * @returns {Promise<object>} `{ passed: false, error }` on a failed check,
 *   otherwise `{ passed: true, artifact, metrics }`.
 */
async function runUlStageTest(stage, options = {}) {
  const ulTraining = buildUlTrainingOverrides({ ...options, trainingStage: stage });
  const fixture = createToyModelFixture({ training: ulTraining || undefined });

  try {
    const runner = new TrainingRunner(fixture.config, {
      optimizer: new AdamOptimizer(fixture.config),
      crossEntropyLoss,
      clipGradients,
    });
    const stepCount = 2;
    const dataset = {
      async *batches() {
        for (let step = 0; step < stepCount; step += 1) {
          yield fixture.batch;
        }
      },
    };
    const ulArtifactDir = await resolveIsolatedArtifactDir(options.ulArtifactDir, 'ul');
    const stepMetrics = await runner.run(fixture.model, dataset, {
      epochs: 1,
      batchSize: 1,
      shuffle: false,
      maxSteps: stepCount,
      modelId: options.modelId || 'training',
      modelUrl: options.modelUrl || null,
      runtimePreset: options.runtimePreset || null,
      trainingStage: stage,
      command: options.command || null,
      surface: options.surface || null,
      forceResume: options.forceResume === true,
      forceResumeReason: options.forceResumeReason || null,
      forceResumeSource: options.forceResumeSource || null,
      checkpointOperator: options.checkpointOperator || null,
      checkpointEvery: options.checkpointEvery ?? null,
      gpuAdapterInfo: getKernelCapabilities(),
      timestamp: options.timestamp || null,
      ulArtifactDir,
    });
    if (!Array.isArray(stepMetrics) || stepMetrics.length === 0) {
      return { passed: false, error: `UL ${stage} produced no metrics.` };
    }

    // Fields every UL stage must report, followed by stage-specific ones.
    const commonFields = [
      'loss_prior',
      'loss_decoder',
      'loss_recon',
      'lambda',
      'latent_bitrate_proxy',
      'loss_total',
      'coeff_ce',
      'coeff_prior',
      'coeff_decoder',
      'coeff_recon',
    ];
    let stageFields = [];
    if (stage === 'stage1_joint') {
      stageFields = [
        'schedule_step_index',
        'latent_clean_mean',
        'latent_clean_std',
        'latent_noise_mean',
        'latent_noise_std',
        'latent_noisy_mean',
        'latent_noisy_std',
        'latent_shape',
        'latent_clean_values',
        'latent_noise_values',
        'latent_noisy_values',
      ];
    } else if (stage === 'stage2_base') {
      stageFields = ['stage1_latent_count'];
    }
    const firstStep = stepMetrics[0];
    const missingField = [...commonFields, ...stageFields].find((field) => !(field in firstStep));
    if (missingField !== undefined) {
      return { passed: false, error: `UL ${stage} missing metric field "${missingField}".` };
    }

    const artifact = runner.lastArtifact;
    if (!artifact || !artifact.manifestPath) {
      return { passed: false, error: `UL ${stage} did not produce artifacts.` };
    }

    const resumeState = runner.resumeState;
    const ulConfig = fixture.config.training?.ul;
    return {
      passed: true,
      artifact: {
        ...artifact,
        resumeAudits: Array.isArray(resumeState?.resumeAudits) ? resumeState.resumeAudits : [],
      },
      metrics: {
        stage,
        steps: stepMetrics.length,
        manifestPath: artifact.manifestPath,
        manifestHash: artifact.manifestHash,
        manifestContentHash: artifact.manifestContentHash,
        manifestFileHash: artifact.manifestFileHash ?? null,
        ulResolvedConfig: {
          enabled: ulConfig?.enabled === true,
          stage: ulConfig?.stage ?? null,
          lambda0: ulConfig?.lambda0 ?? null,
          seed: ulConfig?.seed ?? null,
          noiseSchedule: ulConfig?.noiseSchedule ?? null,
          priorAlignment: ulConfig?.priorAlignment ?? null,
          decoderSigmoidWeight: ulConfig?.decoderSigmoidWeight ?? null,
          freeze: ulConfig?.freeze ?? null,
        },
        resumeAuditCount: Number.isInteger(resumeState?.resumeAuditCount)
          ? resumeState.resumeAuditCount
          : 0,
      },
    };
  } finally {
    fixture.cleanup();
  }
}
|
|
2298
|
+
|
|
2299
|
+
/**
 * Run the UL stage1_joint test; a thin wrapper over runUlStageTest.
 *
 * @param {object} [options] - Forwarded unchanged to runUlStageTest.
 * @returns {Promise<object>} The runUlStageTest result.
 */
async function runUlStage1Test(options = {}) {
  const stage = 'stage1_joint';
  return runUlStageTest(stage, options);
}
|
|
2302
|
+
|
|
2303
|
+
/**
 * Run the UL stage2_base test.
 *
 * If `options.stage1Artifact` is empty, a stage1_joint run is executed first.
 * Its manifest path and hash are forwarded to stage2, and are replaced by the
 * absolute path and hash from computeNodeFileHash when that helper returns a hash.
 *
 * @param {object} [options] - Harness options; `stage1Artifact` and
 *   `stage1ArtifactHash` (trimmed) skip the preflight when provided.
 * @returns {Promise<object>} The runUlStageTest('stage2_base', ...) result, or
 *   `{ passed: false, error }` when the stage1 preflight did not yield a manifest.
 */
async function runUlStage2Test(options = {}) {
  const explicitStage1Artifact = String(options.stage1Artifact || '').trim();
  let stage1Artifact = explicitStage1Artifact || null;
  let stage1ArtifactHash = String(options.stage1ArtifactHash || '').trim() || null;

  if (!stage1Artifact) {
    const stage1 = await runUlStage1Test({
      ...options,
      trainingStage: 'stage1_joint',
    });
    if (!stage1?.passed || !stage1?.artifact?.manifestPath) {
      // Keep the preflight's own error so the root cause is not lost.
      const cause = stage1?.error ? ` Cause: ${stage1.error}` : '';
      return { passed: false, error: `UL stage2 preflight failed to generate stage1 artifact.${cause}` };
    }
    stage1Artifact = stage1.artifact.manifestPath;
    stage1ArtifactHash = stage1.artifact.manifestHash;
    const nodeHash = await computeNodeFileHash(stage1Artifact);
    if (nodeHash?.hash) {
      stage1ArtifactHash = nodeHash.hash;
      stage1Artifact = nodeHash.absolutePath;
    }
  }

  return runUlStageTest('stage2_base', {
    ...options,
    stage1Artifact,
    stage1ArtifactHash,
  });
}
|
|
2331
|
+
|
|
2332
|
+
/**
 * Run a short distillation training pass for one stage and verify its outputs.
 *
 * The pass uses a student runtime fixture fed from a JSONL distill dataset.
 * It checks that the first step's metrics contain the stage's required fields
 * and that the runner produced a manifest artifact.
 *
 * @param {string} stage - Distill stage. 'stage_a' requires KD fields; any
 *   other value is checked against the triplet fields.
 * @param {object} [options] - Harness options forwarded to the override
 *   builder, dataset/scope resolvers, runtime context, and `runner.run`.
 * @returns {Promise<object>} `{ passed: false, error }` on a failed check,
 *   otherwise `{ passed: true, artifact, metrics }`.
 * @throws {Error} When no distill dataset path resolves, or when any
 *   setup/run step throws.
 */
async function runDistillStageTest(stage, options = {}) {
  const distillTraining = buildDistillTrainingOverrides({
    ...options,
    trainingStage: stage,
  });
  const distillOutputDim = clampDistillTopK(distillTraining?.distill?.topK ?? DISTILL_ADAPTER_TOP_K);
  const resolvedTrainingConfig = createTrainingConfig({
    training: distillTraining || undefined,
  }).training;
  let fixture = null;
  let distillRuntime = null;

  try {
    const distillDatasetPath = resolveDistillDatasetPath(options, resolvedTrainingConfig);
    if (!distillDatasetPath) {
      throw new Error('Distill stage requires --distill-dataset-path (training.distill.datasetPath).');
    }
    const distillDataScope = resolveDistillDataScope(options, resolvedTrainingConfig);
    const distillDatasetReport = await loadDistillDatasetFromJsonl(distillDatasetPath, distillDataScope);
    distillRuntime = await createDistillRuntimeContext({
      ...options,
      trainingStage: stage,
    }, resolvedTrainingConfig);
    fixture = await createDistillStudentRuntimeModelFixture({
      training: distillTraining || undefined,
    }, {
      outputDim: distillOutputDim,
      distillRuntime,
    });

    const runner = new TrainingRunner(fixture.config, {
      optimizer: new AdamOptimizer(fixture.config),
      crossEntropyLoss,
      clipGradients,
    });
    // Only a positive integer `trainingBenchSteps` overrides the 2-step default.
    const distillMaxSteps = Number.isInteger(options.trainingBenchSteps) && options.trainingBenchSteps > 0
      ? options.trainingBenchSteps
      : 2;
    const dataset = distillDatasetReport.createDataset({
      batchSize: 1,
      shuffle: false,
      seed: 1337,
      distillRuntime,
    });
    const distillRunStartMs = performance.now();
    const distillArtifactDir = await resolveIsolatedArtifactDir(options.distillArtifactDir, 'distill');
    const metrics = await runner.run(fixture.model, dataset, {
      epochs: 1,
      batchSize: 1,
      shuffle: false,
      maxSteps: distillMaxSteps,
      modelId: options.modelId || distillRuntime.studentModelId || 'training',
      modelUrl: options.modelUrl || distillRuntime.studentModelUrl || null,
      runtimePreset: options.runtimePreset || null,
      trainingStage: stage,
      command: options.command || null,
      surface: options.surface || null,
      forceResume: options.forceResume === true,
      forceResumeReason: options.forceResumeReason || null,
      forceResumeSource: options.forceResumeSource || null,
      checkpointOperator: options.checkpointOperator || null,
      checkpointEvery: options.checkpointEvery ?? null,
      gpuAdapterInfo: getKernelCapabilities(),
      timestamp: options.timestamp || null,
      distillArtifactDir,
      stageAArtifact: options.stageAArtifact || null,
      stageAArtifactHash: options.stageAArtifactHash || null,
      teacherModelId: distillRuntime.teacherModelId || null,
      studentModelId: distillRuntime.studentModelId || null,
      distillDatasetId: options.distillDatasetId || null,
      distillDatasetPath: distillDatasetReport.absolutePath,
      distillLanguagePair: options.distillLanguagePair || null,
      distillSourceLangs: distillDataScope.sourceLangs || null,
      distillTargetLangs: distillDataScope.targetLangs || null,
      distillPairAllowlist: distillDataScope.pairAllowlist || null,
      strictPairContract: distillDataScope.strictPairContract === true,
      distillShardIndex: options.distillShardIndex ?? fixture.config.training?.distill?.shardIndex ?? null,
      distillShardCount: options.distillShardCount ?? fixture.config.training?.distill?.shardCount ?? null,
      resumeFrom: options.resumeFrom ?? fixture.config.training?.distill?.resumeFrom ?? null,
    });
    if (!Array.isArray(metrics) || metrics.length === 0) {
      return { passed: false, error: `Distill ${stage} produced no metrics.` };
    }
    const requiredFields = stage === 'stage_a'
      ? ['loss_kd', 'distill_stage']
      : ['loss_triplet', 'distill_stage', 'distill_triplet_margin'];
    for (const field of requiredFields) {
      if (!(field in metrics[0])) {
        return { passed: false, error: `Distill ${stage} missing metric field "${field}".` };
      }
    }
    const artifact = runner.lastArtifact;
    if (!artifact || !artifact.manifestPath) {
      return { passed: false, error: `Distill ${stage} did not produce artifacts.` };
    }
    const progress = resolveBenchProgressSummary(
      metrics,
      resolveDistillShardProgressContext(
        options,
        fixture.config.training,
        distillMaxSteps,
        distillDatasetReport?.shardCount ?? null
      ),
      distillRunStartMs
    );
    return {
      passed: true,
      artifact: {
        ...artifact,
        resumeAudits: Array.isArray(runner.resumeState?.resumeAudits)
          ? runner.resumeState.resumeAudits
          : [],
      },
      metrics: {
        stage,
        steps: metrics.length,
        progress,
        manifestPath: artifact.manifestPath,
        manifestHash: artifact.manifestHash,
        manifestContentHash: artifact.manifestContentHash,
        manifestFileHash: artifact.manifestFileHash ?? null,
        distillResolvedConfig: {
          enabled: fixture.config.training?.distill?.enabled === true,
          stage: fixture.config.training?.distill?.stage ?? null,
          teacherModelId: fixture.config.training?.distill?.teacherModelId ?? null,
          studentModelId: fixture.config.training?.distill?.studentModelId ?? null,
          datasetId: fixture.config.training?.distill?.datasetId ?? null,
          datasetPath: fixture.config.training?.distill?.datasetPath ?? null,
          languagePair: fixture.config.training?.distill?.languagePair ?? null,
          sourceLangs: fixture.config.training?.distill?.sourceLangs ?? null,
          targetLangs: fixture.config.training?.distill?.targetLangs ?? null,
          pairAllowlist: fixture.config.training?.distill?.pairAllowlist ?? null,
          strictPairContract: fixture.config.training?.distill?.strictPairContract === true,
          shardIndex: fixture.config.training?.distill?.shardIndex ?? null,
          shardCount: fixture.config.training?.distill?.shardCount ?? null,
          resumeFrom: fixture.config.training?.distill?.resumeFrom ?? null,
          temperature: fixture.config.training?.distill?.temperature ?? null,
          alphaKd: fixture.config.training?.distill?.alphaKd ?? null,
          alphaCe: fixture.config.training?.distill?.alphaCe ?? null,
          tripletMargin: fixture.config.training?.distill?.tripletMargin ?? null,
          studentGraphMode: fixture.config.training?.distill?.studentGraphMode ?? null,
          topK: fixture.config.training?.distill?.topK ?? distillOutputDim,
          freeze: fixture.config.training?.distill?.freeze ?? null,
        },
        distillRuntime: {
          teacherModelId: distillRuntime.teacherModelId || null,
          studentModelId: distillRuntime.studentModelId || null,
          teacherModelUrl: distillRuntime.teacherModelUrl || null,
          studentModelUrl: distillRuntime.studentModelUrl || null,
          topK: distillRuntime.topK,
          studentGraphMode: distillRuntime.studentGraphMode || null,
          targetTokenMode: distillRuntime.targetTokenMode || null,
        },
        distillDataset: {
          path: distillDatasetReport.absolutePath,
          rowCount: distillDatasetReport.rowCount,
          sampleCount: distillDatasetReport.sampleCount,
          shardCount: distillDatasetReport.shardCount ?? 1,
          directionCounts: distillDatasetReport.directionCounts,
          dataScope: distillDatasetReport.dataScope || null,
        },
        checkpoint: runner.lastCheckpoint || null,
        resumeAuditCount: Number.isInteger(runner.resumeState?.resumeAuditCount)
          ? runner.resumeState.resumeAuditCount
          : 0,
      },
    };
  } finally {
    // Release the runtime first, as before. The nested finally guarantees the
    // fixture is still cleaned up if runtime cleanup throws.
    try {
      if (distillRuntime && typeof distillRuntime.cleanup === 'function') {
        await distillRuntime.cleanup();
      }
    } finally {
      if (fixture) {
        // The fixture comes from an async factory; await in case cleanup is async.
        await fixture.cleanup();
      }
    }
  }
}
|
|
2508
|
+
|
|
2509
|
+
/**
 * Run the distill stage_a test; a thin wrapper over runDistillStageTest.
 *
 * @param {object} [options] - Forwarded unchanged to runDistillStageTest.
 * @returns {Promise<object>} The runDistillStageTest result.
 */
async function runDistillStageATest(options = {}) {
  const stage = 'stage_a';
  return runDistillStageTest(stage, options);
}
|
|
2512
|
+
|
|
2513
|
+
/**
 * Run the distill stage_b test.
 *
 * If `options.stageAArtifact` is empty, a stage_a run is executed first. Its
 * manifest path and hash are forwarded to stage_b, and are replaced by the
 * absolute path and hash from computeNodeFileHash when that helper returns a hash.
 *
 * @param {object} [options] - Harness options; `stageAArtifact` and
 *   `stageAArtifactHash` (trimmed) skip the preflight when provided.
 * @returns {Promise<object>} The runDistillStageTest('stage_b', ...) result, or
 *   `{ passed: false, error }` when the stage_a preflight did not yield a manifest.
 */
async function runDistillStageBTest(options = {}) {
  const explicitStageAArtifact = String(options.stageAArtifact || '').trim();
  let stageAArtifact = explicitStageAArtifact || null;
  let stageAArtifactHash = String(options.stageAArtifactHash || '').trim() || null;

  if (!stageAArtifact) {
    const stageA = await runDistillStageATest({
      ...options,
      trainingStage: 'stage_a',
    });
    if (!stageA?.passed || !stageA?.artifact?.manifestPath) {
      // Keep the preflight's own error so the root cause is not lost.
      const cause = stageA?.error ? ` Cause: ${stageA.error}` : '';
      return { passed: false, error: `Distill stage_b preflight failed to generate stage_a artifact.${cause}` };
    }
    stageAArtifact = stageA.artifact.manifestPath;
    stageAArtifactHash = stageA.artifact.manifestHash;
    const nodeHash = await computeNodeFileHash(stageAArtifact);
    if (nodeHash?.hash) {
      stageAArtifactHash = nodeHash.hash;
      stageAArtifact = nodeHash.absolutePath;
    }
  }

  return runDistillStageTest('stage_b', {
    ...options,
    stageAArtifact,
    stageAArtifactHash,
  });
}
|
|
2541
|
+
|
|
2542
|
+
/**
 * Build a placeholder test function for a legacy browser-only test.
 * Each invocation resolves to a fresh "passed + skipped" result that names
 * where the legacy test still lives.
 *
 * @param {string} name - Legacy test name.
 * @returns {() => Promise<{passed: true, skipped: true, error: string}>}
 */
function createLegacySkippedTest(name) {
  const skippedTest = async () => {
    const error = `Legacy browser-only test "${name}" remains in tests/training/browser/test-page.js.`;
    return { passed: true, skipped: true, error };
  };
  return skippedTest;
}
|
|
2549
|
+
|
|
2550
|
+
/** Tests implemented natively by this harness, keyed by selectable test name. */
const CORE_TESTS = Object.freeze({
  'runner-smoke': runRunnerSmokeTest,
  'train-step-metrics': runTrainStepMetricsTest,
  'ul-stage1': runUlStage1Test,
  'ul-stage2': runUlStage2Test,
  'distill-stage-a': runDistillStageATest,
  'distill-stage-b': runDistillStageBTest,
});

/**
 * Full test registry: the core tests plus skipped placeholders for legacy
 * browser-only tests. A legacy name that matches a core test is ignored so
 * that a placeholder can never shadow a real implementation.
 */
const TESTS = Object.freeze({
  ...CORE_TESTS,
  ...Object.fromEntries(
    LEGACY_BROWSER_TESTS
      .filter((name) => !Object.prototype.hasOwnProperty.call(CORE_TESTS, name))
      .map((name) => [name, createLegacySkippedTest(name)])
  ),
});
|
|
2563
|
+
|
|
2564
|
+
/**
 * Entry points for running registered training tests by name.
 */
export const trainingHarness = Object.freeze({
  /**
   * Await the training GPU runtime setup (ensureTrainingGpuRuntime).
   * @returns {Promise<true>}
   */
  async getGPU() {
    await ensureTrainingGpuRuntime();
    return true;
  },

  /**
   * Run a single registered test.
   *
   * Only the registry's own keys are accepted. Inherited object properties
   * such as 'toString' are reported as unknown instead of being invoked.
   *
   * @param {string} name - Registered test name.
   * @param {object} [options] - Forwarded to the test function.
   * @returns {Promise<object>} The test result, or `{ passed: false, error }`
   *   for an unknown name.
   */
  async runTest(name, options = {}) {
    const fn = Object.prototype.hasOwnProperty.call(TESTS, name) ? TESTS[name] : null;
    if (typeof fn !== 'function') {
      return { passed: false, error: `Unknown training test: ${name}` };
    }
    return fn(options);
  },

  /**
   * @returns {string[]} Registered test names in registry order.
   */
  listTests() {
    return Object.keys(TESTS);
  },
});
|
|
2580
|
+
|
|
2581
|
+
/**
 * Run the selected training tests sequentially and build a suite summary.
 *
 * Test selection order: `options.trainingTests` (normalized). Otherwise the
 * default test for `options.trainingStage`. Otherwise every registered test.
 * Each test's outcome, or a thrown error, becomes one result entry. When
 * adapter auto-activation is enabled, the adapter payload activation is
 * attempted after all tests have run.
 *
 * @param {object} [options] - Suite options, forwarded to every test.
 * @returns {Promise<object>} Summary with modelId, metrics and deviceInfo.
 * @throws {Error} When explicitly requested test names are not registered,
 *   or when the schema version assertion fails.
 */
export async function runTrainingSuite(options = {}) {
  const trainingSchemaVersion = assertTrainingSchemaVersion(options.trainingSchemaVersion);
  const adapterActivation = normalizeAdapterActivationConfig(options);
  const startTime = performance.now();
  await trainingHarness.getGPU();

  const availableTests = trainingHarness.listTests();
  const requestedTestsFromOptions = normalizeTrainingTestNames(options.trainingTests);
  const requestedStage = normalizeTrainingStage(options.trainingStage);
  // Default test selection per training stage.
  const stageDefaults = new Map([
    ['stage1_joint', ['ul-stage1']],
    ['stage2_base', ['ul-stage2']],
    ['stage_a', ['distill-stage-a']],
    ['stage_b', ['distill-stage-b']],
  ]);
  const stageDefaultTests = stageDefaults.get(requestedStage) ?? null;
  const requestedTests = requestedTestsFromOptions || stageDefaultTests;
  if (requestedTests) {
    const unknownTests = requestedTests.filter((testName) => !availableTests.includes(testName));
    if (unknownTests.length > 0) {
      throw new Error(`Unknown training test(s): ${unknownTests.join(', ')}`);
    }
  }
  const testsToRun = requestedTests ?? availableTests;

  const results = [];
  for (const testName of testsToRun) {
    const testStart = performance.now();
    try {
      const outcome = await trainingHarness.runTest(testName, options);
      const passed = outcome?.passed === true;
      const skipped = outcome?.skipped === true;
      let errorMessage;
      if (skipped) {
        errorMessage = outcome?.error ? String(outcome.error) : undefined;
      } else if (!passed) {
        errorMessage = String(outcome?.error || 'Training test failed');
      }
      const entry = {
        name: testName,
        passed,
        skipped,
        duration: Math.max(0, performance.now() - testStart),
      };
      if (errorMessage) {
        entry.error = errorMessage;
      }
      if (outcome?.metrics && typeof outcome.metrics === 'object') {
        entry.metrics = outcome.metrics;
      }
      if (outcome?.artifact && typeof outcome.artifact === 'object') {
        entry.artifact = outcome.artifact;
      }
      results.push(entry);
    } catch (error) {
      results.push({
        name: testName,
        passed: false,
        duration: Math.max(0, performance.now() - testStart),
        error: String(error?.message || error),
      });
    }
  }

  const summary = buildSuiteSummary('training', results, startTime);
  const shouldActivateAdapter = adapterActivation.enabled && adapterActivation.autoActivate;
  const adapterActivationResult = shouldActivateAdapter
    ? await tryActivateAdapterPayload(adapterActivation.adapterPayload)
    : null;

  return {
    ...summary,
    modelId: options.modelId || options.modelUrl || 'training',
    metrics: {
      testsRun: results.length,
      selectedTests: testsToRun,
      availableTests,
      trainingStage: requestedStage || null,
      trainingSchemaVersion,
      adapterActivation: adapterActivationResult,
    },
    deviceInfo: getKernelCapabilities(),
  };
}
|
|
2665
|
+
|
|
2666
|
+
/**
 * Coerce `value` to a number and floor it; return it if it is >= 1.
 *
 * @param {*} value - Any value accepted by `Number()`.
 * @param {*} fallback - Returned for non-finite input or a floored value < 1.
 * @returns {number|*} A positive integer, or `fallback`.
 */
function toPositiveInteger(value, fallback) {
  const numeric = Number(value);
  const whole = Number.isFinite(numeric) ? Math.floor(numeric) : 0;
  return whole >= 1 ? whole : fallback;
}
|
|
2672
|
+
|
|
2673
|
+
/**
 * Coerce `value` to a number and floor it; return it if positive, else null.
 *
 * @param {*} value - Any value accepted by `Number()`.
 * @returns {number|null} A positive integer, or null for non-finite input or
 *   a floored value <= 0.
 */
function toPositiveIntegerOrNull(value) {
  const numeric = Number(value);
  if (Number.isFinite(numeric)) {
    const whole = Math.floor(numeric);
    if (whole > 0) {
      return whole;
    }
  }
  return null;
}
|
|
2679
|
+
|
|
2680
|
+
/**
 * Resolve shard position and step totals used for distill progress reporting.
 *
 * Shard index/count come from `options.distillShardIndex/Count`, falling back
 * to `trainingOverrides.distill.shardIndex/shardCount`. The shard count then
 * falls back to `fallbackShardCount`, then 1. The shard index defaults to 1
 * and is clamped into [1, shardCount].
 *
 * @param {object} [options]
 * @param {object|null} [trainingOverrides]
 * @param {*} [stepsPerShard] - Steps per shard; non-positive values become null.
 * @param {*} [fallbackShardCount]
 * @returns {{shardIndex: number, shardCount: number, stepsPerShard: number|null, totalGlobalSteps: number|null}}
 * @throws {Error} When an explicit shard index exceeds an explicit shard count.
 */
function resolveDistillShardProgressContext(
  options = {},
  trainingOverrides = null,
  stepsPerShard = null,
  fallbackShardCount = null
) {
  const distillConfig = trainingOverrides?.distill || {};
  const requestedIndex = toPositiveIntegerOrNull(options.distillShardIndex ?? distillConfig.shardIndex ?? null);
  const requestedCount = toPositiveIntegerOrNull(options.distillShardCount ?? distillConfig.shardCount ?? null);
  const datasetShardCount = toPositiveIntegerOrNull(fallbackShardCount);

  const bothExplicit = requestedIndex !== null && requestedCount !== null;
  if (bothExplicit && requestedIndex > requestedCount) {
    throw new Error('distillShardIndex must be <= distillShardCount.');
  }

  const resolvedCount = requestedCount ?? datasetShardCount ?? 1;
  const resolvedIndex = requestedIndex ?? 1;
  const perShard = toPositiveIntegerOrNull(stepsPerShard);

  return {
    shardIndex: Math.min(Math.max(1, resolvedIndex), resolvedCount),
    shardCount: Math.max(1, resolvedCount),
    stepsPerShard: perShard,
    totalGlobalSteps: perShard ? perShard * resolvedCount : null,
  };
}
|
|
2714
|
+
|
|
2715
|
+
/**
 * Summarize benchmark progress for reporting.
 *
 * Finite `progress_*` fields on the last step entry take precedence. When a
 * field is missing, the value is derived from the shard `context`
 * (shardIndex/shardCount/stepsPerShard/totalGlobalSteps) and the number of
 * entries. Elapsed time falls back to `performance.now() - startTimeMs`.
 *
 * @param {Array<object>} stepEntries - Per-step metric entries (non-arrays are treated as empty).
 * @param {object|null} context - Shard progress context.
 * @param {number} startTimeMs - `performance.now()` timestamp at run start.
 * @returns {object} Progress summary with completion, ETA and timestamps.
 */
function resolveBenchProgressSummary(stepEntries, context, startTimeMs) {
  const pickFinite = (candidate, fallback) => (Number.isFinite(candidate) ? candidate : fallback);
  const entries = Array.isArray(stepEntries) ? stepEntries : [];
  const last = entries.length > 0 ? entries[entries.length - 1] : null;

  const shardIndex = context?.shardIndex ?? 1;
  const shardCount = context?.shardCount ?? 1;
  const stepsPerShard = context?.stepsPerShard ?? null;
  const contextTotalSteps = context?.totalGlobalSteps ?? null;

  // Without runner-reported progress, assume earlier shards are complete and
  // count this shard's entries (capped at stepsPerShard).
  const estimatedGlobalStep = stepsPerShard
    ? ((shardIndex - 1) * stepsPerShard) + Math.min(entries.length, stepsPerShard)
    : null;
  const completedGlobalSteps = pickFinite(last?.progress_global_step, estimatedGlobalStep);
  const resolvedTotal = pickFinite(last?.progress_global_steps, contextTotalSteps);

  let percentComplete = null;
  if (Number.isFinite(last?.progress_percent_complete)) {
    percentComplete = last.progress_percent_complete;
  } else if (
    Number.isFinite(completedGlobalSteps)
    && Number.isFinite(resolvedTotal)
    && resolvedTotal > 0
  ) {
    percentComplete = Math.min(100, (completedGlobalSteps / resolvedTotal) * 100);
  }

  let etaMs = null;
  if (Number.isFinite(last?.progress_eta_ms)) {
    etaMs = Math.max(0, last.progress_eta_ms);
  } else if (Number.isFinite(percentComplete) && percentComplete >= 100) {
    etaMs = 0;
  }

  const elapsedMs = Number.isFinite(last?.progress_elapsed_ms)
    ? Math.max(0, last.progress_elapsed_ms)
    : Math.max(0, performance.now() - startTimeMs);

  return {
    shardIndex,
    shardCount,
    stepsPerShard,
    completedGlobalSteps: pickFinite(completedGlobalSteps, null),
    totalGlobalSteps: pickFinite(resolvedTotal, null),
    percentComplete,
    etaMs,
    etaIso: Number.isFinite(etaMs) ? new Date(Date.now() + etaMs).toISOString() : null,
    elapsedMs,
    updatedAt: new Date().toISOString(),
  };
}
|
|
2764
|
+
|
|
2765
|
+
/**
 * Append an event to a timeline array in place.
 *
 * The event gets a 1-based `index`, an ISO `timestamp` and `type`, followed
 * by `details`. Keys in `details` are spread last and can override them.
 *
 * @param {Array<object>} timeline - Timeline to mutate.
 * @param {string} type - Event type.
 * @param {object} [details] - Extra event fields.
 * @returns {void}
 */
function appendTimelineEvent(timeline, type, details = {}) {
  const index = timeline.length + 1;
  const timestamp = new Date().toISOString();
  const event = { index, timestamp, type, ...details };
  timeline.push(event);
}
|
|
2773
|
+
|
|
2774
|
+
/**
 * Resolve run counts and steps per run for the training benchmark.
 *
 * - warmupRuns: `options.benchRun.warmupRuns`, floored and clamped to >= 0.
 *   Non-numeric or non-finite values (NaN, Infinity) fall back to 0.
 * - timedRuns: positive integer from `options.benchRun.timedRuns`, default 1.
 * - stepsPerRun: first non-nullish of `options.trainingBenchSteps`,
 *   `options.benchRun.steps`, `options.trainingSteps`, as a positive integer;
 *   default 2.
 *
 * @param {object} [options]
 * @returns {{warmupRuns: number, timedRuns: number, stepsPerRun: number}}
 */
function resolveBenchRunSettings(options = {}) {
  const benchRun = options.benchRun && typeof options.benchRun === 'object'
    ? options.benchRun
    : {};
  const warmupInput = Number(benchRun.warmupRuns);
  // Infinity would make the benchmark's warmup + timed run loop unbounded.
  const warmupRuns = Number.isFinite(warmupInput)
    ? Math.max(0, Math.floor(warmupInput))
    : 0;
  return {
    warmupRuns,
    timedRuns: toPositiveInteger(benchRun.timedRuns, 1),
    stepsPerRun: toPositiveInteger(
      options.trainingBenchSteps ?? benchRun.steps ?? options.trainingSteps,
      2
    ),
  };
}
|
|
2787
|
+
|
|
2788
|
+
/**
 * Pick the training config override for a benchmark run.
 *
 * Order of precedence: distill overrides (only if `distill.enabled` is
 * truthy), then UL overrides (if truthy), then the normalized
 * `options.trainingConfig`. Returns undefined when none applies.
 *
 * @param {object} [options]
 * @returns {object|undefined}
 */
function resolveTrainingOverrides(options = {}) {
  const distillTraining = buildDistillTrainingOverrides(options);
  if (distillTraining?.distill?.enabled) {
    return distillTraining;
  }
  return buildUlTrainingOverrides(options)
    || normalizeTrainingConfigOverride(options.trainingConfig)
    || undefined;
}
|
|
2799
|
+
|
|
2800
|
+
export async function runTrainingBenchSuite(options = {}) {
|
|
2801
|
+
const trainingSchemaVersion = assertTrainingSchemaVersion(options.trainingSchemaVersion);
|
|
2802
|
+
const startTime = performance.now();
|
|
2803
|
+
await trainingHarness.getGPU();
|
|
2804
|
+
|
|
2805
|
+
const benchSettings = resolveBenchRunSettings(options);
|
|
2806
|
+
const totalRuns = benchSettings.warmupRuns + benchSettings.timedRuns;
|
|
2807
|
+
const trainingOverrides = resolveTrainingOverrides(options);
|
|
2808
|
+
const adapterActivation = normalizeAdapterActivationConfig(options);
|
|
2809
|
+
const distillEnabled = trainingOverrides?.distill?.enabled === true;
|
|
2810
|
+
const distillDatasetPath = resolveDistillDatasetPath(options, trainingOverrides);
|
|
2811
|
+
const distillDataScope = resolveDistillDataScope(options, trainingOverrides);
|
|
2812
|
+
const distillDatasetReport = distillEnabled
|
|
2813
|
+
? await loadDistillDatasetFromJsonl(distillDatasetPath, distillDataScope)
|
|
2814
|
+
: null;
|
|
2815
|
+
const resolvedResumeFrom = options.resumeFrom || trainingOverrides?.distill?.resumeFrom || null;
|
|
2816
|
+
const resolvedStage1Artifact = options.stage1Artifact || trainingOverrides?.ul?.stage1Artifact || null;
|
|
2817
|
+
const resolvedStage1ArtifactHash = options.stage1ArtifactHash || trainingOverrides?.ul?.stage1ArtifactHash || null;
|
|
2818
|
+
const resolvedStageAArtifact = options.stageAArtifact || trainingOverrides?.distill?.stageAArtifact || null;
|
|
2819
|
+
const resolvedStageAArtifactHash = options.stageAArtifactHash || trainingOverrides?.distill?.stageAArtifactHash || null;
|
|
2820
|
+
let distillRuntime = null;
|
|
2821
|
+
if (distillEnabled) {
|
|
2822
|
+
if (!distillDatasetPath) {
|
|
2823
|
+
throw new Error('Distill benchmark requires --distill-dataset-path (training.distill.datasetPath).');
|
|
2824
|
+
}
|
|
2825
|
+
distillRuntime = await createDistillRuntimeContext(options, trainingOverrides);
|
|
2826
|
+
}
|
|
2827
|
+
|
|
2828
|
+
const timedRunDurationsMs = [];
|
|
2829
|
+
const timedRunStepsPerSec = [];
|
|
2830
|
+
const timedStepDurationsMs = [];
|
|
2831
|
+
const timedRunUlArtifacts = [];
|
|
2832
|
+
const timedRunDistillArtifacts = [];
|
|
2833
|
+
const timedRunAdapterExports = [];
|
|
2834
|
+
const trainingMetricsReport = [];
|
|
2835
|
+
const distillShardProgress = resolveDistillShardProgressContext(
|
|
2836
|
+
options,
|
|
2837
|
+
trainingOverrides,
|
|
2838
|
+
benchSettings.stepsPerRun,
|
|
2839
|
+
distillDatasetReport?.shardCount ?? null
|
|
2840
|
+
);
|
|
2841
|
+
const checkpointResumeTimeline = [];
|
|
2842
|
+
appendTimelineEvent(checkpointResumeTimeline, 'benchmark_started', {
|
|
2843
|
+
workloadType: 'training',
|
|
2844
|
+
trainingStage: (
|
|
2845
|
+
options.trainingStage
|
|
2846
|
+
|| trainingOverrides?.distill?.stage
|
|
2847
|
+
|| trainingOverrides?.ul?.stage
|
|
2848
|
+
|| null
|
|
2849
|
+
),
|
|
2850
|
+
forceResume: options.forceResume === true,
|
|
2851
|
+
forceResumeReason: options.forceResume === true
|
|
2852
|
+
? (options.forceResumeReason || null)
|
|
2853
|
+
: null,
|
|
2854
|
+
shardIndex: distillShardProgress.shardIndex,
|
|
2855
|
+
shardCount: distillShardProgress.shardCount,
|
|
2856
|
+
stepsPerShard: distillShardProgress.stepsPerShard,
|
|
2857
|
+
});
|
|
2858
|
+
if (resolvedResumeFrom) {
|
|
2859
|
+
appendTimelineEvent(checkpointResumeTimeline, 'resume_requested', {
|
|
2860
|
+
resumeFrom: String(resolvedResumeFrom),
|
|
2861
|
+
});
|
|
2862
|
+
}
|
|
2863
|
+
if (resolvedStage1Artifact) {
|
|
2864
|
+
appendTimelineEvent(checkpointResumeTimeline, 'resume_dependency_declared', {
|
|
2865
|
+
dependencyType: 'ul_stage1',
|
|
2866
|
+
stage1Artifact: String(resolvedStage1Artifact),
|
|
2867
|
+
stage1ArtifactHash: resolvedStage1ArtifactHash,
|
|
2868
|
+
});
|
|
2869
|
+
}
|
|
2870
|
+
if (resolvedStageAArtifact) {
|
|
2871
|
+
appendTimelineEvent(checkpointResumeTimeline, 'resume_dependency_declared', {
|
|
2872
|
+
dependencyType: 'distill_stage_a',
|
|
2873
|
+
stageAArtifact: String(resolvedStageAArtifact),
|
|
2874
|
+
stageAArtifactHash: resolvedStageAArtifactHash,
|
|
2875
|
+
});
|
|
2876
|
+
}
|
|
2877
|
+
let completedTimedRuns = 0;
|
|
2878
|
+
let latestExportedAdapter = null;
|
|
2879
|
+
|
|
2880
|
+
try {
|
|
2881
|
+
for (let runIndex = 0; runIndex < totalRuns; runIndex += 1) {
|
|
2882
|
+
const fixture = distillEnabled
|
|
2883
|
+
? await createDistillStudentRuntimeModelFixture({
|
|
2884
|
+
training: trainingOverrides,
|
|
2885
|
+
}, {
|
|
2886
|
+
outputDim: distillRuntime?.topK ?? DISTILL_ADAPTER_TOP_K,
|
|
2887
|
+
distillRuntime,
|
|
2888
|
+
})
|
|
2889
|
+
: createToyModelFixture({
|
|
2890
|
+
training: trainingOverrides,
|
|
2891
|
+
});
|
|
2892
|
+
try {
|
|
2893
|
+
const runner = new TrainingRunner(fixture.config, {
|
|
2894
|
+
optimizer: new AdamOptimizer(fixture.config),
|
|
2895
|
+
crossEntropyLoss,
|
|
2896
|
+
clipGradients,
|
|
2897
|
+
});
|
|
2898
|
+
const dataset = distillEnabled
|
|
2899
|
+
? distillDatasetReport.createDataset({
|
|
2900
|
+
batchSize: 1,
|
|
2901
|
+
shuffle: false,
|
|
2902
|
+
seed: 1337 + runIndex,
|
|
2903
|
+
distillRuntime,
|
|
2904
|
+
})
|
|
2905
|
+
: {
|
|
2906
|
+
async *batches() {
|
|
2907
|
+
for (let i = 0; i < benchSettings.stepsPerRun; i += 1) {
|
|
2908
|
+
yield fixture.batch;
|
|
2909
|
+
}
|
|
2910
|
+
},
|
|
2911
|
+
};
|
|
2912
|
+
|
|
2913
|
+
const runStart = performance.now();
|
|
2914
|
+
const isTimedRun = runIndex >= benchSettings.warmupRuns;
|
|
2915
|
+
appendTimelineEvent(checkpointResumeTimeline, 'run_started', {
|
|
2916
|
+
runIndex: runIndex + 1,
|
|
2917
|
+
phase: isTimedRun ? 'timed' : 'warmup',
|
|
2918
|
+
});
|
|
2919
|
+
const runMetrics = await runner.run(fixture.model, dataset, {
|
|
2920
|
+
epochs: 1,
|
|
2921
|
+
batchSize: 1,
|
|
2922
|
+
shuffle: false,
|
|
2923
|
+
maxSteps: benchSettings.stepsPerRun,
|
|
2924
|
+
modelId: options.modelId || distillRuntime?.studentModelId || 'training',
|
|
2925
|
+
modelUrl: options.modelUrl || distillRuntime?.studentModelUrl || null,
|
|
2926
|
+
runtimePreset: options.runtimePreset || null,
|
|
2927
|
+
trainingStage: (
|
|
2928
|
+
options.trainingStage
|
|
2929
|
+
|| trainingOverrides?.distill?.stage
|
|
2930
|
+
|| trainingOverrides?.ul?.stage
|
|
2931
|
+
|| null
|
|
2932
|
+
),
|
|
2933
|
+
command: options.command || null,
|
|
2934
|
+
surface: options.surface || null,
|
|
2935
|
+
forceResume: options.forceResume === true,
|
|
2936
|
+
forceResumeReason: options.forceResumeReason || null,
|
|
2937
|
+
forceResumeSource: options.forceResumeSource || null,
|
|
2938
|
+
checkpointOperator: options.checkpointOperator || null,
|
|
2939
|
+
checkpointEvery: options.checkpointEvery ?? null,
|
|
2940
|
+
gpuAdapterInfo: getKernelCapabilities(),
|
|
2941
|
+
timestamp: options.timestamp || null,
|
|
2942
|
+
ulArtifactDir: options.ulArtifactDir || null,
|
|
2943
|
+
distillArtifactDir: options.distillArtifactDir || null,
|
|
2944
|
+
stageAArtifact: resolvedStageAArtifact,
|
|
2945
|
+
stageAArtifactHash: resolvedStageAArtifactHash,
|
|
2946
|
+
teacherModelId: distillRuntime?.teacherModelId || options.teacherModelId || null,
|
|
2947
|
+
studentModelId: distillRuntime?.studentModelId || options.studentModelId || null,
|
|
2948
|
+
distillDatasetId: options.distillDatasetId || null,
|
|
2949
|
+
distillDatasetPath: distillDatasetReport?.absolutePath || null,
|
|
2950
|
+
distillLanguagePair: options.distillLanguagePair || null,
|
|
2951
|
+
distillSourceLangs: distillDataScope.sourceLangs || null,
|
|
2952
|
+
distillTargetLangs: distillDataScope.targetLangs || null,
|
|
2953
|
+
distillPairAllowlist: distillDataScope.pairAllowlist || null,
|
|
2954
|
+
strictPairContract: distillDataScope.strictPairContract === true,
|
|
2955
|
+
distillShardIndex: distillShardProgress.shardIndex,
|
|
2956
|
+
distillShardCount: distillShardProgress.shardCount,
|
|
2957
|
+
resumeFrom: resolvedResumeFrom,
|
|
2958
|
+
});
|
|
2959
|
+
const runDurationMs = Math.max(0, performance.now() - runStart);
|
|
2960
|
+
if (runner.resumeState && typeof runner.resumeState === 'object') {
|
|
2961
|
+
appendTimelineEvent(checkpointResumeTimeline, 'run_resumed', {
|
|
2962
|
+
runIndex: runIndex + 1,
|
|
2963
|
+
phase: isTimedRun ? 'timed' : 'warmup',
|
|
2964
|
+
resumedStep: runner.resumeState.step ?? null,
|
|
2965
|
+
resumedEpoch: runner.resumeState.epoch ?? null,
|
|
2966
|
+
resumedBatch: runner.resumeState.batch ?? null,
|
|
2967
|
+
resumedCheckpointHash: runner.resumeState.checkpointHash ?? null,
|
|
2968
|
+
previousCheckpointHash: runner.resumeState.previousCheckpointHash ?? null,
|
|
2969
|
+
resumeAuditCount: Number.isInteger(runner.resumeState.resumeAuditCount)
|
|
2970
|
+
? runner.resumeState.resumeAuditCount
|
|
2971
|
+
: 0,
|
|
2972
|
+
checkpointKey: runner.resumeState.checkpointKey ?? null,
|
|
2973
|
+
});
|
|
2974
|
+
if (Number.isInteger(runner.resumeState.resumeAuditCount) && runner.resumeState.resumeAuditCount > 0) {
|
|
2975
|
+
appendTimelineEvent(checkpointResumeTimeline, 'resume_override_applied', {
|
|
2976
|
+
runIndex: runIndex + 1,
|
|
2977
|
+
phase: isTimedRun ? 'timed' : 'warmup',
|
|
2978
|
+
resumeAudits: Array.isArray(runner.resumeState.resumeAudits)
|
|
2979
|
+
? runner.resumeState.resumeAudits
|
|
2980
|
+
: [],
|
|
2981
|
+
});
|
|
2982
|
+
}
|
|
2983
|
+
}
|
|
2984
|
+
appendTimelineEvent(checkpointResumeTimeline, 'run_completed', {
|
|
2985
|
+
runIndex: runIndex + 1,
|
|
2986
|
+
phase: isTimedRun ? 'timed' : 'warmup',
|
|
2987
|
+
durationMs: runDurationMs,
|
|
2988
|
+
stepCount: Array.isArray(runMetrics) ? runMetrics.length : 0,
|
|
2989
|
+
});
|
|
2990
|
+
if (isTimedRun) {
|
|
2991
|
+
completedTimedRuns += 1;
|
|
2992
|
+
timedRunDurationsMs.push(runDurationMs);
|
|
2993
|
+
const runStepCount = Array.isArray(runMetrics) ? runMetrics.length : 0;
|
|
2994
|
+
if (runDurationMs > 0 && runStepCount > 0) {
|
|
2995
|
+
timedRunStepsPerSec.push((runStepCount * 1000) / runDurationMs);
|
|
2996
|
+
}
|
|
2997
|
+
for (const stepEntry of runMetrics) {
|
|
2998
|
+
const stepWithRun = {
|
|
2999
|
+
...stepEntry,
|
|
3000
|
+
bench_run_index: completedTimedRuns,
|
|
3001
|
+
bench_run_global_index: runIndex + 1,
|
|
3002
|
+
};
|
|
3003
|
+
if (isFiniteNumber(stepWithRun?.step_time_ms)) {
|
|
3004
|
+
timedStepDurationsMs.push(stepWithRun.step_time_ms);
|
|
3005
|
+
}
|
|
3006
|
+
trainingMetricsReport.push(stepWithRun);
|
|
3007
|
+
}
|
|
3008
|
+
if (runner.lastCheckpoint && typeof runner.lastCheckpoint === 'object') {
|
|
3009
|
+
appendTimelineEvent(checkpointResumeTimeline, 'checkpoint_state_written', {
|
|
3010
|
+
runIndex: runIndex + 1,
|
|
3011
|
+
timedRunIndex: completedTimedRuns,
|
|
3012
|
+
checkpointKey: runner.lastCheckpoint.key || null,
|
|
3013
|
+
checkpointStep: runner.lastCheckpoint.step ?? null,
|
|
3014
|
+
checkpointEpoch: runner.lastCheckpoint.epoch ?? null,
|
|
3015
|
+
checkpointBatch: runner.lastCheckpoint.batch ?? null,
|
|
3016
|
+
});
|
|
3017
|
+
}
|
|
3018
|
+
if (runner.lastArtifact && typeof runner.lastArtifact === 'object') {
|
|
3019
|
+
const artifactEntry = {
|
|
3020
|
+
runIndex: completedTimedRuns,
|
|
3021
|
+
...runner.lastArtifact,
|
|
3022
|
+
resumeAudits: Array.isArray(runner.resumeState?.resumeAudits)
|
|
3023
|
+
? runner.resumeState.resumeAudits
|
|
3024
|
+
: [],
|
|
3025
|
+
};
|
|
3026
|
+
appendTimelineEvent(checkpointResumeTimeline, 'checkpoint_written', {
|
|
3027
|
+
runIndex: runIndex + 1,
|
|
3028
|
+
timedRunIndex: completedTimedRuns,
|
|
3029
|
+
artifactKind: artifactEntry.kind || null,
|
|
3030
|
+
stage: artifactEntry.stage || null,
|
|
3031
|
+
manifestPath: artifactEntry.manifestPath || null,
|
|
3032
|
+
manifestHash: artifactEntry.manifestHash || null,
|
|
3033
|
+
manifestFileHash: artifactEntry.manifestFileHash || null,
|
|
3034
|
+
});
|
|
3035
|
+
if (artifactEntry.stageADependency) {
|
|
3036
|
+
appendTimelineEvent(checkpointResumeTimeline, 'resume_dependency_resolved', {
|
|
3037
|
+
dependencyType: 'distill_stage_a',
|
|
3038
|
+
runIndex: runIndex + 1,
|
|
3039
|
+
stageADependency: artifactEntry.stageADependency,
|
|
3040
|
+
});
|
|
3041
|
+
}
|
|
3042
|
+
if (artifactEntry.stage1Dependency) {
|
|
3043
|
+
appendTimelineEvent(checkpointResumeTimeline, 'resume_dependency_resolved', {
|
|
3044
|
+
dependencyType: 'ul_stage1',
|
|
3045
|
+
runIndex: runIndex + 1,
|
|
3046
|
+
stage1Dependency: artifactEntry.stage1Dependency,
|
|
3047
|
+
});
|
|
3048
|
+
}
|
|
3049
|
+
if (runner.lastArtifact.kind === 'distill') {
|
|
3050
|
+
timedRunDistillArtifacts.push(artifactEntry);
|
|
3051
|
+
} else {
|
|
3052
|
+
timedRunUlArtifacts.push(artifactEntry);
|
|
3053
|
+
}
|
|
3054
|
+
}
|
|
3055
|
+
if (adapterActivation.enabled && adapterActivation.exportConfig) {
|
|
3056
|
+
const exportedAdapter = await exportLoRAAdapterFromModel(
|
|
3057
|
+
fixture.model,
|
|
3058
|
+
adapterActivation.exportConfig,
|
|
3059
|
+
completedTimedRuns
|
|
3060
|
+
);
|
|
3061
|
+
if (exportedAdapter) {
|
|
3062
|
+
latestExportedAdapter = exportedAdapter;
|
|
3063
|
+
timedRunAdapterExports.push({
|
|
3064
|
+
runIndex: completedTimedRuns,
|
|
3065
|
+
id: exportedAdapter.manifest?.id || null,
|
|
3066
|
+
name: exportedAdapter.manifest?.name || null,
|
|
3067
|
+
hash: exportedAdapter.hash,
|
|
3068
|
+
});
|
|
3069
|
+
}
|
|
3070
|
+
}
|
|
3071
|
+
}
|
|
3072
|
+
} finally {
|
|
3073
|
+
fixture.cleanup();
|
|
3074
|
+
}
|
|
3075
|
+
}
|
|
3076
|
+
} finally {
|
|
3077
|
+
if (distillRuntime && typeof distillRuntime.cleanup === 'function') {
|
|
3078
|
+
await distillRuntime.cleanup();
|
|
3079
|
+
}
|
|
3080
|
+
}
|
|
3081
|
+
|
|
3082
|
+
const runMsStats = computeSampleStats(timedRunDurationsMs);
|
|
3083
|
+
const stepMsStats = computeSampleStats(timedStepDurationsMs);
|
|
3084
|
+
const stepsPerSecStats = computeSampleStats(timedRunStepsPerSec);
|
|
3085
|
+
const progress = resolveBenchProgressSummary(trainingMetricsReport, distillShardProgress, startTime);
|
|
3086
|
+
const activationPayload = adapterActivation.adapterPayload
|
|
3087
|
+
? adapterActivation.adapterPayload
|
|
3088
|
+
: (latestExportedAdapter
|
|
3089
|
+
? {
|
|
3090
|
+
adapterManifest: latestExportedAdapter.manifest,
|
|
3091
|
+
adapterManifestJson: latestExportedAdapter.json,
|
|
3092
|
+
}
|
|
3093
|
+
: null);
|
|
3094
|
+
const adapterActivationResult = (
|
|
3095
|
+
adapterActivation.enabled
|
|
3096
|
+
&& adapterActivation.autoActivate
|
|
3097
|
+
)
|
|
3098
|
+
? await tryActivateAdapterPayload(activationPayload)
|
|
3099
|
+
: null;
|
|
3100
|
+
appendTimelineEvent(checkpointResumeTimeline, 'benchmark_completed', {
|
|
3101
|
+
completedTimedRuns,
|
|
3102
|
+
metricEntryCount: trainingMetricsReport.length,
|
|
3103
|
+
percentComplete: progress.percentComplete,
|
|
3104
|
+
etaMs: progress.etaMs,
|
|
3105
|
+
});
|
|
3106
|
+
|
|
3107
|
+
const results = [
|
|
3108
|
+
{
|
|
3109
|
+
name: 'training-benchmark',
|
|
3110
|
+
passed: completedTimedRuns > 0 && trainingMetricsReport.length > 0,
|
|
3111
|
+
duration: Math.max(0, performance.now() - startTime),
|
|
3112
|
+
error: completedTimedRuns > 0 && trainingMetricsReport.length > 0
|
|
3113
|
+
? undefined
|
|
3114
|
+
: 'No timed training benchmark runs completed.',
|
|
3115
|
+
},
|
|
3116
|
+
];
|
|
3117
|
+
|
|
3118
|
+
const summary = buildSuiteSummary('bench', results, startTime);
|
|
3119
|
+
return {
|
|
3120
|
+
...summary,
|
|
3121
|
+
modelId: options.modelId || distillRuntime?.studentModelId || options.modelUrl || 'training',
|
|
3122
|
+
metrics: {
|
|
3123
|
+
workloadType: 'training',
|
|
3124
|
+
warmupRuns: benchSettings.warmupRuns,
|
|
3125
|
+
timedRuns: benchSettings.timedRuns,
|
|
3126
|
+
completedTimedRuns,
|
|
3127
|
+
stepsPerRun: benchSettings.stepsPerRun,
|
|
3128
|
+
trainingSchemaVersion,
|
|
3129
|
+
trainingMetricsReport,
|
|
3130
|
+
progress,
|
|
3131
|
+
ulArtifacts: timedRunUlArtifacts,
|
|
3132
|
+
distillArtifacts: timedRunDistillArtifacts,
|
|
3133
|
+
adapterExports: timedRunAdapterExports,
|
|
3134
|
+
adapterActivation: adapterActivationResult,
|
|
3135
|
+
checkpointResumeTimeline,
|
|
3136
|
+
distillDataset: distillDatasetReport
|
|
3137
|
+
? {
|
|
3138
|
+
path: distillDatasetReport.absolutePath,
|
|
3139
|
+
rowCount: distillDatasetReport.rowCount,
|
|
3140
|
+
sampleCount: distillDatasetReport.sampleCount,
|
|
3141
|
+
shardCount: distillDatasetReport.shardCount ?? 1,
|
|
3142
|
+
directionCounts: distillDatasetReport.directionCounts,
|
|
3143
|
+
dataScope: distillDatasetReport.dataScope || null,
|
|
3144
|
+
}
|
|
3145
|
+
: null,
|
|
3146
|
+
latency: {
|
|
3147
|
+
runMs: runMsStats,
|
|
3148
|
+
stepMs: stepMsStats,
|
|
3149
|
+
},
|
|
3150
|
+
throughput: {
|
|
3151
|
+
stepsPerSec: stepsPerSecStats,
|
|
3152
|
+
},
|
|
3153
|
+
},
|
|
3154
|
+
deviceInfo: getKernelCapabilities(),
|
|
3155
|
+
};
|
|
3156
|
+
}
|