@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -0,0 +1,846 @@
1
+ import { getDevice } from '../../gpu/device.js';
2
+ import { createTrainingConfig } from '../../config/training-defaults.js';
3
+ import {
4
+ runAttention,
5
+ castF16ToF32,
6
+ runGather,
7
+ runMatmul,
8
+ runResidualAdd,
9
+ runRMSNorm,
10
+ runRoPE,
11
+ runSiLURowSplit,
12
+ } from '../../gpu/kernels/index.js';
13
+ import { createTensor } from '../../gpu/tensor.js';
14
+ import { acquireBuffer, uploadData, releaseBuffer } from '../../memory/buffer-pool.js';
15
+ import { getBufferDtype, getWeightDtype, isCpuWeightBuffer, isWeightBuffer } from '../../gpu/weight-buffer.js';
16
+ import { OpType } from '../autograd.js';
17
+ import { normalizeOptionalString } from './suite-data.js';
18
+
19
// Default top-k value used by clampDistillTopK when the caller supplies no
// (or a non-finite) value; clamped downstream to the range [2, 256].
const DISTILL_ADAPTER_TOP_K = 64;
// Student graph modes: train only a projection head, or the full transformer.
const DISTILL_STUDENT_GRAPH_PROJECTION = 'projection_head';
const DISTILL_STUDENT_GRAPH_FULL = 'transformer_full';
22
+
23
/**
 * Uploads float32 values into a freshly acquired GPU buffer and wraps it
 * as an 'f32' tensor.
 * @param {Float32Array|number[]} values - Source values; copied into a
 *   Float32Array when not already one.
 * @param {number[]} shape - Logical tensor shape.
 * @param {string} [label] - Debug label; defaults to 'train_tensor'.
 * @returns {object} Tensor handle owning the uploaded buffer.
 */
function makeTensorFromFloat32(values, shape, label) {
  const name = label || 'train_tensor';
  const payload = values instanceof Float32Array ? values : new Float32Array(values);
  const gpuBuffer = acquireBuffer(payload.byteLength, undefined, name);
  uploadData(gpuBuffer, payload);
  return createTensor(gpuBuffer, 'f32', shape, name);
}
29
+
30
/**
 * Uploads raw half-precision bit patterns (one Uint16 per f16 value) into a
 * GPU buffer and wraps it as an 'f16' tensor.
 * @param {Uint16Array|number[]} values - f16 bit patterns; copied into a
 *   Uint16Array when not already one.
 * @param {number[]} shape - Logical tensor shape.
 * @param {string} [label] - Debug label; defaults to 'train_tensor_f16'.
 * @returns {object} Tensor handle owning the uploaded buffer.
 */
function makeTensorFromF16Bits(values, shape, label) {
  const name = label || 'train_tensor_f16';
  const payload = values instanceof Uint16Array ? values : new Uint16Array(values);
  const gpuBuffer = acquireBuffer(payload.byteLength, undefined, name);
  uploadData(gpuBuffer, payload);
  return createTensor(gpuBuffer, 'f16', shape, name);
}
36
+
37
// Uploads uint32 values (token ids) into a GPU buffer.
// NOTE(review): the resulting tensor is tagged dtype 'f32' even though the
// payload is raw Uint32 bits — presumably a consuming kernel (e.g. gather)
// reinterprets the bytes as u32. Confirm this is intentional; contrast with
// makeTensorFromF16Bits, which tags its tensor with the matching 'f16' dtype.
function makeTensorFromUint32(values, shape, label) {
  const data = values instanceof Uint32Array ? values : new Uint32Array(values);
  const buffer = acquireBuffer(data.byteLength, undefined, label || 'train_tokens');
  uploadData(buffer, data);
  return createTensor(buffer, 'f32', shape, label || 'train_tokens');
}
43
+
44
/**
 * Returns a tensor's underlying GPU buffer to the pool.
 * Safe to call with null/undefined or a tensor lacking a buffer (no-op).
 * @param {object|null|undefined} tensor - Tensor whose buffer should be released.
 */
function releaseTensor(tensor) {
  const buffer = tensor?.buffer;
  if (buffer) {
    releaseBuffer(buffer);
  }
}
48
+
49
/**
 * Coerces a value into a Float32Array.
 * - Float32Array: returned as-is (no copy).
 * - Other typed-array views: underlying bytes are copied and reinterpreted as f32.
 * - ArrayBuffer: copied and viewed as f32.
 * - Plain arrays: converted element-wise.
 * @param {Float32Array|ArrayBufferView|ArrayBuffer|number[]} values - Input data.
 * @param {string} [label='values'] - Name used in the error message.
 * @returns {Float32Array} The coerced array.
 * @throws {Error} When the input is none of the supported forms.
 */
function toFloat32Array(values, label = 'values') {
  if (values instanceof Float32Array) {
    return values;
  }
  if (ArrayBuffer.isView(values)) {
    // Copy the view's byte span, then reinterpret the bytes as f32.
    const bytes = values.buffer.slice(values.byteOffset, values.byteOffset + values.byteLength);
    return new Float32Array(bytes);
  }
  if (values instanceof ArrayBuffer) {
    return new Float32Array(values.slice(0));
  }
  if (Array.isArray(values)) {
    return new Float32Array(values);
  }
  throw new Error(`Expected ${label} to be a Float32Array-compatible value.`);
}
62
+
63
/**
 * Best-effort cleanup of a prefill result's cache: invokes `cache.clear()`
 * when the result carries a cache exposing that method. Tolerates missing
 * results and caches without a clear function.
 * @param {object|null|undefined} result - Prefill result, possibly with a `cache`.
 */
function disposePrefillSnapshot(result) {
  const cache = result?.cache;
  if (typeof cache?.clear === 'function') {
    cache.clear();
  }
}
69
+
70
/**
 * Converts a value to a finite number, returning the fallback when the
 * conversion yields NaN or ±Infinity.
 * @param {*} value - Value to convert via Number().
 * @param {number} fallback - Returned when conversion is not finite.
 * @returns {number} A finite number or the fallback.
 */
