@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -0,0 +1,93 @@
1
+ import path from 'node:path';
2
+
3
+ import {
4
+ DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE,
5
+ DIRECT_SOURCE_RUNTIME_MODE,
6
+ DIRECT_SOURCE_RUNTIME_SCHEMA,
7
+ DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
8
+ getSourceRuntimeMetadata,
9
+ } from './source-runtime-bundle.js';
10
+
11
+ function cloneJsonValue(value) {
12
+ if (typeof structuredClone === 'function') {
13
+ return structuredClone(value);
14
+ }
15
+ return JSON.parse(JSON.stringify(value));
16
+ }
17
+
18
+ function toRelativeArtifactPath(value, artifactDir, label) {
19
+ const raw = String(value || '').trim();
20
+ if (!raw) {
21
+ throw new Error(`${label} path is required.`);
22
+ }
23
+ const resolvedArtifactDir = path.resolve(artifactDir);
24
+ const resolvedTarget = path.resolve(raw);
25
+ const relativePath = path.relative(resolvedArtifactDir, resolvedTarget).replace(/\\/g, '/');
26
+ if (!relativePath || relativePath.startsWith('../') || relativePath === '..') {
27
+ throw new Error(
28
+ `${label} "${raw}" must live inside artifactDir "${resolvedArtifactDir}" for a persisted direct-source manifest.`
29
+ );
30
+ }
31
+ return relativePath;
32
+ }
33
+
34
+ export function materializeSourceRuntimeManifest(manifest, artifactDir) {
35
+ const sourceRuntime = getSourceRuntimeMetadata(manifest);
36
+ if (!sourceRuntime) {
37
+ throw new Error('materializeSourceRuntimeManifest requires manifest.metadata.sourceRuntime.');
38
+ }
39
+ const resolvedArtifactDir = String(artifactDir || '').trim();
40
+ if (!resolvedArtifactDir) {
41
+ throw new Error('materializeSourceRuntimeManifest requires artifactDir.');
42
+ }
43
+
44
+ const nextManifest = cloneJsonValue(manifest);
45
+ if (!nextManifest.metadata || typeof nextManifest.metadata !== 'object') {
46
+ nextManifest.metadata = {};
47
+ }
48
+ const sourceMetadata = nextManifest.metadata.sourceRuntime && typeof nextManifest.metadata.sourceRuntime === 'object'
49
+ ? cloneJsonValue(nextManifest.metadata.sourceRuntime)
50
+ : {};
51
+
52
+ sourceMetadata.mode = DIRECT_SOURCE_RUNTIME_MODE;
53
+ sourceMetadata.schema = DIRECT_SOURCE_RUNTIME_SCHEMA;
54
+ sourceMetadata.schemaVersion = DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION;
55
+ sourceMetadata.hashAlgorithm = sourceRuntime.hashAlgorithm;
56
+ sourceMetadata.pathSemantics = DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE;
57
+ sourceMetadata.sourceFiles = sourceRuntime.sourceFiles.map((entry) => ({
58
+ index: entry.index,
59
+ filename: entry.filename ?? null,
60
+ path: toRelativeArtifactPath(
61
+ entry.path,
62
+ resolvedArtifactDir,
63
+ `source runtime source file ${entry.index}`
64
+ ),
65
+ size: entry.size,
66
+ hash: entry.hash,
67
+ hashAlgorithm: entry.hashAlgorithm,
68
+ }));
69
+ sourceMetadata.auxiliaryFiles = sourceRuntime.auxiliaryFiles.map((entry) => ({
70
+ path: toRelativeArtifactPath(
71
+ entry.path,
72
+ resolvedArtifactDir,
73
+ `source runtime auxiliary file ${entry.kind}`
74
+ ),
75
+ size: entry.size,
76
+ hash: entry.hash,
77
+ hashAlgorithm: entry.hashAlgorithm,
78
+ kind: entry.kind,
79
+ }));
80
+ sourceMetadata.tokenizer = {
81
+ jsonPath: sourceRuntime.tokenizer.jsonPath
82
+ ? toRelativeArtifactPath(sourceRuntime.tokenizer.jsonPath, resolvedArtifactDir, 'source runtime tokenizer json')
83
+ : null,
84
+ configPath: sourceRuntime.tokenizer.configPath
85
+ ? toRelativeArtifactPath(sourceRuntime.tokenizer.configPath, resolvedArtifactDir, 'source runtime tokenizer config')
86
+ : null,
87
+ modelPath: sourceRuntime.tokenizer.modelPath
88
+ ? toRelativeArtifactPath(sourceRuntime.tokenizer.modelPath, resolvedArtifactDir, 'source runtime tokenizer model')
89
+ : null,
90
+ };
91
+ nextManifest.metadata.sourceRuntime = sourceMetadata;
92
+ return nextManifest;
93
+ }
@@ -1,6 +1,7 @@
1
- import { acquireBuffer, uploadData, readBuffer } from '../memory/buffer-pool.js';
1
+ import { acquireBuffer, uploadData, readBuffer, releaseBuffer } from '../memory/buffer-pool.js';
2
2
  import { createTensor, tensorBytes } from '../gpu/tensor.js';
