@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355) hide show
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -1,5 +1,5 @@
1
1
  import { getDevice } from '../device.js';
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor, dtypeBytes } from '../tensor.js';
4
4
  import { WORKGROUP_SIZES } from './constants.js';
5
5
  import { dispatch, recordDispatch } from './dispatch.js';
@@ -61,15 +61,14 @@ function resolveQuintelSize(state, sizeOverride) {
61
61
  return null;
62
62
  }
63
63
 
64
- function buildQuintelFlags(rules, binarizeWeight) {
65
- let flags = 0;
66
- if (rules?.mirrorX) flags |= 1;
67
- if (rules?.mirrorY) flags |= 2;
68
- if (rules?.diagonal) flags |= 4;
69
- if (rules?.count) flags |= 8;
70
- if (rules?.center) flags |= 16;
71
- if (Number.isFinite(binarizeWeight) && binarizeWeight !== 0) flags |= 32;
72
- return flags >>> 0;
64
+ function resolveQuintelFlags(options, op) {
65
+ if (options.rules !== undefined) {
66
+ throw new Error(`${op}: quintel kernel flags must be resolved before dispatch.`);
67
+ }
68
+ if (!Number.isFinite(options.flags)) {
69
+ throw new Error(`${op}: flags is required for quintel kernels.`);
70
+ }
71
+ return options.flags >>> 0;
73
72
  }
74
73
 
75
74
  function resolveExecution(recorder) {
@@ -103,6 +102,12 @@ function releaseUniformBuffer(execution, uniformBuffer) {
103
102
  }
104
103
  }
105
104
 
105
+ function releaseOwnedBuffer(ownedBuffer) {
106
+ if (ownedBuffer) {
107
+ releaseBuffer(ownedBuffer);
108
+ }
109
+ }
110
+
106
111
  function writeQuintelUpdateUniform(view, params) {
107
112
  view.setUint32(0, params.elementCount, true);
108
113
  view.setUint32(4, params.boardSize, true);
@@ -149,6 +154,7 @@ async function executeEnergyEval(recorder, state, target, options = {}, op) {
149
154
 
150
155
  const outputSize = elementCount * 4;
151
156
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_eval_output');
157
+ const ownedOutput = outputBuffer ? null : output;
152
158
 
153
159
  const variant = selectEnergyEvalVariant(state.dtype);
154
160
  const pipeline = await getPipelineFast('energy_eval', variant);
@@ -157,23 +163,27 @@ async function executeEnergyEval(recorder, state, target, options = {}, op) {
157
163
  view.setUint32(0, elementCount, true);
158
164
  view.setFloat32(4, scale, true);
159
165
  });
166
+ try {
167
+ const bindGroup = execution.device.createBindGroup({
168
+ label: 'energy_eval_bind_group',
169
+ layout: pipeline.getBindGroupLayout(0),
170
+ entries: [
171
+ { binding: 0, resource: { buffer: uniformBuffer } },
172
+ { binding: 1, resource: { buffer: state.buffer } },
173
+ { binding: 2, resource: { buffer: target.buffer } },
174
+ { binding: 3, resource: { buffer: output } },
175
+ ],
176
+ });
160
177
 
161
- const bindGroup = execution.device.createBindGroup({
162
- label: 'energy_eval_bind_group',
163
- layout: pipeline.getBindGroupLayout(0),
164
- entries: [
165
- { binding: 0, resource: { buffer: uniformBuffer } },
166
- { binding: 1, resource: { buffer: state.buffer } },
167
- { binding: 2, resource: { buffer: target.buffer } },
168
- { binding: 3, resource: { buffer: output } },
169
- ],
170
- });
171
-
172
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
173
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_eval');
174
- releaseUniformBuffer(execution, uniformBuffer);
175
-
176
- return createTensor(output, 'f32', [elementCount], 'energy_eval_output');
178
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
179
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_eval');
180
+ return createTensor(output, 'f32', [elementCount], 'energy_eval_output');
181
+ } catch (error) {
182
+ releaseOwnedBuffer(ownedOutput);
183
+ throw error;
184
+ } finally {
185
+ releaseUniformBuffer(execution, uniformBuffer);
186
+ }
177
187
  }
178
188
 
179
189
  async function executeEnergyUpdate(recorder, state, target, options = {}, op) {
@@ -191,21 +201,23 @@ async function executeEnergyUpdate(recorder, state, target, options = {}, op) {
191
201
  view.setFloat32(8, gradientScale, true);
192
202
  });
193
203
 
194
- const bindGroup = execution.device.createBindGroup({
195
- label: 'energy_update_bind_group',
196
- layout: pipeline.getBindGroupLayout(0),
197
- entries: [
198
- { binding: 0, resource: { buffer: uniformBuffer } },
199
- { binding: 1, resource: { buffer: state.buffer } },
200
- { binding: 2, resource: { buffer: target.buffer } },
201
- ],
202
- });
203
-
204
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
205
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_update');
206
- releaseUniformBuffer(execution, uniformBuffer);
204
+ try {
205
+ const bindGroup = execution.device.createBindGroup({
206
+ label: 'energy_update_bind_group',
207
+ layout: pipeline.getBindGroupLayout(0),
208
+ entries: [
209
+ { binding: 0, resource: { buffer: uniformBuffer } },
210
+ { binding: 1, resource: { buffer: state.buffer } },
211
+ { binding: 2, resource: { buffer: target.buffer } },
212
+ ],
213
+ });
207
214
 
208
- return state;
215
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
216
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_update');
217
+ return state;
218
+ } finally {
219
+ releaseUniformBuffer(execution, uniformBuffer);
220
+ }
209
221
  }