function toFiniteNumber(value, fallback) {
  const numeric = Number(value);
  if (Number.isFinite(numeric)) {
    return numeric;
  }
  return fallback;
}
74
+
75
/**
 * Normalizes a distillation top-k setting: integerizes the input (falling
 * back to DISTILL_ADAPTER_TOP_K when non-finite) and clamps it to [2, 256].
 * @param {*} value - Requested top-k; any Number()-coercible value.
 * @returns {number} Integer top-k in the range [2, 256].
 */
function clampDistillTopK(value) {
  const candidate = Math.floor(toFiniteNumber(value, DISTILL_ADAPTER_TOP_K));
  if (candidate < 2) {
    return 2;
  }
  if (candidate > 256) {
    return 256;
  }
  return candidate;
}
79
+
80
/**
 * Resolves the student-graph mode string to one of the two supported modes.
 * Hyphens/whitespace are folded to underscores and case is ignored, so
 * "Projection-Head", "projection head", and "projection" all map to the
 * projection mode; anything else (including empty input) maps to full.
 * @param {*} value - Raw mode value from config/input.
 * @returns {string} DISTILL_STUDENT_GRAPH_PROJECTION or DISTILL_STUDENT_GRAPH_FULL.
 */
function normalizeDistillStudentGraphMode(value) {
  const text = normalizeOptionalString(value);
  if (!text) {
    return DISTILL_STUDENT_GRAPH_FULL;
  }
  const canonical = text.toLowerCase().replace(/[-\s]/g, '_');
  const isProjection = canonical === 'projection_head' || canonical === 'projection';
  return isProjection ? DISTILL_STUDENT_GRAPH_PROJECTION : DISTILL_STUDENT_GRAPH_FULL;
}
89
+
90
/**
 * Determines the dtype of a tensor-like value, recognizing only 'f16' and
 * 'f32' (case-insensitive); any other or missing dtype yields the fallback.
 * @param {*} value - Weight buffer, tensor-like object, or anything else.
 * @param {string} [fallback='f32'] - Returned when the dtype is unrecognized.
 * @returns {string} 'f16', 'f32', or the fallback.
 */
function resolveTensorDtype(value, fallback = 'f32') {
  let dtype;
  if (isWeightBuffer(value)) {
    dtype = value.dtype;
  } else {
    dtype = value?.dtype || getWeightDtype(value) || null;
  }
  const normalized = String(dtype || '').toLowerCase();
  if (normalized === 'f16') {
    return 'f16';
  }
  if (normalized === 'f32') {
    return 'f32';
  }
  return fallback;
}
97
+
98
/**
 * Normalizes a weight value of any supported form into a trainable f32 GPU
 * tensor, promoting f16 data via castF16ToF32 where needed.
 *
 * Accepted input forms (checked in this order):
 *   1. weight buffer (isWeightBuffer): f32 returned as-is; f16 promoted.
 *   2. raw GPUBuffer: dtype read via getBufferDtype; f16 promoted.
 *   3. CPU weight buffer (isCpuWeightBuffer): f32 uploaded; f16 bits uploaded
 *      then promoted (the temporary f16 tensor is released).
 *   4. tensor-like object with a GPUBuffer `.buffer`: f16 promoted.
 * Anything else throws.
 *
 * Newly created GPU-backed tensors are recorded in `ownedTrainables` (when it
 * is a Set) so the caller can release them later; tensors returned unchanged
 * (case 1 f32, and f32 wrappers over caller-owned buffers) are NOT recorded.
 *
 * @param {*} value - Weight payload in one of the forms above.
 * @param {number[]} shape - Fallback shape when the value carries none.
 * @param {string} label - Debug label, also used in error messages.
 * @param {Set|null} [ownedTrainables=null] - Collects tensors this call allocates.
 * @returns {Promise<object>} f32 tensor handle.
 * @throws {Error} On missing value, unsupported dtype, or non-GPU-resident input.
 */
async function ensureTrainableTensor(value, shape, label, ownedTrainables = null) {
  if (!value) {
    throw new Error(`Distill full-graph student missing required weight "${label}".`);
  }
  // Track only tensors backed by a real GPUBuffer so cleanup is safe.
  const registerOwned = (tensor) => {
    if (ownedTrainables instanceof Set && tensor?.buffer instanceof GPUBuffer) {
      ownedTrainables.add(tensor);
    }
    return tensor;
  };
  if (isWeightBuffer(value)) {
    if (value.dtype === 'f32') {
      // Already trainable as-is; caller retains ownership.
      return value;
    }
    if (value.dtype === 'f16') {
      // Prefer the buffer's own shape; fall back to the caller-supplied one.
      const sourceShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
      const source = createTensor(value.buffer, 'f16', sourceShape, `${label}_source_f16`);
      const promoted = await castF16ToF32(source);
      return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
    }
    throw new Error(`Distill full-graph student weight "${label}" uses unsupported dtype "${value.dtype}".`);
  }
  if (value instanceof GPUBuffer) {
    const sourceShape = [...shape];
    // Anything that is not explicitly f16 is treated as f32.
    const rawDtype = String(getBufferDtype(value) || 'f32').toLowerCase();
    const dtype = rawDtype === 'f16' ? 'f16' : 'f32';
    const tensor = createTensor(value, dtype, sourceShape, label);
    if (dtype === 'f16') {
      const promoted = await castF16ToF32(tensor);
      return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
    }
    return tensor;
  }
  if (isCpuWeightBuffer(value)) {
    const sourceShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
    const dtype = resolveTensorDtype(value, 'f32');
    if (dtype === 'f32') {
      // Upload CPU f32 data directly.
      const tensor = makeTensorFromFloat32(value.data, sourceShape, `${label}_cpu_f32`);
      return registerOwned(tensor);
    }
    if (dtype === 'f16') {
      // Coerce the CPU payload into a Uint16Array of f16 bit patterns.
      let raw = null;
      if (value.data instanceof Uint16Array) {
        raw = value.data;
      } else if (ArrayBuffer.isView(value.data)) {
        raw = new Uint16Array(
          value.data.buffer,
          value.data.byteOffset,
          Math.floor(value.data.byteLength / 2)
        );
      } else if (value.data instanceof ArrayBuffer) {
        raw = new Uint16Array(value.data);
      }
      if (!raw) {
        throw new Error(`Distill full-graph student weight "${label}" has non-typed f16 CPU data.`);
      }
      const source = makeTensorFromF16Bits(raw, sourceShape, `${label}_cpu_f16`);
      const promoted = await castF16ToF32(source);
      // The temporary f16 upload is no longer needed once promoted.
      releaseTensor(source);
      return registerOwned(createTensor(promoted.buffer, 'f32', sourceShape, `${label}_trainable_f32`));
    }
    throw new Error(`Distill full-graph student weight "${label}" has unsupported CPU dtype "${dtype}".`);
  }
  if (value.buffer instanceof GPUBuffer) {
    const resolvedShape = Array.isArray(value.shape) && value.shape.length > 0 ? value.shape : [...shape];
    const tensor = createTensor(
      value.buffer,
      resolveTensorDtype(value, 'f32'),
      resolvedShape,
      label
    );
    if (tensor.dtype === 'f16') {
      const promoted = await castF16ToF32(tensor);
      return registerOwned(createTensor(promoted.buffer, 'f32', resolvedShape, `${label}_trainable_f32`));
    }
    return tensor;
  }
  throw new Error(`Distill full-graph student weight "${label}" is not GPU-resident.`);
}
177
+
178
/**
 * Promotes a normalization weight to a trainable f32 tensor. Norm weights are
 * 1-D vectors of length hiddenSize, so this delegates to ensureTrainableTensor
 * with a [hiddenSize] fallback shape.
 * @param {*} value - Weight payload (any form ensureTrainableTensor accepts).
 * @param {number} hiddenSize - Vector length used as the fallback shape.
 * @param {string} label - Debug label.
 * @param {Set|null} [ownedTrainables=null] - Collects tensors allocated on promotion.
 * @returns {Promise<object>} f32 tensor handle.
 */
