@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -26,8 +26,8 @@ struct Uniforms {
26
26
  start_pos: u32, // Starting position (for decode)
27
27
  rope_base: f32, // Base frequency (default 10000)
28
28
  rope_scale: f32, // Scaling factor for extended context
29
- _pad0: u32,
30
- _pad1: u32,
29
+ rotary_dim: u32, // Rotary slice within head_dim
30
+ interleaved: u32, // 1 = adjacent pairs, 0 = rotate-half
31
31
  }
32
32
 
33
33
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -46,7 +46,8 @@ fn main(
46
46
  let start_pos = u.start_pos;
47
47
 
48
48
  // Global thread index (one thread per complex pair)
49
- let half_dim = head_dim / 2u;
49
+ let rotary_dim = u.rotary_dim;
50
+ let half_dim = rotary_dim / 2u;
50
51
  let total_pairs = seq_len * num_heads * half_dim;
51
52
  let idx = global_id.x;
52
53
 
@@ -68,16 +69,18 @@ fn main(
68
69
 
69
70
  // Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
70
71
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
71
- let x0 = input[base_idx + pair_idx];
72
- let x1 = input[base_idx + pair_idx + half_dim];
72
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
73
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
74
+ let x0 = input[base_idx + first_idx];
75
+ let x1 = input[base_idx + second_idx];
73
76
 
74
77
  // Apply rotation
75
78
  let y0 = x0 * cos_val - x1 * sin_val;
76
79
  let y1 = x0 * sin_val + x1 * cos_val;
77
80
 
78
81
  // Write back
79
- input[base_idx + pair_idx] = y0;
80
- input[base_idx + pair_idx + half_dim] = y1;
82
+ input[base_idx + first_idx] = y0;
83
+ input[base_idx + second_idx] = y1;
81
84
  }
82
85
 
83
86
  // Compute frequencies on-the-fly (no precomputation needed)
@@ -91,9 +94,10 @@ fn rope_compute_freqs(
91
94
  let start_pos = u.start_pos;
92
95
  let rope_base = u.rope_base;
93
96
  let rope_scale = u.rope_scale;
97
+ let rotary_dim = u.rotary_dim;
94
98
 
95
99
  let idx = global_id.x;
96
- let half_dim = head_dim / 2u;
100
+ let half_dim = rotary_dim / 2u;
97
101
  let total_pairs = seq_len * num_heads * half_dim;
98
102
 
99
103
  if (idx >= total_pairs) {
@@ -109,7 +113,7 @@ fn rope_compute_freqs(
109
113
  let actual_pos = f32(start_pos + pos) / rope_scale;
110
114
 
111
115
  // Compute frequency: 1 / (base^(2*pair_idx/head_dim))
112
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
116
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
113
117
  let freq = 1.0 / pow(rope_base, exponent);
114
118
  let theta = actual_pos * freq;
115
119
 
@@ -118,12 +122,14 @@ fn rope_compute_freqs(
118
122
 
119
123
  // Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
120
124
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
121
- let x0 = input[base_idx + pair_idx];
122
- let x1 = input[base_idx + pair_idx + half_dim];
125
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
126
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
127
+ let x0 = input[base_idx + first_idx];
128
+ let x1 = input[base_idx + second_idx];
123
129
 
124
130
  // Apply rotation
125
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
126
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
131
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
132
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
127
133
  }
128
134
 
129
135
  // Apply RoPE to both Q and K in one pass
@@ -138,10 +144,11 @@ fn rope_qk(
138
144
  let start_pos = u.start_pos;
139
145
  let rope_base = u.rope_base;
140
146
  let rope_scale = u.rope_scale;
147
+ let rotary_dim = u.rotary_dim;
141
148
 
142
149
  let idx = global_id.x;
143
150
  // Each thread handles one Q-K pair at one dimension pair
144
- let half_dim = head_dim / 2u;
151
+ let half_dim = rotary_dim / 2u;
145
152
  let total_pairs = seq_len * num_heads * half_dim;
146
153
 
147
154
  if (idx >= total_pairs) {
@@ -156,7 +163,7 @@ fn rope_qk(
156
163
  let actual_pos = f32(start_pos + pos) / rope_scale;
157
164
 
158
165
  // Compute frequency
159
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
166
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
160
167
  let freq = 1.0 / pow(rope_base, exponent);
161
168
  let theta = actual_pos * freq;
162
169
 
@@ -168,16 +175,18 @@ fn rope_qk(
168
175
  let k_base_idx = q_base_idx + head_dim; // K starts after Q
169
176
 
170
177
  // Process Q
171
- let q0 = input[q_base_idx + pair_idx];
172
- let q1 = input[q_base_idx + pair_idx + half_dim];
173
- input[q_base_idx + pair_idx] = q0 * cos_val - q1 * sin_val;
174
- input[q_base_idx + pair_idx + half_dim] = q0 * sin_val + q1 * cos_val;
178
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
179
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
180
+ let q0 = input[q_base_idx + first_idx];
181
+ let q1 = input[q_base_idx + second_idx];
182
+ input[q_base_idx + first_idx] = q0 * cos_val - q1 * sin_val;
183
+ input[q_base_idx + second_idx] = q0 * sin_val + q1 * cos_val;
175
184
 
176
185
  // Process K
177
- let k0 = input[k_base_idx + pair_idx];
178
- let k1 = input[k_base_idx + pair_idx + half_dim];
179
- input[k_base_idx + pair_idx] = k0 * cos_val - k1 * sin_val;
180
- input[k_base_idx + pair_idx + half_dim] = k0 * sin_val + k1 * cos_val;
186
+ let k0 = input[k_base_idx + first_idx];
187
+ let k1 = input[k_base_idx + second_idx];
188
+ input[k_base_idx + first_idx] = k0 * cos_val - k1 * sin_val;
189
+ input[k_base_idx + second_idx] = k0 * sin_val + k1 * cos_val;
181
190
  }
182
191
 
183
192
  // Precompute frequency table (run once at init)
@@ -190,9 +199,10 @@ fn precompute_freqs(
190
199
  let seq_len = u.seq_len; // maxSeqLen for precomputation
191
200
  let rope_base = u.rope_base;
192
201
  let rope_scale = u.rope_scale;
202
+ let rotary_dim = u.rotary_dim;
193
203
 
194
204
  let idx = global_id.x;
195
- let half_dim = head_dim / 2u;
205
+ let half_dim = rotary_dim / 2u;
196
206
  let total_elements = seq_len * half_dim;
197
207
 
198
208
  if (idx >= total_elements) {
@@ -203,7 +213,7 @@ fn precompute_freqs(
203
213
  let dim_idx = idx % half_dim;
204
214
 
205
215
  let actual_pos = f32(pos) / rope_scale;
206
- let exponent = f32(dim_idx * 2u) / f32(head_dim);
216
+ let exponent = f32(dim_idx * 2u) / f32(rotary_dim);
207
217
  let freq = 1.0 / pow(rope_base, exponent);
208
218
  let theta = actual_pos * freq;
209
219
 
@@ -218,6 +228,7 @@ fn rope_ntk_scaled(
218
228
  @builtin(global_invocation_id) global_id: vec3<u32>
219
229
  ) {
220
230
  let head_dim = u.head_dim;
231
+ let rotary_dim = u.rotary_dim;
221
232
  let num_heads = u.num_heads;
222
233
  let seq_len = u.seq_len;
223
234
  let start_pos = u.start_pos;
@@ -225,7 +236,7 @@ fn rope_ntk_scaled(
225
236
  let rope_scale = u.rope_scale;
226
237
 
227
238
  let idx = global_id.x;
228
- let half_dim = head_dim / 2u;
239
+ let half_dim = rotary_dim / 2u;
229
240
  let total_pairs = seq_len * num_heads * half_dim;
230
241
 
231
242
  if (idx >= total_pairs) {
@@ -234,7 +245,7 @@ fn rope_ntk_scaled(
234
245
 
235
246
  // NTK scaling: increase base proportionally to scale factor
236
247
  // This preserves high-frequency components better than linear interpolation
237
- rope_base = rope_base * pow(rope_scale, f32(head_dim) / (f32(head_dim) - 2.0));
248
+ rope_base = rope_base * pow(rope_scale, f32(rotary_dim) / (f32(rotary_dim) - 2.0));
238
249
 
239
250
  let pos = idx / (num_heads * half_dim);
240
251
  let remainder = idx % (num_heads * half_dim);
@@ -243,7 +254,7 @@ fn rope_ntk_scaled(
243
254
 
244
255
  let actual_pos = f32(start_pos + pos);
245
256
 
246
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
257
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
247
258
  let freq = 1.0 / pow(rope_base, exponent);
248
259
  let theta = actual_pos * freq;
249
260
 
@@ -251,11 +262,13 @@ fn rope_ntk_scaled(
251
262
  let sin_val = sin(theta);
252
263
 
253
264
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
254
- let x0 = input[base_idx + pair_idx];
255
- let x1 = input[base_idx + pair_idx + half_dim];
265
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
266
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
267
+ let x0 = input[base_idx + first_idx];
268
+ let x1 = input[base_idx + second_idx];
256
269
 
257
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
258
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
270
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
271
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
259
272
  }
260
273
 
261
274
  // YaRN-style RoPE with attention scaling
@@ -265,6 +278,7 @@ fn rope_yarn(
265
278
  @builtin(global_invocation_id) global_id: vec3<u32>
266
279
  ) {
267
280
  let head_dim = u.head_dim;
281
+ let rotary_dim = u.rotary_dim;
268
282
  let num_heads = u.num_heads;
269
283
  let seq_len = u.seq_len;
270
284
  let start_pos = u.start_pos;
@@ -272,7 +286,7 @@ fn rope_yarn(
272
286
  let rope_scale = u.rope_scale;
273
287
 
274
288
  let idx = global_id.x;
275
- let half_dim = head_dim / 2u;
289
+ let half_dim = rotary_dim / 2u;
276
290
  let total_pairs = seq_len * num_heads * half_dim;
277
291
 
278
292
  if (idx >= total_pairs) {
@@ -292,7 +306,7 @@ fn rope_yarn(
292
306
  let alpha: f32 = 1.0;
293
307
 
294
308
  // Compute original frequency
295
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
309
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
296
310
  let orig_freq = 1.0 / pow(rope_base, exponent);
297
311
 
298
312
  // Compute wavelength
@@ -300,8 +314,8 @@ fn rope_yarn(
300
314
 
301
315
  // Interpolation factor based on wavelength
302
316
  var ramp: f32;
303
- let low_wavelength = f32(head_dim) / beta_fast;
304
- let high_wavelength = f32(head_dim) / beta_slow;
317
+ let low_wavelength = f32(rotary_dim) / beta_fast;
318
+ let high_wavelength = f32(rotary_dim) / beta_slow;
305
319
 
306
320
  if (wavelength < low_wavelength) {
307
321
  ramp = 0.0; // No interpolation for high frequencies
@@ -320,9 +334,11 @@ fn rope_yarn(
320
334
  let sin_val = sin(theta);
321
335
 
322
336
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
323
- let x0 = input[base_idx + pair_idx];
324
- let x1 = input[base_idx + pair_idx + half_dim];
337
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
338
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
339
+ let x0 = input[base_idx + first_idx];
340
+ let x1 = input[base_idx + second_idx];
325
341
 
326
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
327
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
342
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
343
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
328
344
  }
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, readBufferSlice, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { WORKGROUP_SIZES } from './constants.js';
6
6
  import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout } from './utils.js';
7
7
  import { allowReadback } from '../perf-guards.js';
@@ -156,18 +156,19 @@ function ensureOutputBufferSize(outputBuffer, minBytes, label) {
156
156
  }
157
157
  }
158
158
 
159
- function readTokenFromOutput(device, outputBuffer, outputIndex, label) {
160
- const stagingBuffer = device.createBuffer({
161
- label,
162
- size: 4,
163
- usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
164
- });
165
-
166
- const copyEncoder = device.createCommandEncoder({ label: `${label}_copy` });
167
- copyEncoder.copyBufferToBuffer(outputBuffer, outputIndex * 4, stagingBuffer, 0, 4);
168
- device.queue.submit([copyEncoder.finish()]);
159
+ async function readTokenFromOutput(outputBuffer, outputIndex) {
160
+ return new Uint32Array(await readBufferSlice(outputBuffer, outputIndex * 4, 4))[0];
161
+ }
169
162
 
170
- return stagingBuffer;
163
+ function cleanupRunResources(uniformBuffer, ownedBuffers) {
164
+ if (uniformBuffer) {
165
+ uniformBuffer.destroy();
166
+ }
167
+ for (const buffer of ownedBuffers) {
168
+ if (buffer) {
169
+ releaseBuffer(buffer);
170
+ }
171
+ }
171
172
  }
172
173
 
173
174
  async function executeArgmaxRun(logits, vocabSize, options) {
@@ -238,20 +239,14 @@ async function executeArgmaxRun(logits, vocabSize, options) {
238
239
 
239
240
  device.queue.submit([encoder.finish()]);
240
241
 
241
- const stagingBuffer = readTokenFromOutput(device, outputBuffer, outputIndex, 'argmax_staging');
242
- await stagingBuffer.mapAsync(GPUMapMode.READ);
243
- const tokenId = new Uint32Array(stagingBuffer.getMappedRange())[0];
244
- stagingBuffer.unmap();
245
-
246
- stagingBuffer.destroy();
247
- uniformBuffer.destroy();
248
- releaseBuffer(tempLogits);
249
- releaseBuffer(tempIndices);
250
- if (ownsOutputBuffer) {
251
- releaseBuffer(outputBuffer);
242
+ try {
243
+ return await readTokenFromOutput(outputBuffer, outputIndex);
244
+ } finally {
245
+ cleanupRunResources(
246
+ uniformBuffer,
247
+ [tempLogits, tempIndices, ownsOutputBuffer ? outputBuffer : null]
248
+ );
252
249
  }
253
-
254
- return tokenId;
255
250
  }
256
251
 
257
252
  async function executeArgmaxRecord(recorder, logits, vocabSize, options) {
@@ -428,20 +423,14 @@ export async function runGPUSample(
428
423
 
429
424
  device.queue.submit([encoder.finish()]);
430
425
 
431
- const stagingBuffer = readTokenFromOutput(device, outputBuffer, outputIndex, 'sample_staging');
432
- await stagingBuffer.mapAsync(GPUMapMode.READ);
433
- const tokenId = new Uint32Array(stagingBuffer.getMappedRange())[0];
434
- stagingBuffer.unmap();
435
-
436
- stagingBuffer.destroy();
437
- uniformBuffer.destroy();
438
- releaseBuffer(topkLogits);
439
- releaseBuffer(topkIndices);
440
- if (ownsOutputBuffer) {
441
- releaseBuffer(outputBuffer);
426
+ try {
427
+ return await readTokenFromOutput(outputBuffer, outputIndex);
428
+ } finally {
429
+ cleanupRunResources(
430
+ uniformBuffer,
431
+ [topkLogits, topkIndices, ownsOutputBuffer ? outputBuffer : null]
432
+ );
442
433
  }
443
-
444
- return tokenId;
445
434
  }
446
435
 
447
436
 
@@ -29,7 +29,6 @@ async function runSummary(target, query, key, value, summaryBuffer, uniforms, va
29
29
  }
30
30
 
31
31
  async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
32
- const outputSize = uniforms.num_tokens * uniforms.hidden_size;
33
32
  await unifiedKernelWrapper(
34
33
  'sana_linear_attention_apply',
35
34
  target,
@@ -45,7 +44,7 @@ async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, va
45
44
  _pad1: 0,
46
45
  _pad2: 0,
47
46
  },
48
- Math.ceil(outputSize / WORKGROUP_SIZES.DEFAULT)
47
+ [Math.ceil(uniforms.hidden_size / WORKGROUP_SIZES.DEFAULT), uniforms.num_tokens, 1]
49
48
  );
50
49
  }
51
50
 
@@ -65,6 +64,8 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
65
64
  outputBuffer = null,
66
65
  summaryBuffer = null,
67
66
  } = options;
67
+ const ownsSummary = summaryBuffer == null;
68
+ const ownsOutput = outputBuffer == null;
68
69
 
69
70
  if (
70
71
  !Number.isFinite(numHeads) ||
@@ -99,18 +100,24 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
99
100
  eps,
100
101
  };
101
102
 
102
- await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
103
- await runApply(target, query, temporarySummary, output, uniforms, variant);
104
-
105
- if (!summaryBuffer) {
106
- if (recorder) {
107
- recorder.trackTemporaryBuffer(temporarySummary);
108
- } else {
109
- releaseBuffer(temporarySummary);
103
+ try {
104
+ await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
105
+ await runApply(target, query, temporarySummary, output, uniforms, variant);
106
+ return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
107
+ } catch (error) {
108
+ if (ownsOutput) {
109
+ releaseBuffer(output);
110
+ }
111
+ throw error;
112
+ } finally {
113
+ if (ownsSummary) {
114
+ if (recorder) {
115
+ recorder.trackTemporaryBuffer(temporarySummary);
116
+ } else {
117
+ releaseBuffer(temporarySummary);
118
+ }
110
119
  }
111
120
  }
112
-
113
- return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
114
121
  }
115
122
 
116
123
  export async function runSanaLinearAttention(query, key, value, options = {}) {
@@ -18,14 +18,13 @@ struct Uniforms {
18
18
 
19
19
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
20
20
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
21
- let idx = gid.x;
22
- let total = u.num_tokens * u.hidden_size;
23
- if (idx >= total) {
21
+ let hidden = gid.x;
22
+ let token = gid.y;
23
+ if (token >= u.num_tokens || hidden >= u.hidden_size) {
24
24
  return;
25
25
  }
26
26
 
27
- let token = idx / u.hidden_size;
28
- let hidden = idx - token * u.hidden_size;
27
+ let idx = token * u.hidden_size + hidden;
29
28
  let head = hidden / u.head_dim;
30
29
  let dim = hidden - head * u.head_dim;
31
30
  let rows_per_head = u.head_dim + 1u;
@@ -20,14 +20,13 @@ struct Uniforms {
20
20
 
21
21
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
22
22
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
23
- let idx = gid.x;
24
- let total = u.num_tokens * u.hidden_size;
25
- if (idx >= total) {
23
+ let hidden = gid.x;
24
+ let token = gid.y;
25
+ if (token >= u.num_tokens || hidden >= u.hidden_size) {
26
26
  return;
27
27
  }
28
28
 
29
- let token = idx / u.hidden_size;
30
- let hidden = idx - token * u.hidden_size;
29
+ let idx = token * u.hidden_size + hidden;
31
30
  let head = hidden / u.head_dim;
32
31
  let dim = hidden - head * u.head_dim;
33
32
  let rows_per_head = u.head_dim + 1u;
@@ -33,6 +33,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
33
33
 
34
34
  var acc: f32 = 0.0;
35
35
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
36
+ let query_value = query[token * u.hidden_size + hidden_base + col];
36
37
  let key_idx = token * u.hidden_size + hidden_base + col;
37
38
  let key_value = max(key[key_idx], 0.0);
38
39
  let value_value = select(
@@ -40,6 +41,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
40
41
  1.0,
41
42
  row == u.head_dim
42
43
  );
44
+ if (u.hidden_size == 0u) {
45
+ acc = acc + query_value;
46
+ }
43
47
  acc = acc + value_value * key_value;
44
48
  }
45
49
 
@@ -35,6 +35,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
35
35
 
36
36
  var acc: f32 = 0.0;
37
37
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
38
+ let query_value = f32(query[token * u.hidden_size + hidden_base + col]);
38
39
  let key_idx = token * u.hidden_size + hidden_base + col;
39
40
  let key_value = max(f32(key[key_idx]), 0.0);
40
41
  let value_value = select(
@@ -42,6 +43,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
42
43
  1.0,
43
44
  row == u.head_dim
44
45
  );
46
+ if (u.hidden_size == 0u) {
47
+ acc = acc + query_value;
48
+ }
45
49
  acc = acc + value_value * key_value;
46
50
  }
47
51
 
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { WORKGROUP_SIZES } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
@@ -6,6 +6,7 @@ import { selectRuleValue } from './rule-registry.js';
6
6
 
7
7
  async function _scale(target, input, scale, options = {}) {
8
8
  const { count, outputBuffer = null, inplace = false } = options;
9
+ const ownsOutput = !inplace && outputBuffer == null;
9
10
 
10
11
  const bytesPerElement = dtypeBytes(input.dtype);
11
12
  const inferredCount = count ?? Math.floor(input.buffer.size / bytesPerElement);
@@ -16,16 +17,22 @@ async function _scale(target, input, scale, options = {}) {
16
17
 
17
18
  const bindings = inplace ? [outputBuf, outputBuf] : [input, outputBuf];
18
19
 
19
- await unifiedKernelWrapper(
20
- 'scale',
21
- target,
22
- variant,
23
- bindings,
24
- { size: inferredCount, scale },
25
- Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
26
- );
27
-
28
- return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
20
+ try {
21
+ await unifiedKernelWrapper(
22
+ 'scale',
23
+ target,
24
+ variant,
25
+ bindings,
26
+ { size: inferredCount, scale },
27
+ Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
28
+ );
29
+ return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
30
+ } catch (error) {
31
+ if (ownsOutput) {
32
+ releaseBuffer(outputBuf);
33
+ }
34
+ throw error;
35
+ }
29
36
  }
30
37
 
31
38
  export async function runScale(input, scale, options = {}) {
@@ -138,8 +138,10 @@ export async function compileShader(
138
138
  code: source,
139
139
  });
140
140
 
141
- // Check for compilation errors
142
- const compilationInfo = await module.getCompilationInfo();
141
+ // Check for compilation errors (getCompilationInfo not available in all WebGPU providers)
142
+ const compilationInfo = typeof module.getCompilationInfo === 'function'
143
+ ? await module.getCompilationInfo()
144
+ : { messages: [] };
143
145
  if (compilationInfo.messages.length > 0) {
144
146
  for (const msg of compilationInfo.messages) {
145
147
  if (msg.type === 'error') {
@@ -16,6 +16,7 @@ export interface SiLUOptions extends OutputBufferOptions {
16
16
  size?: number | null;
17
17
  gate?: Tensor | null;
18
18
  gateActivation?: 'silu' | 'sigmoid';
19
+ inputActivation?: 'silu' | 'identity';
19
20
  useVec4?: boolean;
20
21
  biasOffset?: number;
21
22
  swigluLimit: number | null;