@simulatte/doppler 0.1.6 → 0.1.8

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (355)
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -46,7 +46,16 @@ export function recordAttentionInputs(
46
46
  info: AttentionInputInfo | null | undefined
47
47
  ): void;
48
48
 
49
- export function resolveAttentionProjectionOutputDtype(attentionInputDtype: string): 'f16' | 'f32' | string;
49
+ export function shouldForceF32AttentionProjectionForRoPE(options: {
50
+ attentionInputDtype: string;
51
+ headDim: number;
52
+ rotaryDim?: number;
53
+ interleaved?: boolean;
54
+ }): boolean;
55
+ export function resolveAttentionProjectionOutputDtype(
56
+ attentionInputDtype: string,
57
+ options?: { forceF32?: boolean }
58
+ ): 'f16' | 'f32' | string;
50
59
  export function resolveProjectionSliceOffsetBytes(
51
60
  weightBuffer: WeightBuffer | Tensor | GPUBuffer | null | undefined,
52
61
  outputRows: number,
@@ -1,10 +1,12 @@
1
- import { acquireBuffer } from '../../../../memory/buffer-pool.js';
1
+ import { releaseBuffer } from '../../../../memory/buffer-pool.js';
2
2
  import { isWeightBuffer, getLayout, getWeightDtype } from '../../../../gpu/weight-buffer.js';
3
3
  import {
4
4
  runMatmul,
5
5
  recordMatmul,
6
6
  runSplitQKV,
7
7
  recordSplitQKV,
8
+ runSplitQG,
9
+ recordSplitQG,
8
10
  runRMSNorm,
9
11
  recordRMSNorm,
10
12
  } from '../../../../gpu/kernel-selector.js';
@@ -28,6 +30,13 @@ function getSplitRunner(recorder) {
28
30
  return (qkvTensor, options) => recordSplitQKV(recorder, qkvTensor, options);
29
31
  }
30
32
 
33
+ function getSplitQGRunner(recorder) {
34
+ if (!recorder) {
35
+ return (qgTensor, options) => runSplitQG(qgTensor, options);
36
+ }
37
+ return (qgTensor, options) => recordSplitQG(recorder, qgTensor, options);
38
+ }
39
+
31
40
  function getRmsNormRunner(recorder) {
32
41
  if (!recorder) {
33
42
  return (input, weight, eps, options) => runRMSNorm(input, weight, eps, options);
@@ -36,7 +45,7 @@ function getRmsNormRunner(recorder) {
36
45
  }
37
46
 
38
47
  function releaseOwnedWeightBuffer(layerWeight, resolvedWeightBuffer, releaseTemporary) {
39
- if (layerWeight instanceof GPUBuffer || isWeightBuffer(layerWeight)) {
48
+ if ((typeof GPUBuffer !== 'undefined' && layerWeight instanceof GPUBuffer) || isWeightBuffer(layerWeight)) {
40
49
  return;
41
50
  }
42
51
  if (!resolvedWeightBuffer) {
@@ -66,10 +75,16 @@ async function projectSingleQkvTensor({
66
75
  }) {
67
76
  const runMatmulForMode = getMatmulRunner(recorder);
68
77
  const layerWeight = layerWeights?.[weightKey];
69
- let projected;
78
+ if (!layerWeight) {
79
+ throw new Error(`Attention projection requires ${weightKey}.`);
80
+ }
81
+ if (!getWeightBuffer) {
82
+ throw new Error(`Attention projection requires getWeightBuffer for ${role}.`);
83
+ }
70
84
 
71
- if (layerWeight && getWeightBuffer) {
72
- const projBuffer = getWeightBuffer(layerWeight, role);
85
+ let projected;
86
+ const projBuffer = getWeightBuffer(layerWeight, role);
87
+ try {
73
88
  projected = await runMatmulForMode(normed, projBuffer, numTokens, outputSize, hiddenSize, {
74
89
  transposeB: 'auto',
75
90
  role,
@@ -77,26 +92,31 @@ async function projectSingleQkvTensor({
77
92
  kernelPath,
78
93
  outputDtype: matmulOutputDtype,
79
94
  });
95
+ } finally {
80
96
  releaseOwnedWeightBuffer(layerWeight, projBuffer, releaseTemporary);
81
- } else {
82
- const fallback = acquireBuffer(numTokens * outputSize * 4, undefined, outputLabel);
83
- projected = createTensor(fallback, normed.dtype, [numTokens, outputSize], outputLabel);
84
97
  }
85
98
 
86
99
  const loraModule = getLoRAModule(lora, layerIdx, loraKey);
87
100
  if (loraModule && getWeightBuffer) {
88
- const combined = await applyLoRA(
89
- normed,
90
- projected,
91
- loraModule,
92
- { M: numTokens, N: outputSize, K: hiddenSize },
93
- getWeightBuffer,
94
- recorder ?? undefined,
95
- { kernelPath }
96
- );
97
- if (combined.buffer !== projected.buffer) {
98
- releaseTemporary(projected.buffer);
99
- projected = combined;
101
+ try {
102
+ const combined = await applyLoRA(
103
+ normed,
104
+ projected,
105
+ loraModule,
106
+ { M: numTokens, N: outputSize, K: hiddenSize },
107
+ getWeightBuffer,
108
+ recorder ?? undefined,
109
+ { kernelPath }
110
+ );
111
+ if (combined.buffer !== projected.buffer) {
112
+ releaseTemporary(projected.buffer);
113
+ projected = combined;
114
+ }
115
+ } catch (error) {
116
+ if (projected?.buffer) {
117
+ releaseTemporary(projected.buffer);
118
+ }
119
+ throw error;
100
120
  }
101
121
  }
102
122
 
@@ -190,13 +210,17 @@ async function projectQueryWithOptionalGate({
190
210
  return { qTensor, qGateTensor: null };
191
211
  }
192
212
 
213
+ // q_proj weights are stored with interleaved head layout: for head h,
214
+ // rows [h*headDim*2 : h*headDim*2+headDim] = Q, rows [h*headDim*2+headDim : (h+1)*headDim*2] = gate.
215
+ // Compute the full 2*qSize matmul, then de-interleave into separate Q and gate tensors.
193
216
  const runMatmulForMode = getMatmulRunner(recorder);
217
+ const runSplitQGForMode = getSplitQGRunner(recorder);
194
218
  const qWeightBuffer = getWeightBuffer(qWeight, 'q_proj');
195
- const gateOffset = resolveProjectionSliceOffsetBytes(qWeightBuffer, qSize, hiddenSize);
219
+ let fullQGTensor = null;
196
220
  let qTensor = null;
197
221
  let qGateTensor = null;
198
222
  try {
199
- qTensor = await runMatmulForMode(normed, qWeightBuffer, numTokens, qSize, hiddenSize, {
223
+ fullQGTensor = await runMatmulForMode(normed, qWeightBuffer, numTokens, qSize * 2, hiddenSize, {
200
224
  transposeB: 'auto',
201
225
  role: 'q_proj',
202
226
  layerIdx,
@@ -204,32 +228,54 @@ async function projectQueryWithOptionalGate({
204
228
  outputDtype: matmulOutputDtype,
205
229
  });
206
230
 
207
- qGateTensor = await runMatmulForMode(normed, qWeightBuffer, numTokens, qSize, hiddenSize, {
208
- transposeB: 'auto',
209
- role: 'q_proj_gate',
210
- layerIdx,
211
- kernelPath,
212
- bOffset: gateOffset,
213
- outputDtype: matmulOutputDtype,
231
+ const split = await runSplitQGForMode(fullQGTensor, {
232
+ numTokens,
233
+ numHeads,
234
+ headDim,
214
235
  });
236
+ releaseTemporary(fullQGTensor.buffer);
237
+ fullQGTensor = null;
238
+ qTensor = split.Q;
239
+ qGateTensor = split.G;
240
+ } catch (error) {
241
+ if (fullQGTensor) {
242
+ releaseTemporary(fullQGTensor.buffer);
243
+ }
244
+ if (qTensor) {
245
+ releaseTemporary(qTensor.buffer);
246
+ }
247
+ if (qGateTensor) {
248
+ releaseTemporary(qGateTensor.buffer);
249
+ }
250
+ throw error;
215
251
  } finally {
216
252
  releaseOwnedWeightBuffer(qWeight, qWeightBuffer, releaseTemporary);
217
253
  }
218
254
 
219
255
  const loraModule = getLoRAModule(lora, layerIdx, 'q_proj');
220
256
  if (loraModule && getWeightBuffer) {
221
- const combined = await applyLoRA(
222
- normed,
223
- qTensor,
224
- loraModule,
225
- { M: numTokens, N: qSize, K: hiddenSize },
226
- getWeightBuffer,
227
- recorder ?? undefined,
228
- { kernelPath }
229
- );
230
- if (combined.buffer !== qTensor.buffer) {
231
- releaseTemporary(qTensor.buffer);
232
- qTensor = combined;
257
+ try {
258
+ const combined = await applyLoRA(
259
+ normed,
260
+ qTensor,
261
+ loraModule,
262
+ { M: numTokens, N: qSize, K: hiddenSize },
263
+ getWeightBuffer,
264
+ recorder ?? undefined,
265
+ { kernelPath }
266
+ );
267
+ if (combined.buffer !== qTensor.buffer) {
268
+ releaseTemporary(qTensor.buffer);
269
+ qTensor = combined;
270
+ }
271
+ } catch (error) {
272
+ if (qTensor?.buffer) {
273
+ releaseTemporary(qTensor.buffer);
274
+ }
275
+ if (qGateTensor?.buffer) {
276
+ releaseTemporary(qGateTensor.buffer);
277
+ }
278
+ throw error;
233
279
  }
234
280
  }
235
281
 
@@ -248,9 +294,22 @@ export function recordAttentionInputs(state, info) {
248
294
  state.stats.attentionInputs.push(info);
249
295
  }
250
296
 
251
- export function resolveAttentionProjectionOutputDtype(attentionInputDtype) {
297
+ export function shouldForceF32AttentionProjectionForRoPE({
298
+ attentionInputDtype,
299
+ headDim,
300
+ rotaryDim = headDim,
301
+ interleaved = false,
302
+ }) {
303
+ return attentionInputDtype === 'f16'
304
+ && Number.isFinite(headDim)
305
+ && Number.isFinite(rotaryDim)
306
+ && (rotaryDim !== headDim || interleaved === true);
307
+ }
308
+
309
+ export function resolveAttentionProjectionOutputDtype(attentionInputDtype, options = {}) {
252
310
  const useF16Activations = attentionInputDtype === 'f16';
253
- return selectRuleValue('shared', 'dtype', 'f16OrFallbackByFlag', {
311
+ return selectRuleValue('inference', 'dtype', 'attentionProjectionOutputDtype', {
312
+ forceF32: options.forceF32 === true,
254
313
  useF16: useF16Activations,
255
314
  fallback: attentionInputDtype,
256
315
  });
@@ -289,82 +348,103 @@ export async function projectAttentionQKV({
289
348
  if (useFusedQKV && layerWeights.qkvProj && layerWeights.qkvSizes) {
290
349
  const [qSizeFused, kSizeFused, vSizeFused] = layerWeights.qkvSizes;
291
350
  const qkvSizeTotal = qSizeFused + kSizeFused + vSizeFused;
292
- const qkvTensor = await runMatmulForMode(normed, layerWeights.qkvProj, numTokens, qkvSizeTotal, hiddenSize, {
293
- transposeB: 'auto',
294
- role: 'qkv_proj',
351
+ let qkvTensor = null;
352
+ try {
353
+ qkvTensor = await runMatmulForMode(normed, layerWeights.qkvProj, numTokens, qkvSizeTotal, hiddenSize, {
354
+ transposeB: 'auto',
355
+ role: 'qkv_proj',
356
+ layerIdx,
357
+ kernelPath,
358
+ outputDtype: matmulOutputDtype,
359
+ });
360
+ const split = await runSplitForMode(qkvTensor, {
361
+ numTokens,
362
+ qSize: qSizeFused,
363
+ kSize: kSizeFused,
364
+ vSize: vSizeFused,
365
+ });
366
+ releaseTemporary(qkvTensor.buffer);
367
+ if (onFusedQKV) {
368
+ onFusedQKV({ qSize: qSizeFused, kSize: kSizeFused, vSize: vSizeFused, totalSize: qkvSizeTotal });
369
+ }
370
+ return { qTensor: split.Q, qGateTensor: null, kTensor: split.K, vTensor: split.V, usedFusedQKV: true };
371
+ } catch (error) {
372
+ if (qkvTensor) {
373
+ releaseTemporary(qkvTensor.buffer);
374
+ }
375
+ throw error;
376
+ }
377
+ }
378
+
379
+ let qTensor = null;
380
+ let qGateTensor = null;
381
+ let kTensor = null;
382
+ let vTensor = null;
383
+ try {
384
+ ({ qTensor, qGateTensor } = await projectQueryWithOptionalGate({
385
+ recorder,
386
+ normed,
387
+ layerWeights,
388
+ numTokens,
389
+ numHeads,
390
+ headDim,
391
+ hiddenSize,
295
392
  layerIdx,
296
393
  kernelPath,
297
- outputDtype: matmulOutputDtype,
394
+ matmulOutputDtype,
395
+ getWeightBuffer,
396
+ lora,
397
+ releaseTemporary,
398
+ attentionOutputGate,
399
+ }));
400
+
401
+ kTensor = await projectSingleQkvTensor({
402
+ recorder,
403
+ normed,
404
+ layerWeights,
405
+ weightKey: 'kProj',
406
+ role: 'k_proj',
407
+ outputSize: numKVHeads * headDim,
408
+ outputLabel: 'K',
409
+ loraKey: 'k_proj',
410
+ numTokens,
411
+ hiddenSize,
412
+ layerIdx,
413
+ kernelPath,
414
+ matmulOutputDtype,
415
+ getWeightBuffer,
416
+ lora,
417
+ releaseTemporary,
298
418
  });
299
- const split = await runSplitForMode(qkvTensor, {
419
+
420
+ vTensor = await projectSingleQkvTensor({
421
+ recorder,
422
+ normed,
423
+ layerWeights,
424
+ weightKey: 'vProj',
425
+ role: 'v_proj',
426
+ outputSize: numKVHeads * headDim,
427
+ outputLabel: 'V',
428
+ loraKey: 'v_proj',
300
429
  numTokens,
301
- qSize: qSizeFused,
302
- kSize: kSizeFused,
303
- vSize: vSizeFused,
430
+ hiddenSize,
431
+ layerIdx,
432
+ kernelPath,
433
+ matmulOutputDtype,
434
+ getWeightBuffer,
435
+ lora,
436
+ releaseTemporary,
304
437
  });
305
- releaseTemporary(qkvTensor.buffer);
306
- if (onFusedQKV) {
307
- onFusedQKV({ qSize: qSizeFused, kSize: kSizeFused, vSize: vSizeFused, totalSize: qkvSizeTotal });
438
+
439
+ return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
440
+ } catch (error) {
441
+ for (const tensor of [qTensor, qGateTensor, kTensor, vTensor]) {
442
+ if (tensor?.buffer) {
443
+ releaseTemporary(tensor.buffer);
444
+ }
308
445
  }
309
- return { qTensor: split.Q, qGateTensor: null, kTensor: split.K, vTensor: split.V, usedFusedQKV: true };
446
+ throw error;
310
447
  }
311
-
312
- const { qTensor, qGateTensor } = await projectQueryWithOptionalGate({
313
- recorder,
314
- normed,
315
- layerWeights,
316
- numTokens,
317
- numHeads,
318
- headDim,
319
- hiddenSize,
320
- layerIdx,
321
- kernelPath,
322
- matmulOutputDtype,
323
- getWeightBuffer,
324
- lora,
325
- releaseTemporary,
326
- attentionOutputGate,
327
- });
328
-
329
- const kTensor = await projectSingleQkvTensor({
330
- recorder,
331
- normed,
332
- layerWeights,
333
- weightKey: 'kProj',
334
- role: 'k_proj',
335
- outputSize: numKVHeads * headDim,
336
- outputLabel: 'K',
337
- loraKey: 'k_proj',
338
- numTokens,
339
- hiddenSize,
340
- layerIdx,
341
- kernelPath,
342
- matmulOutputDtype,
343
- getWeightBuffer,
344
- lora,
345
- releaseTemporary,
346
- });
347
-
348
- const vTensor = await projectSingleQkvTensor({
349
- recorder,
350
- normed,
351
- layerWeights,
352
- weightKey: 'vProj',
353
- role: 'v_proj',
354
- outputSize: numKVHeads * headDim,
355
- outputLabel: 'V',
356
- loraKey: 'v_proj',
357
- numTokens,
358
- hiddenSize,
359
- layerIdx,
360
- kernelPath,
361
- matmulOutputDtype,
362
- getWeightBuffer,
363
- lora,
364
- releaseTemporary,
365
- });
366
-
367
- return { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV: false };
368
448
  }
369
449
 
370
450
  export async function applyAttentionQKNorm({
@@ -24,10 +24,12 @@ import { selectRuleValue } from '../../../../rules/rule-registry.js';
24
24
  import { SlidingWindowKVCache } from '../../../kv-cache.js';
25
25
  import {
26
26
  recordAttentionInputs,
27
+ shouldForceF32AttentionProjectionForRoPE,
27
28
  resolveAttentionProjectionOutputDtype,
28
29
  projectAttentionQKV,
29
30
  applyAttentionQKNorm,
30
31
  } from './projections.js';
32
+ import { prepareAttentionProjectionInput } from './output-projection.js';
31
33
 
32
34
  import { releaseOrTrack, shouldDebugLayer } from './types.js';
33
35
 
@@ -90,9 +92,20 @@ export async function recordLayerAttentionGPU(
90
92
  const allowF16Attention = wantsF16Output && kvCacheDtype === 'f16';
91
93
  let attentionInput = input;
92
94
  let attentionInputTemp = false;
95
+ let normed = attentionInput;
96
+ let qTensor = null;
97
+ let qGateTensor = null;
98
+ let kTensor = null;
99
+ let vTensor = null;
100
+ let attnOutput = null;
101
+ let attnForProjection = null;
102
+ let output = null;
103
+ let finalOutput = null;
104
+ let oProjInputTemp = null;
93
105
  if (wantsF16Output && !allowF16Attention) {
94
106
  attentionInput = await recordCastF16ToF32(recorder, input);
95
107
  attentionInputTemp = true;
108
+ normed = attentionInput;
96
109
  }
97
110
 
98
111
  if (!layerWeights) {
@@ -108,7 +121,7 @@ export async function recordLayerAttentionGPU(
108
121
 
109
122
  // 1. Input norm
110
123
 
111
- let normed = attentionInput;
124
+ try {
112
125
  if (!skipInputNorm && layerWeights.inputNorm && getNormWeightBuffer) {
113
126
  const normWeightBuf = getNormWeightBuffer(layerWeights.inputNorm, 'input_norm');
114
127
  normed = await recordRMSNorm(recorder, attentionInput, normWeightBuf, rmsNormEps, {
@@ -131,8 +144,16 @@ export async function recordLayerAttentionGPU(
131
144
  }
132
145
 
133
146
  // 2. Q/K/V projections
134
- const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype);
135
- let { qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
147
+ const matmulOutputDtype = resolveAttentionProjectionOutputDtype(desiredOutputDtype, {
148
+ forceF32: shouldForceF32AttentionProjectionForRoPE({
149
+ attentionInputDtype: desiredOutputDtype,
150
+ headDim,
151
+ rotaryDim: config.ropeRotaryDim,
152
+ interleaved: config.ropeInterleaved,
153
+ }),
154
+ });
155
+ let usedFusedQKV = false;
156
+ ({ qTensor, qGateTensor, kTensor, vTensor, usedFusedQKV } = await projectAttentionQKV({
136
157
  recorder,
137
158
  normed,
138
159
  layerWeights,
@@ -153,7 +174,7 @@ export async function recordLayerAttentionGPU(
153
174
  trace.attn(layerIdx, `Using fused QKV path: ${qSizeFused}+${kSizeFused}+${vSizeFused}=${totalSize}`);
154
175
  }
155
176
  : null,
156
- });
177
+ }));
157
178
 
158
179
  // Optional per-head Q/K normalization.
159
180
  // Some models use RMSNorm with (1+weight) offset formula, controlled by rmsNormWeightOffset.
@@ -502,9 +523,9 @@ export async function recordLayerAttentionGPU(
502
523
  throw new Error(`Unsupported attention kernel variant "${attentionKernelVariant}" at layer ${layerIdx}`);
503
524
  }
504
525
 
505
- const attnOutput = await runAttentionKernel();
526
+ attnOutput = await runAttentionKernel();
506
527
 
507
- let attnForProjection = attnOutput;
528
+ attnForProjection = attnOutput;
508
529
  if (qGateTensor) {
509
530
  attnForProjection = await recordSiLU(recorder, attnOutput, {
510
531
  size: numTokens * numHeads * headDim,
@@ -518,19 +539,19 @@ export async function recordLayerAttentionGPU(
518
539
 
519
540
  // 6. Output projection (with optional fused residual for decode)
520
541
 
521
- let output;
542
+ output = null;
522
543
  let residualFused = false;
523
544
  let oProjInput = attnForProjection;
524
- let oProjInputTemp = null;
545
+ oProjInputTemp = null;
525
546
  if (layerWeights.oProj && getWeightBuffer) {
547
+ ({ oProjInput, oProjInputTemp } = await prepareAttentionProjectionInput(
548
+ attnForProjection,
549
+ matmulOutputDtype,
550
+ (tensor) => recordCastF32ToF16(recorder, tensor)
551
+ ));
526
552
  const oProjBuf = getWeightBuffer(layerWeights.oProj, 'o_proj');
527
553
  const loraO = getLoRAModule(lora, layerIdx, 'o_proj');
528
554
 
529
- if (matmulOutputDtype === 'f16' && attnForProjection.dtype !== 'f16') {
530
- oProjInput = await recordCastF32ToF16(recorder, attnForProjection);
531
- oProjInputTemp = oProjInput;
532
- }
533
-
534
555
  // Use fused o_proj + residual for decode when possible
535
556
  // Note: dtype from WeightBuffer metadata (buffer-dtypes WeakMap removed)
536
557
  const oProjDtype = getWeightDtype(oProjBuf);
@@ -589,7 +610,7 @@ export async function recordLayerAttentionGPU(
589
610
  }
590
611
  }
591
612
 
592
- let finalOutput = output;
613
+ finalOutput = output;
593
614
 
594
615
  const buffersToTrack = [];
595
616
  if (output.buffer !== attnForProjection.buffer) {
@@ -619,4 +640,46 @@ export async function recordLayerAttentionGPU(
619
640
  }
620
641
 
621
642
  return { output: finalOutput, residualFused };
643
+ } catch (error) {
644
+ const tracked = new Set();
645
+ const trackOnce = (buffer) => {
646
+ if (!buffer || tracked.has(buffer)) return;
647
+ tracked.add(buffer);
648
+ recorder.trackTemporaryBuffer(buffer);
649
+ };
650
+ if (finalOutput?.buffer && finalOutput.buffer !== output?.buffer) {
651
+ trackOnce(finalOutput.buffer);
652
+ }
653
+ if (output?.buffer && output.buffer !== attnForProjection?.buffer) {
654
+ trackOnce(output.buffer);
655
+ }
656
+ if (oProjInputTemp?.buffer) {
657
+ trackOnce(oProjInputTemp.buffer);
658
+ }
659
+ if (attnForProjection?.buffer && attnForProjection.buffer !== attnOutput?.buffer) {
660
+ trackOnce(attnForProjection.buffer);
661
+ }
662
+ if (attnOutput?.buffer) {
663
+ trackOnce(attnOutput.buffer);
664
+ }
665
+ if (qGateTensor?.buffer) {
666
+ trackOnce(qGateTensor.buffer);
667
+ }
668
+ if (qTensor?.buffer) {
669
+ trackOnce(qTensor.buffer);
670
+ }
671
+ if (kTensor?.buffer) {
672
+ trackOnce(kTensor.buffer);
673
+ }
674
+ if (vTensor?.buffer) {
675
+ trackOnce(vTensor.buffer);
676
+ }
677
+ if (normed?.buffer && normed.buffer !== attentionInput?.buffer) {
678
+ trackOnce(normed.buffer);
679
+ }
680
+ if (attentionInputTemp && attentionInput?.buffer) {
681
+ trackOnce(attentionInput.buffer);
682
+ }
683
+ throw error;
684
+ }
622
685
  }