async function ensureNormTensor(value, hiddenSize, label, ownedTrainables = null) {
  const trainable = await ensureTrainableTensor(value, [hiddenSize], label, ownedTrainables);
  return trainable;
}
181
+
182
/**
 * Reports whether a value carries tensor data in any recognized form:
 * a GPUBuffer, a GPU/CPU weight buffer, a tensor-like object wrapping a
 * GPUBuffer, a typed-array view, or a plain array.
 * @param {*} value - Candidate payload.
 * @returns {boolean} True when the value can serve as tensor data.
 */
function hasTensorPayload(value) {
  if (!value) {
    return false;
  }
  if (value instanceof GPUBuffer) {
    return true;
  }
  if (isWeightBuffer(value)) {
    return true;
  }
  if (isCpuWeightBuffer(value)) {
    return true;
  }
  if (value?.buffer instanceof GPUBuffer) {
    return true;
  }
  return ArrayBuffer.isView(value) || Array.isArray(value);
}
190
+
191
/**
 * Concatenates a gate projection and an up projection into one fused
 * [2 * intermediateSize, hiddenSize] f32 tensor by stacking their rows:
 * gate occupies the first half of the fused buffer, up the second half.
 * The copy is done on-GPU via copyBufferToBuffer (no readback).
 *
 * @param {object} gateTensor - f32 tensor shaped [intermediateSize, hiddenSize].
 * @param {object} upTensor - f32 tensor shaped [intermediateSize, hiddenSize].
 * @param {number} intermediateSize - Expected row count of each input.
 * @param {number} hiddenSize - Expected column count of each input.
 * @param {string} label - Debug label for buffers and error messages.
 * @param {Set|null} [ownedTrainables=null] - Collects the fused tensor for later release.
 * @returns {Promise<object>} Fused f32 tensor [intermediateSize * 2, hiddenSize].
 * @throws {Error} When no device is active, dtypes are not f32, or shapes mismatch.
 */
async function fuseGateUpTensors(gateTensor, upTensor, intermediateSize, hiddenSize, label, ownedTrainables = null) {
  const device = getDevice();
  if (!device) {
    throw new Error('Distill full-graph student requires active GPU device.');
  }
  if (gateTensor?.dtype !== 'f32' || upTensor?.dtype !== 'f32') {
    throw new Error(`Distill fused gate_up expects f32 tensors for "${label}".`);
  }
  const expectedRows = intermediateSize;
  const expectedCols = hiddenSize;
  // Non-finite / missing shape entries collapse to 0, which fails the check below.
  const gateRows = Number.isFinite(gateTensor?.shape?.[0]) ? gateTensor.shape[0] : 0;
  const gateCols = Number.isFinite(gateTensor?.shape?.[1]) ? gateTensor.shape[1] : 0;
  const upRows = Number.isFinite(upTensor?.shape?.[0]) ? upTensor.shape[0] : 0;
  const upCols = Number.isFinite(upTensor?.shape?.[1]) ? upTensor.shape[1] : 0;
  if (gateRows !== expectedRows || gateCols !== expectedCols || upRows !== expectedRows || upCols !== expectedCols) {
    throw new Error(
      `Distill gate/up shape mismatch for "${label}": gate=[${gateRows},${gateCols}] up=[${upRows},${upCols}] ` +
      `expected=[${expectedRows},${expectedCols}]`
    );
  }
  // f32 => 4 bytes per element; blockBytes is the byte size of one input matrix.
  const rowBytes = expectedCols * 4;
  const blockBytes = expectedRows * rowBytes;
  const fusedBuffer = acquireBuffer(blockBytes * 2, undefined, `${label}_fused`);
  const encoder = device.createCommandEncoder();
  // gate -> first half, up -> second half of the fused buffer.
  encoder.copyBufferToBuffer(gateTensor.buffer, 0, fusedBuffer, 0, blockBytes);
  encoder.copyBufferToBuffer(upTensor.buffer, 0, fusedBuffer, blockBytes, blockBytes);
  device.queue.submit([encoder.finish()]);
  const fused = createTensor(fusedBuffer, 'f32', [expectedRows * 2, expectedCols], `${label}_fused`);
  if (ownedTrainables instanceof Set) {
    ownedTrainables.add(fused);
  }
  return fused;
}
224
+
225
/**
 * Selects the prompt list for a distillation triplet phase from a batch:
 * 'positive' → tripletPositivePrompts, 'negative' → tripletNegativePrompts,
 * anything else → the anchor prompts.
 *
 * @param {object} batch - Batch object carrying an optional `distill` section.
 * @param {string} phase - 'anchor' | 'positive' | 'negative'.
 * @returns {Array} Non-empty prompt list for the phase.
 * @throws {Error} When the phase's prompt list is missing or empty.
 */