3
3
  import { f16ToF32Array } from '../inference/kv-cache/types.js';
4
+ import { createUploadedTensor } from './tensor-factory.js';
4
5
 
5
6
  function toFloat32(buffer, dtype) {
6
7
  if (dtype === 'f16') {
@@ -67,9 +68,7 @@ export async function buildAttentionSoftmaxCache(q, k, options) {
67
68
  const kData = toFloat32(kBuf, k.dtype);
68
69
  const sData = computeSoftmax(qData, kData, options);
69
70
  const { seqLen, numHeads } = options;
70
- const outBuf = acquireBuffer(tensorBytes([numHeads, seqLen, seqLen], 'f32'), undefined, 'attn_softmax_cache');
71
- uploadData(outBuf, sData);
72
- return createTensor(outBuf, 'f32', [numHeads, seqLen, seqLen], 'attn_softmax_cache');
71
+ return createUploadedTensor(sData, 'f32', [numHeads, seqLen, seqLen], 'attn_softmax_cache');
73
72
  }
74
73
 
75
74
  export async function attentionBackwardCpu(
@@ -201,17 +200,33 @@ export async function attentionBackwardCpu(
201
200
  }
202
201
  }
203
202
 
204
- const qBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_q');
205
- const kBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_k');
206
- const vBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_v');
207
-
208
- uploadData(qBufOut, dQ);
209
- uploadData(kBufOut, dK);
210
- uploadData(vBufOut, dV);
211
-
212
- return {
213
- gradQ: createTensor(qBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_q'),
214
- gradK: createTensor(kBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_k'),
215
- gradV: createTensor(vBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_v'),
216
- };
203
+ let qBufOut = null;
204
+ let kBufOut = null;
205
+ let vBufOut = null;
206
+ try {
207
+ qBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_q');
208
+ kBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_k');
209
+ vBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_v');
210
+
211
+ uploadData(qBufOut, dQ);
212
+ uploadData(kBufOut, dK);
213
+ uploadData(vBufOut, dV);
214
+
215
+ return {
216
+ gradQ: createTensor(qBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_q'),
217
+ gradK: createTensor(kBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_k'),
218
+ gradV: createTensor(vBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_v'),
219
+ };
220
+ } catch (error) {
221
+ if (qBufOut) {
222
+ releaseBuffer(qBufOut);
223
+ }
224
+ if (kBufOut) {
225
+ releaseBuffer(kBufOut);
226
+ }
227
+ if (vBufOut) {
228
+ releaseBuffer(vBufOut);
229
+ }
230
+ throw error;
231
+ }
217
232
  }
@@ -6,6 +6,7 @@ import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/
6
6
  import { createTensor } from '../gpu/tensor.js';
7
7
  import { attentionBackwardCpu } from './attention-backward.js';
8
8
  import { f16ToF32Array, f32ToF16Array } from '../inference/kv-cache/types.js';
9
+ import { createUploadedTensor } from './tensor-factory.js';
9
10
 
10
11
  export const OpType = {
11
12
  EMBED: 'embed',
@@ -35,6 +36,7 @@ export class AutogradTape {
35
36
  constructor(registry) {
36
37
  this.registry = registry;
37
38
  this.records = [];
39
+ this.retainedBuffers = new Set();
38
40
  }
39
41
 
40
42
  watch(tensor) {
@@ -43,6 +45,13 @@ export class AutogradTape {
43
45
 
44
46
  async record(op, fn, inputs, options = {}) {
45
47
  const output = await fn(...inputs);
48
+ if (Array.isArray(options.retainBuffers)) {
49
+ for (const buffer of options.retainBuffers) {
50
+ if (buffer) {
51
+ this.retainedBuffers.add(buffer);
52
+ }
53
+ }
54
+ }
46
55
  this.records.push({ op, inputs, output, options });
47
56
  return output;
48
57
  }
@@ -50,31 +59,40 @@ export class AutogradTape {
50
59
  async backward(gradOutput) {
51
60
  const grads = new Map();
52
61
  const seeds = this.normalizeBackwardSeeds(gradOutput);
53
- for (const seed of seeds) {
54
- await this.accumulateGrad(grads, seed.tensor, seed.grad);
55
- }
56
-
57
- for (let i = this.records.length - 1; i >= 0; i -= 1) {
58
- const record = this.records[i];
59
- const entry = this.registry.ops[record.op];
60
- if (!entry) {
61
- continue;
62
+ try {
63
+ for (const seed of seeds) {
64
+ await this.accumulateGrad(grads, seed.tensor, seed.grad);
62
65
  }
63
66
 
64
- const gradOut = grads.get(record.output);
65
- if (!gradOut) {
66
- continue;
67
- }
67
+ for (let i = this.records.length - 1; i >= 0; i -= 1) {
68
+ const record = this.records[i];
69
+ const entry = this.registry.ops[record.op];
70
+ if (!entry) {
71
+ continue;
72
+ }
73
+
74
+ const gradOut = grads.get(record.output);
75
+ if (!gradOut) {
76
+ continue;
77
+ }
68
78
 
69
- const gradsOut = await this.runBackward(entry.backward, record, gradOut);
70
- for (const { input, grad } of gradsOut) {
71
- if (input && grad) {
72
- await this.accumulateGrad(grads, input, grad);
79
+ const gradsOut = await this.runBackward(entry.backward, record, gradOut);
80
+ for (const { input, grad } of gradsOut) {
81
+ if (input && grad) {
82
+ await this.accumulateGrad(grads, input, grad);
83
+ }
73
84
  }
74
85
  }
75
- }
76
86
 
77
- return grads;
87
+ return grads;
88
+ } finally {
89
+ for (const buffer of this.retainedBuffers) {
90
+ try {
91
+ releaseBuffer(buffer);
92
+ } catch {}
93
+ }
94
+ this.retainedBuffers.clear();
95
+ }
78
96
  }
79
97
 
80
98
  isTensorLike(value) {
@@ -245,9 +263,7 @@ export class AutogradTape {
245
263
  expanded.set(gradRow.subarray(0, copyCount), rowOffset);
246
264
  const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
247
265
  const payload = dtype === 'f16' ? f32ToF16Array(expanded) : expanded;
248
- const outBuffer = acquireBuffer(payload.byteLength, undefined, 'row_slice_backward_output');
249
- uploadData(outBuffer, payload);
250
- return createTensor(outBuffer, dtype, [rows, cols], 'row_slice_backward_output');
266
+ return createUploadedTensor(payload, dtype, [rows, cols], 'row_slice_backward_output');
251
267
  }
252
268
 
253
269
  resolveSiluRowsplitGate(gateValue, activation) {
@@ -305,9 +321,7 @@ export class AutogradTape {
305
321
 
306
322
  const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
307
323
  const payload = dtype === 'f16' ? f32ToF16Array(output) : output;
308
- const outBuffer = acquireBuffer(payload.byteLength, undefined, 'silu_rowsplit_backward_output');
309
- uploadData(outBuffer, payload);
310
- return createTensor(outBuffer, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
324
+ return createUploadedTensor(payload, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
311
325
  }
312
326
 
313
327
  async accumulateLargeGradF32(existing, grad, size, shape) {
@@ -317,35 +331,49 @@ export class AutogradTape {
317
331
  }
318
332
  const bytesPerElement = 4;
319
333
  const outputBuffer = acquireBuffer(size * bytesPerElement, undefined, 'grad_accum_large_output');
320
-
321
- for (let offset = 0; offset < size; offset += MAX_RESIDUAL_ELEMENTS_PER_DISPATCH) {
322
- const chunkElements = Math.min(MAX_RESIDUAL_ELEMENTS_PER_DISPATCH, size - offset);
323
- const chunkBytes = chunkElements * bytesPerElement;
324
- const chunkOffsetBytes = offset * bytesPerElement;
325
-
326
- const aChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_a_chunk');
327
- const bChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_b_chunk');
328
- const copyIn = device.createCommandEncoder();
329
- copyIn.copyBufferToBuffer(existing.buffer, chunkOffsetBytes, aChunkBuffer, 0, chunkBytes);
330
- copyIn.copyBufferToBuffer(grad.buffer, chunkOffsetBytes, bChunkBuffer, 0, chunkBytes);
331
- device.queue.submit([copyIn.finish()]);
332
-
333
- const aChunk = createTensor(aChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_a_tensor');
334
- const bChunk = createTensor(bChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_b_tensor');
335
- const summedChunk = await runResidualAdd(aChunk, bChunk, chunkElements);
336
-
337
- const copyOut = device.createCommandEncoder();
338
- copyOut.copyBufferToBuffer(summedChunk.buffer, 0, outputBuffer, chunkOffsetBytes, chunkBytes);
339
- device.queue.submit([copyOut.finish()]);
340
-
341
- releaseBuffer(aChunkBuffer);
342
- releaseBuffer(bChunkBuffer);
343
- if (summedChunk?.buffer && summedChunk.buffer !== outputBuffer) {
344
- releaseBuffer(summedChunk.buffer);
334
+ try {
335
+ for (let offset = 0; offset < size; offset += MAX_RESIDUAL_ELEMENTS_PER_DISPATCH) {
336
+ const chunkElements = Math.min(MAX_RESIDUAL_ELEMENTS_PER_DISPATCH, size - offset);
337
+ const chunkBytes = chunkElements * bytesPerElement;
338
+ const chunkOffsetBytes = offset * bytesPerElement;
339
+
340
+ let aChunkBuffer = null;
341
+ let bChunkBuffer = null;
342
+ let summedChunkBuffer = null;
343
+ try {
344
+ aChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_a_chunk');
345
+ bChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_b_chunk');
346
+ const copyIn = device.createCommandEncoder();
347
+ copyIn.copyBufferToBuffer(existing.buffer, chunkOffsetBytes, aChunkBuffer, 0, chunkBytes);
348
+ copyIn.copyBufferToBuffer(grad.buffer, chunkOffsetBytes, bChunkBuffer, 0, chunkBytes);
349
+ device.queue.submit([copyIn.finish()]);
350
+
351
+ const aChunk = createTensor(aChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_a_tensor');
352
+ const bChunk = createTensor(bChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_b_tensor');
353
+ const summedChunk = await runResidualAdd(aChunk, bChunk, chunkElements);
354
+ summedChunkBuffer = summedChunk?.buffer ?? null;
355
+
356
+ const copyOut = device.createCommandEncoder();
357
+ copyOut.copyBufferToBuffer(summedChunk.buffer, 0, outputBuffer, chunkOffsetBytes, chunkBytes);
358
+ device.queue.submit([copyOut.finish()]);
359
+ } finally {
360
+ if (aChunkBuffer) {
361
+ releaseBuffer(aChunkBuffer);
362
+ }
363
+ if (bChunkBuffer) {
364
+ releaseBuffer(bChunkBuffer);
365
+ }
366
+ if (summedChunkBuffer && summedChunkBuffer !== outputBuffer) {
367
+ releaseBuffer(summedChunkBuffer);
368
+ }
369
+ }
345
370
  }
346
- }
347
371
 
348
- return createTensor(outputBuffer, 'f32', [...shape], 'grad_accum_large_output');
372
+ return createTensor(outputBuffer, 'f32', [...shape], 'grad_accum_large_output');
373
+ } catch (error) {
374
+ releaseBuffer(outputBuffer);
375
+ throw error;
376
+ }
349
377
  }
350
378
 
351
379
 
@@ -0,0 +1,8 @@
1
+ export declare function watchFinalizedCheckpoints(options: {
2
+ checkpointsDir: string;
3
+ manifestPath: string;
4
+ pollIntervalMs?: number | null;
5
+ stopWhenIdle?: boolean;
6
+ signal?: AbortSignal | null;
7
+ onCheckpoint: (markerPath: string) => Promise<void> | void;
8
+ }): Promise<{ ok: true; processedCount: number; manifestPath: string; aborted?: boolean }>;
@@ -0,0 +1,139 @@
1
+ import { readdir, readFile } from 'node:fs/promises';
2
+ import { join, resolve } from 'node:path';
3
+
4
+ import { writeJsonArtifact } from './operator-artifacts.js';
5
+
6
+ async function listCheckpointMarkers(checkpointsDir) {
7
+ const absoluteDir = resolve(String(checkpointsDir));
8
+ const entries = await readdir(absoluteDir, { withFileTypes: true });
9
+ const markers = [];
10
+ for (const entry of entries) {
11
+ if (!entry.isDirectory()) {
12
+ continue;
13
+ }
14
+ const entryPath = join(absoluteDir, entry.name);
15
+ const markerPath = join(entryPath, 'checkpoint.complete.json');
16
+ try {
17
+ await readFile(markerPath, 'utf8');
18
+ markers.push(markerPath);
19
+ continue;
20
+ } catch (error) {
21
+ if (error?.code !== 'ENOENT') {
22
+ throw error;
23
+ }
24
+ }
25
+ markers.push(...await listCheckpointMarkers(entryPath));
26
+ }
27
+ return markers.sort((left, right) => left.localeCompare(right));
28
+ }
29
+
30
+ async function ensureDirectoryExists(directoryPath) {
31
+ try {
32
+ const entries = await readdir(directoryPath, { withFileTypes: true });
33
+ return Array.isArray(entries);
34
+ } catch (error) {
35
+ if (error?.code === 'ENOENT') {
36
+ return false;
37
+ }
38
+ throw error;
39
+ }
40
+ }
41
+
42
+ async function readProcessedManifest(manifestPath) {
43
+ try {
44
+ const raw = await readFile(manifestPath, 'utf8');
45
+ const parsed = JSON.parse(raw);
46
+ const processed = Array.isArray(parsed?.processedCheckpointMarkers)
47
+ ? parsed.processedCheckpointMarkers.filter((entry) => typeof entry === 'string')
48
+ : [];
49
+ return new Set(processed);
50
+ } catch (error) {
51
+ if (error?.code === 'ENOENT') {
52
+ return new Set();
53
+ }
54
+ throw error;
55
+ }
56
+ }
57
+
58
+ function createWatchResult(processed, manifestPath, aborted = false) {
59
+ return {
60
+ ok: true,
61
+ processedCount: processed.size,
62
+ manifestPath,
63
+ aborted,
64
+ };
65
+ }
66
+
67
+ async function waitForPollInterval(pollIntervalMs, signal) {
68
+ if (!signal) {
69
+ await new Promise((resolvePromise) => setTimeout(resolvePromise, pollIntervalMs));
70
+ return true;
71
+ }
72
+ if (signal.aborted) {
73
+ return false;
74
+ }
75
+ return new Promise((resolvePromise) => {
76
+ const onAbort = () => {
77
+ clearTimeout(timer);
78
+ resolvePromise(false);
79
+ };
80
+ const timer = setTimeout(() => {
81
+ signal.removeEventListener('abort', onAbort);
82
+ resolvePromise(true);
83
+ }, pollIntervalMs);
84
+ signal.addEventListener('abort', onAbort, { once: true });
85
+ });
86
+ }
87
+
88
+ export async function watchFinalizedCheckpoints(options) {
89
+ const checkpointsDir = resolve(String(options.checkpointsDir));
90
+ const manifestPath = resolve(String(options.manifestPath));
91
+ const pollIntervalMs = Number.isFinite(options.pollIntervalMs)
92
+ ? Math.max(100, Math.floor(options.pollIntervalMs))
93
+ : 2000;
94
+ const stopWhenIdle = options.stopWhenIdle === true;
95
+ const onCheckpoint = typeof options.onCheckpoint === 'function'
96
+ ? options.onCheckpoint
97
+ : null;
98
+ const signal = options.signal ?? null;
99
+ if (!onCheckpoint) {
100
+ throw new Error('watchFinalizedCheckpoints requires onCheckpoint(markerPath).');
101
+ }
102
+
103
+ const processed = await readProcessedManifest(manifestPath);
104
+ let idlePolls = 0;
105
+ for (;;) {
106
+ if (signal?.aborted) {
107
+ return createWatchResult(processed, manifestPath, true);
108
+ }
109
+ const checkpointsExist = await ensureDirectoryExists(checkpointsDir);
110
+ const markers = checkpointsExist
111
+ ? await listCheckpointMarkers(checkpointsDir)
112
+ : [];
113
+ let sawNewMarker = false;
114
+ for (const markerPath of markers) {
115
+ if (processed.has(markerPath)) continue;
116
+ sawNewMarker = true;
117
+ await onCheckpoint(markerPath);
118
+ processed.add(markerPath);
119
+ await writeJsonArtifact(manifestPath, {
120
+ artifactType: 'training_checkpoint_watch_manifest',
121
+ schemaVersion: 1,
122
+ generatedAt: new Date().toISOString(),
123
+ processedCheckpointMarkers: [...processed].sort((left, right) => left.localeCompare(right)),
124
+ });
125
+ }
126
+ if (!sawNewMarker) {
127
+ idlePolls += 1;
128
+ if (stopWhenIdle && idlePolls > 0) {
129
+ return createWatchResult(processed, manifestPath);
130
+ }
131
+ } else {
132
+ idlePolls = 0;
133
+ }
134
+ const shouldContinue = await waitForPollInterval(pollIntervalMs, signal);
135
+ if (!shouldContinue) {
136
+ return createWatchResult(processed, manifestPath, true);
137
+ }
138
+ }
139
+ }
@@ -23,7 +23,12 @@ export declare function saveCheckpoint(
23
23
  key: string,
24
24
  data: unknown,
25
25
  options?: CheckpointStoreOptions
26
- ): Promise<void>;
26
+ ): Promise<{
27
+ key: string;
28
+ path: string | null;
29
+ metadata: Record<string, unknown>;
30
+ data: unknown;
31
+ }>;
27
32
 
28
33
  export declare function loadCheckpoint(
29
34
  key: string,
@@ -31,6 +31,13 @@ function openCheckpointDB(options = {}) {
31
31
  });
32
32
  }
33
33
 
34
+ function closeCheckpointDB(db) {
35
+ if (!db || typeof db.close !== 'function') {
36
+ return;
37
+ }
38
+ db.close();
39
+ }
40
+
34
41
  async function resolveNodeCheckpointPath(key, options = {}) {
35
42
  const [{ resolve, join, dirname }, { mkdir }] = await Promise.all([
36
43
  import('node:path'),
@@ -140,9 +147,15 @@ export async function saveCheckpoint(key, payload, options = {}) {
140
147
  const useNodeStore = isNodeRuntime() && typeof indexedDB === 'undefined';
141
148
  const nodePath = useNodeStore ? await resolveNodeCheckpointPath(key, options) : null;
142
149
  const browserStore = useNodeStore ? null : await openCheckpointDB(options);
143
- const previousData = useNodeStore
144
- ? await readNodeCheckpointRecord(nodePath)
145
- : await readCheckpointRecord(browserStore.db, browserStore.storeName, key);
150
+ let previousData;
151
+ try {
152
+ previousData = useNodeStore
153
+ ? await readNodeCheckpointRecord(nodePath)
154
+ : await readCheckpointRecord(browserStore.db, browserStore.storeName, key);
155
+ } catch (error) {
156
+ closeCheckpointDB(browserStore?.db);
157
+ throw error;
158
+ }
146
159
  const previousMetadata = previousData?.metadata || {};
147
160
  const previousLineage = previousMetadata.lineage || {};
148
161
  const previousCheckpointHash = options.priorCheckpointHash
@@ -184,13 +197,35 @@ export async function saveCheckpoint(key, payload, options = {}) {
184
197
 
185
198
  if (useNodeStore) {
186
199
  await writeNodeCheckpointRecord(nodePath, data);
187
- return;
200
+ return {
201
+ key,
202
+ path: nodePath,
203
+ metadata: data.metadata,
204
+ data,
205
+ };
188
206
  }
189
207
 
190
208
  return new Promise((resolve, reject) => {
191
209
  const tx = browserStore.db.transaction(browserStore.storeName, 'readwrite');
192
- tx.oncomplete = () => resolve();
193
- tx.onerror = () => reject(tx.error);
210
+ tx.oncomplete = () => {
211
+ closeCheckpointDB(browserStore.db);
212
+ resolve({
213
+ key,
214
+ path: null,
215
+ metadata: data.metadata,
216
+ data,
217
+ });
218
+ };
219
+ tx.onerror = () => {
220
+ const error = tx.error;
221
+ closeCheckpointDB(browserStore.db);
222
+ reject(error);
223
+ };
224
+ tx.onabort = () => {
225
+ const error = tx.error ?? new Error('Checkpoint transaction aborted');
226
+ closeCheckpointDB(browserStore.db);
227
+ reject(error);
228
+ };
194
229
  const store = tx.objectStore(browserStore.storeName);
195
230
  store.put(data, key);
196
231
  });
@@ -203,7 +238,11 @@ export async function loadCheckpoint(key, options = {}) {
203
238
  ? await readNodeCheckpointRecord(nodePath)
204
239
  : await (async () => {
205
240
  const { db, storeName } = await openCheckpointDB(options);
206
- return readCheckpointRecord(db, storeName, key);
241
+ try {
242
+ return await readCheckpointRecord(db, storeName, key);
243
+ } finally {
244
+ closeCheckpointDB(db);
245
+ }
207
246
  })();
208
247
 
209
248
  if (!data || !data.metadata || !options.expectedMetadata) {
@@ -12,7 +12,8 @@ async function readGradData(grad) {
12
12
  }
13
13
 
14
14
  export async function clipGradients(grads, config) {
15
- const maxNorm = config?.training?.gradient?.maxNorm;
15
+ const maxNorm = config?.training?.gradientClipping?.maxNorm
16
+ ?? config?.training?.gradient?.maxNorm;
16
17
  let sumSq = 0;
17
18
  let totalParamCount = 0;
18
19
 
@@ -1,5 +1,5 @@
1
1
 
2
- import { acquireBuffer, uploadData } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, uploadData, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor } from '../../gpu/tensor.js';
4
4
 
5
5
  function flattenTokenBatch(samples, key) {
@@ -27,14 +27,26 @@ export function buildTokenBatch(samples) {
27
27
  }
28
28
 
29
29
  export function createTokenBatchTensors(batch) {
30
- const inputBuf = acquireBuffer(batch.inputFlat.byteLength, undefined, 'train_input_tokens');
31
- uploadData(inputBuf, batch.inputFlat);
30
+ let inputBuf = null;
31
+ let targetBuf = null;
32
+ try {
33
+ inputBuf = acquireBuffer(batch.inputFlat.byteLength, undefined, 'train_input_tokens');
34
+ uploadData(inputBuf, batch.inputFlat);
32
35
 
33
- const targetBuf = acquireBuffer(batch.targetFlat.byteLength, undefined, 'train_target_tokens');
34
- uploadData(targetBuf, batch.targetFlat);
36
+ targetBuf = acquireBuffer(batch.targetFlat.byteLength, undefined, 'train_target_tokens');
37
+ uploadData(targetBuf, batch.targetFlat);
35
38
 
36
- const input = createTensor(inputBuf, 'f32', [batch.inputFlat.length], 'train_input_tokens');
37
- const targets = createTensor(targetBuf, 'f32', [batch.targetFlat.length], 'train_target_tokens');
39
+ const input = createTensor(inputBuf, 'f32', [batch.inputFlat.length], 'train_input_tokens');
40
+ const targets = createTensor(targetBuf, 'f32', [batch.targetFlat.length], 'train_target_tokens');
38
41
 
39
- return { input, targets, offsets: batch.offsets };
42
+ return { input, targets, offsets: batch.offsets };
43
+ } catch (error) {
44
+ if (inputBuf) {
45
+ releaseBuffer(inputBuf);
46
+ }
47
+ if (targetBuf) {
48
+ releaseBuffer(targetBuf);
49
+ }
50
+ throw error;
51
+ }
40
52
  }