@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355)
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { getBuffer } from '../weight-buffer.js';
7
7
  import { dispatch, recordDispatch } from './dispatch.js';
@@ -91,7 +91,8 @@ export async function runMatmulRMSNormFused(
91
91
  // Output buffer: [1, N] - size depends on dtype
92
92
  const bytesPerElement = dtype === 'f16' ? 2 : 4;
93
93
  const outputSize = N * bytesPerElement;
94
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
94
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
95
+ const output = outputBuffer || ownedOutput;
95
96
 
96
97
  // Create uniform buffer (8 u32/f32 = 32 bytes, padded for alignment)
97
98
  const uniformBuffer = createUniformBufferWithView(
@@ -110,36 +111,44 @@ export async function runMatmulRMSNormFused(
110
111
  );
111
112
 
112
113
  // Create placeholder for residual if not provided
114
+ const ownsResidualBuffer = !residual;
113
115
  const residualBuffer = residual || device.createBuffer({
114
116
  label: 'matmul_rmsnorm_residual_placeholder',
115
117
  size: 4,
116
118
  usage: GPUBufferUsage.STORAGE,
117
119
  });
118
120
 
119
- // Create bind group
120
- const bindGroup = device.createBindGroup({
121
- label: 'matmul_rmsnorm_fused_bind_group',
122
- layout: pipeline.getBindGroupLayout(0),
123
- entries: [
124
- { binding: 0, resource: { buffer: uniformBuffer } },
125
- { binding: 1, resource: { buffer: input.buffer } },
126
- { binding: 2, resource: { buffer: weightBuffer } },
127
- { binding: 3, resource: { buffer: normWeightBuffer } },
128
- { binding: 4, resource: { buffer: output } },
129
- { binding: 5, resource: { buffer: residualBuffer } },
130
- ],
131
- });
132
-
133
- // Calculate workgroups
134
-
135
- const workgroups = 1;
136
-
137
- const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
138
- dispatch(device, pipeline, bindGroup, workgroups, dispatchLabel);
121
+ try {
122
+ const bindGroup = device.createBindGroup({
123
+ label: 'matmul_rmsnorm_fused_bind_group',
124
+ layout: pipeline.getBindGroupLayout(0),
125
+ entries: [
126
+ { binding: 0, resource: { buffer: uniformBuffer } },
127
+ { binding: 1, resource: { buffer: input.buffer } },
128
+ { binding: 2, resource: { buffer: weightBuffer } },
129
+ { binding: 3, resource: { buffer: normWeightBuffer } },
130
+ { binding: 4, resource: { buffer: output } },
131
+ { binding: 5, resource: { buffer: residualBuffer } },
132
+ ],
133
+ });
134
+
135
+ const workgroups = 1;
136
+ const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
137
+ dispatch(device, pipeline, bindGroup, workgroups, dispatchLabel);
138
+ } catch (error) {
139
+ uniformBuffer.destroy();
140
+ if (ownsResidualBuffer) {
141
+ residualBuffer.destroy();
142
+ }
143
+ if (ownedOutput) {
144
+ releaseBuffer(ownedOutput);
145
+ }
146
+ throw error;
147
+ }
139
148
 
140
149
  // Cleanup
141
150
  uniformBuffer.destroy();
142
- if (!residual) residualBuffer.destroy();
151
+ if (ownsResidualBuffer) residualBuffer.destroy();
143
152
 
144
153
  // Output dtype matches input dtype
145
154
  return createTensor(output, input.dtype, [1, N], 'matmul_rmsnorm_fused_output');
@@ -199,7 +208,8 @@ export async function recordMatmulRMSNormFused(
199
208
  // Output buffer - size depends on dtype
200
209
  const bytesPerElement = dtype === 'f16' ? 2 : 4;
201
210
  const outputSize = N * bytesPerElement;
202
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
211
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
212
+ const output = outputBuffer || ownedOutput;
203
213
 
204
214
  // Uniform buffer via recorder (8 u32/f32 = 32 bytes, padded for alignment)
205
215
  const uniformBuffer = createUniformBufferWithView(
@@ -217,35 +227,42 @@ export async function recordMatmulRMSNormFused(
217
227
  );
218
228
 
219
229
  // Placeholder for residual
230
+ const ownsResidualBuffer = !residual;
220
231
  const residualBuffer = residual || device.createBuffer({
221
232
  label: 'matmul_rmsnorm_residual_placeholder',
222
233
  size: 4,
223
234
  usage: GPUBufferUsage.STORAGE,
224
235
  });
225
236
 
226
- // Bind group
227
- const bindGroup = device.createBindGroup({
228
- label: 'matmul_rmsnorm_fused_bind_group',
229
- layout: pipeline.getBindGroupLayout(0),
230
- entries: [
231
- { binding: 0, resource: { buffer: uniformBuffer } },
232
- { binding: 1, resource: { buffer: input.buffer } },
233
- { binding: 2, resource: { buffer: weightBuffer } },
234
- { binding: 3, resource: { buffer: normWeightBuffer } },
235
- { binding: 4, resource: { buffer: output } },
236
- { binding: 5, resource: { buffer: residualBuffer } },
237
- ],
238
- });
239
-
240
- // Calculate workgroups
241
-
242
- const workgroups = 1;
243
-
244
- const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
245
- recordDispatch(recorder, pipeline, bindGroup, workgroups, dispatchLabel);
237
+ try {
238
+ const bindGroup = device.createBindGroup({
239
+ label: 'matmul_rmsnorm_fused_bind_group',
240
+ layout: pipeline.getBindGroupLayout(0),
241
+ entries: [
242
+ { binding: 0, resource: { buffer: uniformBuffer } },
243
+ { binding: 1, resource: { buffer: input.buffer } },
244
+ { binding: 2, resource: { buffer: weightBuffer } },
245
+ { binding: 3, resource: { buffer: normWeightBuffer } },
246
+ { binding: 4, resource: { buffer: output } },
247
+ { binding: 5, resource: { buffer: residualBuffer } },
248
+ ],
249
+ });
250
+
251
+ const workgroups = 1;
252
+ const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
253
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, dispatchLabel);
254
+ } catch (error) {
255
+ if (ownsResidualBuffer) {
256
+ residualBuffer.destroy();
257
+ }
258
+ if (ownedOutput) {
259
+ releaseBuffer(ownedOutput);
260
+ }
261
+ throw error;
262
+ }
246
263
 
247
264
  // Track placeholder for cleanup
248
- if (!residual) {
265
+ if (ownsResidualBuffer) {
249
266
  recorder.trackTemporaryBuffer(residualBuffer);
250
267
  }
251
268
 
@@ -1,5 +1,5 @@
1
1
  import { getKernelCapabilities } from '../device.js';
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { WORKGROUP_SIZES, VEC4_ELEMENTS_PER_WG } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
5
5
  import { trace } from '../../debug/index.js';
@@ -26,7 +26,6 @@ async function _gather(
26
26
  options = {}
27
27
  ) {
28
28
  const {
29
- useVec4 = true,
30
29
  outputBuffer = null,
31
30
  embeddingDtype,
32
31
  outputDtype,
@@ -43,9 +42,22 @@ async function _gather(
43
42
  if (outputDtype == null) {
44
43
  throw new Error('[Gather] outputDtype is required.');
45
44
  }
45
+ if (embeddingDtype === 'f16' && !caps.hasF16) {
46
+ throw new Error('[Gather] embeddingDtype=f16 requires shader-f16 support.');
47
+ }
48
+ if (outputDtype === 'f16' && !caps.hasF16) {
49
+ throw new Error('[Gather] outputDtype=f16 requires shader-f16 support.');
50
+ }
46
51
 
47
- const useF16Input = embeddingDtype === 'f16' && caps.hasF16;
48
- const useF16Output = outputDtype === 'f16' && caps.hasF16;
52
+ const requestedVec4 = options.useVec4;
53
+ const wantsVec4 = requestedVec4 ?? true;
54
+ if (requestedVec4 === true && hiddenSize % 4 !== 0) {
55
+ throw new Error('[Gather] useVec4=true requires hiddenSize to be divisible by 4.');
56
+ }
57
+
58
+ const useF16Input = embeddingDtype === 'f16';
59
+ const useF16Output = outputDtype === 'f16';
60
+ const useVec4 = wantsVec4 && hiddenSize % 4 === 0;
49
61
 
50
62
  trace.embed(
51
63
  `Gather: numTokens=${numTokens}, hiddenSize=${hiddenSize}, vocabSize=${vocabSize}, ` +
@@ -64,6 +76,7 @@ async function _gather(
64
76
  const paddedHiddenSize = padToQ4KBlock(hiddenSize);
65
77
  const outputSize = numTokens * paddedHiddenSize * bytesPerElement;
66
78
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gather_output');
79
+ const ownedOutput = outputBuffer ? null : output;
67
80
 
68
81
  const uniforms = {
69
82
  num_tokens: numTokens,
@@ -82,16 +95,22 @@ async function _gather(
82
95
  ? Math.ceil((numTokens * hiddenSize) / VEC4_ELEMENTS_PER_WG)
83
96
  : Math.ceil((numTokens * hiddenSize) / WORKGROUP_SIZES.DEFAULT));
84
97
 
85
- await unifiedKernelWrapper(
86
- 'gather',
87
- target,
88
- variant,
89
- [indices, embeddings, output],
90
- uniforms,
91
- workgroups
92
- );
93
-
94
- return createTensor(output, actualDtype, [numTokens, hiddenSize], 'gather_output');
98
+ try {
99
+ await unifiedKernelWrapper(
100
+ 'gather',
101
+ target,
102
+ variant,
103
+ [indices, embeddings, output],
104
+ uniforms,
105
+ workgroups
106
+ );
107
+ return createTensor(output, actualDtype, [numTokens, hiddenSize], 'gather_output');
108
+ } catch (error) {
109
+ if (ownedOutput) {
110
+ releaseBuffer(ownedOutput);
111
+ }
112
+ throw error;
113
+ }
95
114
  }
96
115
 
97
116
  export async function runGather(
@@ -116,4 +135,3 @@ export async function recordGather(
116
135
  ) {
117
136
  return _gather(recorder, indices, embeddings, numTokens, hiddenSize, vocabSize, options);
118
137
  }
119
-
@@ -1,5 +1,5 @@
1
1
 
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor, dtypeBytes } from '../tensor.js';
4
4
  import { WORKGROUP_SIZES } from './constants.js';
5
5
  import { unifiedKernelWrapper } from './utils.js';
@@ -26,16 +26,24 @@ async function _gelu(target, input, options = {}) {
26
26
  const outputSize = inferredSize * bytesPerElement;
27
27
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gelu_output');
28
28
  const gateBuffer = gate ?? input;
29
-
30
- await unifiedKernelWrapper(
31
- 'gelu', target, variant,
32
- [input, output, gateBuffer],
33
- { size: inferredSize, rowsplit_dim: 0 },
34
- Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT),
35
- overrides
36
- );
37
-
38
- return createTensor(output, input.dtype, [inferredSize], 'gelu_output');
29
+ const ownedOutput = outputBuffer ? null : output;
30
+
31
+ try {
32
+ await unifiedKernelWrapper(
33
+ 'gelu', target, variant,
34
+ [input, output, gateBuffer],
35
+ { size: inferredSize, rowsplit_dim: 0 },
36
+ Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT),
37
+ overrides
38
+ );
39
+
40
+ return createTensor(output, input.dtype, [inferredSize], 'gelu_output');
41
+ } catch (error) {
42
+ if (ownedOutput) {
43
+ releaseBuffer(ownedOutput);
44
+ }
45
+ throw error;
46
+ }
39
47
  }
40
48
 
41
49
  export async function runGeLU(input, options = {}) {
@@ -55,33 +55,43 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
55
55
  device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
56
56
  }
57
57
 
58
- await unifiedKernelWrapper(
59
- 'grouped_pointwise_conv2d',
60
- target,
61
- variant,
62
- [input, weightBuffer, biasBuffer, output],
63
- {
64
- in_channels: inChannels,
65
- out_channels: outChannels,
66
- height,
67
- width,
68
- groups,
69
- _pad0: 0,
70
- _pad1: 0,
71
- _pad2: 0,
72
- },
73
- [Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
74
- );
58
+ try {
59
+ await unifiedKernelWrapper(
60
+ 'grouped_pointwise_conv2d',
61
+ target,
62
+ variant,
63
+ [input, weightBuffer, biasBuffer, output],
64
+ {
65
+ in_channels: inChannels,
66
+ out_channels: outChannels,
67
+ height,
68
+ width,
69
+ groups,
70
+ _pad0: 0,
71
+ _pad1: 0,
72
+ _pad2: 0,
73
+ },
74
+ [Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
75
+ );
76
+
77
+ if (tempBias) {
78
+ if (recorder) {
79
+ recorder.trackTemporaryBuffer(tempBias);
80
+ } else {
81
+ releaseBuffer(tempBias);
82
+ }
83
+ }
75
84
 
76
- if (tempBias) {
77
- if (recorder) {
78
- recorder.trackTemporaryBuffer(tempBias);
79
- } else {
85
+ return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
86
+ } catch (error) {
87
+ if (tempBias) {
80
88
  releaseBuffer(tempBias);
81
89
  }
90
+ if (!outputBuffer) {
91
+ releaseBuffer(output);
92
+ }
93
+ throw error;
82
94
  }
83
-
84
- return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
85
95
  }
86
96
 
87
97
  export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
@@ -17,6 +17,9 @@ function validateOptions(options) {
17
17
  if (!Number.isFinite(numGroups) || numGroups <= 0) {
18
18
  throw new Error('GroupNorm requires numGroups > 0.');
19
19
  }
20
+ if (channels % numGroups !== 0) {
21
+ throw new Error('GroupNorm requires channels to be divisible by numGroups.');
22
+ }
20
23
  if (!Number.isFinite(eps)) {
21
24
  throw new Error('GroupNorm requires eps.');
22
25
  }
@@ -44,34 +47,42 @@ async function _groupNorm(target, input, weight, bias, options = {}) {
44
47
 
45
48
  const statsSize = numGroups * 2 * 4;
46
49
  const statsBuffer = acquireBuffer(statsSize, undefined, 'groupnorm_stats');
47
-
48
- await unifiedKernelWrapper(
49
- 'groupnorm_stats',
50
- target,
51
- statsVariant,
52
- [input, statsBuffer],
53
- uniforms,
54
- numGroups
55
- );
56
-
57
50
  const bytesPerElement = dtypeBytes(input.dtype);
58
51
  const outputSize = channels * height * width * bytesPerElement;
59
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'groupnorm_output');
52
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'groupnorm_output');
53
+ const output = outputBuffer || ownedOutput;
60
54
 
61
- const weightBuffer = getBuffer(weight);
62
- const biasBuffer = getBuffer(bias);
55
+ try {
56
+ await unifiedKernelWrapper(
57
+ 'groupnorm_stats',
58
+ target,
59
+ statsVariant,
60
+ [input, statsBuffer],
61
+ uniforms,
62
+ numGroups
63
+ );
63
64
 
64
- const total = channels * height * width;
65
- const workgroups = Math.ceil(total / WORKGROUP_SIZES.DEFAULT);
65
+ const weightBuffer = getBuffer(weight);
66
+ const biasBuffer = getBuffer(bias);
66
67
 
67
- await unifiedKernelWrapper(
68
- 'groupnorm_apply',
69
- target,
70
- applyVariant,
71
- [input, statsBuffer, weightBuffer, biasBuffer, output],
72
- uniforms,
73
- workgroups
74
- );
68
+ const total = channels * height * width;
69
+ const workgroups = Math.ceil(total / WORKGROUP_SIZES.DEFAULT);
70
+
71
+ await unifiedKernelWrapper(
72
+ 'groupnorm_apply',
73
+ target,
74
+ applyVariant,
75
+ [input, statsBuffer, weightBuffer, biasBuffer, output],
76
+ uniforms,
77
+ workgroups
78
+ );
79
+ } catch (error) {
80
+ releaseBuffer(statsBuffer);
81
+ if (ownedOutput) {
82
+ releaseBuffer(ownedOutput);
83
+ }
84
+ throw error;
85
+ }
75
86
 
76
87
  if (recorder) {
77
88
  recorder.trackTemporaryBuffer(statsBuffer);
@@ -326,6 +326,14 @@ export {
326
326
  type SplitQKVResult,
327
327
  } from './split_qkv.js';
328
328
 
329
+ // Split Q and Gate (de-interleave attentionOutputGate q_proj output)
330
+ export {
331
+ runSplitQG,
332
+ recordSplitQG,
333
+ type SplitQGOptions,
334
+ type SplitQGResult,
335
+ } from './split_qg.js';
336
+
329
337
  // Transpose
330
338
  export {
331
339
  runTranspose,
@@ -268,6 +268,12 @@ export {
268
268
  recordSplitQKV,
269
269
  } from './split_qkv.js';
270
270
 
271
+ // Split Q and Gate (de-interleave attentionOutputGate q_proj output)
272
+ export {
273
+ runSplitQG,
274
+ recordSplitQG,
275
+ } from './split_qg.js';
276
+
271
277
  // Transpose
272
278
  export {
273
279
  runTranspose,
@@ -78,8 +78,11 @@ export async function runKVQuantize(
78
78
  });
79
79
 
80
80
  const workgroups = [numKVHeads, numTokens, 1];
81
- dispatch(device, pipeline, bindGroup, workgroups, 'kv_quantize');
82
- uniformBuffer.destroy();
81
+ try {
82
+ dispatch(device, pipeline, bindGroup, workgroups, 'kv_quantize');
83
+ } finally {
84
+ uniformBuffer.destroy();
85
+ }
83
86
  }
84
87
 
85
88
 
@@ -1,6 +1,6 @@
1
1
 
2
2
  import { getKernelCapabilities } from '../device.js';
3
- import { acquireBuffer } from '../../memory/buffer-pool.js';
3
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
4
  import { createTensor } from '../tensor.js';
5
5
  import { padToQ4KBlock } from '../../config/schema/index.js';
6
6
  import { selectRuleValue } from './rule-registry.js';
@@ -36,17 +36,25 @@ export async function runLayerNorm(
36
36
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
37
37
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
38
38
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'layernorm_output');
39
+ const ownedOutput = outputBuffer ? null : outputBuf;
39
40
 
40
- await unifiedKernelWrapper(
41
- 'layernorm',
42
- null,
43
- variant,
44
- [input, weight, bias, outputBuf],
45
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
46
- batchSize
47
- );
41
+ try {
42
+ await unifiedKernelWrapper(
43
+ 'layernorm',
44
+ null,
45
+ variant,
46
+ [input, weight, bias, outputBuf],
47
+ { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
48
+ batchSize
49
+ );
48
50
 
49
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
51
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
52
+ } catch (error) {
53
+ if (ownedOutput) {
54
+ releaseBuffer(ownedOutput);
55
+ }
56
+ throw error;
57
+ }
50
58
  }
51
59
 
52
60
  export async function recordLayerNorm(
@@ -66,15 +74,23 @@ export async function recordLayerNorm(
66
74
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
67
75
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
68
76
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'layernorm_output');
77
+ const ownedOutput = outputBuffer ? null : outputBuf;
69
78
 
70
- await unifiedKernelWrapper(
71
- 'layernorm',
72
- recorder,
73
- variant,
74
- [input, weight, bias, outputBuf],
75
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
76
- batchSize
77
- );
79
+ try {
80
+ await unifiedKernelWrapper(
81
+ 'layernorm',
82
+ recorder,
83
+ variant,
84
+ [input, weight, bias, outputBuf],
85
+ { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
86
+ batchSize
87
+ );
78
88
 
79
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
89
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
90
+ } catch (error) {
91
+ if (ownedOutput) {
92
+ releaseBuffer(ownedOutput);
93
+ }
94
+ throw error;
95
+ }
80
96
  }
@@ -266,9 +266,11 @@ export class LogitMergeKernel {
266
266
  pass.end();
267
267
 
268
268
  this.#device.queue.submit([encoder.finish()]);
269
-
270
- // Cleanup temporary buffer
271
- paramsBuffer.destroy();
269
+ this.#device.queue.onSubmittedWorkDone()
270
+ .catch(() => {})
271
+ .finally(() => {
272
+ paramsBuffer.destroy();
273
+ });
272
274
 
273
275
  return mergedBuffer;
274
276
  }
@@ -29,7 +29,13 @@ function selectQ4KFusedVariant(isM1, wantF16Output, aDtype) {
29
29
  }
30
30
 
31
31
 
32
- export function resolveMatmulPhase(M) {
32
+ export function resolveMatmulPhase(M, phaseOverride = null) {
33
+ if (phaseOverride != null) {
34
+ if (phaseOverride !== 'decode' && phaseOverride !== 'prefill') {
35
+ throw new Error(`[Matmul] Invalid phase override "${phaseOverride}". Expected "decode" or "prefill".`);
36
+ }
37
+ return phaseOverride;
38
+ }
33
39
  return selectKernelRuleValue('matmul', 'phase', { isDecode: M === 1 });
34
40
  }
35
41
 
@@ -125,7 +131,9 @@ export function selectMatmulKernel(options = {}) {
125
131
  const { tiledPrefillMinRows } = getKernelThresholds().matmul;
126
132
 
127
133
  const inputsAreF16 = aDtype === 'f16' && bDtype === 'f16';
128
- const weightsAreF16 = bDtype === 'f16' && aDtype !== 'f16';
134
+ // F16 weights needing F32a path: weights are F16 and either activation is already F32,
135
+ // or both inputs are F16 but output is F32 (activation will be cast to F32 by executeMatmul)
136
+ const weightsAreF16 = bDtype === 'f16' && (aDtype !== 'f16' || outputDtype !== 'f16');
129
137
  const useF16Matmul = outputDtype === 'f16' && preferF16 && inputsAreF16 && capabilities.hasF16;
130
138
  const useF16wF32a = preferF16 && weightsAreF16 && capabilities.hasF16;
131
139
  const useTiled = isPrefill
@@ -244,6 +252,30 @@ export function requiresF32Input(variant) {
244
252
  return !supportsF16Input(variant);
245
253
  }
246
254
 
255
+ function resolveRequiredWeightDtype(config) {
256
+ const shaderFile = String(config?.shaderFile ?? config?.wgsl ?? '');
257
+ if (!shaderFile) {
258
+ return null;
259
+ }
260
+ if (shaderFile.startsWith('fused_matmul_q4')) {
261
+ return 'q4k';
262
+ }
263
+ if (
264
+ shaderFile === 'matmul_f16.wgsl'
265
+ || shaderFile === 'matmul_f16_tiled.wgsl'
266
+ || shaderFile === 'matmul_f16w_f32a.wgsl'
267
+ || shaderFile === 'matmul_f16w_f32a_tiled.wgsl'
268
+ || shaderFile === 'matmul_gemv_subgroup.wgsl'
269
+ || shaderFile === 'matmul_gemv_subgroup_f16a.wgsl'
270
+ ) {
271
+ return 'f16';
272
+ }
273
+ if (shaderFile === 'matmul_f32.wgsl') {
274
+ return 'f32';
275
+ }
276
+ return null;
277
+ }
278
+
247
279
 
248
280
  function resolveMatmulOverride(
249
281
  variantOverride,
@@ -287,6 +319,16 @@ function resolveMatmulOverride(
287
319
  );
288
320
  }
289
321
 
322
+ const requiredWeightDtype = resolveRequiredWeightDtype(config);
323
+ const weightDtypeOk = !requiredWeightDtype
324
+ || bDtype === requiredWeightDtype
325
+ || (requiredWeightDtype === 'f16' && bDtype === 'q4k');
326
+ if (!weightDtypeOk) {
327
+ return failOrWarn(
328
+ `Matmul kernel "${variantOverride}" requires ${requiredWeightDtype} weights but B dtype is ${bDtype}.`
329
+ );
330
+ }
331
+
290
332
  if (supportsF16Input(override) && aDtype !== 'f16') {
291
333
  return failOrWarn(`Matmul kernel "${variantOverride}" requires f16 activations but A dtype is ${aDtype}.`);
292
334
  }
@@ -341,7 +383,7 @@ function selectGemvVariant(useF16Gemv, useF32Gemv, hasSubgroups, useVec4, N, mul
341
383
  export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, transposeB, requestedOutputDtype, options) {
342
384
  const capabilities = getKernelCapabilities();
343
385
  const strict = getKernelPathStrict();
344
- const phase = resolveMatmulPhase(M);
386
+ const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
345
387
  let pathVariant = getKernelPathMatmulVariant(options.role, phase, options.layerIdx, options.kernelPath);
346
388
  const hadPathVariant = Boolean(pathVariant);
347
389
 
@@ -426,7 +468,8 @@ export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, trans
426
468
 
427
469
  const canGemv = M === 1 && effectiveBDtype === 'f16' && capabilities.hasF16;
428
470
  const useF16Gemv = canGemv && aDtype === 'f16' && wantF16Output;
429
- const useF32Gemv = canGemv && aDtype === 'f32';
471
+ // F32 GEMV: activation is F32, or activation is F16 with F32 output (will be cast to F32)
472
+ const useF32Gemv = canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output));
430
473
  const useGemv = useF16Gemv || useF32Gemv;
431
474
  const useVec4 = (K % 4 === 0);
432
475
  const { multicolThreshold } = getKernelThresholds().matmul;
@@ -23,6 +23,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
23
23
  layerIdx?: number;
24
24
  /** Explicit kernel path context for variant selection (avoids global path state). */
25
25
  kernelPath?: KernelPathSchema | null;
26
+ /** Optional explicit phase for kernel-path lookup when the runtime rewrites rows (for example prefill last-position logits). */
27
+ phaseOverride?: 'decode' | 'prefill' | null;
26
28
  /**
27
29
  * Whether B matrix is stored transposed.
28
30
  * - true: B is [N,K] (SafeTensors/row-major), needs transpose