function resolvePhasePrompts(batch, phase) {
  const distill = batch?.distill || {};
  let prompts;
  switch (phase) {
    case 'positive':
      prompts = distill.tripletPositivePrompts;
      break;
    case 'negative':
      prompts = distill.tripletNegativePrompts;
      break;
    default:
      prompts = distill.prompts;
      break;
  }
  if (!Array.isArray(prompts) || prompts.length === 0) {
    throw new Error(`Distill student fixture requires distill prompts for phase "${phase}".`);
  }
  return prompts;
}
235
+
236
/**
 * Copies one row out of a [rows, cols] tensor into a fresh [1, cols] tensor
 * using a GPU buffer-to-buffer copy. The row index is clamped into
 * [0, rows - 1] so out-of-range requests still slice a valid row.
 *
 * @param {object} inputTensor - Source tensor (f16 or f32; anything else is treated as f32).
 * @param {number} rows - Row count of the source tensor.
 * @param {number} cols - Column count of the source tensor.
 * @param {number} rowIndex - Requested row; clamped into range.
 * @param {string} label - Debug label for the sliced tensor.
 * @returns {object} Tensor of shape [1, cols] with the source dtype.
 * @throws {Error} When no GPU device is active.
 */
function createRowSliceTensor(inputTensor, rows, cols, rowIndex, label) {
  const device = getDevice();
  if (!device) {
    throw new Error('Distill full-graph student requires active GPU device.');
  }
  const dtype = inputTensor?.dtype === 'f16' ? 'f16' : 'f32';
  const rowBytes = cols * (dtype === 'f16' ? 2 : 4);
  // Clamp order matches the original: min against rows-1 first, then floor at 0.
  const safeRow = Math.max(0, Math.min(rows - 1, rowIndex));
  const outputBuffer = acquireBuffer(rowBytes, undefined, label);
  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(inputTensor.buffer, safeRow * rowBytes, outputBuffer, 0, rowBytes);
  device.queue.submit([encoder.finish()]);
  return createTensor(outputBuffer, dtype, [1, cols], label);
}
257
+
258
/**
 * Builds a minimal "projection head" student model fixture for distillation:
 * prompts are embedded by the student pipeline, and a single trainable
 * [embeddingDim, outputDim] matmul projects embeddings to logits.
 *
 * @param {object} overrides - Training-config overrides merged into createTrainingConfig.
 * @param {object} options - Must carry options.distillRuntime.studentPipeline;
 *   may carry outputDim / inputDim / embeddingDim.
 * @returns {{config, model, outputDim, embeddingDim, cleanup}} Fixture handle.
 * @throws {Error} When distillRuntime.studentPipeline is missing.
 */
