@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355) hide show
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -1,13 +1,26 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor, dtypeBytes } from '../tensor.js';
6
6
  import { WORKGROUP_SIZES } from './constants.js';
7
7
  import { dispatch, recordDispatch } from './dispatch.js';
8
8
  import { getPipelineFast, createUniformBufferWithView } from './utils.js';
9
9
  import { selectRuleValue } from './rule-registry.js';
10
10
 
11
+ function destroyAfterSubmit(device, buffer) {
12
+ if (!buffer) {
13
+ return;
14
+ }
15
+ device.queue.onSubmittedWorkDone()
16
+ .then(() => {
17
+ buffer.destroy();
18
+ })
19
+ .catch(() => {
20
+ buffer.destroy();
21
+ });
22
+ }
23
+
11
24
  function canUseF16(input) {
12
25
  return input.dtype === 'f16';
13
26
  }
@@ -47,6 +60,12 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
47
60
  ];
48
61
  }
49
62
 
63
+ function cleanupRunResources(uniformBuffer, ownedOutput) {
64
+ if (ownedOutput) {
65
+ releaseBuffer(ownedOutput);
66
+ }
67
+ }
68
+
50
69
  function planSiLUDispatch(device, size, useVec4) {
51
70
  const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
52
71
  ? device.limits.maxComputeWorkgroupsPerDimension
@@ -97,6 +116,7 @@ export async function runSiLU(
97
116
  const inferredSize = size || (input.buffer.size / bytesPerElement);
98
117
  const outputSize = inferredSize * bytesPerElement;
99
118
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
119
+ const ownedOutput = outputBuffer ? null : output;
100
120
  const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
101
121
 
102
122
  // Create uniform buffer
@@ -116,17 +136,21 @@ export async function runSiLU(
116
136
  // Create bind group using helper
117
137
  const entries = createSiLUBindGroupEntries(uniformBuffer, input, output, gate);
118
138
 
119
- const bindGroup = device.createBindGroup({
120
- label: 'silu_bind_group',
121
- layout: pipeline.getBindGroupLayout(0),
122
- entries,
123
- });
124
-
125
- dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
126
-
127
- uniformBuffer.destroy();
128
-
129
- return createTensor(output, input.dtype, [inferredSize], 'silu_output');
139
+ try {
140
+ const bindGroup = device.createBindGroup({
141
+ label: 'silu_bind_group',
142
+ layout: pipeline.getBindGroupLayout(0),
143
+ entries,
144
+ });
145
+
146
+ dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
147
+ return createTensor(output, input.dtype, [inferredSize], 'silu_output');
148
+ } catch (error) {
149
+ cleanupRunResources(null, ownedOutput);
150
+ throw error;
151
+ } finally {
152
+ destroyAfterSubmit(device, uniformBuffer);
153
+ }
130
154
  }
131
155
 
132
156
 
@@ -148,6 +172,7 @@ export async function runSwiGLURowsplitBias(
148
172
  const bytesPerElement = dtypeBytes(input.dtype);
149
173
  const outputSize = numTokens * dim * bytesPerElement;
150
174
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'swiglu_output');
175
+ const ownedOutput = outputBuffer ? null : output;
151
176
 
152
177
  // Create uniform buffer
153
178
  const uniformBuffer = createUniformBufferWithView(
@@ -164,23 +189,27 @@ export async function runSwiGLURowsplitBias(
164
189
  );
165
190
 
166
191
  // Create bind group
167
- const bindGroup = device.createBindGroup({
168
- label: 'swiglu_bind_group',
169
- layout: pipeline.getBindGroupLayout(0),
170
- entries: [
171
- { binding: 0, resource: { buffer: uniformBuffer } },
172
- { binding: 1, resource: { buffer: input.buffer } },
173
- { binding: 2, resource: { buffer: bias.buffer } },
174
- { binding: 3, resource: { buffer: output } },
175
- ],
176
- });
177
-
178
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
179
- dispatch(device, pipeline, bindGroup, workgroups, 'swiglu');
180
-
181
- uniformBuffer.destroy();
182
-
183
- return createTensor(output, input.dtype, [numTokens, dim], 'swiglu_output');
192
+ try {
193
+ const bindGroup = device.createBindGroup({
194
+ label: 'swiglu_bind_group',
195
+ layout: pipeline.getBindGroupLayout(0),
196
+ entries: [
197
+ { binding: 0, resource: { buffer: uniformBuffer } },
198
+ { binding: 1, resource: { buffer: input.buffer } },
199
+ { binding: 2, resource: { buffer: bias.buffer } },
200
+ { binding: 3, resource: { buffer: output } },
201
+ ],
202
+ });
203
+
204
+ const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
205
+ dispatch(device, pipeline, bindGroup, workgroups, 'swiglu');
206
+ return createTensor(output, input.dtype, [numTokens, dim], 'swiglu_output');
207
+ } catch (error) {
208
+ cleanupRunResources(null, ownedOutput);
209
+ throw error;
210
+ } finally {
211
+ destroyAfterSubmit(device, uniformBuffer);
212
+ }
184
213
  }
185
214
 
186
215
 
@@ -202,6 +231,7 @@ export async function runSiLURowSplit(
202
231
 
203
232
  const outputSize = numTokens * dim * bytesPerElement;
204
233
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_rowsplit_output');
234
+ const ownedOutput = outputBuffer ? null : output;
205
235
 
206
236
  // Create uniform buffer
207
237
  const uniformBuffer = createUniformBufferWithView(
@@ -218,24 +248,28 @@ export async function runSiLURowSplit(
218
248
  );
219
249
 
220
250
  // Bind group: provide a dummy gate buffer to satisfy the fixed layout
221
- const gateBuffer = input.buffer;
222
- const bindGroup = device.createBindGroup({
223
- label: 'silu_rowsplit_bind_group',
224
- layout: pipeline.getBindGroupLayout(0),
225
- entries: [
226
- { binding: 0, resource: { buffer: uniformBuffer } },
227
- { binding: 1, resource: { buffer: input.buffer } },
228
- { binding: 2, resource: { buffer: output } },
229
- { binding: 3, resource: { buffer: gateBuffer } },
230
- ],
231
- });
232
-
233
- const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
234
- dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
235
-
236
- uniformBuffer.destroy();
237
-
238
- return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
251
+ try {
252
+ const gateBuffer = input.buffer;
253
+ const bindGroup = device.createBindGroup({
254
+ label: 'silu_rowsplit_bind_group',
255
+ layout: pipeline.getBindGroupLayout(0),
256
+ entries: [
257
+ { binding: 0, resource: { buffer: uniformBuffer } },
258
+ { binding: 1, resource: { buffer: input.buffer } },
259
+ { binding: 2, resource: { buffer: output } },
260
+ { binding: 3, resource: { buffer: gateBuffer } },
261
+ ],
262
+ });
263
+
264
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
265
+ dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
266
+ return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
267
+ } catch (error) {
268
+ cleanupRunResources(null, ownedOutput);
269
+ throw error;
270
+ } finally {
271
+ uniformBuffer.destroy();
272
+ }
239
273
  }
240
274
 
241
275
 
@@ -258,6 +292,7 @@ export async function recordSiLURowSplit(
258
292
 
259
293
  const outputSize = numTokens * dim * bytesPerElement;
260
294
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_rowsplit_output');
295
+ const ownedOutput = outputBuffer ? null : output;
261
296
 
262
297
  // Uniform buffer
263
298
  const uniformBuffer = createUniformBufferWithView(
@@ -272,22 +307,28 @@ export async function recordSiLURowSplit(
272
307
  recorder
273
308
  );
274
309
 
275
- const gateBuffer = input.buffer;
276
- const bindGroup = device.createBindGroup({
277
- label: 'silu_rowsplit_bind_group',
278
- layout: pipeline.getBindGroupLayout(0),
279
- entries: [
280
- { binding: 0, resource: { buffer: uniformBuffer } },
281
- { binding: 1, resource: { buffer: input.buffer } },
282
- { binding: 2, resource: { buffer: output } },
283
- { binding: 3, resource: { buffer: gateBuffer } },
284
- ],
285
- });
286
-
287
- const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
288
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
289
-
290
- return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
310
+ try {
311
+ const gateBuffer = input.buffer;
312
+ const bindGroup = device.createBindGroup({
313
+ label: 'silu_rowsplit_bind_group',
314
+ layout: pipeline.getBindGroupLayout(0),
315
+ entries: [
316
+ { binding: 0, resource: { buffer: uniformBuffer } },
317
+ { binding: 1, resource: { buffer: input.buffer } },
318
+ { binding: 2, resource: { buffer: output } },
319
+ { binding: 3, resource: { buffer: gateBuffer } },
320
+ ],
321
+ });
322
+
323
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
324
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
325
+ return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
326
+ } catch (error) {
327
+ if (ownedOutput) {
328
+ releaseBuffer(ownedOutput);
329
+ }
330
+ throw error;
331
+ }
291
332
  }
292
333
 
293
334
 
@@ -328,6 +369,7 @@ export async function recordSiLU(
328
369
  const inferredSize = size || (input.buffer.size / bytesPerElement);
329
370
  const outputSize = inferredSize * bytesPerElement;
330
371
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
372
+ const ownedOutput = outputBuffer ? null : output;
331
373
  const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
332
374
 
333
375
  // Uniform buffer
@@ -346,13 +388,19 @@ export async function recordSiLU(
346
388
  // Create bind group using helper
347
389
  const entries = createSiLUBindGroupEntries(uniformBuffer, input, output, gate);
348
390
 
349
- const bindGroup = device.createBindGroup({
350
- label: 'silu_bind_group',
351
- layout: pipeline.getBindGroupLayout(0),
352
- entries,
353
- });
354
-
355
- recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
356
-
357
- return createTensor(output, input.dtype, [inferredSize], 'silu_output');
391
+ try {
392
+ const bindGroup = device.createBindGroup({
393
+ label: 'silu_bind_group',
394
+ layout: pipeline.getBindGroupLayout(0),
395
+ entries,
396
+ });
397
+
398
+ recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
399
+ return createTensor(output, input.dtype, [inferredSize], 'silu_output');
400
+ } catch (error) {
401
+ if (ownedOutput) {
402
+ releaseBuffer(ownedOutput);
403
+ }
404
+ throw error;
405
+ }
358
406
  }
@@ -1,6 +1,6 @@
1
1
 
2
2
  import { getKernelCapabilities } from '../device.js';
3
- import { acquireBuffer } from '../../memory/buffer-pool.js';
3
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
4
  import { createTensor } from '../tensor.js';
5
5
  import { unifiedKernelWrapper } from './utils.js';
6
6
  import { createPipeline, createUniformBufferWithView, createBindGroupWithValidation } from './utils.js';
@@ -20,23 +20,34 @@ function selectSoftmaxVariant(innerSize) {
20
20
 
21
21
  async function _softmax(target, input, axis, options = {}) {
22
22
  const { batchSize = 1, size, seqLen, temperature = 1.0, outputBuffer = null } = options;
23
+ if (input.dtype !== 'f32') {
24
+ throw new Error(`Softmax requires f32 input, got ${input.dtype}.`);
25
+ }
23
26
 
24
- const bytesPerElement = input.dtype === 'f16' ? 2 : 4;
27
+ const bytesPerElement = 4;
25
28
  const inferredSize = size || seqLen || (input.buffer.size / (batchSize * bytesPerElement));
26
29
  const variant = selectSoftmaxVariant(inferredSize);
27
30
  trace.kernels(`Softmax: size=${inferredSize}, variant=${variant}`);
28
31
 
29
32
  const outputSize = batchSize * inferredSize * bytesPerElement;
30
33
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'softmax_output');
34
+ const ownedOutput = outputBuffer ? null : output;
35
+
36
+ try {
37
+ await unifiedKernelWrapper(
38
+ 'softmax', target, variant,
39
+ [input, output],
40
+ { inner_size: inferredSize, outer_size: batchSize, temperature },
41
+ batchSize
42
+ );
43
+ } catch (error) {
44
+ if (ownedOutput) {
45
+ releaseBuffer(ownedOutput);
46
+ }
47
+ throw error;
48
+ }
31
49
 
32
- await unifiedKernelWrapper(
33
- 'softmax', target, variant,
34
- [input, output],
35
- { inner_size: inferredSize, outer_size: batchSize, temperature },
36
- batchSize
37
- );
38
-
39
- return createTensor(output, input.dtype, [batchSize, inferredSize], 'softmax_output');
50
+ return createTensor(output, 'f32', [batchSize, inferredSize], 'softmax_output');
40
51
  }
41
52
 
42
53
  export async function runSoftmax(input, axis, options = {}) {
@@ -76,6 +87,7 @@ export async function runSoftmaxTopK(logits, numTokens, numExperts, topK, option
76
87
 
77
88
  const indices = acquireBuffer(indicesSize, undefined, 'softmax_topk_indices');
78
89
  const weights = acquireBuffer(weightsSize, undefined, 'softmax_topk_weights');
90
+ let completed = false;
79
91
 
80
92
  const uniformBuffer = createUniformBufferWithView(
81
93
  'softmax_topk_uniforms', 16,
@@ -88,19 +100,26 @@ export async function runSoftmaxTopK(logits, numTokens, numExperts, topK, option
88
100
  null, device
89
101
  );
90
102
 
91
- const bindGroup = await createBindGroupWithValidation(device, {
92
- label: 'softmax_topk_bind_group',
93
- layout: pipeline.getBindGroupLayout(0),
94
- entries: [
95
- { binding: 0, resource: { buffer: uniformBuffer } },
96
- { binding: 1, resource: { buffer: logits } },
97
- { binding: 2, resource: { buffer: indices } },
98
- { binding: 3, resource: { buffer: weights } },
99
- ],
100
- }, `topk:${variant}`);
101
-
102
- dispatchKernel(null, pipeline, bindGroup, numTokens, 'softmax_topk');
103
- uniformBuffer.destroy();
104
-
105
- return { indices, weights };
103
+ try {
104
+ const bindGroup = await createBindGroupWithValidation(device, {
105
+ label: 'softmax_topk_bind_group',
106
+ layout: pipeline.getBindGroupLayout(0),
107
+ entries: [
108
+ { binding: 0, resource: { buffer: uniformBuffer } },
109
+ { binding: 1, resource: { buffer: logits } },
110
+ { binding: 2, resource: { buffer: indices } },
111
+ { binding: 3, resource: { buffer: weights } },
112
+ ],
113
+ }, `topk:${variant}`);
114
+
115
+ dispatchKernel(null, pipeline, bindGroup, numTokens, 'softmax_topk');
116
+ completed = true;
117
+ return { indices, weights };
118
+ } finally {
119
+ uniformBuffer.destroy();
120
+ if (!completed) {
121
+ releaseBuffer(indices);
122
+ releaseBuffer(weights);
123
+ }
124
+ }
106
125
  }
@@ -0,0 +1,50 @@
1
+ /**
2
+ * Split Q and Gate Kernel
3
+ *
4
+ * De-interleaves Q and Gate projections from q_proj output for attentionOutputGate models.
5
+ * Models like Qwen 3.5 store q_proj weights in per-head interleaved layout:
6
+ * rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
7
+ * rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
8
+ * This kernel separates the full matmul output into contiguous Q and Gate tensors.
9
+ */
10
+
11
+ import type { Tensor } from '../tensor.js';
12
+ import type { CommandRecorder } from '../command-recorder.js';
13
+
14
+ /** Split Q and Gate options */
15
+ export interface SplitQGOptions {
16
+ numTokens: number;
17
+ numHeads: number;
18
+ headDim: number;
19
+ /** Pre-allocated Q output tensor */
20
+ qTensor?: Tensor | null;
21
+ /** Pre-allocated Gate output tensor */
22
+ gTensor?: Tensor | null;
23
+ }
24
+
25
+ /** Split Q and Gate result */
26
+ export interface SplitQGResult {
27
+ Q: Tensor;
28
+ G: Tensor;
29
+ }
30
+
31
+ /**
32
+ * De-interleave Q and Gate from q_proj output.
33
+ *
34
+ * @param qgTensor - Full q_proj output [numTokens, numHeads * headDim * 2] (interleaved)
35
+ * @param options - Split configuration
36
+ * @returns Separate Q and Gate tensors, each [numTokens, numHeads * headDim]
37
+ */
38
+ export declare function runSplitQG(
39
+ qgTensor: Tensor,
40
+ options: SplitQGOptions
41
+ ): Promise<SplitQGResult>;
42
+
43
+ /**
44
+ * Record split Q and Gate (batched, no submit).
45
+ */
46
+ export declare function recordSplitQG(
47
+ recorder: CommandRecorder,
48
+ qgTensor: Tensor,
49
+ options: SplitQGOptions
50
+ ): Promise<SplitQGResult>;
@@ -0,0 +1,46 @@
1
+
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
+ import { createTensor, dtypeBytes } from '../tensor.js';
4
+ import { WORKGROUP_SIZES } from './constants.js';
5
+ import { unifiedKernelWrapper } from './utils.js';
6
+ import { selectRuleValue } from './rule-registry.js';
7
+
8
+ async function _splitQG(target, qgTensor, options) {
9
+ const { numTokens, numHeads, headDim, qTensor = null, gTensor = null } = options;
10
+ const ownsQ = qTensor == null;
11
+ const ownsG = gTensor == null;
12
+
13
+ const outputDtype = qgTensor.dtype;
14
+ const pipelineVariant = selectRuleValue('splitQg', 'variant', { outputDtype });
15
+ const bytesPerElement = dtypeBytes(outputDtype);
16
+ const qSize = numHeads * headDim;
17
+
18
+ const qBuffer = qTensor?.buffer || acquireBuffer(numTokens * qSize * bytesPerElement, undefined, 'Q');
19
+ const gBuffer = gTensor?.buffer || acquireBuffer(numTokens * qSize * bytesPerElement, undefined, 'Q_gate');
20
+
21
+ try {
22
+ await unifiedKernelWrapper(
23
+ 'split_qg', target, pipelineVariant,
24
+ [qgTensor, qBuffer, gBuffer],
25
+ { num_tokens: numTokens, num_heads: numHeads, head_dim: headDim, _pad: 0 },
26
+ Math.ceil((numTokens * qSize) / WORKGROUP_SIZES.DEFAULT)
27
+ );
28
+
29
+ const Q = qTensor || createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q');
30
+ const G = gTensor || createTensor(gBuffer, outputDtype, [numTokens, qSize], 'Q_gate');
31
+
32
+ return { Q, G };
33
+ } catch (error) {
34
+ if (ownsQ) releaseBuffer(qBuffer);
35
+ if (ownsG) releaseBuffer(gBuffer);
36
+ throw error;
37
+ }
38
+ }
39
+
40
+ export async function runSplitQG(qgTensor, options) {
41
+ return _splitQG(null, qgTensor, options);
42
+ }
43
+
44
+ export async function recordSplitQG(recorder, qgTensor, options) {
45
+ return _splitQG(recorder, qgTensor, options);
46
+ }
@@ -0,0 +1,58 @@
1
+ // split_qg.wgsl
2
+
3
+ /**
4
+ * De-interleave Q and Gate projections from q_proj output for attentionOutputGate models.
5
+ *
6
+ * Models like Qwen 3.5 store q_proj weights with interleaved head layout:
7
+ * rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
8
+ * rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
9
+ *
10
+ * A single full matmul over all 2*qSize rows produces interleaved output:
11
+ * input[token, h*headDim*2 : h*headDim*2+headDim] = Q head h
12
+ * input[token, h*headDim*2+headDim : (h+1)*headDim*2] = Gate head h
13
+ *
14
+ * This kernel separates them into contiguous Q and G outputs:
15
+ * Q[token, h*headDim + dim] = input[token, h*headDim*2 + dim]
16
+ * G[token, h*headDim + dim] = input[token, h*headDim*2 + headDim + dim]
17
+ *
18
+ * Input layout (row-major): [numTokens, numHeads * headDim * 2]
19
+ * Output Q layout (row-major): [numTokens, numHeads * headDim]
20
+ * Output G layout (row-major): [numTokens, numHeads * headDim]
21
+ */
22
+
23
+ struct Params {
24
+ num_tokens: u32,
25
+ num_heads: u32,
26
+ head_dim: u32,
27
+ _pad: u32,
28
+ }
29
+
30
+ override WORKGROUP_SIZE: u32 = 256u;
31
+
32
+ @group(0) @binding(0) var<uniform> params: Params;
33
+ @group(0) @binding(1) var<storage, read> input: array<f32>;
34
+ @group(0) @binding(2) var<storage, read_write> Q: array<f32>;
35
+ @group(0) @binding(3) var<storage, read_write> G: array<f32>;
36
+
37
+ @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
38
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
39
+ let idx = gid.x;
40
+ let q_size = params.num_heads * params.head_dim;
41
+ let total_elements = params.num_tokens * q_size;
42
+
43
+ if (idx >= total_elements) {
44
+ return;
45
+ }
46
+
47
+ let token = idx / q_size;
48
+ let elem = idx % q_size;
49
+ let head = elem / params.head_dim;
50
+ let dim = elem % params.head_dim;
51
+
52
+ // Input is interleaved per head: [Q_h (headDim elems), G_h (headDim elems)]
53
+ let src_q = token * (q_size * 2u) + head * (params.head_dim * 2u) + dim;
54
+ let src_g = src_q + params.head_dim;
55
+
56
+ Q[idx] = input[src_q];
57
+ G[idx] = input[src_g];
58
+ }
@@ -0,0 +1,62 @@
1
+ // AUTO-GENERATED from src/gpu/kernels/split_qg.wgsl.
2
+ // Edit the source kernel and tools/configs/wgsl-variants.js, then run `npm run kernels:generate`.
3
+ // split_qg_f16.wgsl
4
+
5
+ /**
6
+ * De-interleave Q and Gate projections from q_proj output for attentionOutputGate models (f16).
7
+ *
8
+ * Models like Qwen 3.5 store q_proj weights with interleaved head layout:
9
+ * rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
10
+ * rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
11
+ *
12
+ * A single full matmul over all 2*qSize rows produces interleaved output:
13
+ * input[token, h*headDim*2 : h*headDim*2+headDim] = Q head h
14
+ * input[token, h*headDim*2+headDim : (h+1)*headDim*2] = Gate head h
15
+ *
16
+ * This kernel separates them into contiguous Q and G outputs:
17
+ * Q[token, h*headDim + dim] = input[token, h*headDim*2 + dim]
18
+ * G[token, h*headDim + dim] = input[token, h*headDim*2 + headDim + dim]
19
+ *
20
+ * Input layout (row-major): [numTokens, numHeads * headDim * 2]
21
+ * Output Q layout (row-major): [numTokens, numHeads * headDim]
22
+ * Output G layout (row-major): [numTokens, numHeads * headDim]
23
+ */
24
+
25
+ enable f16;
26
+
27
+ struct Params {
28
+ num_tokens: u32,
29
+ num_heads: u32,
30
+ head_dim: u32,
31
+ _pad: u32,
32
+ }
33
+
34
+ override WORKGROUP_SIZE: u32 = 256u;
35
+
36
+ @group(0) @binding(0) var<uniform> params: Params;
37
+ @group(0) @binding(1) var<storage, read> input: array<f16>;
38
+ @group(0) @binding(2) var<storage, read_write> Q: array<f16>;
39
+ @group(0) @binding(3) var<storage, read_write> G: array<f16>;
40
+
41
+ @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
42
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
43
+ let idx = gid.x;
44
+ let q_size = params.num_heads * params.head_dim;
45
+ let total_elements = params.num_tokens * q_size;
46
+
47
+ if (idx >= total_elements) {
48
+ return;
49
+ }
50
+
51
+ let token = idx / q_size;
52
+ let elem = idx % q_size;
53
+ let head = elem / params.head_dim;
54
+ let dim = elem % params.head_dim;
55
+
56
+ // Input is interleaved per head: [Q_h (headDim elems), G_h (headDim elems)]
57
+ let src_q = token * (q_size * 2u) + head * (params.head_dim * 2u) + dim;
58
+ let src_g = src_q + params.head_dim;
59
+
60
+ Q[idx] = input[src_q];
61
+ G[idx] = input[src_g];
62
+ }
@@ -1,5 +1,5 @@
1
1
 
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor, dtypeBytes } from '../tensor.js';
4
4
  import { WORKGROUP_SIZES } from './constants.js';
5
5
  import { unifiedKernelWrapper } from './utils.js';
@@ -7,6 +7,9 @@ import { selectRuleValue } from './rule-registry.js';
7
7
 
8
8
  async function _splitQKV(target, qkvTensor, options) {
9
9
  const { numTokens, qSize, kSize, vSize, qTensor = null, kTensor = null, vTensor = null } = options;
10
+ const ownsQ = qTensor == null;
11
+ const ownsK = kTensor == null;
12
+ const ownsV = vTensor == null;
10
13
 
11
14
  const outputDtype = qkvTensor.dtype;
12
15
  const pipelineVariant = selectRuleValue('splitQkv', 'variant', { outputDtype });
@@ -18,18 +21,25 @@ async function _splitQKV(target, qkvTensor, options) {
18
21
 
19
22
  const totalElements = numTokens * (qSize + kSize + vSize);
20
23
 
21
- await unifiedKernelWrapper(
22
- 'split_qkv', target, pipelineVariant,
23
- [qkvTensor, qBuffer, kBuffer, vBuffer],
24
- { num_tokens: numTokens, q_size: qSize, k_size: kSize, v_size: vSize },
25
- Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT)
26
- );
27
-
28
- const Q = qTensor || createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q');
29
- const K = kTensor || createTensor(kBuffer, outputDtype, [numTokens, kSize], 'K');
30
- const V = vTensor || createTensor(vBuffer, outputDtype, [numTokens, vSize], 'V');
31
-
32
- return { Q, K, V };
24
+ try {
25
+ await unifiedKernelWrapper(
26
+ 'split_qkv', target, pipelineVariant,
27
+ [qkvTensor, qBuffer, kBuffer, vBuffer],
28
+ { num_tokens: numTokens, q_size: qSize, k_size: kSize, v_size: vSize },
29
+ Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT)
30
+ );
31
+
32
+ const Q = qTensor || createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q');
33
+ const K = kTensor || createTensor(kBuffer, outputDtype, [numTokens, kSize], 'K');
34
+ const V = vTensor || createTensor(vBuffer, outputDtype, [numTokens, vSize], 'V');
35
+
36
+ return { Q, K, V };
37
+ } catch (error) {
38
+ if (ownsQ) releaseBuffer(qBuffer);
39
+ if (ownsK) releaseBuffer(kBuffer);
40
+ if (ownsV) releaseBuffer(vBuffer);
41
+ throw error;
42
+ }
33
43
  }
34
44
 
35
45
  export async function runSplitQKV(qkvTensor, options) {