@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -16,6 +16,7 @@ export interface EnergyUpdateOptions {
16
16
  export interface EnergyQuintelUpdateOptions {
17
17
  count?: number;
18
18
  size?: number;
19
+ flags?: number;
19
20
  stepSize?: number;
20
21
  gradientScale?: number;
21
22
  countDiff?: number;
@@ -26,48 +27,29 @@ export interface EnergyQuintelUpdateOptions {
26
27
  centerTarget?: number;
27
28
  clampMin?: number;
28
29
  clampMax?: number;
29
- rules?: {
30
- mirrorX?: boolean;
31
- mirrorY?: boolean;
32
- diagonal?: boolean;
33
- count?: boolean;
34
- center?: boolean;
35
- };
36
30
  }
37
31
 
38
32
  export interface EnergyQuintelReduceOptions {
39
33
  count?: number;
40
34
  size?: number;
35
+ flags?: number;
41
36
  symmetryWeight?: number;
42
37
  centerWeight?: number;
43
38
  binarizeWeight?: number;
44
39
  centerTarget?: number;
45
- rules?: {
46
- mirrorX?: boolean;
47
- mirrorY?: boolean;
48
- diagonal?: boolean;
49
- count?: boolean;
50
- center?: boolean;
51
- };
52
40
  outputBuffer?: GPUBuffer | null;
53
41
  }
54
42
 
55
43
  export interface EnergyQuintelGradOptions {
56
44
  count?: number;
57
45
  size?: number;
46
+ flags?: number;
58
47
  countDiff?: number;
59
48
  symmetryWeight?: number;
60
49
  countWeight?: number;
61
50
  centerWeight?: number;
62
51
  binarizeWeight?: number;
63
52
  centerTarget?: number;
64
- rules?: {
65
- mirrorX?: boolean;
66
- mirrorY?: boolean;
67
- diagonal?: boolean;
68
- count?: boolean;
69
- center?: boolean;
70
- };
71
53
  outputBuffer?: GPUBuffer | null;
72
54
  }
73
55
 
@@ -1,5 +1,5 @@
1
1
  import { getDevice } from '../device.js';
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor, dtypeBytes } from '../tensor.js';
4
4
  import { WORKGROUP_SIZES } from './constants.js';
5
5
  import { dispatch, recordDispatch } from './dispatch.js';
@@ -61,15 +61,14 @@ function resolveQuintelSize(state, sizeOverride) {
61
61
  return null;
62
62
  }
63
63
 
64
- function buildQuintelFlags(rules, binarizeWeight) {
65
- let flags = 0;
66
- if (rules?.mirrorX) flags |= 1;
67
- if (rules?.mirrorY) flags |= 2;
68
- if (rules?.diagonal) flags |= 4;
69
- if (rules?.count) flags |= 8;
70
- if (rules?.center) flags |= 16;
71
- if (Number.isFinite(binarizeWeight) && binarizeWeight !== 0) flags |= 32;
72
- return flags >>> 0;
64
+ function resolveQuintelFlags(options, op) {
65
+ if (options.rules !== undefined) {
66
+ throw new Error(`${op}: quintel kernel flags must be resolved before dispatch.`);
67
+ }
68
+ if (!Number.isFinite(options.flags)) {
69
+ throw new Error(`${op}: flags is required for quintel kernels.`);
70
+ }
71
+ return options.flags >>> 0;
73
72
  }
74
73
 
75
74
  function resolveExecution(recorder) {
@@ -103,6 +102,12 @@ function releaseUniformBuffer(execution, uniformBuffer) {
103
102
  }
104
103
  }
105
104
 
105
+ function releaseOwnedBuffer(ownedBuffer) {
106
+ if (ownedBuffer) {
107
+ releaseBuffer(ownedBuffer);
108
+ }
109
+ }
110
+
106
111
  function writeQuintelUpdateUniform(view, params) {
107
112
  view.setUint32(0, params.elementCount, true);
108
113
  view.setUint32(4, params.boardSize, true);
@@ -149,6 +154,7 @@ async function executeEnergyEval(recorder, state, target, options = {}, op) {
149
154
 
150
155
  const outputSize = elementCount * 4;
151
156
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_eval_output');
157
+ const ownedOutput = outputBuffer ? null : output;
152
158
 
153
159
  const variant = selectEnergyEvalVariant(state.dtype);
154
160
  const pipeline = await getPipelineFast('energy_eval', variant);
@@ -157,23 +163,27 @@ async function executeEnergyEval(recorder, state, target, options = {}, op) {
157
163
  view.setUint32(0, elementCount, true);
158
164
  view.setFloat32(4, scale, true);
159
165
  });
166
+ try {
167
+ const bindGroup = execution.device.createBindGroup({
168
+ label: 'energy_eval_bind_group',
169
+ layout: pipeline.getBindGroupLayout(0),
170
+ entries: [
171
+ { binding: 0, resource: { buffer: uniformBuffer } },
172
+ { binding: 1, resource: { buffer: state.buffer } },
173
+ { binding: 2, resource: { buffer: target.buffer } },
174
+ { binding: 3, resource: { buffer: output } },
175
+ ],
176
+ });
160
177
 
161
- const bindGroup = execution.device.createBindGroup({
162
- label: 'energy_eval_bind_group',
163
- layout: pipeline.getBindGroupLayout(0),
164
- entries: [
165
- { binding: 0, resource: { buffer: uniformBuffer } },
166
- { binding: 1, resource: { buffer: state.buffer } },
167
- { binding: 2, resource: { buffer: target.buffer } },
168
- { binding: 3, resource: { buffer: output } },
169
- ],
170
- });
171
-
172
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
173
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_eval');
174
- releaseUniformBuffer(execution, uniformBuffer);
175
-
176
- return createTensor(output, 'f32', [elementCount], 'energy_eval_output');
178
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
179
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_eval');
180
+ return createTensor(output, 'f32', [elementCount], 'energy_eval_output');
181
+ } catch (error) {
182
+ releaseOwnedBuffer(ownedOutput);
183
+ throw error;
184
+ } finally {
185
+ releaseUniformBuffer(execution, uniformBuffer);
186
+ }
177
187
  }
178
188
 
179
189
  async function executeEnergyUpdate(recorder, state, target, options = {}, op) {
@@ -191,21 +201,23 @@ async function executeEnergyUpdate(recorder, state, target, options = {}, op) {
191
201
  view.setFloat32(8, gradientScale, true);
192
202
  });
193
203
 
194
- const bindGroup = execution.device.createBindGroup({
195
- label: 'energy_update_bind_group',
196
- layout: pipeline.getBindGroupLayout(0),
197
- entries: [
198
- { binding: 0, resource: { buffer: uniformBuffer } },
199
- { binding: 1, resource: { buffer: state.buffer } },
200
- { binding: 2, resource: { buffer: target.buffer } },
201
- ],
202
- });
203
-
204
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
205
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_update');
206
- releaseUniformBuffer(execution, uniformBuffer);
204
+ try {
205
+ const bindGroup = execution.device.createBindGroup({
206
+ label: 'energy_update_bind_group',
207
+ layout: pipeline.getBindGroupLayout(0),
208
+ entries: [
209
+ { binding: 0, resource: { buffer: uniformBuffer } },
210
+ { binding: 1, resource: { buffer: state.buffer } },
211
+ { binding: 2, resource: { buffer: target.buffer } },
212
+ ],
213
+ });
207
214
 
208
- return state;
215
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
216
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_update');
217
+ return state;
218
+ } finally {
219
+ releaseUniformBuffer(execution, uniformBuffer);
220
+ }
209
221
  }
210
222
 
211
223
  async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
@@ -224,7 +236,6 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
224
236
  centerTarget = 1.0,
225
237
  clampMin = 0.0,
226
238
  clampMax = 1.0,
227
- rules = {},
228
239
  } = options;