function createDistillStudentProjectionModelFixture(overrides = {}, options = {}) {
  const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
    ? options.distillRuntime
    : null;
  if (!distillRuntime?.studentPipeline) {
    throw new Error('Distill student fixture requires distillRuntime.studentPipeline.');
  }
  // Output width: explicit option, else inputDim, else the adapter top-K
  // default; clampDistillTopK bounds it to the supported range.
  const outputDim = clampDistillTopK(
    options.outputDim
    ?? options.inputDim
    ?? DISTILL_ADAPTER_TOP_K
  );
  // Embedding width: explicit positive-integer option wins, else the student
  // pipeline's hiddenSize, else fall back to outputDim (square projection).
  const inferredEmbeddingDim = Math.floor(
    Number(distillRuntime.studentPipeline?.modelConfig?.hiddenSize)
  );
  const embeddingDim = Number.isInteger(options.embeddingDim) && options.embeddingDim > 0
    ? options.embeddingDim
    : (Number.isFinite(inferredEmbeddingDim) && inferredEmbeddingDim > 0
      ? inferredEmbeddingDim
      : outputDim);
  // Fixture defaults (loss scaling off, no grad clipping) can be overridden
  // by overrides.training since it is spread last.
  const config = createTrainingConfig({
    ...overrides,
    training: {
      enabled: true,
      lossScaling: { enabled: false },
      gradient: { maxNorm: 0 },
      ...(overrides.training || {}),
    },
  });

  // Zero-initialized trainable projection weight — the only parameter.
  const projectionWeights = new Float32Array(embeddingDim * outputDim);
  const projectionWeight = makeTensorFromFloat32(
    projectionWeights,
    [embeddingDim, outputDim],
    'distill_student_head_weight'
  );
  // Per-step input tensors; released by cleanupDistillStep().
  const temporaryInputs = new Set();

  // Records a [rows, embeddingDim] x [embeddingDim, outputDim] matmul on the
  // gradient tape so the projection weight receives gradients.
  async function projectEmbeddingInput(inputTensor, tape) {
    const rows = Number.isFinite(inputTensor?.shape?.[0]) ? inputTensor.shape[0] : 1;
    return tape.record(
      OpType.MATMUL,
      (a, b) => runMatmul(a, b, rows, outputDim, embeddingDim, { transposeB: false }),
      [inputTensor, projectionWeight],
      { M: rows, N: outputDim, K: embeddingDim, transposeB: false }
    );
  }

  // Embeds each phase prompt via the student pipeline (last-token embedding,
  // no chat template) and packs the rows into one [rows, embeddingDim] tensor.
  async function buildStudentEmbeddingInput(batch, phase = 'anchor') {
    const prompts = resolvePhasePrompts(batch, phase);
    const rows = prompts.length;
    const features = new Float32Array(rows * embeddingDim);
    for (let row = 0; row < rows; row += 1) {
      const prompt = String(prompts[row] || '').trim();
      const studentResult = await distillRuntime.studentPipeline.prefillWithEmbedding(prompt, {
        useChatTemplate: false,
        embeddingMode: 'last',
      });
      try {
        const studentEmbedding = toFloat32Array(studentResult?.embedding, 'student embedding');
        const rowOffset = row * embeddingDim;
        // Truncate if the pipeline returns a wider embedding; shorter
        // embeddings leave the remainder of the row zero-filled.
        const copyCount = Math.min(embeddingDim, studentEmbedding.length);
        features.set(studentEmbedding.subarray(0, copyCount), rowOffset);
      } finally {
        // Always free the prefill snapshot and reset pipeline state, even if
        // embedding extraction throws.
        disposePrefillSnapshot(studentResult);
        distillRuntime.studentPipeline.reset();
      }
    }
    const inputTensor = makeTensorFromFloat32(
      features,
      [rows, embeddingDim],
      `distill_student_${phase}_embedding`
    );
    temporaryInputs.add(inputTensor);
    return inputTensor;
  }

  const model = {
    // Plain forward: project a pre-built embedding tensor.
    async forward(inputTensor, tape) {
      return projectEmbeddingInput(inputTensor, tape);
    },
    // Distill forward: build embeddings for the requested triplet phase
    // (unknown phases fall back to 'anchor'), then project to logits.
    async forwardDistill(batch, tape, forwardOptions = {}) {
      const requestedPhase = String(forwardOptions?.phase || 'anchor').trim();
      const phase = requestedPhase === 'positive'
        ? 'positive'
        : (requestedPhase === 'negative' ? 'negative' : 'anchor');
      const inputTensor = await buildStudentEmbeddingInput(batch, phase);
      const logits = await projectEmbeddingInput(inputTensor, tape);
      return { logits };
    },
    // Releases the per-step input tensors accumulated during forwardDistill.
    cleanupDistillStep() {
      for (const tensor of temporaryInputs) {
        releaseTensor(tensor);
      }
      temporaryInputs.clear();
    },
    loraParams() {
      return [projectionWeight];
    },
    // The single projection weight appears in both 'base' and 'lora' groups —
    // presumably so either optimizer grouping trains it; verify against callers.
    paramGroups() {
      return {
        encoder: [],
        prior: [],
        decoder: [],
        base: [projectionWeight],
        lora: [projectionWeight],
      };
    },
  };

  return {
    config,
    model,
    outputDim,
    embeddingDim,
    // Full teardown: per-step inputs plus the trainable projection weight.
    cleanup() {
      model.cleanupDistillStep();
      releaseTensor(projectionWeight);
    },
  };
}
379
+
380
/**
 * Builds a full-graph transformer student fixture for distillation: every
 * weight from the student pipeline (embeddings, per-layer attention/FFN
 * projections, norms, lm_head) is promoted to a trainable f32 tensor and the
 * complete forward pass is recorded on the gradient tape, ending in
 * last-token logits over the vocabulary.
 *
 * @param {object} overrides - Training-config overrides merged into createTrainingConfig.
 * @param {object} options - Must carry options.distillRuntime.studentPipeline with
 *   loaded modelConfig and a Map of weights.
 * @returns {Promise<{config, model, outputDim, embeddingDim, cleanup}>} Fixture handle.
 * @throws {Error} When the student pipeline weights/config are missing or a
 *   required layer weight is absent.
 */
