@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1,33 +1,29 @@
1
- let fallbackRandomState = (Date.now() >>> 0) || 0x6d2b79f5;
2
-
3
- function unseededRandom() {
4
- if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
5
- const values = new Uint32Array(1);
6
- crypto.getRandomValues(values);
7
- return values[0] / 4294967296;
1
+ function requireRandomSource(random) {
2
+ if (typeof random !== 'function') {
3
+ throw new Error('network evolution requires an explicit random() source.');
8
4
  }
9
- fallbackRandomState = (fallbackRandomState + 0x6d2b79f5) >>> 0;
10
- return fallbackRandomState / 4294967296;
5
+ return random;
11
6
  }
12
7
 
13
- export const mutateGenome = (genome, mutationRate = 0.1) => {
8
+ export const mutateGenome = (genome, mutationRate = 0.1, random = null) => {
9
+ const sample = requireRandomSource(random);
14
10
 
15
11
  const mutated = JSON.parse(JSON.stringify(genome));
16
- if (unseededRandom() < mutationRate) {
12
+ if (sample() < mutationRate) {
17
13
 
18
14
  const types = ['chain', 'tree', 'mesh', 'dag'];
19
- mutated.topology.type = types[Math.floor(unseededRandom() * types.length)];
15
+ mutated.topology.type = types[Math.floor(sample() * types.length)];
20
16
  }
21
17
 
22
18
  for (const node of mutated.nodes) {
23
- if (unseededRandom() < mutationRate && typeof node.temperature === 'number') {
24
- node.temperature = Math.min(1, Math.max(0, node.temperature + (unseededRandom() - 0.5) * 0.2));
19
+ if (sample() < mutationRate && typeof node.temperature === 'number') {
20
+ node.temperature = Math.min(1, Math.max(0, node.temperature + (sample() - 0.5) * 0.2));
25
21
  }
26
22
  }
27
23
 
28
24
  for (const edge of mutated.edges) {
29
- if (unseededRandom() < mutationRate) {
30
- edge.weight = Math.min(1, Math.max(0, edge.weight + (unseededRandom() - 0.5) * 0.4));
25
+ if (sample() < mutationRate) {
26
+ edge.weight = Math.min(1, Math.max(0, edge.weight + (sample() - 0.5) * 0.4));
31
27
  }
32
28
  }
33
29
 
@@ -35,8 +31,9 @@ export const mutateGenome = (genome, mutationRate = 0.1) => {
35
31
  };
36
32
 
37
33
 
38
- export const crossoverGenome = (a, b) => {
39
- return unseededRandom() < 0.5 ? JSON.parse(JSON.stringify(a)) : JSON.parse(JSON.stringify(b));
34
+ export const crossoverGenome = (a, b, random = null) => {
35
+ const sample = requireRandomSource(random);
36
+ return sample() < 0.5 ? JSON.parse(JSON.stringify(a)) : JSON.parse(JSON.stringify(b));
40
37
  };
41
38
 
42
39
 
@@ -48,7 +45,9 @@ export async function evolveNetwork(config) {
48
45
  mutationRate = 0.1,
49
46
  evaluate,
50
47
  randomGenome,
48
+ random,
51
49
  } = config;
50
+ const sample = requireRandomSource(random);
52
51
 
53
52
  let population = Array.from({ length: populationSize }, () => randomGenome());
54
53
 
@@ -63,9 +62,9 @@ export async function evolveNetwork(config) {
63
62
  const offspring = [];
64
63
 
65
64
  while (offspring.length < populationSize - eliteCount) {
66
- const parentA = scored[Math.floor(unseededRandom() * scored.length)].genome;
67
- const parentB = scored[Math.floor(unseededRandom() * scored.length)].genome;
68
- const child = mutateGenome(crossoverGenome(parentA, parentB), mutationRate);
65
+ const parentA = scored[Math.floor(sample() * scored.length)].genome;
66
+ const parentB = scored[Math.floor(sample() * scored.length)].genome;
67
+ const child = mutateGenome(crossoverGenome(parentA, parentB, sample), mutationRate, sample);
69
68
  offspring.push(child);
70
69
  }
71
70
 
@@ -8,6 +8,8 @@ export type PipelineContextOptions = {
8
8
  assignProgress?: boolean;
9
9
  };
10
10
 
11
+ export declare function restorePipelineContexts(target: Record<string, unknown>): boolean;
12
+
11
13
  export declare function applyPipelineContexts(
12
14
  target: Record<string, unknown>,
13
15
  contexts?: Record<string, unknown>,
@@ -15,4 +17,5 @@ export declare function applyPipelineContexts(
15
17
  ): {
16
18
  runtimeConfig: Record<string, unknown>;
17
19
  sharedDebug: Record<string, unknown> | null | undefined;
20
+ restore: () => void;
18
21
  };
@@ -1,8 +1,115 @@
1
- import { getDevice, setDevice } from '../../gpu/device.js';
1
+ import {
2
+ getDevice,
3
+ getKernelCapabilities,
4
+ getPlatformConfig,
5
+ setDevice,
6
+ } from '../../gpu/device.js';
2
7
  import { applyDebugConfig, setGPUDevice } from '../../debug/index.js';
3
8
  import { getRuntimeConfig, setRuntimeConfig } from '../../config/runtime.js';
9
+ import {
10
+ getLogLevel,
11
+ getTrace,
12
+ isSilentMode,
13
+ setLogLevel,
14
+ setSilentMode,
15
+ setTrace,
16
+ } from '../../debug/config.js';
17
+ import {
18
+ gpuDevice as debugGpuDevice,
19
+ traceBreakOnAnomaly,
20
+ traceLayerFilter,
21
+ traceMaxDecodeSteps,
22
+ } from '../../debug/config.js';
23
+
24
+ const RESTORE_PIPELINE_CONTEXTS = Symbol('restorePipelineContexts');
25
+
26
+ function captureTargetField(target, key) {
27
+ return {
28
+ present: Object.prototype.hasOwnProperty.call(target, key),
29
+ value: target[key],
30
+ };
31
+ }
32
+
33
+ function restoreTargetField(target, key, snapshot) {
34
+ if (snapshot.present) {
35
+ target[key] = snapshot.value;
36
+ return;
37
+ }
38
+ delete target[key];
39
+ }
40
+
41
+ function captureDebugState() {
42
+ return {
43
+ logLevel: getLogLevel(),
44
+ traceCategories: getTrace(),
45
+ traceLayers: [...traceLayerFilter],
46
+ traceMaxDecodeSteps,
47
+ traceBreakOnAnomaly,
48
+ silentMode: isSilentMode(),
49
+ gpuDevice: debugGpuDevice,
50
+ };
51
+ }
52
+
53
+ function restoreDebugState(snapshot) {
54
+ if (snapshot.silentMode !== isSilentMode()) {
55
+ setSilentMode(snapshot.silentMode);
56
+ }
57
+ if (getLogLevel() !== snapshot.logLevel) {
58
+ setLogLevel(snapshot.logLevel);
59
+ }
60
+
61
+ const traceCategories = getTrace();
62
+ const traceChanged = traceCategories.length !== snapshot.traceCategories.length
63
+ || traceCategories.some((category, idx) => category !== snapshot.traceCategories[idx])
64
+ || traceLayerFilter.length !== snapshot.traceLayers.length
65
+ || traceLayerFilter.some((layer, idx) => layer !== snapshot.traceLayers[idx])
66
+ || traceMaxDecodeSteps !== snapshot.traceMaxDecodeSteps
67
+ || traceBreakOnAnomaly !== snapshot.traceBreakOnAnomaly;
68
+
69
+ if (traceChanged) {
70
+ if (snapshot.traceCategories.length > 0) {
71
+ setTrace(snapshot.traceCategories.join(','), {
72
+ layers: snapshot.traceLayers.length > 0 ? snapshot.traceLayers : undefined,
73
+ maxDecodeSteps: snapshot.traceMaxDecodeSteps > 0 ? snapshot.traceMaxDecodeSteps : undefined,
74
+ breakOnAnomaly: snapshot.traceBreakOnAnomaly,
75
+ });
76
+ } else {
77
+ setTrace(false);
78
+ }
79
+ }
80
+
81
+ setGPUDevice(snapshot.gpuDevice ?? null);
82
+ }
83
+
84
+ export function restorePipelineContexts(target) {
85
+ const restore = target?.[RESTORE_PIPELINE_CONTEXTS];
86
+ if (typeof restore !== 'function') {
87
+ return false;
88
+ }
89
+ delete target[RESTORE_PIPELINE_CONTEXTS];
90
+ restore();
91
+ return true;
92
+ }
4
93
 
5
94
  export function applyPipelineContexts(target, contexts = {}, options = {}) {
95
+ restorePipelineContexts(target);
96
+
97
+ const previousRuntimeConfig = getRuntimeConfig();
98
+ const previousDevice = getDevice();
99
+ const previousPlatformConfig = getPlatformConfig();
100
+ const previousAdapterInfo = previousDevice
101
+ ? (getKernelCapabilities().adapterInfo ?? null)
102
+ : null;
103
+ const previousDebugState = captureDebugState();
104
+ const targetSnapshot = {
105
+ gpuContext: captureTargetField(target, 'gpuContext'),
106
+ useGPU: captureTargetField(target, 'useGPU'),
107
+ memoryContext: captureTargetField(target, 'memoryContext'),
108
+ storageContext: captureTargetField(target, 'storageContext'),
109
+ baseUrl: captureTargetField(target, 'baseUrl'),
110
+ _onProgress: captureTargetField(target, '_onProgress'),
111
+ };
112
+
6
113
  const runtimeConfig = contexts.runtimeConfig
7
114
  ? setRuntimeConfig(contexts.runtimeConfig)
8
115
  : getRuntimeConfig();
@@ -40,5 +147,38 @@ export function applyPipelineContexts(target, contexts = {}, options = {}) {
40
147
  target._onProgress = contexts.onProgress;
41
148
  }
42
149
 
43
- return { runtimeConfig, sharedDebug };
150
+ let restored = false;
151
+ const restore = () => {
152
+ if (restored) {
153
+ return;
154
+ }
155
+ restored = true;
156
+ delete target[RESTORE_PIPELINE_CONTEXTS];
157
+
158
+ setRuntimeConfig(previousRuntimeConfig);
159
+ if (previousDevice) {
160
+ setDevice(previousDevice, {
161
+ platformConfig: previousPlatformConfig,
162
+ adapterInfo: previousAdapterInfo,
163
+ });
164
+ } else {
165
+ setDevice(null);
166
+ }
167
+ restoreDebugState(previousDebugState);
168
+ restoreTargetField(target, 'gpuContext', targetSnapshot.gpuContext);
169
+ restoreTargetField(target, 'useGPU', targetSnapshot.useGPU);
170
+ restoreTargetField(target, 'memoryContext', targetSnapshot.memoryContext);
171
+ restoreTargetField(target, 'storageContext', targetSnapshot.storageContext);
172
+ restoreTargetField(target, 'baseUrl', targetSnapshot.baseUrl);
173
+ restoreTargetField(target, '_onProgress', targetSnapshot._onProgress);
174
+ };
175
+
176
+ Object.defineProperty(target, RESTORE_PIPELINE_CONTEXTS, {
177
+ value: restore,
178
+ configurable: true,
179
+ enumerable: false,
180
+ writable: false,
181
+ });
182
+
183
+ return { runtimeConfig, sharedDebug, restore };
44
184
  }
@@ -54,8 +54,13 @@ export function createDiffusionIndexBuffer(device, indices, label) {
54
54
  size: indices.byteLength,
55
55
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
56
56
  });
57
- device.queue.writeBuffer(buffer, 0, indices);
58
- return buffer;
57
+ try {
58
+ device.queue.writeBuffer(buffer, 0, indices);
59
+ return buffer;
60
+ } catch (error) {
61
+ buffer.destroy();
62
+ throw error;
63
+ }
59
64
  }
60
65
 
61
66
  export function expectDiffusionWeight(weight, label) {
@@ -1,7 +1,7 @@
1
1
  import { getDevice, getKernelCapabilities } from '../../../gpu/device.js';
2
2
  import { log, trace } from '../../../debug/index.js';
3
3
  import { registerPipeline } from '../registry.js';
4
- import { applyPipelineContexts } from '../context.js';
4
+ import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
5
5
  import { createInitializedPipeline } from '../factory.js';
6
6
  import { createRng, sampleNormal } from '../rng.js';
7
7
  import { initializeDiffusion } from './init.js';
@@ -52,6 +52,18 @@ function generateLatents(width, height, channels, latentScale, seed) {
52
52
  return { latents, latentWidth, latentHeight };
53
53
  }
54
54
 
55
+ function generateNoiseVector(size, seed) {
56
+ if (!Number.isFinite(size) || size <= 0) {
57
+ throw new Error(`generateNoiseVector requires a positive size, got ${size}.`);
58
+ }
59
+ const out = new Float32Array(size);
60
+ const rand = createRng(seed ?? createRandomSeed());
61
+ for (let i = 0; i < size; i++) {
62
+ out[i] = sampleNormal(rand);
63
+ }
64
+ return out;
65
+ }
66
+
55
67
  function extractTokenSet(tokensByEncoder, key) {
56
68
  const output = {};
57
69
  for (const [name, entry] of Object.entries(tokensByEncoder || {})) {
@@ -195,13 +207,10 @@ async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep,
195
207
  const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
196
208
  const noise = isFinalStep
197
209
  ? null
198
- : generateLatents(
199
- runtime.latent.width,
200
- runtime.latent.height,
201
- runtime.latent.channels,
202
- runtime.latent.scale,
210
+ : generateNoiseVector(
211
+ sample.length,
203
212
  (options.seedBase ?? createRandomSeed()) + stepIndex + 1
204
- ).latents;
213
+ );
205
214
  const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
206
215
  return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
207
216
  }
@@ -310,6 +319,7 @@ export class DiffusionPipeline {
310
319
  this.vaeWeights = null;
311
320
  this.textEncoderWeights = null;
312
321
  this.transformerWeights = null;
322
+ restorePipelineContexts(this);
313
323
  }
314
324
 
315
325
  async ensureVaeWeights() {
@@ -299,26 +299,26 @@ function resolveModulationSegments(weight, hiddenSize, fallbackSegments, resolve
299
299
  if (Number.isInteger(segments) && segments > 0) {
300
300
  return segments;
301
301
  }
302
- log.warn(
303
- 'Diffusion',
304
- `Modulation segments mismatch for ${name || 'unknown'}: rows=${rows}, hidden=${hiddenSize}, fallback=${fallbackSegments}`
302
+ throw new Error(
303
+ `Modulation segments mismatch for ${name || 'unknown'}: rows=${rows}, hidden=${hiddenSize}, ` +
304
+ `expected an integer multiple instead of falling back to ${fallbackSegments}.`
305
305
  );
306
306
  }
307
- return fallbackSegments;
307
+ throw new Error(
308
+ `Modulation tensor "${name || 'unknown'}" is missing shape metadata. ` +
309
+ `Runtime cannot fall back to ${fallbackSegments} segments.`
310
+ );
308
311
  }
309
312
 
310
313
  function resolveModulationOffsets(segments, hiddenSize) {
311
- if (segments >= 9) {
314
+ if (segments === 9) {
312
315
  return {
313
316
  attn: { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 },
314
317
  attn2: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
315
318
  ff: { scale: hiddenSize * 6, shift: hiddenSize * 7, gate: hiddenSize * 8 },
316
319
  };
317
320
  }
318
- if (segments >= 6) {
319
- if (segments !== 6) {
320
- log.warn('Diffusion', `Unexpected modulation segment count=${segments}; using 6-segment layout.`);
321
- }
321
+ if (segments === 6) {
322
322
  const attn = { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 };
323
323
  return {
324
324
  attn,
@@ -326,7 +326,7 @@ function resolveModulationOffsets(segments, hiddenSize) {
326
326
  ff: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
327
327
  };
328
328
  }
329
- throw new Error(`Unsupported modulation segments=${segments} (expected >= 6).`);
329
+ throw new Error(`Unsupported modulation segments=${segments} (expected 6 or 9).`);
330
330
  }
331
331
 
332
332
  async function buildModulation(timeText, weight, bias, hiddenSize, segments, runtime, matmul, weightName, ops) {
@@ -80,3 +80,8 @@ export declare function projectContext(
80
80
  ): Promise<Tensor>;
81
81
 
82
82
  export declare function assertClipHiddenActivationSupported(config: { hidden_act?: string }): void;
83
+
84
+ export declare function resolveGemma2WeightRoot(
85
+ weights: Map<string, any>,
86
+ prefix?: string
87
+ ): string;
@@ -723,8 +723,19 @@ function buildGemma2LayerTypes(layerCount, slidingWindow) {
723
723
  ));
724
724
  }
725
725
 
726
- function getGemma2LayerWeight(weights, prefix, layerIdx, suffix, required = true) {
727
- const key = `${prefix}.model.layers.${layerIdx}.${suffix}`;
726
+ export function resolveGemma2WeightRoot(weights, prefix = 'text_encoder') {
727
+ const nestedRoot = `${prefix}.model`;
728
+ if (weights?.has(`${nestedRoot}.embed_tokens.weight`)) {
729
+ return nestedRoot;
730
+ }
731
+ if (weights?.has(`${prefix}.embed_tokens.weight`)) {
732
+ return prefix;
733
+ }
734
+ return nestedRoot;
735
+ }
736
+
737
+ function getGemma2LayerWeight(weights, weightRoot, layerIdx, suffix, required = true) {
738
+ const key = `${weightRoot}.layers.${layerIdx}.${suffix}`;
728
739
  const weight = weights.get(key) || null;
729
740
  if (!weight && required) {
730
741
  throw new Error(`Missing Gemma2 diffusion weight "${key}".`);
@@ -805,8 +816,9 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
805
816
  const tokenIds = normalizeTokens(tokens, options.maxLength ?? resolved.maxPositionEmbeddings, padTokenId);
806
817
  const numTokens = tokenIds.length;
807
818
  const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);
819
+ const weightRoot = resolveGemma2WeightRoot(weights, prefix);
808
820
 
809
- const embedKey = `${prefix}.model.embed_tokens.weight`;
821
+ const embedKey = `${weightRoot}.embed_tokens.weight`;
810
822
  const embedWeight = expectDiffusionWeight(
811
823
  weights.get(embedKey),
812
824
  embedKey
@@ -837,16 +849,16 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
837
849
  const layerWeights = new Map();
838
850
  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
839
851
  layerWeights.set(`layer_${layerIdx}`, {
840
- inputNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'input_layernorm.weight'),
841
- qProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.q_proj.weight'),
842
- kProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.k_proj.weight'),
843
- vProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.v_proj.weight'),
844
- oProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.o_proj.weight'),
845
- postAttentionNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'post_attention_layernorm.weight'),
846
- preFeedforwardNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'pre_feedforward_layernorm.weight'),
847
- gate: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.gate_proj.weight'),
848
- up: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.up_proj.weight'),
849
- down: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.down_proj.weight'),
852
+ inputNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'input_layernorm.weight'),
853
+ qProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.q_proj.weight'),
854
+ kProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.k_proj.weight'),
855
+ vProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.v_proj.weight'),
856
+ oProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.o_proj.weight'),
857
+ postAttentionNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'post_attention_layernorm.weight'),
858
+ preFeedforwardNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'pre_feedforward_layernorm.weight'),
859
+ gate: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.gate_proj.weight'),
860
+ up: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.up_proj.weight'),
861
+ down: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.down_proj.weight'),
850
862
  });
851
863
  }
852
864
 
@@ -910,10 +922,10 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
910
922
  numTokens * resolved.hiddenSize,
911
923
  context
912
924
  );
913
- hidden = createTensor(output.buffer, output.dtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
925
+ hidden = createTensor(output, activationDtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
914
926
  }
915
927
 
916
- const finalNormKey = `${prefix}.model.norm.weight`;
928
+ const finalNormKey = `${weightRoot}.norm.weight`;
917
929
  const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
918
930
  const final = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
919
931
  batchSize: numTokens,
@@ -118,13 +118,9 @@ function resolveAttentionHeadShape(channels, config) {
118
118
  headDim: channels / configuredNumHeads,
119
119
  };
120
120
  }
121
-
122
- const fallbackHeadDims = [64, 40, 32, 24, 20, 16, 12, 10, 8, 6, 5, 4, 3, 2, 1];
123
- const headDim = fallbackHeadDims.find((candidate) => candidate <= channels && channels % candidate === 0) || 1;
124
- return {
125
- numHeads: Math.max(1, channels / headDim),
126
- headDim,
127
- };
121
+ throw new Error(
122
+ `VAE attention requires explicit compatible attention_head_dim or num_attention_heads for channels=${channels}.`
123
+ );
128
124
  }
129
125
 
130
126
  function createBiasTensor(weight, label, fallbackDtype = 'f16') {
@@ -16,10 +16,10 @@ import { log, trace } from '../../../debug/index.js';
16
16
  import { DEFAULT_ENERGY_CONFIG } from '../../../config/schema/energy.schema.js';
17
17
  import { f32ToF16Array, f16ToF32Array } from '../../kv-cache/types.js';
18
18
  import { registerPipeline } from '../registry.js';
19
- import { applyPipelineContexts } from '../context.js';
19
+ import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
20
20
  import { createInitializedPipeline } from '../factory.js';
21
21
  import { createRng, sampleNormal } from '../rng.js';
22
- import { mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
22
+ import { buildQuintelKernelFlags, mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
23
23
 
24
24
 
25
25
  function generateRandomArray(count, mode, seed, scale) {
@@ -140,24 +140,28 @@ async function createEnergyTensor(device, data, dtype, shape, label) {
140
140
  const byteLength = data.byteLength;
141
141
  const alignedSize = Math.ceil(byteLength / 4) * 4;
142
142
  const buffer = acquireBuffer(alignedSize, undefined, label);
143
+ try {
144
+ let payload = data;
145
+ if (alignedSize !== byteLength) {
146
+ const padded = new Uint8Array(alignedSize);
147
+ const view = data instanceof ArrayBuffer
148
+ ? new Uint8Array(data)
149
+ : new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
150
+ padded.set(view);
151
+ payload = padded;
152
+ }
143
153
 
144
- let payload = data;
145
- if (alignedSize !== byteLength) {
146
- const padded = new Uint8Array(alignedSize);
147
- const view = data instanceof ArrayBuffer
148
- ? new Uint8Array(data)
149
- : new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
150
- padded.set(view);
151
- payload = padded;
152
- }
153
-
154
- device.queue.writeBuffer(buffer, 0, payload);
155
- const tensor = createTensor(buffer, dtype, shape, label);
156
- const expectedBytes = tensorBytes(shape, dtype);
157
- if (expectedBytes !== byteLength) {
158
- log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
154
+ device.queue.writeBuffer(buffer, 0, payload);
155
+ const tensor = createTensor(buffer, dtype, shape, label);
156
+ const expectedBytes = tensorBytes(shape, dtype);
157
+ if (expectedBytes !== byteLength) {
158
+ log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
159
+ }
160
+ return tensor;
161
+ } catch (error) {
162
+ releaseBuffer(buffer);
163
+ throw error;
159
164
  }
160
- return tensor;
161
165
  }
162
166
 
163
167
  async function readTensorToFloat32(tensor) {
@@ -202,6 +206,7 @@ export class EnergyPipeline {
202
206
 
203
207
  async unload() {
204
208
  this.manifest = null;
209
+ restorePipelineContexts(this);
205
210
  }
206
211
 
207
212
  async generate(request = {}) {
@@ -336,6 +341,7 @@ export class EnergyPipeline {
336
341
  const centerWeight = Number.isFinite(weights.center) ? weights.center : 1.0;
337
342
  const binarizeWeight = Number.isFinite(weights.binarize) ? weights.binarize : 0.0;
338
343
  const centerTarget = Number.isFinite(quintelConfig.centerTarget) ? quintelConfig.centerTarget : 1.0;
344
+ const flags = buildQuintelKernelFlags(rules, binarizeWeight);
339
345
  const energyHistory = [];
340
346
  const stepTimesMs = [];
341
347
  let lastEnergy = null;
@@ -387,11 +393,11 @@ export class EnergyPipeline {
387
393
  await runEnergyQuintelReduce(stateTensor, {
388
394
  count: elementCount,
389
395
  size,
396
+ flags,
390
397
  symmetryWeight,
391
398
  centerWeight,
392
399
  binarizeWeight,
393
400
  centerTarget,
394
- rules,
395
401
  outputBuffer: reduceBuffer,
396
402
  });
397
403
 
@@ -447,13 +453,13 @@ export class EnergyPipeline {
447
453
  await runEnergyQuintelGrad(stateTensor, {
448
454
  count: elementCount,
449
455
  size,
456
+ flags,
450
457
  countDiff: safeCountDiff,
451
458
  symmetryWeight,
452
459
  countWeight,
453
460
  centerWeight,
454
461
  binarizeWeight,
455
462
  centerTarget,
456
- rules,
457
463
  outputBuffer: gradBuffer,
458
464
  });
459
465
 
@@ -471,6 +477,7 @@ export class EnergyPipeline {
471
477
  await runEnergyQuintelUpdate(stateTensor, {
472
478
  count: elementCount,
473
479
  size,
480
+ flags,
474
481
  stepSize,
475
482
  gradientScale,
476
483
  countDiff: safeCountDiff,
@@ -481,7 +488,6 @@ export class EnergyPipeline {
481
488
  centerTarget,
482
489
  clampMin,
483
490
  clampMax,
484
- rules,
485
491
  });
486
492
  }
487
493
 
@@ -84,4 +84,9 @@ export function mergeQuintelConfig(
84
84
  override?: Partial<QuintelEnergyConfig> | null
85
85
  ): QuintelEnergyConfig;
86
86
 
87
+ export function buildQuintelKernelFlags(
88
+ rules: Partial<QuintelRuleConfig> | null | undefined,
89
+ binarizeWeight?: number
90
+ ): number;
91
+
87
92
  export function runQuintelEnergyLoop(options: QuintelEnergyLoopOptions): QuintelEnergyLoopResult;
@@ -22,6 +22,17 @@ export function mergeQuintelConfig(base, override) {
22
22
  };
23
23
  }
24
24
 
25
+ export function buildQuintelKernelFlags(rules, binarizeWeight) {
26
+ let flags = 0;
27
+ if (rules?.mirrorX) flags |= 1;
28
+ if (rules?.mirrorY) flags |= 2;
29
+ if (rules?.diagonal) flags |= 4;
30
+ if (rules?.count) flags |= 8;
31
+ if (rules?.center) flags |= 16;
32
+ if (Number.isFinite(binarizeWeight) && binarizeWeight !== 0) flags |= 32;
33
+ return flags >>> 0;
34
+ }
35
+
25
36
  function applyPairEnergy(state, gradients, indexA, indexB, weight) {
26
37
  const diff = state[indexA] - state[indexB];
27
38
  const energy = weight * diff * diff;
@@ -5,7 +5,7 @@ import { runEnergyEval, runEnergyUpdate } from '../../../gpu/kernels/index.js';
5
5
  import { log } from '../../../debug/index.js';
6
6
  import { f16ToF32Array, f32ToF16Array } from '../../kv-cache/types.js';
7
7
  import { registerPipeline } from '../registry.js';
8
- import { applyPipelineContexts } from '../context.js';
8
+ import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
9
9
  import { createInitializedPipeline } from '../factory.js';
10
10
  import { selectRuleValue } from '../../../rules/rule-registry.js';
11
11
 
@@ -165,19 +165,22 @@ async function createFeatureTensor(device, values, dtype, label) {
165
165
  const byteLength = payload.byteLength;
166
166
  const alignedSize = Math.ceil(byteLength / 4) * 4;
167
167
  const buffer = acquireBuffer(alignedSize, undefined, label);
168
-
169
- if (alignedSize === byteLength) {
170
- device.queue.writeBuffer(buffer, 0, payload);
171
- } else {
172
- const bytes = payload instanceof Uint16Array
173
- ? new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength)
174
- : new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength);
175
- const padded = new Uint8Array(alignedSize);
176
- padded.set(bytes);
177
- device.queue.writeBuffer(buffer, 0, padded);
168
+ try {
169
+ if (alignedSize === byteLength) {
170
+ device.queue.writeBuffer(buffer, 0, payload);
171
+ } else {
172
+ const bytes = payload instanceof Uint16Array
173
+ ? new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength)
174
+ : new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength);
175
+ const padded = new Uint8Array(alignedSize);
176
+ padded.set(bytes);
177
+ device.queue.writeBuffer(buffer, 0, padded);
178
+ }
179
+ return createTensor(buffer, dtype, [values.length], label);
180
+ } catch (error) {
181
+ releaseBuffer(buffer);
182
+ throw error;
178
183
  }
179
-
180
- return createTensor(buffer, dtype, [values.length], label);
181
184
  }
182
185
 
183
186
  async function readTensorF32(tensor) {
@@ -307,6 +310,7 @@ export class EnergyRowHeadPipeline {
307
310
  this.manifest = null;
308
311
  this.model = null;
309
312
  this.stats = {};
313
+ restorePipelineContexts(this);
310
314
  }
311
315
 
312
316
  async scoreRows(request = {}) {