210
222
 
211
223
  async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
@@ -224,7 +236,6 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
224
236
  centerTarget = 1.0,
225
237
  clampMin = 0.0,
226
238
  clampMax = 1.0,
227
- rules = {},
228
239
  } = options;
229
240
  const elementCount = inferCount(state, count);
230
241
  const boardSize = resolveQuintelSize(state, size);
@@ -234,7 +245,7 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
234
245
 
235
246
  const variant = selectEnergyQuintelUpdateVariant(state.dtype);
236
247
  const pipeline = await getPipelineFast('energy_quintel_update', variant);
237
- const flags = buildQuintelFlags(rules, binarizeWeight);
248
+ const flags = resolveQuintelFlags(options, op);
238
249
 
239
250
  const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_uniforms', 64, (view) => {
240
251
  writeQuintelUpdateUniform(view, {
@@ -254,20 +265,22 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
254
265
  });
255
266
  });
256
267
 
257
- const bindGroup = execution.device.createBindGroup({
258
- label: 'energy_quintel_update_bind_group',
259
- layout: pipeline.getBindGroupLayout(0),
260
- entries: [
261
- { binding: 0, resource: { buffer: uniformBuffer } },
262
- { binding: 1, resource: { buffer: state.buffer } },
263
- ],
264
- });
265
-
266
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
267
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_update');
268
- releaseUniformBuffer(execution, uniformBuffer);
268
+ try {
269
+ const bindGroup = execution.device.createBindGroup({
270
+ label: 'energy_quintel_update_bind_group',
271
+ layout: pipeline.getBindGroupLayout(0),
272
+ entries: [
273
+ { binding: 0, resource: { buffer: uniformBuffer } },
274
+ { binding: 1, resource: { buffer: state.buffer } },
275
+ ],
276
+ });
269
277
 
270
- return state;
278
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
279
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_update');
280
+ return state;
281
+ } finally {
282
+ releaseUniformBuffer(execution, uniformBuffer);
283
+ }
271
284
  }
272
285
 
273
286
  async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
@@ -280,7 +293,6 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
280
293
  centerWeight = 1.0,
281
294
  binarizeWeight = 0.0,
282
295
  centerTarget = 1.0,
283
- rules = {},
284
296
  outputBuffer = null,
285
297
  } = options;
286
298
  const elementCount = inferCount(state, count);
@@ -291,7 +303,7 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
291
303
 
292
304
  const variant = selectEnergyQuintelReduceVariant(state.dtype);
293
305
  const pipeline = await getPipelineFast('energy_quintel_reduce', variant);
294
- const flags = buildQuintelFlags(rules, binarizeWeight);
306
+ const flags = resolveQuintelFlags(options, op);
295
307
 
296
308
  const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_reduce_uniforms', 48, (view) => {