async function createDistillStudentTransformerModelFixture(overrides = {}, options = {}) {
  const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
    ? options.distillRuntime
    : null;
  const studentPipeline = distillRuntime?.studentPipeline || null;
  if (!studentPipeline?.modelConfig || !(studentPipeline.weights instanceof Map)) {
    throw new Error('Distill full-graph student fixture requires loaded student pipeline weights.');
  }
  // Normalize every dimension to a positive integer (>= 1) so downstream
  // byte-size math never sees 0/NaN.
  const modelConfig = studentPipeline.modelConfig;
  const hiddenSize = Math.max(1, Math.floor(Number(modelConfig.hiddenSize) || 0));
  const intermediateSize = Math.max(1, Math.floor(Number(modelConfig.intermediateSize) || 0));
  const numLayers = Math.max(1, Math.floor(Number(modelConfig.numLayers) || 0));
  const numHeads = Math.max(1, Math.floor(Number(modelConfig.numHeads) || 0));
  // KV heads default to the full head count when not set (no GQA).
  const numKVHeads = Math.max(1, Math.floor(Number(modelConfig.numKVHeads || numHeads) || 0));
  const headDim = Math.max(1, Math.floor(Number(modelConfig.headDim) || 0));
  const vocabSize = Math.max(1, Math.floor(Number(modelConfig.vocabSize) || 0));
  const rmsNormEps = Number.isFinite(modelConfig.rmsNormEps) ? modelConfig.rmsNormEps : 1e-6;
  const hiddenActivation = String(modelConfig.hiddenActivation || 'silu').toLowerCase();
  const swigluLimit = Number.isFinite(modelConfig.swigluLimit) ? modelConfig.swigluLimit : 0;
  const useEmbeddingTranspose = modelConfig.embeddingTranspose === true;
  const tieWordEmbeddings = modelConfig.useTiedEmbeddings === true;

  // Fixture defaults (loss scaling off, no grad clipping); overridable since
  // overrides.training is spread last.
  const config = createTrainingConfig({
    ...overrides,
    training: {
      enabled: true,
      lossScaling: { enabled: false },
      gradient: { maxNorm: 0 },
      ...(overrides.training || {}),
    },
  });

  // Every tensor created here is tracked for release in cleanup().
  const ownedTrainables = new Set();
  const embeddingWeight = await ensureTrainableTensor(
    studentPipeline.weights.get('embed'),
    [vocabSize, hiddenSize],
    'embed',
    ownedTrainables
  );
  // Tied embeddings reuse the embedding matrix as the LM head.
  const lmHeadWeight = tieWordEmbeddings
    ? embeddingWeight
    : await ensureTrainableTensor(
      studentPipeline.weights.get('lm_head'),
      [vocabSize, hiddenSize],
      'lm_head',
      ownedTrainables
    );
  const finalNormWeight = await ensureNormTensor(
    studentPipeline.weights.get('final_norm'),
    hiddenSize,
    'final_norm',
    ownedTrainables
  );

  // RoPE cos/sin tables: [maxSeqLen, headDim/2]. They are routed through
  // ensureTrainableTensor for ownership tracking, not because they train.
  const ropeDim = Math.max(1, Math.floor(headDim / 2));
  const ropeRows = Math.max(1, Math.floor(Number(modelConfig.maxSeqLen) || 1));
  const ropeCos = await ensureTrainableTensor(
    createTensor(studentPipeline.ropeFreqsCos, 'f32', [ropeRows, ropeDim], 'rope_cos'),
    [ropeRows, ropeDim],
    'rope_cos',
    ownedTrainables
  );
  const ropeSin = await ensureTrainableTensor(
    createTensor(studentPipeline.ropeFreqsSin, 'f32', [ropeRows, ropeDim], 'rope_sin'),
    [ropeRows, ropeDim],
    'rope_sin',
    ownedTrainables
  );

  // Collect per-layer weights; gate/up may arrive pre-fused or as two
  // separate matrices that get fused on-device.
  const layerParams = [];
  const layers = [];
  for (let layerIdx = 0; layerIdx < numLayers; layerIdx += 1) {
    const layerWeights = studentPipeline.weights.get(`layer_${layerIdx}`);
    if (!layerWeights) {
      throw new Error(`Distill full-graph student missing layer_${layerIdx} weights.`);
    }
    const gateUpWeight = layerWeights.gateUp || layerWeights.ffnGateUp || null;
    let layerGateUp = null;
    if (hasTensorPayload(gateUpWeight)) {
      // Pre-fused gate_up: [2 * intermediateSize, hiddenSize].
      layerGateUp = await ensureTrainableTensor(
        gateUpWeight,
        [intermediateSize * 2, hiddenSize],
        `layer_${layerIdx}.ffn_gate_up`,
        ownedTrainables
      );
    } else {
      // Separate gate and up matrices must both be present; fuse them.
      const gateWeight = layerWeights.gate || layerWeights.ffnGate || null;
      const upWeight = layerWeights.up || layerWeights.ffnUp || null;
      if (!hasTensorPayload(gateWeight) || !hasTensorPayload(upWeight)) {
        throw new Error(
          `Distill full-graph student missing gate/up projections on layer ${layerIdx}.`
        );
      }
      const gateTensor = await ensureTrainableTensor(
        gateWeight,
        [intermediateSize, hiddenSize],
        `layer_${layerIdx}.ffn_gate`,
        ownedTrainables
      );
      const upTensor = await ensureTrainableTensor(
        upWeight,
        [intermediateSize, hiddenSize],
        `layer_${layerIdx}.ffn_up`,
        ownedTrainables
      );
      layerGateUp = await fuseGateUpTensors(
        gateTensor,
        upTensor,
        intermediateSize,
        hiddenSize,
        `layer_${layerIdx}.ffn_gate_up`,
        ownedTrainables
      );
    }
    const layer = {
      inputNorm: await ensureNormTensor(
        layerWeights.inputNorm,
        hiddenSize,
        `layer_${layerIdx}.input_norm`,
        ownedTrainables
      ),
      qProj: await ensureTrainableTensor(
        layerWeights.qProj,
        [numHeads * headDim, hiddenSize],
        `layer_${layerIdx}.q_proj`,
        ownedTrainables
      ),
      kProj: await ensureTrainableTensor(
        layerWeights.kProj,
        [numKVHeads * headDim, hiddenSize],
        `layer_${layerIdx}.k_proj`,
        ownedTrainables
      ),
      vProj: await ensureTrainableTensor(
        layerWeights.vProj,
        [numKVHeads * headDim, hiddenSize],
        `layer_${layerIdx}.v_proj`,
        ownedTrainables
      ),
      // NOTE(review): oProj is shaped [hiddenSize, hiddenSize] — this assumes
      // numHeads * headDim === hiddenSize; confirm for models where they differ.
      oProj: await ensureTrainableTensor(
        layerWeights.oProj,
        [hiddenSize, hiddenSize],
        `layer_${layerIdx}.o_proj`,
        ownedTrainables
      ),
      // Post-attention norm is optional (architecture-dependent).
      postAttentionNorm: layerWeights.postAttentionNorm
        ? await ensureNormTensor(
          layerWeights.postAttentionNorm,
          hiddenSize,
          `layer_${layerIdx}.post_attention_norm`,
          ownedTrainables
        )
        : null,
      gateUp: layerGateUp,
      down: await ensureTrainableTensor(
        layerWeights.down || layerWeights.ffnDown,
        [hiddenSize, intermediateSize],
        `layer_${layerIdx}.ffn_down`,
        ownedTrainables
      ),
    };
    layers.push(layer);
    layerParams.push(layer.inputNorm, layer.qProj, layer.kProj, layer.vProj, layer.oProj, layer.gateUp, layer.down);
    if (layer.postAttentionNorm) {
      layerParams.push(layer.postAttentionNorm);
    }
  }

  // Optimizer grouping: embeddings + all layer weights train as "encoder";
  // final norm + lm_head as "decoder". Note lmHeadWeight may alias
  // embeddingWeight when embeddings are tied.
  const encoderParams = [embeddingWeight, ...layerParams];
  const decoderParams = [finalNormWeight, lmHeadWeight];
  const baseParams = [...encoderParams, ...decoderParams];
  // Per-step token tensors; released by cleanupDistillStep().
  const temporaryInputs = new Set();

  // Tokenizes a prompt into a u32 token tensor tracked for per-step cleanup.
  async function buildPromptTokens(prompt) {
    const normalized = String(prompt || '').trim();
    if (!normalized) {
      throw new Error('Distill full-graph student prompt is empty.');
    }
    const tokenIds = studentPipeline.tokenizer.encode(normalized);
    if (!Array.isArray(tokenIds) || tokenIds.length === 0) {
      throw new Error('Distill full-graph student tokenizer produced no tokens.');
    }
    const tokenTensor = makeTensorFromUint32(
      tokenIds,
      [tokenIds.length],
      'distill_student_prompt_tokens'
    );
    temporaryInputs.add(tokenTensor);
    return { tokenTensor, seqLen: tokenIds.length };
  }

  // Records the full transformer forward pass for one prompt on the tape:
  // embed -> N x (norm, QKV, RoPE, causal attention, o_proj, residual,
  // [post-attn norm], gate_up, activation, down, residual) -> final norm ->
  // last-row slice -> lm_head logits of shape [1, vocabSize].
  async function runTransformerPrompt(prompt, tape) {
    const { tokenTensor, seqLen } = await buildPromptTokens(prompt);
    // Token embedding lookup; dtype follows the embedding weight, output f32.
    let hidden = await tape.record(
      OpType.EMBED,
      (indices, embeddings) => runGather(
        indices,
        embeddings,
        seqLen,
        hiddenSize,
        vocabSize,
        {
          embeddingDtype: resolveTensorDtype(embeddingWeight, 'f32'),
          outputDtype: 'f32',
          transpose: useEmbeddingTranspose,
        }
      ),
      [tokenTensor, embeddingWeight],
      {
        numTokens: seqLen,
        hiddenSize,
        vocabSize,
        transpose: useEmbeddingTranspose,
        indexOffset: 0,
      }
    );

    for (let layerIdx = 0; layerIdx < layers.length; layerIdx += 1) {
      const layer = layers[layerIdx];
      // Pre-attention RMSNorm.
      const normed = await tape.record(
        OpType.RMSNORM,
        (x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
          batchSize: seqLen,
          hiddenSize,
          rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
        }),
        [hidden, layer.inputNorm],
        { numTokens: seqLen, hiddenSize, eps: rmsNormEps }
      );

      // Q/K/V projections; transposeB 'auto' lets the kernel pick the weight
      // layout. K/V use the (possibly smaller) KV head count.
      const q2d = await tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, seqLen, numHeads * headDim, hiddenSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [normed, layer.qProj],
        { M: seqLen, N: numHeads * headDim, K: hiddenSize, transposeB: 'auto' }
      );
      const k2d = await tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, seqLen, numKVHeads * headDim, hiddenSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [normed, layer.kProj],
        { M: seqLen, N: numKVHeads * headDim, K: hiddenSize, transposeB: 'auto' }
      );
      const v2d = await tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, seqLen, numKVHeads * headDim, hiddenSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [normed, layer.vProj],
        { M: seqLen, N: numKVHeads * headDim, K: hiddenSize, transposeB: 'auto' }
      );

      // Reinterpret the 2-D projections as [seq, heads, headDim] views over
      // the same buffers (no copy).
      const q3d = createTensor(q2d.buffer, q2d.dtype, [seqLen, numHeads, headDim], `layer_${layerIdx}_q`);
      const k3d = createTensor(k2d.buffer, k2d.dtype, [seqLen, numKVHeads, headDim], `layer_${layerIdx}_k`);
      const v3d = createTensor(v2d.buffer, v2d.dtype, [seqLen, numKVHeads, headDim], `layer_${layerIdx}_v`);

      // Rotary position embedding on Q and K (startPos 0: full-prompt pass).
      const qRope = await tape.record(
        OpType.ROPE,
        (q, cos, sin) => runRoPE(q, cos, sin, seqLen, { numHeads, headDim, startPos: 0 }),
        [q3d, ropeCos, ropeSin],
        { seqLen, numHeads, headDim, startPos: 0 }
      );
      const kRope = await tape.record(
        OpType.ROPE,
        (k, cos, sin) => runRoPE(k, cos, sin, seqLen, { numHeads: numKVHeads, headDim, startPos: 0 }),
        [k3d, ropeCos, ropeSin],
        { seqLen, numHeads: numKVHeads, headDim, startPos: 0 }
      );

      // Causal self-attention with 1/sqrt(headDim) scaling; recomputeForward
      // in the metadata suggests the backward pass re-runs attention rather
      // than storing intermediates.
      const attention = await tape.record(
        OpType.ATTENTION,
        (q, k, v) => runAttention(q, k, v, null, numHeads, headDim, {
          seqLen,
          kvLen: seqLen,
          numKVHeads,
          causal: true,
          startPos: 0,
          scale: 1 / Math.sqrt(headDim),
        }),
        [qRope, kRope, v3d],
        { seqLen, numHeads, headDim, scale: 1 / Math.sqrt(headDim), causal: true, recomputeForward: true }
      );
      // Flatten attention output back to [seq, hiddenSize] for o_proj.
      const attention2d = createTensor(
        attention.buffer,
        attention.dtype,
        [seqLen, hiddenSize],
        `layer_${layerIdx}_attn_2d`
      );

      const attentionOutput = await tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, seqLen, hiddenSize, hiddenSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [attention2d, layer.oProj],
        { M: seqLen, N: hiddenSize, K: hiddenSize, transposeB: 'auto' }
      );
      // Residual connection around attention.
      const postAttention = await tape.record(
        OpType.RESIDUAL_ADD,
        (a, b) => runResidualAdd(a, b, seqLen * hiddenSize),
        [attentionOutput, hidden],
        { size: seqLen * hiddenSize }
      );

      // Optional post-attention RMSNorm before the FFN.
      const ffnInput = layer.postAttentionNorm
        ? await tape.record(
          OpType.RMSNORM,
          (x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
            batchSize: seqLen,
            hiddenSize,
            rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
          }),
          [postAttention, layer.postAttentionNorm],
          { numTokens: seqLen, hiddenSize, eps: rmsNormEps }
        )
        : postAttention;
      // Fused gate+up projection: output is [seq, 2 * intermediateSize].
      const gateUp = await tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, seqLen, intermediateSize * 2, hiddenSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [ffnInput, layer.gateUp],
        { M: seqLen, N: intermediateSize * 2, K: hiddenSize, transposeB: 'auto' }
      );
      // Gated activation over the row-split halves (SiLU or GELU); the
      // swiglu limit is only applied for the SiLU path.
      const activated = await tape.record(
        OpType.SILU_ROWSPLIT,
        (x) => runSiLURowSplit(x, {
          numTokens: seqLen,
          dim: intermediateSize,
          activation: hiddenActivation === 'gelu' ? 'gelu' : 'silu',
          swigluLimit: hiddenActivation === 'gelu' ? null : swigluLimit,
        }),
        [gateUp],
        {
          numTokens: seqLen,
          dim: intermediateSize,
          activation: hiddenActivation === 'gelu' ? 'gelu' : 'silu',
          swigluLimit: hiddenActivation === 'gelu' ? 0 : swigluLimit,
        }
      );
      const ffnOutput = await tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, seqLen, hiddenSize, intermediateSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [activated, layer.down],
        { M: seqLen, N: hiddenSize, K: intermediateSize, transposeB: 'auto' }
      );
      // Residual connection around the FFN closes the layer.
      hidden = await tape.record(
        OpType.RESIDUAL_ADD,
        (a, b) => runResidualAdd(a, b, seqLen * hiddenSize),
        [ffnOutput, postAttention],
        { size: seqLen * hiddenSize }
      );
    }

    // Final RMSNorm over the full sequence.
    const finalHidden = await tape.record(
      OpType.RMSNORM,
      (x, gamma) => runRMSNorm(x, gamma, rmsNormEps, {
        batchSize: seqLen,
        hiddenSize,
        rmsNormWeightOffset: modelConfig.rmsNormWeightOffset === true,
      }),
      [hidden, finalNormWeight],
      { numTokens: seqLen, hiddenSize, eps: rmsNormEps }
    );
    // Only the last token's hidden state feeds the LM head.
    const lastHidden = await tape.record(
      OpType.ROW_SLICE,
      (x) => createRowSliceTensor(x, seqLen, hiddenSize, seqLen - 1, 'distill_last_hidden'),
      [finalHidden],
      { rows: seqLen, cols: hiddenSize, rowIndex: seqLen - 1 }
    );
    // LM head: [1, hiddenSize] x lm_head -> [1, vocabSize] logits.
    return tape.record(
      OpType.MATMUL,
      (x, w) => runMatmul(x, w, 1, vocabSize, hiddenSize, {
        transposeB: 'auto',
        outputDtype: 'f32',
      }),
      [lastHidden, lmHeadWeight],
      { M: 1, N: vocabSize, K: hiddenSize, transposeB: 'auto' }
    );
  }

  const model = {
    // Plain forward: apply only the LM head to a pre-built [1, hiddenSize] input.
    async forward(inputTensor, tape) {
      return tape.record(
        OpType.MATMUL,
        (x, w) => runMatmul(x, w, 1, vocabSize, hiddenSize, {
          transposeB: 'auto',
          outputDtype: 'f32',
        }),
        [inputTensor, lmHeadWeight],
        { M: 1, N: vocabSize, K: hiddenSize, transposeB: 'auto' }
      );
    },
    // Distill forward: run the full transformer for the single prompt of the
    // requested phase (unknown phases fall back to 'anchor').
    async forwardDistill(batch, tape, forwardOptions = {}) {
      const requestedPhase = String(forwardOptions?.phase || 'anchor').trim();
      const phase = requestedPhase === 'positive'
        ? 'positive'
        : (requestedPhase === 'negative' ? 'negative' : 'anchor');
      const prompts = resolvePhasePrompts(batch, phase);
      if (prompts.length !== 1) {
        throw new Error(
          `Distill full-graph student currently requires batchSize=1, got ${prompts.length}.`
        );
      }
      const logits = await runTransformerPrompt(prompts[0], tape);
      return { logits };
    },
    // Releases the per-step token tensors created by buildPromptTokens.
    cleanupDistillStep() {
      for (const tensor of temporaryInputs) {
        releaseTensor(tensor);
      }
      temporaryInputs.clear();
    },
    // In full-graph mode the "LoRA" group is the decoder tail (final norm + lm_head).
    loraParams() {
      return decoderParams;
    },
    paramGroups() {
      return {
        encoder: encoderParams,
        prior: [],
        decoder: decoderParams,
        base: baseParams,
        lora: [],
      };
    },
  };

  return {
    config,
    model,
    outputDim: vocabSize,
    embeddingDim: hiddenSize,
    // Full teardown: per-step inputs plus every owned weight tensor.
    cleanup() {
      model.cleanupDistillStep();
      for (const tensor of ownedTrainables) {
        releaseTensor(tensor);
      }
      ownedTrainables.clear();
    },
  };
}
832
+
833
/**
 * Entry point for building a distill student fixture. Resolves the student
 * graph mode (explicit option > runtime setting > training-config override)
 * and dispatches to the projection-head fixture or the full-graph
 * transformer fixture.
 *
 * @param {object} overrides - Training-config overrides forwarded to the fixture builder.
 * @param {object} options - May carry distillRuntime and studentGraphMode.
 * @returns {Promise<object>} The fixture handle from the selected builder.
 */
export async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
  const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
    ? options.distillRuntime
    : null;
  const requestedMode = options.studentGraphMode
    ?? distillRuntime?.studentGraphMode
    ?? overrides?.training?.distill?.studentGraphMode;
  const graphMode = normalizeDistillStudentGraphMode(requestedMode);
  return graphMode === DISTILL_STUDENT_GRAPH_PROJECTION
    ? createDistillStudentProjectionModelFixture(overrides, options)
    : createDistillStudentTransformerModelFixture(overrides, options);
}