229
240
  const elementCount = inferCount(state, count);
230
241
  const boardSize = resolveQuintelSize(state, size);
@@ -234,7 +245,7 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
234
245
 
235
246
  const variant = selectEnergyQuintelUpdateVariant(state.dtype);
236
247
  const pipeline = await getPipelineFast('energy_quintel_update', variant);
237
- const flags = buildQuintelFlags(rules, binarizeWeight);
248
+ const flags = resolveQuintelFlags(options, op);
238
249
 
239
250
  const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_uniforms', 64, (view) => {
240
251
  writeQuintelUpdateUniform(view, {
@@ -254,20 +265,22 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
254
265
  });
255
266
  });
256
267
 
257
- const bindGroup = execution.device.createBindGroup({
258
- label: 'energy_quintel_update_bind_group',
259
- layout: pipeline.getBindGroupLayout(0),
260
- entries: [
261
- { binding: 0, resource: { buffer: uniformBuffer } },
262
- { binding: 1, resource: { buffer: state.buffer } },
263
- ],
264
- });
265
-
266
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
267
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_update');
268
- releaseUniformBuffer(execution, uniformBuffer);
268
+ try {
269
+ const bindGroup = execution.device.createBindGroup({
270
+ label: 'energy_quintel_update_bind_group',
271
+ layout: pipeline.getBindGroupLayout(0),
272
+ entries: [
273
+ { binding: 0, resource: { buffer: uniformBuffer } },
274
+ { binding: 1, resource: { buffer: state.buffer } },
275
+ ],
276
+ });
269
277
 
270
- return state;
278
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
279
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_update');
280
+ return state;
281
+ } finally {
282
+ releaseUniformBuffer(execution, uniformBuffer);
283
+ }
271
284
  }
