@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -0,0 +1,793 @@
1
+ import { mkdir, readFile, readdir, writeFile } from 'node:fs/promises';
2
+ import { join, resolve } from 'node:path';
3
+
4
+ import { loadBackwardRegistry } from '../config/backward-registry-loader.js';
5
+ import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/buffer-pool.js';
6
+ import { runMatmul } from '../gpu/kernels/index.js';
7
+ import { runResidualAdd } from '../gpu/kernels/residual.js';
8
+ import { parseJsonl } from './datasets/jsonl.js';
9
+ import { LoraAdapter } from './lora.js';
10
+ import { TrainingRunner, restoreTrainingCheckpointState } from './runner.js';
11
+ import { AdamOptimizer } from './optimizer.js';
12
+ import { crossEntropyLoss } from './loss.js';
13
+ import { clipGradients } from './clip.js';
14
+ import { OpType, AutogradTape } from './autograd.js';
15
+ import { loadCheckpoint } from './checkpoint.js';
16
+ import { exportLoRAAdapter } from './export.js';
17
+ import { computeEvalMetrics } from './operator-eval.js';
18
+ import { appendScoreboardRow } from './operator-scoreboard.js';
19
+ import {
20
+ buildArtifactBase,
21
+ createTrainingRunLayout,
22
+ hashArtifactPayload,
23
+ writeJsonArtifact,
24
+ writeRunContract,
25
+ writeWorkloadLock,
26
+ } from './operator-artifacts.js';
27
+ import { watchFinalizedCheckpoints } from './checkpoint-watch.js';
28
+ import { loadLoRAFromManifest } from '../adapters/lora-loader.js';
29
+ import { createUploadedTensor } from './tensor-factory.js';
30
+
31
/**
 * Returns a deep copy of `value` in which every plain object's keys have been
 * re-inserted in ascending (default `sort()`) order, so `JSON.stringify` of the
 * result is independent of the original key insertion order.
 *
 * Arrays keep their element order (elements are processed recursively);
 * primitives, `null` and `undefined` are returned as-is.
 *
 * @param {*} value - Any JSON-like value.
 * @returns {*} The key-ordered copy.
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((item) => stableSortObject(item));
  }
  // Objects are always truthy, so this matches "falsy or not an object".
  if (value === null || typeof value !== 'object') {
    return value;
  }
  return Object.keys(value)
    .sort()
    .reduce((ordered, key) => {
      ordered[key] = stableSortObject(value[key]);
      return ordered;
    }, {});
}
44
+
45
/**
 * Serializes `value` to compact JSON after canonicalizing object key order with
 * `stableSortObject`, so structurally equal inputs yield identical strings.
 *
 * @param {*} value - Any JSON-serializable value.
 * @returns {string|undefined} The canonical JSON text (or `undefined` where
 *   `JSON.stringify` itself returns `undefined`).
 */
function stableJson(value) {
  const canonical = stableSortObject(value);
  return JSON.stringify(canonical);
}
48
+
49
/**
 * Uploads numeric data as an f32 tensor via `createUploadedTensor`.
 *
 * An existing `Float32Array` is passed through without copying; anything else
 * is converted with `new Float32Array(values)`.
 *
 * @param {Float32Array|ArrayLike<number>} values - Tensor contents.
 * @param {number[]} shape - Tensor shape.
 * @param {string} label - Debug label for the allocation.
 * @returns {*} Whatever `createUploadedTensor` returns.
 */
function makeTensorFromFloat32(values, shape, label) {
  let floats = values;
  if (!(floats instanceof Float32Array)) {
    floats = new Float32Array(values);
  }
  return createUploadedTensor(floats, 'f32', shape, label);
}
53
+
54
/**
 * Uploads integer data as a u32 tensor via `createUploadedTensor`.
 *
 * An existing `Uint32Array` is passed through without copying; anything else
 * is converted with `new Uint32Array(values)`.
 *
 * @param {Uint32Array|ArrayLike<number>} values - Tensor contents.
 * @param {number[]} shape - Tensor shape.
 * @param {string} label - Debug label for the allocation.
 * @returns {*} Whatever `createUploadedTensor` returns.
 */
function makeTensorFromUint32(values, shape, label) {
  let words = values;
  if (!(words instanceof Uint32Array)) {
    words = new Uint32Array(values);
  }
  return createUploadedTensor(words, 'u32', shape, label);
}
58
+
59
/**
 * Returns a tensor's GPU buffer to the pool; a no-op for `null`/`undefined`
 * tensors or tensors without a buffer.
 *
 * @param {{ buffer?: * }|null|undefined} tensor - Tensor to release.
 */
function releaseTensor(tensor) {
  const buffer = tensor?.buffer;
  if (buffer) {
    releaseBuffer(buffer);
  }
}
63
+
64
/**
 * Builds the fixed-size toy model used by the LoRA operator pipeline.
 *
 * The model maps a 3-feature input to 2 logits as
 * `matmul(input, baseWeight) + adapter.forward(input)`. Both the base matmul
 * and the residual add are recorded on the supplied autograd tape. The base
 * weight is a constant [3, 2] tensor; only the adapter's `A`/`B` tensors are
 * exposed as LoRA parameters.
 *
 * @param {object} workload - Workload config. Reads `pipeline.adapter.targetModules`
 *   (only the first entry is used), `pipeline.adapter.rank` and `pipeline.adapter.alpha`.
 * @returns {{ model: object, cleanup(): void }} The model, plus a cleanup hook that
 *   disposes the adapter and releases the base weight.
 * @throws {Error} When no adapter target module is configured. Errors from the
 *   adapter constructor are rethrown after the base weight has been released.
 */