297
309
  writeQuintelReduceUniform(view, {
@@ -308,21 +320,27 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
308
320
  const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
309
321
  const outputSize = workgroups * 16;
310
322
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_quintel_reduce_output');
323
+ const ownedOutput = outputBuffer ? null : output;
324
+
325
+ try {
326
+ const bindGroup = execution.device.createBindGroup({
327
+ label: 'energy_quintel_reduce_bind_group',
328
+ layout: pipeline.getBindGroupLayout(0),
329
+ entries: [
330
+ { binding: 0, resource: { buffer: uniformBuffer } },
331
+ { binding: 1, resource: { buffer: state.buffer } },
332
+ { binding: 2, resource: { buffer: output } },
333
+ ],
334
+ });
311
335
 
312
- const bindGroup = execution.device.createBindGroup({
313
- label: 'energy_quintel_reduce_bind_group',
314
- layout: pipeline.getBindGroupLayout(0),
315
- entries: [
316
- { binding: 0, resource: { buffer: uniformBuffer } },
317
- { binding: 1, resource: { buffer: state.buffer } },
318
- { binding: 2, resource: { buffer: output } },
319
- ],
320
- });
321
-
322
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_reduce');
323
- releaseUniformBuffer(execution, uniformBuffer);
324
-
325
- return createTensor(output, 'f32', [workgroups, 4], 'energy_quintel_reduce_output');
336
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_reduce');
337
+ return createTensor(output, 'f32', [workgroups, 4], 'energy_quintel_reduce_output');
338
+ } catch (error) {
339
+ releaseOwnedBuffer(ownedOutput);
340
+ throw error;
341
+ } finally {
342
+ releaseUniformBuffer(execution, uniformBuffer);
343
+ }
326
344
  }
327
345
 
328
346
  async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
@@ -337,7 +355,6 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
337
355
  centerWeight = 1.0,
338
356
  binarizeWeight = 0.0,
339
357
  centerTarget = 1.0,
340
- rules = {},
341
358
  outputBuffer = null,
342
359
  } = options;
343
360
  const elementCount = inferCount(state, count);
@@ -348,7 +365,7 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
348
365
 
349
366
  const variant = selectEnergyQuintelGradVariant(state.dtype);
350
367
  const pipeline = await getPipelineFast('energy_quintel_grad', variant);
351
- const flags = buildQuintelFlags(rules, binarizeWeight);
368
+ const flags = resolveQuintelFlags(options, op);
352
369
 
353
370
  const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_grad_uniforms', 64, (view) => {