272
285
 
273
286
  async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
@@ -280,7 +293,6 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
280
293
  centerWeight = 1.0,
281
294
  binarizeWeight = 0.0,
282
295
  centerTarget = 1.0,
283
- rules = {},
284
296
  outputBuffer = null,
285
297
  } = options;
286
298
  const elementCount = inferCount(state, count);
@@ -291,7 +303,7 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
291
303
 
292
304
  const variant = selectEnergyQuintelReduceVariant(state.dtype);
293
305
  const pipeline = await getPipelineFast('energy_quintel_reduce', variant);
294
- const flags = buildQuintelFlags(rules, binarizeWeight);
306
+ const flags = resolveQuintelFlags(options, op);
295
307
 
296
308
  const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_reduce_uniforms', 48, (view) => {
297
309
  writeQuintelReduceUniform(view, {
@@ -308,21 +320,27 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
308
320
  const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
309
321
  const outputSize = workgroups * 16;
310
322
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_quintel_reduce_output');
323
+ const ownedOutput = outputBuffer ? null : output;
324
+
325
+ try {
326
+ const bindGroup = execution.device.createBindGroup({
327
+ label: 'energy_quintel_reduce_bind_group',
328
+ layout: pipeline.getBindGroupLayout(0),
329
+ entries: [
330
+ { binding: 0, resource: { buffer: uniformBuffer } },
331
+ { binding: 1, resource: { buffer: state.buffer } },
332
+ { binding: 2, resource: { buffer: output } },
333
+ ],
334
+ });
311
335
 
312
- const bindGroup = execution.device.createBindGroup({
313
- label: 'energy_quintel_reduce_bind_group',
314
- layout: pipeline.getBindGroupLayout(0),
315
- entries: [
316
- { binding: 0, resource: { buffer: uniformBuffer } },
317
- { binding: 1, resource: { buffer: state.buffer } },
318
- { binding: 2, resource: { buffer: output } },
319
- ],
320
- });
321
-
322
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_reduce');
323
- releaseUniformBuffer(execution, uniformBuffer);
324
-
325
- return createTensor(output, 'f32', [workgroups, 4], 'energy_quintel_reduce_output');
336
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_reduce');
337
+ return createTensor(output, 'f32', [workgroups, 4], 'energy_quintel_reduce_output');
338
+ } catch (error) {
339
+ releaseOwnedBuffer(ownedOutput);
340
+ throw error;
341
+ } finally {
342
+ releaseUniformBuffer(execution, uniformBuffer);
343
+ }
326
344
  }
327
345
 
328
346
  async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
@@ -337,7 +355,6 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
337
355
  centerWeight = 1.0,
338
356
  binarizeWeight = 0.0,
339
357
  centerTarget = 1.0,
340
- rules = {},
341
358
  outputBuffer = null,
342
359
  } = options;