function createToyLoraModel(workload) {
  const adapterConfig = workload?.pipeline?.adapter;
  // Optional chaining so a missing adapter section reports the validation
  // error below instead of an opaque "cannot read properties of undefined".
  const targetModule = adapterConfig?.targetModules?.[0];
  if (!targetModule) {
    throw new Error('LoRA workload requires at least one adapter target module.');
  }

  // Fixed toy dimensions: 3 input features -> 2 output logits.
  const inDim = 3;
  const outDim = 2;

  const baseWeight = makeTensorFromFloat32(
    [0.08, -0.12, 0.16, 0.22, -0.03, 0.09],
    [inDim, outDim],
    'lora_toy_base_weight'
  );

  let adapter;
  try {
    adapter = new LoraAdapter({
      inDim,
      outDim,
      rank: adapterConfig.rank,
      alpha: adapterConfig.alpha,
    });
  } catch (error) {
    // The base weight is already uploaded; release it so a rejected adapter
    // config does not leak a GPU buffer.
    releaseTensor(baseWeight);
    throw error;
  }

  const model = {
    adapter,
    baseWeight,
    targetModule,
    async forward(inputTensor, tape) {
      // Inputs without an integer leading dimension are treated as one row.
      const batchSize = Number.isInteger(inputTensor?.shape?.[0]) ? inputTensor.shape[0] : 1;
      const baseLogits = await tape.record(
        OpType.MATMUL,
        (a, b) => runMatmul(a, b, batchSize, outDim, inDim, { transposeB: false }),
        [inputTensor, baseWeight],
        { M: batchSize, N: outDim, K: inDim, transposeB: false }
      );
      const delta = await adapter.forward(inputTensor, tape);
      return tape.record(
        OpType.RESIDUAL_ADD,
        (a, b) => runResidualAdd(a, b, batchSize * outDim),
        [baseLogits, delta],
        { size: batchSize * outDim }
      );
    },
    loraParams() {
      return [adapter.A, adapter.B];
    },
    paramGroups() {
      return {
        encoder: [],
        prior: [],
        decoder: [],
        base: [baseWeight],
        lora: [adapter.A, adapter.B],
      };
    },
  };

  return {
    model,
    cleanup() {
      adapter.dispose();
      releaseTensor(baseWeight);
    },
  };
}
121
+
122
/**
 * Validates one raw toy-dataset record and normalizes it to `{ id, input, target }`.
 *
 * - `input` (falling back to `features`) must be an array of exactly three numeric values.
 * - `target` (falling back to `label`) must be the integer class 0 or 1.
 *
 * Numeric strings such as "0.5" are still accepted. Values that `Number()` would
 * silently coerce to 0 are rejected so malformed rows fail loudly instead of
 * becoming fake zero features or a fake class-0 label: `null`, `undefined`,
 * empty or whitespace-only strings, arrays and other objects.
 *
 * @param {unknown} record - One parsed dataset row.
 * @param {number} index - Zero-based row index, used for the default id and error messages.
 * @returns {{ id: string, input: number[], target: number }} The normalized row.
 * @throws {Error} When the record is not an object, has the wrong number of inputs,
 *   contains a non-numeric input, or has a target other than 0/1.
 */
function normalizeToyRow(record, index) {
  const rowNumber = index + 1;

  // Strict counterpart to Number(): only genuinely numeric inputs are converted.
  const toStrictNumber = (value) => {
    switch (typeof value) {
      case 'number':
        return value;
      case 'bigint':
      case 'boolean':
        return Number(value);
      case 'string':
        return value.trim() === '' ? Number.NaN : Number(value);
      default:
        return Number.NaN;
    }
  };

  if (!record || typeof record !== 'object' || Array.isArray(record)) {
    throw new Error(`LoRA toy dataset row ${rowNumber} must be an object.`);
  }

  let values = null;
  if (Array.isArray(record.input)) {
    values = record.input;
  } else if (Array.isArray(record.features)) {
    values = record.features;
  }
  if (!Array.isArray(values) || values.length !== 3) {
    throw new Error(`LoRA toy dataset row ${rowNumber} requires input[3].`);
  }

  const input = values.map((value, valueIndex) => {
    const parsed = toStrictNumber(value);
    if (!Number.isFinite(parsed)) {
      throw new Error(`LoRA toy dataset row ${rowNumber} input[${valueIndex}] must be finite.`);
    }
    return parsed;
  });

  const target = toStrictNumber(record.target ?? record.label);
  if (!Number.isInteger(target) || target < 0 || target > 1) {
    throw new Error(`LoRA toy dataset row ${rowNumber} requires integer target 0 or 1.`);
  }

  // Keep falsy-but-meaningful ids such as 0; only missing or empty ids get a default.
  const rawId = record.id ?? '';
  const id = rawId === '' ? `row-${rowNumber}` : String(rawId);

  return {
    id,
    input,
    target,
  };
}
149
+
150
/**
 * Reads a toy LoRA dataset from disk and normalizes every row with `normalizeToyRow`.
 *
 * A path ending in `.json` is parsed as a single JSON document, which must be an
 * array. Any other path is parsed line-by-line via `parseJsonl`.
 *
 * @param {string} datasetPath - Dataset location; resolved with `path.resolve`
 *   (relative paths are taken against the current working directory).
 * @returns {Promise<{ absolutePath: string, raw: string, rows: object[], datasetHash: * }>}
 *   The resolved path, the raw file text, the normalized rows, and the value
 *   `hashArtifactPayload` produces for `{ rows }`.
 * @throws {Error} When the parsed content is not an array, or when a row fails validation.
 */