354
371
  writeQuintelGradUniform(view, {
@@ -366,22 +383,28 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
366
383
 
367
384
  const outputSize = elementCount * 4;
368
385
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_quintel_grad_output');
386
+ const ownedOutput = outputBuffer ? null : output;
387
+
388
+ try {
389
+ const bindGroup = execution.device.createBindGroup({
390
+ label: 'energy_quintel_grad_bind_group',
391
+ layout: pipeline.getBindGroupLayout(0),
392
+ entries: [
393
+ { binding: 0, resource: { buffer: uniformBuffer } },
394
+ { binding: 1, resource: { buffer: state.buffer } },
395
+ { binding: 2, resource: { buffer: output } },
396
+ ],
397
+ });
369
398
 
370
- const bindGroup = execution.device.createBindGroup({
371
- label: 'energy_quintel_grad_bind_group',
372
- layout: pipeline.getBindGroupLayout(0),
373
- entries: [
374
- { binding: 0, resource: { buffer: uniformBuffer } },
375
- { binding: 1, resource: { buffer: state.buffer } },
376
- { binding: 2, resource: { buffer: output } },
377
- ],
378
- });
379
-
380
- const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
381
- dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_grad');
382
- releaseUniformBuffer(execution, uniformBuffer);
383
-
384
- return createTensor(output, 'f32', [elementCount], 'energy_quintel_grad_output');
399
+ const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
400
+ dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_grad');
401
+ return createTensor(output, 'f32', [elementCount], 'energy_quintel_grad_output');
402
+ } catch (error) {
403
+ releaseOwnedBuffer(ownedOutput);
404
+ throw error;
405
+ } finally {
406
+ releaseUniformBuffer(execution, uniformBuffer);
407
+ }
385
408
  }
386
409
 
387
410
  export async function runEnergyEval(state, target, options = {}) {
@@ -16,7 +16,7 @@ export function hasRequiredFeatures(
16
16
  for (const feature of required) {
17
17
  if (feature === 'shader-f16' && !capabilities.hasF16) return false;
18
18
  if (feature === 'subgroups' && !capabilities.hasSubgroups) return false;
19
- if (feature === 'subgroups-f16' && !capabilities.hasSubgroups) return false;
19
+ if (feature === 'subgroups-f16' && !capabilities.hasSubgroupsF16) return false;
20
20
  }
21
21
  return true;
22
22
  }
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { KernelBase } from './kernel-base.js';
7
7
  import { createUniformBufferWithView } from './utils.js';
@@ -77,6 +77,17 @@ function resolveSwigluLimit(value, context) {
77
77
  return value;
78
78
  }
79
79
 
80
+ function releaseRunResources(uniformBuffer, ownedBuffers) {
81
+ if (uniformBuffer) {
82
+ uniformBuffer.destroy();
83
+ }
84
+ for (const buffer of ownedBuffers) {
85
+ if (buffer) {
86
+ releaseBuffer(buffer);
87
+ }
88
+ }
89
+ }
90
+
80
91
 
81
92
  export async function runFusedFFN(
82
93
  input,
@@ -132,7 +143,8 @@ export async function runFusedFFN(
132
143
  const outputBytesPerElement = isF16Native ? 2 : 4;
133
144
  const outputDtype = isF16Native ? 'f16' : 'f32';
134
145
  const outputSize = batchSize * intermediateSize * outputBytesPerElement;
135
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'fused_ffn_output');
146
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'fused_ffn_output');
147
+ const output = outputBuffer || ownedOutput;
136
148
 
137
149
  // Create uniform buffer
138
150
  const uniformBuffer = createFFNUniformBuffer(device, null, {
@@ -145,41 +157,42 @@ export async function runFusedFFN(
145
157
  swigluLimit: activation === 'silu' ? swigluLimit : null,
146
158
  });
147
159
 
148
- // Create bind group
149
- const bindGroup = device.createBindGroup({
150
- label: 'fused_ffn_bind_group',
151
- layout: pipeline.getBindGroupLayout(0),
152
- entries: [
153
- { binding: 0, resource: { buffer: uniformBuffer } },
154
- { binding: 1, resource: { buffer: input.buffer } },
155
- { binding: 2, resource: { buffer: getBuffer(W_gate) } },
156
- { binding: 3, resource: { buffer: getBuffer(W_up) } },
157
- { binding: 4, resource: { buffer: output } },
158
- ],
159
- });
160
+ try {
161
+ const bindGroup = device.createBindGroup({
162
+ label: 'fused_ffn_bind_group',
163
+ layout: pipeline.getBindGroupLayout(0),
164
+ entries: [
165
+ { binding: 0, resource: { buffer: uniformBuffer } },
166
+ { binding: 1, resource: { buffer: input.buffer } },
167
+ { binding: 2, resource: { buffer: getBuffer(W_gate) } },
168
+ { binding: 3, resource: { buffer: getBuffer(W_up) } },
169
+ { binding: 4, resource: { buffer: output } },
170
+ ],
171
+ });
172
+
173
+ let workgroupsX;
174
+ let workgroupsY = 1;
175
+
176
+ if (variant === 'multi') {
177
+ const outputsPerWg = 4;
178
+ workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
179
+ } else if (variant === 'q4k' || variant === 'q4k_batched') {
180
+ const colsPerWg = 32;
181
+ workgroupsX = Math.ceil(intermediateSize / colsPerWg);
182
+ workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
183
+ } else if (variant === 'batched' || variant === 'f16_native_batched') {
184
+ workgroupsX = intermediateSize;
185
+ workgroupsY = batchSize;
186
+ } else {
187
+ workgroupsX = intermediateSize;
188
+ }
160
189
 
161
- // Calculate workgroups
162
-
163
- let workgroupsX;
164
- let workgroupsY = 1;
165
-
166
- if (variant === 'multi') {
167
- const outputsPerWg = 4;
168
- workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
169
- } else if (variant === 'q4k' || variant === 'q4k_batched') {
170
- // Q4K uses multi-column: 32 columns per workgroup
171
- const colsPerWg = 32;
172
- workgroupsX = Math.ceil(intermediateSize / colsPerWg);
173
- workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
174
- } else if (variant === 'batched' || variant === 'f16_native_batched') {
175
- workgroupsX = intermediateSize;
176
- workgroupsY = batchSize;
177
- } else {
178
- workgroupsX = intermediateSize;
190
+ kernel.dispatch(pipeline, bindGroup, workgroupsX, workgroupsY);
191
+ } catch (error) {
192
+ releaseRunResources(uniformBuffer, [ownedOutput]);
193
+ throw error;
179
194
  }
180
195
 
181
- kernel.dispatch(pipeline, bindGroup, workgroupsX, workgroupsY);
182
-
183
196
  uniformBuffer.destroy();
184
197
 
185
198
  return createTensor(output, outputDtype, [batchSize, intermediateSize], 'fused_ffn_output');
@@ -240,7 +253,8 @@ export async function recordFusedFFN(
240
253
  const outputBytesPerElement = isF16Native ? 2 : 4;
241
254
  const outputDtype = isF16Native ? 'f16' : 'f32';
242
255
  const outputSize = batchSize * intermediateSize * outputBytesPerElement;
243
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'fused_ffn_output');
256
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'fused_ffn_output');
257
+ const output = outputBuffer || ownedOutput;
244
258
 
245
259
  const uniformBuffer = createFFNUniformBuffer(device, recorder, {
246
260
  M: batchSize,
@@ -252,39 +266,44 @@ export async function recordFusedFFN(
252
266
  swigluLimit: activation === 'silu' ? swigluLimit : null,
253
267
  });
254
268
 
255
- const bindGroup = device.createBindGroup({
256
- label: 'fused_ffn_bind_group',
257
- layout: pipeline.getBindGroupLayout(0),
258
- entries: [
259
- { binding: 0, resource: { buffer: uniformBuffer } },
260
- { binding: 1, resource: { buffer: input.buffer } },
261
- { binding: 2, resource: { buffer: getBuffer(W_gate) } },
262
- { binding: 3, resource: { buffer: getBuffer(W_up) } },
263
- { binding: 4, resource: { buffer: output } },
264
- ],
265
- });
266
-
269
+ try {
270
+ const bindGroup = device.createBindGroup({
271
+ label: 'fused_ffn_bind_group',
272
+ layout: pipeline.getBindGroupLayout(0),
273
+ entries: [
274
+ { binding: 0, resource: { buffer: uniformBuffer } },
275
+ { binding: 1, resource: { buffer: input.buffer } },
276
+ { binding: 2, resource: { buffer: getBuffer(W_gate) } },
277
+ { binding: 3, resource: { buffer: getBuffer(W_up) } },
278
+ { binding: 4, resource: { buffer: output } },
279
+ ],
280
+ });
281
+
282
+ let workgroupsX;
283
+ let workgroupsY = 1;
284
+
285
+ if (variant === 'multi') {
286
+ const outputsPerWg = 4;
287
+ workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
288
+ } else if (variant === 'q4k' || variant === 'q4k_batched') {
289
+ const colsPerWg = 32;
290
+ workgroupsX = Math.ceil(intermediateSize / colsPerWg);
291
+ workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
292
+ } else if (variant === 'batched' || variant === 'f16_native_batched') {
293
+ workgroupsX = intermediateSize;
294
+ workgroupsY = batchSize;
295
+ } else {
296
+ workgroupsX = intermediateSize;
297
+ }
267
298
 
268
- let workgroupsX;
269
- let workgroupsY = 1;
270
-
271
- if (variant === 'multi') {
272
- const outputsPerWg = 4;
273
- workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
274
- } else if (variant === 'q4k' || variant === 'q4k_batched') {
275
- // Q4K uses multi-column: 32 columns per workgroup
276
- const colsPerWg = 32;
277
- workgroupsX = Math.ceil(intermediateSize / colsPerWg);
278
- workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
279
- } else if (variant === 'batched' || variant === 'f16_native_batched') {
280
- workgroupsX = intermediateSize;
281
- workgroupsY = batchSize;
282
- } else {
283
- workgroupsX = intermediateSize;
299
+ kernel.record(recorder, pipeline, bindGroup, workgroupsX, workgroupsY);
300
+ } catch (error) {
301
+ if (ownedOutput) {
302
+ releaseBuffer(ownedOutput);
303
+ }
304
+ throw error;
284
305
  }
285
306
 
286
- kernel.record(recorder, pipeline, bindGroup, workgroupsX, workgroupsY);
287
-
288
307
  return createTensor(output, outputDtype, [batchSize, intermediateSize], 'fused_ffn_output');
289
308
  }
290
309
 
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor, dtypeBytes } from '../tensor.js';
6
6
  import { getBuffer } from '../weight-buffer.js';
7
7
  import { dispatch, recordDispatch } from './dispatch.js';
@@ -47,7 +47,12 @@ export async function runMatmulResidualFused(
47
47
  const pipelineVariant = resolveFusedResidualVariant(input, residual);
48
48
  const pipeline = await getPipelineFast('fused_matmul_residual', pipelineVariant);
49
49
 
50
- const output = outputBuffer || acquireBuffer(N * dtypeBytes(outputDtype), undefined, 'matmul_residual_output');
50
+ const ownedOutput = outputBuffer ? null : acquireBuffer(
51
+ N * dtypeBytes(outputDtype),
52
+ undefined,
53
+ 'matmul_residual_output'
54
+ );
55
+ const output = outputBuffer || ownedOutput;
51
56
 
52
57
  // Create uniform buffer (same layout as matmul_gemv)
53
58
  const uniformBuffer = createUniformBufferWithView(
@@ -68,21 +73,28 @@ export async function runMatmulResidualFused(
68
73
  );
69
74
 
70
75
  // Create bind group
71
- const bindGroup = device.createBindGroup({
72
- label: 'matmul_residual_bind_group',
73
- layout: pipeline.getBindGroupLayout(0),
74
- entries: [
75
- { binding: 0, resource: { buffer: uniformBuffer } },
76
- { binding: 1, resource: { buffer: input.buffer } },
77
- { binding: 2, resource: { buffer: weightBuffer } },
78
- { binding: 3, resource: { buffer: output } },
79
- { binding: 4, resource: { buffer: residual.buffer } },
80
- ],
81
- });
82
-
83
- // One workgroup per output element
84
- const workgroups = N;
85
- dispatch(device, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
76
+ try {
77
+ const bindGroup = device.createBindGroup({
78
+ label: 'matmul_residual_bind_group',
79
+ layout: pipeline.getBindGroupLayout(0),
80
+ entries: [
81
+ { binding: 0, resource: { buffer: uniformBuffer } },
82
+ { binding: 1, resource: { buffer: input.buffer } },
83
+ { binding: 2, resource: { buffer: weightBuffer } },
84
+ { binding: 3, resource: { buffer: output } },
85
+ { binding: 4, resource: { buffer: residual.buffer } },
86
+ ],
87
+ });
88
+
89
+ const workgroups = N;
90
+ dispatch(device, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
91
+ } catch (error) {
92
+ uniformBuffer.destroy();
93
+ if (ownedOutput) {
94
+ releaseBuffer(ownedOutput);
95
+ }
96
+ throw error;
97
+ }
86
98
 
87
99
  uniformBuffer.destroy();
88
100
 
@@ -112,7 +124,12 @@ export async function recordMatmulResidualFused(
112
124
  const pipelineVariant = resolveFusedResidualVariant(input, residual);
113
125
  const pipeline = await getPipelineFast('fused_matmul_residual', pipelineVariant);
114
126
 
115
- const output = outputBuffer || acquireBuffer(N * dtypeBytes(outputDtype), undefined, 'matmul_residual_output');
127
+ const ownedOutput = outputBuffer ? null : acquireBuffer(
128
+ N * dtypeBytes(outputDtype),
129
+ undefined,
130
+ 'matmul_residual_output'
131
+ );
132
+ const output = outputBuffer || ownedOutput;
116
133
 
117
134
  // Create uniform buffer
118
135
  const uniformBuffer = createUniformBufferWithView(
@@ -132,21 +149,27 @@ export async function recordMatmulResidualFused(
132
149
  );
133
150
 
134
151
  // Create bind group
135
- const bindGroup = device.createBindGroup({
136
- label: 'matmul_residual_bind_group',
137
- layout: pipeline.getBindGroupLayout(0),
138
- entries: [
139
- { binding: 0, resource: { buffer: uniformBuffer } },
140
- { binding: 1, resource: { buffer: input.buffer } },
141
- { binding: 2, resource: { buffer: weightBuffer } },
142
- { binding: 3, resource: { buffer: output } },
143
- { binding: 4, resource: { buffer: residual.buffer } },
144
- ],
145
- });
146
-
147
- // One workgroup per output element
148
- const workgroups = N;
149
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
152
+ try {
153
+ const bindGroup = device.createBindGroup({
154
+ label: 'matmul_residual_bind_group',
155
+ layout: pipeline.getBindGroupLayout(0),
156
+ entries: [
157
+ { binding: 0, resource: { buffer: uniformBuffer } },
158
+ { binding: 1, resource: { buffer: input.buffer } },
159
+ { binding: 2, resource: { buffer: weightBuffer } },
160
+ { binding: 3, resource: { buffer: output } },
161
+ { binding: 4, resource: { buffer: residual.buffer } },
162
+ ],
163
+ });
164
+
165
+ const workgroups = N;
166
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
167
+ } catch (error) {
168
+ if (ownedOutput) {
169
+ releaseBuffer(ownedOutput);
170
+ }
171
+ throw error;
172
+ }
150
173
 
151
174
  return createTensor(output, outputDtype, [1, N], 'matmul_residual_output');
152
175
  }