343
360
  const elementCount = inferCount(state, count);
@@ -348,7 +365,7 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
348
365
 
349
366
  const variant = selectEnergyQuintelGradVariant(state.dtype);
350
367
  const pipeline = await getPipelineFast('energy_quintel_grad', variant);
351
- const flags = buildQuintelFlags(rules, binarizeWeight);
368
+ const flags = resolveQuintelFlags(options, op);
352
369
 
353
370
  const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_grad_uniforms', 64, (view) => {
354
371
  writeQuintelGradUniform(view, {
@@ -366,22 +383,28 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
366
383
 
367
384
  const outputSize = elementCount * 4;
368
385
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_quintel_grad_output');
386
+ const ownedOutput = outputBuffer ? null : output;
387
+
388
+ try {
389
+ const bindGroup = execution.device.createBindGroup({
390
+ label: 'energy_quintel_grad_bind_group',
391
+ layout: pipeline.getBindGroupLayout(0),
392
+ entries: [
393
+ { binding: 0, resource: { buffer: uniformBuffer } },
394
+ { binding: 1, resource: { buffer: state.buffer } },
395
+ { binding: 2, resource: { buffer: output } },
396
+ ],
397
+ });
369
398
 
370
- const bindGroup = execution.device.createBindGroup({
371
- label: 'energy_quintel_grad_bind_group',
372
- layout: pipeline.getBindGroupLayout(0),
373
- entries: [
374
- { binding: 0, resource: { buffer: uniformBuffer } },
375
- { binding: 1, resource: { buffer: state.buffer } },
376
- { binding: 2, resource: { buffer: output } },
377
- ],
378
- });
379
-
380
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
381
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_grad');
382
- releaseUniformBuffer(execution, uniformBuffer);
383
-
384
- return createTensor(output, 'f32', [elementCount], 'energy_quintel_grad_output');
399
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
400
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_grad');
401
+ return createTensor(output, 'f32', [elementCount], 'energy_quintel_grad_output');
402
+ } catch (error) {
403
+ releaseOwnedBuffer(ownedOutput);
404
+ throw error;
405
+ } finally {
406
+ releaseUniformBuffer(execution, uniformBuffer);
407
+ }
385
408
  }
386
409
 
387
410
  export async function runEnergyEval(state, target, options = {}) {
@@ -16,7 +16,7 @@ export function hasRequiredFeatures(
16
16
  for (const feature of required) {
17
17
  if (feature === 'shader-f16' && !capabilities.hasF16) return false;
18
18
  if (feature === 'subgroups' && !capabilities.hasSubgroups) return false;
19
- if (feature === 'subgroups-f16' && !capabilities.hasSubgroups) return false;
19
+ if (feature === 'subgroups-f16' && !capabilities.hasSubgroupsF16) return false;
20
20
  }
21
21
  return true;
22
22
  }
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { KernelBase } from './kernel-base.js';
7
7
  import { createUniformBufferWithView } from './utils.js';
@@ -77,6 +77,17 @@ function resolveSwigluLimit(value, context) {
77
77
  return value;
78
78
  }
79
79
 
80
+ function releaseRunResources(uniformBuffer, ownedBuffers) {
81
+ if (uniformBuffer) {
82
+ uniformBuffer.destroy();
83
+ }
84
+ for (const buffer of ownedBuffers) {
85
+ if (buffer) {
86
+ releaseBuffer(buffer);
87
+ }
88
+ }
89
+ }
90
+
80
91
 
81
92
  export async function runFusedFFN(
82
93
  input,
@@ -132,7 +143,8 @@ export async function runFusedFFN(
132
143
  const outputBytesPerElement = isF16Native ? 2 : 4;
133
144
  const outputDtype = isF16Native ? 'f16' : 'f32';
134
145
  const outputSize = batchSize * intermediateSize * outputBytesPerElement;
135
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'fused_ffn_output');
146
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'fused_ffn_output');
147
+ const output = outputBuffer || ownedOutput;
136
148
 
137
149
  // Create uniform buffer
138
150
  const uniformBuffer = createFFNUniformBuffer(device, null, {
@@ -145,41 +157,42 @@ export async function runFusedFFN(
145
157
  swigluLimit: activation === 'silu' ? swigluLimit : null,
146
158
  });
147
159
 
148
- // Create bind group
149
- const bindGroup = device.createBindGroup({
150
- label: 'fused_ffn_bind_group',
151
- layout: pipeline.getBindGroupLayout(0),
152
- entries: [
153
- { binding: 0, resource: { buffer: uniformBuffer } },
154
- { binding: 1, resource: { buffer: input.buffer } },
155
- { binding: 2, resource: { buffer: getBuffer(W_gate) } },
156
- { binding: 3, resource: { buffer: getBuffer(W_up) } },
157
- { binding: 4, resource: { buffer: output } },
158
- ],
159
- });
160
+ try {
161
+ const bindGroup = device.createBindGroup({
162
+ label: 'fused_ffn_bind_group',
163
+ layout: pipeline.getBindGroupLayout(0),
164
+ entries: [
165
+ { binding: 0, resource: { buffer: uniformBuffer } },
166
+ { binding: 1, resource: { buffer: input.buffer } },
167
+ { binding: 2, resource: { buffer: getBuffer(W_gate) } },
168
+ { binding: 3, resource: { buffer: getBuffer(W_up) } },
169
+ { binding: 4, resource: { buffer: output } },
170
+ ],
171
+ });
172
+
173
+ let workgroupsX;
174
+ let workgroupsY = 1;
175
+
176
+ if (variant === 'multi') {
177
+ const outputsPerWg = 4;
178
+ workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
179
+ } else if (variant === 'q4k' || variant === 'q4k_batched') {
180
+ const colsPerWg = 32;
181
+ workgroupsX = Math.ceil(intermediateSize / colsPerWg);
182
+ workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
183
+ } else if (variant === 'batched' || variant === 'f16_native_batched') {
184
+ workgroupsX = intermediateSize;
185
+ workgroupsY = batchSize;
186
+ } else {
187
+ workgroupsX = intermediateSize;
188
+ }
160
189
 
161
- // Calculate workgroups
162
-
163
- let workgroupsX;
164
- let workgroupsY = 1;
165
-
166
- if (variant === 'multi') {
167
- const outputsPerWg = 4;
168
- workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
169
- } else if (variant === 'q4k' || variant === 'q4k_batched') {
170
- // Q4K uses multi-column: 32 columns per workgroup
171
- const colsPerWg = 32;
172
- workgroupsX = Math.ceil(intermediateSize / colsPerWg);
173
- workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
174
- } else if (variant === 'batched' || variant === 'f16_native_batched') {
175
- workgroupsX = intermediateSize;
176
- workgroupsY = batchSize;
177
- } else {
178
- workgroupsX = intermediateSize;
190
+ kernel.dispatch(pipeline, bindGroup, workgroupsX, workgroupsY);
191
+ } catch (error) {
192
+ releaseRunResources(uniformBuffer, [ownedOutput]);
193
+ throw error;
179
194
  }
180
195
 
181
- kernel.dispatch(pipeline, bindGroup, workgroupsX, workgroupsY);
182
-
183
196
  uniformBuffer.destroy();
184
197
 
185
198
  return createTensor(output, outputDtype, [batchSize, intermediateSize], 'fused_ffn_output');
@@ -240,7 +253,8 @@ export async function recordFusedFFN(
240
253
  const outputBytesPerElement = isF16Native ? 2 : 4;
241
254
  const outputDtype = isF16Native ? 'f16' : 'f32';
242
255
  const outputSize = batchSize * intermediateSize * outputBytesPerElement;
243
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'fused_ffn_output');
256
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'fused_ffn_output');
257
+ const output = outputBuffer || ownedOutput;
244
258
 
245
259
  const uniformBuffer = createFFNUniformBuffer(device, recorder, {
246
260
  M: batchSize,
@@ -252,39 +266,44 @@ export async function recordFusedFFN(
252
266
  swigluLimit: activation === 'silu' ? swigluLimit : null,
253
267
  });
254
268
 
255
- const bindGroup = device.createBindGroup({
256
- label: 'fused_ffn_bind_group',
257
- layout: pipeline.getBindGroupLayout(0),
258
- entries: [
259
- { binding: 0, resource: { buffer: uniformBuffer } },
260
- { binding: 1, resource: { buffer: input.buffer } },
261
- { binding: 2, resource: { buffer: getBuffer(W_gate) } },
262
- { binding: 3, resource: { buffer: getBuffer(W_up) } },
263
- { binding: 4, resource: { buffer: output } },
264
- ],
265
- });
266
-
269
+ try {
270
+ const bindGroup = device.createBindGroup({
271
+ label: 'fused_ffn_bind_group',
272
+ layout: pipeline.getBindGroupLayout(0),
273
+ entries: [
274
+ { binding: 0, resource: { buffer: uniformBuffer } },
275
+ { binding: 1, resource: { buffer: input.buffer } },
276
+ { binding: 2, resource: { buffer: getBuffer(W_gate) } },
277
+ { binding: 3, resource: { buffer: getBuffer(W_up) } },
278
+ { binding: 4, resource: { buffer: output } },
279
+ ],
280
+ });
281
+
282
+ let workgroupsX;
283
+ let workgroupsY = 1;
284
+
285
+ if (variant === 'multi') {
286
+ const outputsPerWg = 4;
287
+ workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
288
+ } else if (variant === 'q4k' || variant === 'q4k_batched') {
289
+ const colsPerWg = 32;
290
+ workgroupsX = Math.ceil(intermediateSize / colsPerWg);
291
+ workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
292
+ } else if (variant === 'batched' || variant === 'f16_native_batched') {
293
+ workgroupsX = intermediateSize;
294
+ workgroupsY = batchSize;
295
+ } else {
296
+ workgroupsX = intermediateSize;
297
+ }
267
298
 
268
- let workgroupsX;
269
- let workgroupsY = 1;
270
-
271
- if (variant === 'multi') {
272
- const outputsPerWg = 4;
273
- workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
274
- } else if (variant === 'q4k' || variant === 'q4k_batched') {
275
- // Q4K uses multi-column: 32 columns per workgroup
276
- const colsPerWg = 32;
277
- workgroupsX = Math.ceil(intermediateSize / colsPerWg);
278
- workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
279
- } else if (variant === 'batched' || variant === 'f16_native_batched') {
280
- workgroupsX = intermediateSize;
281
- workgroupsY = batchSize;
282
- } else {
283
- workgroupsX = intermediateSize;
299
+ kernel.record(recorder, pipeline, bindGroup, workgroupsX, workgroupsY);
300
+ } catch (error) {
301
+ if (ownedOutput) {
302
+ releaseBuffer(ownedOutput);
303
+ }
304
+ throw error;
284
305
  }
285
306
 
286
- kernel.record(recorder, pipeline, bindGroup, workgroupsX, workgroupsY);
287
-
288
307
  return createTensor(output, outputDtype, [batchSize, intermediateSize], 'fused_ffn_output');
289
308
  }
290
309