async function loadToyLoraDataset(datasetPath) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');

  let parsed;
  if (absolutePath.endsWith('.json')) {
    parsed = JSON.parse(raw);
  } else {
    parsed = parseJsonl(raw);
  }
  if (!Array.isArray(parsed)) {
    throw new Error(`LoRA dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }

  const rows = parsed.map((entry, position) => normalizeToyRow(entry, position));
  const datasetHash = hashArtifactPayload({ rows });
  return {
    absolutePath,
    raw,
    rows,
    datasetHash,
  };
}
167
+
168
/**
 * Wraps normalized toy rows in a batch source for the training runner.
 *
 * The returned object's `batches()` async generator yields `{ input, targets }`
 * pairs. `input` is an f32 tensor of shape [rows, 3] and `targets` is a u32
 * tensor of shape [rows]. While the batch size stays the same, the same tensors
 * are re-filled with `uploadData` instead of being reallocated. They are
 * reallocated only when the size changes (e.g. a trailing partial batch), so a
 * consumer must not keep a yielded tensor past the next iteration. The
 * generator's `finally` releases whatever tensors it still holds when it
 * completes, is returned early, or throws.
 *
 * @param {Array<{ input: number[], target: number }>} rows - Normalized rows.
 * @param {number} batchSize - Rows per batch; must be a positive integer.
 * @returns {{ batches(): AsyncGenerator<{ input: object, targets: object }> }}
 * @throws {Error} When `batchSize` is not a positive integer.
 */
function createToyDatasetBatches(rows, batchSize) {
  // A step of 0, a negative step or NaN never advances `offset` properly
  // (infinite or empty loop). A fractional step makes `slice` overlap or skip
  // rows. Reject all of them up front.
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new Error(`LoRA toy dataset batchSize must be a positive integer, got ${String(batchSize)}.`);
  }
  const featureCount = 3;
  return {
    async *batches() {
      let inputTensor = null;
      let targetTensor = null;
      let tensorBatchSize = 0;
      try {
        for (let offset = 0; offset < rows.length; offset += batchSize) {
          const batchRows = rows.slice(offset, offset + batchSize);
          const inputData = new Float32Array(batchRows.length * featureCount);
          const targetData = new Uint32Array(batchRows.length);
          for (let rowIndex = 0; rowIndex < batchRows.length; rowIndex += 1) {
            inputData.set(batchRows[rowIndex].input, rowIndex * featureCount);
            targetData[rowIndex] = batchRows[rowIndex].target;
          }
          if (!inputTensor || !targetTensor || tensorBatchSize !== batchRows.length) {
            // Clear each handle as soon as it is released. Otherwise a failed
            // allocation below would leave a stale handle that `finally`
            // releases a second time.
            releaseTensor(inputTensor);
            inputTensor = null;
            releaseTensor(targetTensor);
            targetTensor = null;
            inputTensor = makeTensorFromFloat32(inputData, [batchRows.length, featureCount], 'lora_toy_input');
            targetTensor = makeTensorFromUint32(targetData, [batchRows.length], 'lora_toy_target');
            tensorBatchSize = batchRows.length;
          } else {
            uploadData(inputTensor.buffer, inputData);
            uploadData(targetTensor.buffer, targetData);
          }
          yield {
            input: inputTensor,
            targets: targetTensor,
          };
        }
      } finally {
        releaseTensor(inputTensor);
        releaseTensor(targetTensor);
      }
    },
  };
}
205
+
206
/**
 * Gathers the GPU buffers of every parameter tensor in `model.paramGroups()`.
 *
 * The resulting set marks buffers that tape-cleanup code must not release.
 * Entries that are nullish or have no `buffer` are skipped.
 *
 * @param {{ paramGroups(): Record<string, Iterable<{ buffer?: * }|null>> }} model
 * @returns {Set<*>} Distinct parameter buffers, in first-seen order.
 */
function collectProtectedBuffers(model) {
  const allTensors = Object.values(model.paramGroups()).flatMap((params) => [...params]);
  const buffers = allTensors
    .filter((tensor) => tensor?.buffer)
    .map((tensor) => tensor.buffer);
  return new Set(buffers);
}
218
+
219
/**
 * Releases the output buffer of every record on an autograd tape.
 *
 * Buffers in `protectedBuffers` (e.g. model parameters) are left alone. A
 * buffer that appears as the output of several records is released only once.
 * A tape without a `records` array is ignored.
 *
 * @param {{ records?: Array<{ output?: { buffer?: * } }|null> }|null|undefined} tape
 * @param {Set<*>} [protectedBuffers=new Set()] - Buffers that must not be released.
 */
function disposeTapeOutputs(tape, protectedBuffers = new Set()) {
  const records = tape?.records;
  if (!Array.isArray(records)) return;

  const freed = new Set();
  for (const record of records) {
    const buffer = record?.output?.buffer;
    if (!buffer) continue;
    if (protectedBuffers.has(buffer) || freed.has(buffer)) continue;
    freed.add(buffer);
    releaseBuffer(buffer);
  }
}
230
+
231
/**
 * Returns the index of the largest finite entry in `values`.
 *
 * Non-finite entries (NaN, ±Infinity, non-numbers) are never selected. Ties
 * resolve to the earliest index. When there is no finite entry (including an
 * empty input), 0 is returned.
 *
 * @param {ArrayLike<number>} values - Scores such as a logits vector.
 * @returns {number} Index of the maximum finite value, or 0.
 */
function argmax(values) {
  let winner = 0;
  let winningScore = -Infinity;
  for (let i = 0; i < values.length; i += 1) {
    const candidate = values[i];
    // Skipping non-finite entries is equivalent to scoring them -Infinity,
    // which can never beat the running best under a strict comparison.
    if (!Number.isFinite(candidate)) continue;
    if (candidate > winningScore) {
      winningScore = candidate;
      winner = i;
    }
  }
  return winner;
}
243
+
244
/**
 * Score `model` on every eval dataset declared by `workload`, optionally
 * persisting one `training_eval_report` JSON per dataset.
 *
 * Each row is run through `model.forward` on a fresh autograd tape. The argmax
 * of the returned logits is compared (as a string) against `row.target`.
 * All per-row buffers except the model's own parameters are released afterwards.
 *
 * @param {object} workload Loaded LoRA workload (reads evalDatasets, id, configHash, baseModelId).
 * @param {object} model Toy LoRA model exposing `forward(input, tape)` and `paramGroups()`.
 * @param {object} dataset Training dataset; reused when an eval dataset points at its `absolutePath`.
 * @param {object|null} [layout] Run layout; when set, reports are written under `layout.eval`.
 * @param {object} [checkpointMeta] Checkpoint id/step and workload provenance recorded in each report.
 * @returns {Promise<object[]>} One report per eval dataset, with `reportPath` (null when not written).
 * @throws {Error} If an eval dataset's `evalKind` is not classification/text_generation.
 */
async function evaluateToyLoraModel(workload, model, dataset, layout = null, checkpointMeta = {}) {
  const protectedBuffers = collectProtectedBuffers(model);
  const evalReports = [];
  const evalDatasets = Array.isArray(workload.evalDatasets) ? workload.evalDatasets : [];
  for (const evalDataset of evalDatasets) {
    if (evalDataset.evalKind !== 'classification' && evalDataset.evalKind !== 'text_generation') {
      throw new Error(`LoRA eval currently supports classification/text_generation only, got "${evalDataset.evalKind}".`);
    }
    const evalDatasetMaterialized = evalDataset.datasetPath === dataset.absolutePath
      ? dataset
      : await loadToyLoraDataset(evalDataset.datasetPath);
    const predictions = [];
    const labels = [];
    for (const row of evalDatasetMaterialized.rows) {
      const tape = new AutogradTape(loadBackwardRegistry());
      // NOTE(review): the toy model is assumed to take exactly 3 input features per row.
      const inputTensor = makeTensorFromFloat32(row.input, [1, 3], 'lora_eval_input');
      let logits = null;
      try {
        logits = await model.forward(inputTensor, tape);
        const logitsData = new Float32Array(await readBuffer(logits.buffer));
        predictions.push(String(argmax(logitsData)));
        labels.push(String(row.target));
      } finally {
        releaseTensor(inputTensor);
        // `logits` is typically also recorded as a tape output. Track what has
        // already been released so the same buffer is never released twice.
        const skipBuffers = new Set(protectedBuffers);
        if (logits?.buffer && !skipBuffers.has(logits.buffer)) {
          releaseBuffer(logits.buffer);
          skipBuffers.add(logits.buffer);
        }
        disposeTapeOutputs(tape, skipBuffers);
      }
    }
    // NOTE(review): text_generation datasets are also scored with classification metrics here.
    const metrics = computeEvalMetrics('classification', predictions, labels, {});
    const reportPayload = {
      artifactType: 'training_eval_report',
      schemaVersion: 1,
      generatedAt: new Date().toISOString(),
      workloadId: workload.id,
      workloadPath: checkpointMeta.workloadPath || null,
      workloadSha256: checkpointMeta.workloadSha256 || null,
      configHash: checkpointMeta.configHash || workload.configHash,
      datasetPath: evalDataset.datasetPath,
      datasetHash: evalDatasetMaterialized.datasetHash,
      baseModelId: workload.baseModelId,
      stage: 'lora',
      checkpointStep: checkpointMeta.checkpointStep ?? null,
      evalDatasetId: evalDataset.id,
      metrics,
      primaryMetric: metrics.primaryMetric,
      primaryScore: metrics.primaryScore,
      accuracy: metrics.accuracy?.score ?? null,
    };
    const reportFile = layout
      ? await writeJsonArtifact(
        join(layout.eval, `${checkpointMeta.checkpointId || 'checkpoint'}__${evalDataset.id}.json`),
        reportPayload
      )
      : null;
    evalReports.push({
      ...reportPayload,
      reportPath: reportFile?.path || null,
    });
  }
  return evalReports;
}
308
+
309
/**
 * Build the `training_run_contract` payload that pins a run to its workload:
 * identity, hashes, claim boundary, kind and the declared eval datasets.
 *
 * @param {{ workload: object, absolutePath: string, workloadSha256: string }} loadedWorkload
 * @returns {object} Contract payload stamped with the current ISO-8601 time.
 */
function buildRunContract(loadedWorkload) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  const { id, configHash, claimBoundary, kind, evalDatasets } = workload;
  return {
    artifactType: 'training_run_contract',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    workloadId: id,
    workloadPath: absolutePath,
    workloadSha256,
    configHash,
    claimBoundary,
    kind,
    evalDatasets,
  };
}
323
+
324
/**
 * Assemble a hashed artifact payload for this LoRA workload.
 *
 * Fields not supplied in `options` fall back to workload values (datasetPath,
 * configHash) or fixed defaults (stage 'lora', null datasetHash/checkpointStep,
 * empty parentArtifacts). The result carries `artifactHash` computed over the
 * base payload.
 *
 * @param {{ workload: object, absolutePath: string, workloadSha256: string }} loadedWorkload
 * @param {object} options `prefix`, `id`, `artifactType` plus optional overrides.
 * @returns {object} Base payload plus `artifactHash`.
 */
function buildArtifact(loadedWorkload, options) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  const base = buildArtifactBase({
    artifactType: options.artifactType,
    reportId: `${options.prefix}_${workload.id}_${options.id}`,
    workload,
    workloadPath: absolutePath,
    workloadSha256,
    datasetPath: options.datasetPath || workload.datasetPath,
    datasetHash: options.datasetHash || null,
    baseModelId: workload.baseModelId,
    stage: options.stage || 'lora',
    checkpointStep: options.checkpointStep ?? null,
    parentArtifacts: options.parentArtifacts || [],
    runtime: 'node',
    surface: 'node',
    claimBoundary: workload.claimBoundary,
    configHash: options.configHash || workload.configHash,
  });
  return { ...base, artifactHash: hashArtifactPayload(base) };
}
348
+
349
/**
 * Export the toy model's LoRA A/B tensors as an adapter manifest, plus a
 * `lora_adapter_manifest` artifact describing the export.
 *
 * Writes `<checkpointId>.adapter.manifest.json` and `<checkpointId>.export.json`
 * into `layout.exports`, creating that directory if it does not exist yet.
 *
 * @param {object} loadedWorkload Loaded workload (`workload`, `absolutePath`, `workloadSha256`).
 * @param {{ exports: string }} layout Directory layout; only `exports` is used.
 * @param {object} model Toy model with `adapter.A`, `adapter.B` and optional `targetModule`.
 * @param {string} checkpointId Used in file names and the default adapter id/name.
 * @param {number|null} checkpointStep Recorded in the export artifact.
 * @param {string|null} datasetHash Recorded in the export artifact.
 * @returns {Promise<{checkpointId: string, manifestPath: string, exportPath: string, manifest: object}>}
 */
async function exportToyLoraModel(loadedWorkload, layout, model, checkpointId, checkpointStep, datasetHash) {
  const workload = loadedWorkload.workload;
  const targetModule = model.targetModule || workload.pipeline.adapter.targetModules[0];
  const defaultExportName = `${workload.id}-${checkpointId}`;
  const exported = await exportLoRAAdapter({
    id: workload.pipeline.export?.id || defaultExportName,
    name: workload.pipeline.export?.name || defaultExportName,
    baseModel: workload.baseModelId,
    rank: workload.pipeline.adapter.rank,
    alpha: workload.pipeline.adapter.alpha,
    targetModules: [targetModule],
    tensors: [
      { name: `layers.0.${targetModule}.lora_a`, tensor: model.adapter.A },
      { name: `layers.0.${targetModule}.lora_b`, tensor: model.adapter.B },
    ],
  });
  // Callers such as exportLoraCheckpoint may pass a default exports path that
  // has never been created; writeFile would otherwise fail with ENOENT.
  await mkdir(layout.exports, { recursive: true });
  const manifestPath = join(layout.exports, `${checkpointId}.adapter.manifest.json`);
  await writeFile(manifestPath, exported.json, 'utf8');
  // The loaded result is discarded; presumably this round-trip exists so that an
  // unloadable manifest fails the export instead of surfacing later.
  await loadLoRAFromManifest(exported.manifest, {});
  const artifactPayload = {
    ...buildArtifact(loadedWorkload, {
      prefix: 'lora_export',
      id: checkpointId,
      artifactType: 'lora_adapter_manifest',
      checkpointStep,
      datasetHash,
    }),
    checkpointId,
    manifestPath,
    manifest: exported.manifest,
  };
  const artifactFile = await writeJsonArtifact(
    join(layout.exports, `${checkpointId}.export.json`),
    artifactPayload
  );
  return {
    checkpointId,
    manifestPath,
    exportPath: artifactFile.path,
    manifest: exported.manifest,
  };
}
390
+
391
/**
 * Locate the most recent checkpoint directory under `<runRoot>/checkpoints`.
 *
 * Directory names are ordered with numeric-aware collation, so
 * `checkpoint-1000000` sorts after `checkpoint-999999` even once the step
 * outgrows the 6-digit zero padding used by runLoraPipeline when naming
 * checkpoints. Whether the completion marker exists is NOT checked here.
 *
 * @param {string} runRoot Run directory containing a `checkpoints/` folder.
 * @returns {Promise<{checkpointId: string, checkpointPath: string, markerPath: string}>}
 * @throws {Error} When the checkpoints directory is missing or has no subdirectories.
 */
async function selectLatestCheckpoint(runRoot) {
  const checkpointsDir = join(runRoot, 'checkpoints');
  let entries;
  try {
    entries = await readdir(checkpointsDir, { withFileTypes: true });
  } catch (error) {
    if (error?.code === 'ENOENT') {
      throw new Error(`No checkpoints found in ${checkpointsDir}.`, { cause: error });
    }
    throw error;
  }
  const dirs = entries
    .filter((entry) => entry.isDirectory())
    .map((entry) => entry.name)
    .sort((left, right) => left.localeCompare(right, undefined, { numeric: true }));
  const latest = dirs[dirs.length - 1];
  if (!latest) {
    throw new Error(`No checkpoints found in ${checkpointsDir}.`);
  }
  return {
    checkpointId: latest,
    checkpointPath: join(checkpointsDir, latest, 'state.json'),
    markerPath: join(checkpointsDir, latest, 'checkpoint.complete.json'),
  };
}
408
+
409
/**
 * Zero-pad a training step into the checkpoint naming scheme (`checkpoint-000042`).
 * @param {number} step
 * @returns {string}
 */
function formatLoraCheckpointId(step) {
  return `checkpoint-${String(step).padStart(6, '0')}`;
}

/**
 * Build the run directory layout rooted at an explicit, caller-supplied path.
 * @param {string} runRoot
 * @returns {Record<string, string>} Absolute directory paths keyed by role.
 */
function buildExplicitLoraRunLayout(runRoot) {
  const root = resolve(String(runRoot));
  return {
    runRoot: root,
    logs: join(root, 'logs'),
    checkpoints: join(root, 'checkpoints'),
    eval: join(root, 'eval'),
    scoreboard: join(root, 'scoreboard'),
    exports: join(root, 'exports'),
    compare: join(root, 'compare'),
    qualityGate: join(root, 'quality-gate'),
  };
}

/**
 * Optimizer, gradient-clipping and precision settings shared by the
 * TrainingRunner config and the AdamOptimizer. A new object is built on every
 * call so the two consumers never share mutable config state.
 * @param {object} workload
 * @returns {{ optimizer: object, gradient: object, precision: unknown }}
 */
function buildLoraOptimizerTrainingConfig(workload) {
  const optimizer = workload.training.optimizer;
  return {
    optimizer: {
      type: optimizer.type,
      lr: optimizer.lr,
      beta1: optimizer.beta1,
      beta2: optimizer.beta2,
      eps: optimizer.eps,
      weightDecay: optimizer.weightDecay,
      scheduler: optimizer.scheduler,
    },
    gradient: {
      maxNorm: workload.training.gradientClipping.maxNorm,
    },
    precision: workload.training.precision,
  };
}

/**
 * Train a toy LoRA adapter for `options.loadedWorkload`, writing checkpoints,
 * per-checkpoint eval reports and scoreboard rows, and (optionally) adapter exports.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload Loaded workload (`workload`, `absolutePath`, `workloadSha256`).
 * @param {string} [options.runRoot] Explicit run directory; otherwise one is created via createTrainingRunLayout.
 * @param {string} [options.timestamp] Forwarded to createTrainingRunLayout when no runRoot is given.
 * @returns {Promise<object>} Run summary: run root, checkpoint artifacts, eval reports, exports, metrics.
 * @throws {Error} For non-lora workloads, non-toy base models, or unsupported dataset formats.
 */
export async function runLoraPipeline(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  if (workload.kind !== 'lora') {
    throw new Error('runLoraPipeline requires a lora workload.');
  }
  if (workload.baseModelId !== 'training-toy') {
    throw new Error('LoRA run currently supports baseModelId="training-toy" only.');
  }
  if (workload.pipeline.datasetFormat !== 'toy_linear_classification_jsonl') {
    throw new Error('LoRA run currently supports datasetFormat="toy_linear_classification_jsonl" only.');
  }
  const layout = options.runRoot
    ? buildExplicitLoraRunLayout(options.runRoot)
    : await createTrainingRunLayout({
      kind: 'lora',
      workloadId: workload.id,
      timestamp: options.timestamp || null,
    });
  await Promise.all(Object.values(layout).map((dirPath) => mkdir(dirPath, { recursive: true })));
  await writeRunContract(layout, buildRunContract(loadedWorkload));
  await writeWorkloadLock(layout, loadedWorkload);
  const dataset = await loadToyLoraDataset(workload.datasetPath);
  const fixture = createToyLoraModel(workload);
  try {
    const evalReports = [];
    const checkpointArtifacts = [];
    const exports = [];
    const runner = new TrainingRunner({
      training: {
        enabled: true,
        ...buildLoraOptimizerTrainingConfig(workload),
        lossScaling: { enabled: false },
        distill: {
          enabled: false,
          stage: 'stage_a',
          teacherModelId: null,
          studentModelId: null,
          datasetId: null,
          datasetPath: null,
          languagePair: null,
          sourceLangs: null,
          targetLangs: null,
          pairAllowlist: null,
          strictPairContract: false,
          shardIndex: null,
          shardCount: null,
          resumeFrom: null,
          artifactDir: null,
          stageAArtifact: null,
          stageAArtifactHash: null,
          temperature: 1,
          alphaKd: 1,
          alphaCe: 0,
          allowHintFallback: false,
          tripletMargin: 0.2,
          studentGraphMode: 'projection_head',
          freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false },
        },
        ul: {
          enabled: false,
          stage: 'stage1_joint',
          stage1Artifact: null,
          stage1ArtifactHash: null,
          artifactDir: null,
          lambda0: 5,
          seed: workload.seed,
          noiseSchedule: { name: 'linear', minSigma: 0.1, maxSigma: 1, steps: 1 },
          priorAlignment: { enabled: false, weight: 1 },
          decoderSigmoidWeight: { enabled: false, maxWeight: 1 },
          lossWeights: { prior: 1, decoder: 1, recon: 1 },
          freeze: null,
        },
      },
    }, {
      optimizer: new AdamOptimizer({
        training: buildLoraOptimizerTrainingConfig(workload),
      }),
      crossEntropyLoss,
      clipGradients,
      resolveCheckpointKey({ step }) {
        return join(layout.checkpoints, formatLoraCheckpointId(step), 'state.json');
      },
      onCheckpoint: async (checkpoint) => {
        const checkpointId = formatLoraCheckpointId(checkpoint.step);
        const checkpointPayload = {
          ...buildArtifact(loadedWorkload, {
            prefix: 'lora_ckpt',
            id: checkpointId,
            artifactType: 'training_checkpoint',
            datasetHash: dataset.datasetHash,
            checkpointStep: checkpoint.step,
          }),
          checkpointId,
          checkpointPath: checkpoint.path,
          optimizerStatePresent: true,
          // `?.` so a workload without a scheduler block does not crash checkpointing.
          schedulerStatePresent: workload.training.optimizer.scheduler?.enabled === true,
          resumeLineage: checkpoint.metadata?.lineage || null,
        };
        await writeJsonArtifact(
          join(layout.checkpoints, checkpointId, 'checkpoint.json'),
          checkpointPayload
        );
        // The `.complete` marker is written last; watchLoraCheckpoints reads these
        // marker files, presumably as the signal that a checkpoint is finalized.
        const checkpointArtifact = await writeJsonArtifact(
          join(layout.checkpoints, checkpointId, 'checkpoint.complete.json'),
          checkpointPayload
        );
        checkpointArtifacts.push({
          checkpointId,
          checkpointPath: checkpoint.path,
          markerPath: checkpointArtifact.path,
          checkpointStep: checkpoint.step,
        });
        if (workload.pipeline.export?.enabled === true && workload.pipeline.export.atCheckpoints === true) {
          exports.push(await exportToyLoraModel(
            loadedWorkload,
            layout,
            fixture.model,
            checkpointId,
            checkpoint.step,
            dataset.datasetHash
          ));
        }
        const reports = await evaluateToyLoraModel(workload, fixture.model, dataset, layout, {
          checkpointId,
          checkpointStep: checkpoint.step,
          configHash: workload.configHash,
          workloadPath: loadedWorkload.absolutePath,
          workloadSha256: loadedWorkload.workloadSha256,
        });
        for (const report of reports) {
          evalReports.push(report);
          await appendScoreboardRow(layout.scoreboard, {
            artifactType: 'training_scoreboard',
            schemaVersion: 1,
            generatedAt: new Date().toISOString(),
            checkpointId,
            checkpointStep: checkpoint.step,
            evalDatasetId: report.evalDatasetId,
            primaryMetric: report.primaryMetric,
            primaryScore: report.primaryScore,
            accuracy: report.accuracy,
            metrics: {
              accuracy: report.accuracy,
              primaryScore: report.primaryScore,
            },
          }, {
            selectionMetric: workload.selectionMetric,
            selectionGoal: workload.selectionGoal,
          });
        }
      },
    });
    const metrics = await runner.run(
      fixture.model,
      createToyDatasetBatches(dataset.rows, workload.training.batchSize),
      {
        epochs: 1,
        batchSize: workload.training.batchSize,
        shuffle: false,
        maxSteps: workload.training.steps,
        checkpointEvery: workload.checkpointEvery,
        modelId: workload.baseModelId,
      }
    );
    const finalCheckpointId = runner.lastCheckpoint
      ? formatLoraCheckpointId(runner.lastCheckpoint.step)
      : null;
    // Export the final checkpoint unless it was already exported during training.
    if (workload.pipeline.export?.enabled === true && finalCheckpointId && exports.every((entry) => entry.checkpointId !== finalCheckpointId)) {
      exports.push(await exportToyLoraModel(
        loadedWorkload,
        layout,
        fixture.model,
        finalCheckpointId,
        runner.lastCheckpoint.step,
        dataset.datasetHash
      ));
    }
    return {
      ok: true,
      kind: 'lora',
      action: 'run',
      workloadId: workload.id,
      runRoot: layout.runRoot,
      checkpointArtifacts,
      evalReports,
      exports,
      metrics,
      lastCheckpoint: runner.lastCheckpoint,
    };
  } finally {
    fixture.cleanup();
  }
}
638
+
639
/**
 * Restore a saved LoRA checkpoint into a fresh toy model and evaluate it on the
 * workload's eval datasets.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload Loaded workload (`workload`, `absolutePath`, `workloadSha256`).
 * @param {string} options.checkpointPath Path to the checkpoint state file.
 * @param {string} [options.checkpointId] Recorded in reports (defaults to 'checkpoint').
 * @param {number} [options.checkpointStep] Recorded in reports (defaults to null).
 * @param {object} [options.layout] When set, reports are written under `layout.eval`.
 * @returns {Promise<object[]>} Eval reports produced by evaluateToyLoraModel.
 * @throws {Error} If the checkpoint cannot be loaded.
 */
export async function evaluateLoraCheckpoint(options) {
  const loadedWorkload = options.loadedWorkload;
  const checkpointPath = resolve(String(options.checkpointPath));
  const workload = loadedWorkload.workload;
  const dataset = await loadToyLoraDataset(workload.datasetPath);
  const checkpointRecord = await loadCheckpoint(checkpointPath);
  if (!checkpointRecord) {
    throw new Error(`Checkpoint not found: ${checkpointPath}`);
  }
  const fixture = createToyLoraModel(workload);
  try {
    await restoreTrainingCheckpointState(fixture.model, { getState: () => null }, checkpointRecord, {
      training: {
        distill: { freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false } },
        ul: { freeze: null },
      },
    });
    // `return await` is required: a bare `return` would run the `finally`
    // (freeing the model) before the evaluation promise settles.
    return await evaluateToyLoraModel(workload, fixture.model, dataset, options.layout || null, {
      checkpointId: options.checkpointId || 'checkpoint',
      checkpointStep: options.checkpointStep ?? null,
      configHash: workload.configHash,
      workloadPath: loadedWorkload.absolutePath,
      workloadSha256: loadedWorkload.workloadSha256,
    });
  } finally {
    fixture.cleanup();
  }
}
667
+
668
/**
 * Restore a saved LoRA checkpoint into a fresh toy model and export its adapter.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload Loaded workload (`workload`, `absolutePath`, `workloadSha256`).
 * @param {string} options.checkpointPath Path to the checkpoint state file.
 * @param {object} [options.layout] Layout whose `exports` directory receives the files.
 * @param {string} [options.exportsDir] Exports directory used when `layout.exports` is absent
 *   (defaults to 'reports/training/lora/exports', resolved against the CWD).
 * @param {string} [options.checkpointId] Defaults to 'checkpoint'.
 * @param {number} [options.checkpointStep] Defaults to null.
 * @param {string} [options.datasetHash] Defaults to null.
 * @returns {Promise<object>} Export summary from exportToyLoraModel.
 * @throws {Error} If the checkpoint cannot be loaded.
 */
export async function exportLoraCheckpoint(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  const fallbackExportsDir = resolve(options.exportsDir || 'reports/training/lora/exports');
  const layout = options.layout || { exports: fallbackExportsDir };
  const checkpointPath = resolve(String(options.checkpointPath));
  const checkpointRecord = await loadCheckpoint(checkpointPath);
  if (!checkpointRecord) {
    throw new Error(`Checkpoint not found: ${checkpointPath}`);
  }
  const fixture = createToyLoraModel(workload);
  try {
    await restoreTrainingCheckpointState(fixture.model, { getState: () => null }, checkpointRecord, {
      training: {
        distill: { freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false } },
        ul: { freeze: null },
      },
    });
    const checkpointId = options.checkpointId || 'checkpoint';
    // `return await` keeps the model alive until the export has finished;
    // a bare `return` would run `fixture.cleanup()` while tensors are still read.
    return await exportToyLoraModel(
      loadedWorkload,
      { ...layout, exports: layout.exports || fallbackExportsDir },
      fixture.model,
      checkpointId,
      options.checkpointStep ?? null,
      options.datasetHash || null
    );
  } finally {
    fixture.cleanup();
  }
}
700
+
701
/**
 * Watch `<runRoot>/checkpoints` for finalized checkpoint markers and evaluate
 * each one, writing reports under `<runRoot>/eval`.
 *
 * The newest checkpoint directory is consulted only when a marker lacks its own
 * `checkpointPath`/`checkpointId`. That lookup happens lazily, per marker, so the
 * watcher can be started before the first checkpoint exists.
 *
 * @param {object} options
 * @param {string} options.runRoot Run directory.
 * @param {object} options.loadedWorkload Forwarded to evaluateLoraCheckpoint.
 * @param {number} [options.pollIntervalMs] Defaults to 2000.
 * @param {boolean} [options.stopWhenIdle] Only `true` enables it.
 * @param {AbortSignal} [options.signal]
 * @returns {Promise<unknown>} Whatever watchFinalizedCheckpoints resolves with.
 */
export async function watchLoraCheckpoints(options) {
  const { runRoot } = options;
  return watchFinalizedCheckpoints({
    checkpointsDir: join(runRoot, 'checkpoints'),
    manifestPath: join(runRoot, 'scoreboard', 'watch-manifest.json'),
    pollIntervalMs: options.pollIntervalMs || 2000,
    stopWhenIdle: options.stopWhenIdle === true,
    signal: options.signal ?? null,
    onCheckpoint: async (markerPath) => {
      const marker = JSON.parse(await readFile(markerPath, 'utf8'));
      const fallback = marker.checkpointPath && marker.checkpointId
        ? null
        : await selectLatestCheckpoint(runRoot);
      await evaluateLoraCheckpoint({
        loadedWorkload: options.loadedWorkload,
        checkpointPath: marker.checkpointPath || fallback.checkpointPath,
        checkpointId: marker.checkpointId || fallback.checkpointId,
        checkpointStep: marker.checkpointStep ?? null,
        layout: {
          eval: join(runRoot, 'eval'),
        },
      });
    },
  });
}
724
+
725
/**
 * Rank every eval report in `<runRoot>/eval` by `primaryScore` (highest first)
 * and write a `training_compare_report` to `<runRoot>/compare/compare.json`.
 *
 * Reports are read in file-name order, so ties keep a deterministic order.
 * Missing or non-numeric scores rank last. Persisted eval reports do not store
 * their own location, so each report's `reportPath` falls back to the file it
 * was read from.
 *
 * @param {{ runRoot: string }} options
 * @returns {Promise<object>} Compare payload plus `comparePath`.
 * @throws {Error} If the eval directory cannot be read or a report is not valid JSON.
 */
export async function compareLoraRun(options) {
  const evalDir = join(options.runRoot, 'eval');
  const entries = await readdir(evalDir, { withFileTypes: true });
  // readdir order is filesystem-dependent; sort names for a deterministic tie order.
  const reportNames = entries
    .filter((entry) => entry.isFile() && entry.name.endsWith('.json'))
    .map((entry) => entry.name)
    .sort();
  const reports = [];
  for (const name of reportNames) {
    const filePath = join(evalDir, name);
    const report = JSON.parse(await readFile(filePath, 'utf8'));
    reports.push({ ...report, reportPath: report?.reportPath || filePath });
  }
  const scoreOf = (report) => {
    const score = Number(report.primaryScore ?? Number.NEGATIVE_INFINITY);
    return Number.isNaN(score) ? Number.NEGATIVE_INFINITY : score;
  };
  // Explicit three-way comparison: subtracting two -Infinity scores yields NaN,
  // which is not a valid comparator result.
  const sorted = [...reports].sort((left, right) => {
    const leftScore = scoreOf(left);
    const rightScore = scoreOf(right);
    if (leftScore === rightScore) return 0;
    return rightScore > leftScore ? 1 : -1;
  });
  const payload = {
    artifactType: 'training_compare_report',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    runRoot: options.runRoot,
    count: sorted.length,
    best: sorted[0] || null,
    reports: sorted.map((report) => ({
      checkpointId: report.checkpointId || null,
      evalDatasetId: report.evalDatasetId || null,
      primaryMetric: report.primaryMetric || null,
      primaryScore: report.primaryScore ?? null,
      accuracy: report.accuracy ?? null,
      reportPath: report.reportPath || null,
    })),
  };
  const artifact = await writeJsonArtifact(join(options.runRoot, 'compare', 'compare.json'), payload);
  return {
    ...payload,
    comparePath: artifact.path,
  };
}
763
+
764
/**
 * Check that a run directory contains its required provenance files
 * (`run_contract.json`, `workload.lock.json`) and persist the verdict to
 * `<runRoot>/quality-gate/quality-gate.json`.
 *
 * A file counts as present when it can be read as UTF-8. Read failures are
 * recorded per file rather than thrown.
 *
 * @param {{ runRoot: string }} options
 * @returns {Promise<object>} Gate payload (`passed`, per-file `checks`) plus `reportPath`.
 */
export async function qualityGateLoraRun(options) {
  const runRoot = resolve(String(options.runRoot));
  const probe = async (filePath) => {
    try {
      await readFile(filePath, 'utf8');
    } catch (error) {
      return { path: filePath, ok: false, error: error?.message || String(error) };
    }
    return { path: filePath, ok: true };
  };
  const checks = [];
  for (const requiredName of ['run_contract.json', 'workload.lock.json']) {
    checks.push(await probe(join(runRoot, requiredName)));
  }
  const passed = !checks.some((check) => check.ok !== true);
  const payload = {
    artifactType: 'training_quality_gate',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    runRoot,
    passed,
    checks,
  };
  const gateFile = await writeJsonArtifact(join(runRoot, 'quality-gate', 'quality-gate.json'), payload);
  return { ...payload, reportPath: gateFile.path };
}