@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1,13 +1,26 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor, dtypeBytes } from '../tensor.js';
6
6
  import { WORKGROUP_SIZES } from './constants.js';
7
7
  import { dispatch, recordDispatch } from './dispatch.js';
8
8
  import { getPipelineFast, createUniformBufferWithView } from './utils.js';
9
9
  import { selectRuleValue } from './rule-registry.js';
10
10
 
11
+ function destroyAfterSubmit(device, buffer) {
12
+ if (!buffer) {
13
+ return;
14
+ }
15
+ device.queue.onSubmittedWorkDone()
16
+ .then(() => {
17
+ buffer.destroy();
18
+ })
19
+ .catch(() => {
20
+ buffer.destroy();
21
+ });
22
+ }
23
+
11
24
  function canUseF16(input) {
12
25
  return input.dtype === 'f16';
13
26
  }
@@ -47,6 +60,24 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
47
60
  ];
48
61
  }
49
62
 
63
+ function cleanupRunResources(uniformBuffer, ownedOutput) {
64
+ if (ownedOutput) {
65
+ releaseBuffer(ownedOutput);
66
+ }
67
+ }
68
+
69
+ function planSiLUDispatch(device, size, useVec4) {
70
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
71
+ ? device.limits.maxComputeWorkgroupsPerDimension
72
+ : 65535;
73
+ const laneWidth = useVec4 ? 4 : 1;
74
+ const chunkSize = maxPerDim * WORKGROUP_SIZES.DEFAULT * laneWidth;
75
+ const dispatchStride = Math.min(size, chunkSize);
76
+ const x = Math.min(maxPerDim, Math.ceil(dispatchStride / (WORKGROUP_SIZES.DEFAULT * laneWidth)));
77
+ const y = Math.max(1, Math.ceil(size / chunkSize));
78
+ return { dispatchStride, workgroups: [x, y, 1] };
79
+ }
80
+
50
81
 
51
82
  export async function runSiLU(
52
83
  input,
@@ -60,6 +91,7 @@ export async function runSiLU(
60
91
  useVec4 = false,
61
92
  swigluLimit,
62
93
  gateActivation = 'silu',
94
+ inputActivation = 'silu',
63
95
  } = options;
64
96
  const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
65
97
 
@@ -74,14 +106,18 @@ export async function runSiLU(
74
106
  useSplit: false,
75
107
  useRowsplit: false,
76
108
  });
77
- const constants = gate && gateActivation === 'sigmoid'
78
- ? { ...(overrides || {}), GATE_USE_SIGMOID: true }
79
- : overrides;
109
+ const constants = {
110
+ ...(overrides || {}),
111
+ ...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
112
+ ...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
113
+ };
80
114
  const pipeline = await getPipelineFast('silu', variant, null, constants);
81
115
 
82
116
  const inferredSize = size || (input.buffer.size / bytesPerElement);
83
117
  const outputSize = inferredSize * bytesPerElement;
84
118
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
119
+ const ownedOutput = outputBuffer ? null : output;
120
+ const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
85
121
 
86
122
  // Create uniform buffer
87
123
  const uniformBuffer = createUniformBufferWithView(
@@ -89,7 +125,7 @@ export async function runSiLU(
89
125
  16,
90
126
  (view) => {
91
127
  view.setUint32(0, inferredSize, true);
92
- view.setUint32(4, 0, true);
128
+ view.setUint32(4, dispatchPlan.dispatchStride, true);
93
129
  view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
94
130
  view.setFloat32(12, 0, true);
95
131
  },
@@ -100,18 +136,21 @@ export async function runSiLU(
100
136
  // Create bind group using helper
101
137
  const entries = createSiLUBindGroupEntries(uniformBuffer, input, output, gate);
102
138
 
103
- const bindGroup = device.createBindGroup({
104
- label: 'silu_bind_group',
105
- layout: pipeline.getBindGroupLayout(0),
106
- entries,
107
- });
108
-
109
- const workgroups = Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT);
110
- dispatch(device, pipeline, bindGroup, workgroups, 'silu');
111
-
112
- uniformBuffer.destroy();
113
-
114
- return createTensor(output, input.dtype, [inferredSize], 'silu_output');
139
+ try {
140
+ const bindGroup = device.createBindGroup({
141
+ label: 'silu_bind_group',
142
+ layout: pipeline.getBindGroupLayout(0),
143
+ entries,
144
+ });
145
+
146
+ dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
147
+ return createTensor(output, input.dtype, [inferredSize], 'silu_output');
148
+ } catch (error) {
149
+ cleanupRunResources(null, ownedOutput);
150
+ throw error;
151
+ } finally {
152
+ destroyAfterSubmit(device, uniformBuffer);
153
+ }
115
154
  }
116
155
 
117
156
 
@@ -133,6 +172,7 @@ export async function runSwiGLURowsplitBias(
133
172
  const bytesPerElement = dtypeBytes(input.dtype);
134
173
  const outputSize = numTokens * dim * bytesPerElement;
135
174
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'swiglu_output');
175
+ const ownedOutput = outputBuffer ? null : output;
136
176
 
137
177
  // Create uniform buffer
138
178
  const uniformBuffer = createUniformBufferWithView(
@@ -149,23 +189,27 @@ export async function runSwiGLURowsplitBias(
149
189
  );
150
190
 
151
191
  // Create bind group
152
- const bindGroup = device.createBindGroup({
153
- label: 'swiglu_bind_group',
154
- layout: pipeline.getBindGroupLayout(0),
155
- entries: [
156
- { binding: 0, resource: { buffer: uniformBuffer } },
157
- { binding: 1, resource: { buffer: input.buffer } },
158
- { binding: 2, resource: { buffer: bias.buffer } },
159
- { binding: 3, resource: { buffer: output } },
160
- ],
161
- });
162
-
163
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
164
- dispatch(device, pipeline, bindGroup, workgroups, 'swiglu');
165
-
166
- uniformBuffer.destroy();
167
-
168
- return createTensor(output, input.dtype, [numTokens, dim], 'swiglu_output');
192
+ try {
193
+ const bindGroup = device.createBindGroup({
194
+ label: 'swiglu_bind_group',
195
+ layout: pipeline.getBindGroupLayout(0),
196
+ entries: [
197
+ { binding: 0, resource: { buffer: uniformBuffer } },
198
+ { binding: 1, resource: { buffer: input.buffer } },
199
+ { binding: 2, resource: { buffer: bias.buffer } },
200
+ { binding: 3, resource: { buffer: output } },
201
+ ],
202
+ });
203
+
204
+ const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
205
+ dispatch(device, pipeline, bindGroup, workgroups, 'swiglu');
206
+ return createTensor(output, input.dtype, [numTokens, dim], 'swiglu_output');
207
+ } catch (error) {
208
+ cleanupRunResources(null, ownedOutput);
209
+ throw error;
210
+ } finally {
211
+ destroyAfterSubmit(device, uniformBuffer);
212
+ }
169
213
  }
170
214
 
171
215
 
@@ -187,6 +231,7 @@ export async function runSiLURowSplit(
187
231
 
188
232
  const outputSize = numTokens * dim * bytesPerElement;
189
233
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_rowsplit_output');
234
+ const ownedOutput = outputBuffer ? null : output;
190
235
 
191
236
  // Create uniform buffer
192
237
  const uniformBuffer = createUniformBufferWithView(
@@ -203,24 +248,28 @@ export async function runSiLURowSplit(
203
248
  );
204
249
 
205
250
  // Bind group: provide a dummy gate buffer to satisfy the fixed layout
206
- const gateBuffer = input.buffer;
207
- const bindGroup = device.createBindGroup({
208
- label: 'silu_rowsplit_bind_group',
209
- layout: pipeline.getBindGroupLayout(0),
210
- entries: [
211
- { binding: 0, resource: { buffer: uniformBuffer } },
212
- { binding: 1, resource: { buffer: input.buffer } },
213
- { binding: 2, resource: { buffer: output } },
214
- { binding: 3, resource: { buffer: gateBuffer } },
215
- ],
216
- });
217
-
218
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
219
- dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
220
-
221
- uniformBuffer.destroy();
222
-
223
- return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
251
+ try {
252
+ const gateBuffer = input.buffer;
253
+ const bindGroup = device.createBindGroup({
254
+ label: 'silu_rowsplit_bind_group',
255
+ layout: pipeline.getBindGroupLayout(0),
256
+ entries: [
257
+ { binding: 0, resource: { buffer: uniformBuffer } },
258
+ { binding: 1, resource: { buffer: input.buffer } },
259
+ { binding: 2, resource: { buffer: output } },
260
+ { binding: 3, resource: { buffer: gateBuffer } },
261
+ ],
262
+ });
263
+
264
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
265
+ dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
266
+ return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
267
+ } catch (error) {
268
+ cleanupRunResources(null, ownedOutput);
269
+ throw error;
270
+ } finally {
271
+ uniformBuffer.destroy();
272
+ }
224
273
  }
225
274
 
226
275
 
@@ -243,6 +292,7 @@ export async function recordSiLURowSplit(
243
292
 
244
293
  const outputSize = numTokens * dim * bytesPerElement;
245
294
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_rowsplit_output');
295
+ const ownedOutput = outputBuffer ? null : output;
246
296
 
247
297
  // Uniform buffer
248
298
  const uniformBuffer = createUniformBufferWithView(
@@ -257,22 +307,28 @@ export async function recordSiLURowSplit(
257
307
  recorder
258
308
  );
259
309
 
260
- const gateBuffer = input.buffer;
261
- const bindGroup = device.createBindGroup({
262
- label: 'silu_rowsplit_bind_group',
263
- layout: pipeline.getBindGroupLayout(0),
264
- entries: [
265
- { binding: 0, resource: { buffer: uniformBuffer } },
266
- { binding: 1, resource: { buffer: input.buffer } },
267
- { binding: 2, resource: { buffer: output } },
268
- { binding: 3, resource: { buffer: gateBuffer } },
269
- ],
270
- });
271
-
272
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
273
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
274
-
275
- return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
310
+ try {
311
+ const gateBuffer = input.buffer;
312
+ const bindGroup = device.createBindGroup({
313
+ label: 'silu_rowsplit_bind_group',
314
+ layout: pipeline.getBindGroupLayout(0),
315
+ entries: [
316
+ { binding: 0, resource: { buffer: uniformBuffer } },
317
+ { binding: 1, resource: { buffer: input.buffer } },
318
+ { binding: 2, resource: { buffer: output } },
319
+ { binding: 3, resource: { buffer: gateBuffer } },
320
+ ],
321
+ });
322
+
323
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
324
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
325
+ return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
326
+ } catch (error) {
327
+ if (ownedOutput) {
328
+ releaseBuffer(ownedOutput);
329
+ }
330
+ throw error;
331
+ }
276
332
  }
277
333
 
278
334
 
@@ -288,6 +344,7 @@ export async function recordSiLU(
288
344
  outputBuffer = null,
289
345
  swigluLimit,
290
346
  gateActivation = 'silu',
347
+ inputActivation = 'silu',
291
348
  } = options;
292
349
  const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
293
350
 
@@ -302,14 +359,18 @@ export async function recordSiLU(
302
359
  useSplit: false,
303
360
  useRowsplit: false,
304
361
  });
305
- const constants = gate && gateActivation === 'sigmoid'
306
- ? { ...(overrides || {}), GATE_USE_SIGMOID: true }
307
- : overrides;
362
+ const constants = {
363
+ ...(overrides || {}),
364
+ ...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
365
+ ...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
366
+ };
308
367
  const pipeline = await getPipelineFast('silu', variant, null, constants);
309
368
 
310
369
  const inferredSize = size || (input.buffer.size / bytesPerElement);
311
370
  const outputSize = inferredSize * bytesPerElement;
312
371
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
372
+ const ownedOutput = outputBuffer ? null : output;
373
+ const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
313
374
 
314
375
  // Uniform buffer
315
376
  const uniformBuffer = createUniformBufferWithView(
@@ -317,7 +378,7 @@ export async function recordSiLU(
317
378
  16,
318
379
  (view) => {
319
380
  view.setUint32(0, inferredSize, true);
320
- view.setUint32(4, 0, true);
381
+ view.setUint32(4, dispatchPlan.dispatchStride, true);
321
382
  view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
322
383
  view.setFloat32(12, 0, true);
323
384
  },
@@ -327,14 +388,19 @@ export async function recordSiLU(
327
388
  // Create bind group using helper
328
389
  const entries = createSiLUBindGroupEntries(uniformBuffer, input, output, gate);
329
390
 
330
- const bindGroup = device.createBindGroup({
331
- label: 'silu_bind_group',
332
- layout: pipeline.getBindGroupLayout(0),
333
- entries,
334
- });
335
-
336
- const workgroups = Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT);
337
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu');
338
-
339
- return createTensor(output, input.dtype, [inferredSize], 'silu_output');
391
+ try {
392
+ const bindGroup = device.createBindGroup({
393
+ label: 'silu_bind_group',
394
+ layout: pipeline.getBindGroupLayout(0),
395
+ entries,
396
+ });
397
+
398
+ recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
399
+ return createTensor(output, input.dtype, [inferredSize], 'silu_output');
400
+ } catch (error) {
401
+ if (ownedOutput) {
402
+ releaseBuffer(ownedOutput);
403
+ }
404
+ throw error;
405
+ }
340
406
  }
@@ -10,13 +10,14 @@
10
10
  override WORKGROUP_SIZE: u32 = 256u;
11
11
  override HAS_GATE: bool = false;
12
12
  override GATE_USE_SIGMOID: bool = false;
13
+ override INPUT_USE_IDENTITY: bool = false;
13
14
  override USE_SPLIT: bool = false;
14
15
  override USE_VEC4: bool = false;
15
16
  override USE_ROWSPLIT: bool = false;
16
17
 
17
18
  struct Uniforms {
18
19
  size: u32, // Total output elements
19
- rowsplit_dim: u32, // Dim for rowsplit variants (0 when unused)
20
+ rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
20
21
  clamp_max: f32, // SwiGLU clamp (0 = disabled)
21
22
  _pad1: f32,
22
23
  }
@@ -35,6 +36,10 @@ fn silu(x: f32) -> f32 {
35
36
  return x * sigmoid(x);
36
37
  }
37
38
 
39
+ fn apply_input_activation(x: f32) -> f32 {
40
+ return select(silu(x), x, INPUT_USE_IDENTITY);
41
+ }
42
+
38
43
  fn clamp_swiglu(x: f32) -> f32 {
39
44
  if (u.clamp_max <= 0.0) {
40
45
  return x;
@@ -46,8 +51,9 @@ fn clamp_swiglu(x: f32) -> f32 {
46
51
  fn main(
47
52
  @builtin(global_invocation_id) global_id: vec3<u32>
48
53
  ) {
54
+ let dispatch_stride = max(u.rowsplit_dim, 1u);
49
55
  if (USE_VEC4) {
50
- let base_idx = global_id.x * 4u;
56
+ let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
51
57
  if (base_idx >= u.size) {
52
58
  return;
53
59
  }
@@ -55,12 +61,12 @@ fn main(
55
61
  let remaining = min(4u, u.size - base_idx);
56
62
  for (var i: u32 = 0u; i < remaining; i = i + 1u) {
57
63
  let x = input[base_idx + i];
58
- output[base_idx + i] = silu(x);
64
+ output[base_idx + i] = apply_input_activation(x);
59
65
  }
60
66
  return;
61
67
  }
62
68
 
63
- let idx = global_id.x;
69
+ let idx = global_id.y * dispatch_stride + global_id.x;
64
70
  if (idx >= u.size) {
65
71
  return;
66
72
  }
@@ -70,12 +76,16 @@ fn main(
70
76
  return;
71
77
  }
72
78
  let dim = u.rowsplit_dim;
73
- let token_idx = idx / dim;
74
- let dim_idx = idx % dim;
79
+ let num_tokens = u.size / dim;
80
+ let token_idx = global_id.y;
81
+ let dim_idx = global_id.x;
82
+ if (token_idx >= num_tokens || dim_idx >= dim) {
83
+ return;
84
+ }
75
85
  let row_base = token_idx * dim * 2u;
76
86
  let g = input[row_base + dim_idx];
77
87
  let up = input[row_base + dim + dim_idx];
78
- output[idx] = clamp_swiglu(silu(g) * up);
88
+ output[token_idx * dim + dim_idx] = clamp_swiglu(silu(g) * up);
79
89
  return;
80
90
  }
81
91
 
@@ -83,7 +93,7 @@ fn main(
83
93
  let up = input[idx];
84
94
  let g = gate[idx];
85
95
  let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
86
- output[idx] = clamp_swiglu(gateAct * up);
96
+ output[idx] = clamp_swiglu(gateAct * apply_input_activation(up));
87
97
  return;
88
98
  }
89
99
 
@@ -95,5 +105,5 @@ fn main(
95
105
  }
96
106
 
97
107
  let x = input[idx];
98
- output[idx] = silu(x);
108
+ output[idx] = apply_input_activation(x);
99
109
  }
@@ -9,13 +9,14 @@ enable f16;
9
9
  override WORKGROUP_SIZE: u32 = 256u;
10
10
  override HAS_GATE: bool = false;
11
11
  override GATE_USE_SIGMOID: bool = false;
12
+ override INPUT_USE_IDENTITY: bool = false;
12
13
  override USE_SPLIT: bool = false;
13
14
  override USE_VEC4: bool = false;
14
15
  override USE_ROWSPLIT: bool = false;
15
16
 
16
17
  struct Uniforms {
17
18
  size: u32, // Total output elements
18
- rowsplit_dim: u32, // Dim for rowsplit variants (0 when unused)
19
+ rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
19
20
  clamp_max: f32, // SwiGLU clamp (0 = disabled)
20
21
  _pad1: f32,
21
22
  }
@@ -34,6 +35,10 @@ fn silu(x: f32) -> f32 {
34
35
  return x * sigmoid(x);
35
36
  }
36
37
 
38
+ fn apply_input_activation(x: f32) -> f32 {
39
+ return select(silu(x), x, INPUT_USE_IDENTITY);
40
+ }
41
+
37
42
  fn clamp_swiglu(x: f32) -> f32 {
38
43
  if (u.clamp_max <= 0.0) {
39
44
  return x;
@@ -45,8 +50,9 @@ fn clamp_swiglu(x: f32) -> f32 {
45
50
  fn main(
46
51
  @builtin(global_invocation_id) global_id: vec3<u32>
47
52
  ) {
53
+ let dispatch_stride = max(u.rowsplit_dim, 1u);
48
54
  if (USE_VEC4) {
49
- let base_idx = global_id.x * 4u;
55
+ let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
50
56
  if (base_idx >= u.size) {
51
57
  return;
52
58
  }
@@ -54,12 +60,12 @@ fn main(
54
60
  let remaining = min(4u, u.size - base_idx);
55
61
  for (var i: u32 = 0u; i < remaining; i = i + 1u) {
56
62
  let x = f32(input[base_idx + i]);
57
- output[base_idx + i] = f16(silu(x));
63
+ output[base_idx + i] = f16(apply_input_activation(x));
58
64
  }
59
65
  return;
60
66
  }
61
67
 
62
- let idx = global_id.x;
68
+ let idx = global_id.y * dispatch_stride + global_id.x;
63
69
  if (idx >= u.size) {
64
70
  return;
65
71
  }
@@ -69,12 +75,16 @@ fn main(
69
75
  return;
70
76
  }
71
77
  let dim = u.rowsplit_dim;
72
- let token_idx = idx / dim;
73
- let dim_idx = idx % dim;
78
+ let num_tokens = u.size / dim;
79
+ let token_idx = global_id.y;
80
+ let dim_idx = global_id.x;
81
+ if (token_idx >= num_tokens || dim_idx >= dim) {
82
+ return;
83
+ }
74
84
  let row_base = token_idx * dim * 2u;
75
85
  let g = f32(input[row_base + dim_idx]);
76
86
  let up = f32(input[row_base + dim + dim_idx]);
77
- output[idx] = f16(clamp_swiglu(silu(g) * up));
87
+ output[token_idx * dim + dim_idx] = f16(clamp_swiglu(silu(g) * up));
78
88
  return;
79
89
  }
80
90
 
@@ -82,7 +92,7 @@ fn main(
82
92
  let up = f32(input[idx]);
83
93
  let g = f32(gate[idx]);
84
94
  let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
85
- output[idx] = f16(clamp_swiglu(gateAct * up));
95
+ output[idx] = f16(clamp_swiglu(gateAct * apply_input_activation(up)));
86
96
  return;
87
97
  }
88
98
 
@@ -94,5 +104,5 @@ fn main(
94
104
  }
95
105
 
96
106
  let x = f32(input[idx]);
97
- output[idx] = f16(silu(x));
107
+ output[idx] = f16(apply_input_activation(x));
98
108
  }
@@ -1,6 +1,6 @@
1
1
 
2
2
  import { getKernelCapabilities } from '../device.js';
3
- import { acquireBuffer } from '../../memory/buffer-pool.js';
3
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
4
  import { createTensor } from '../tensor.js';
5
5
  import { unifiedKernelWrapper } from './utils.js';
6
6
  import { createPipeline, createUniformBufferWithView, createBindGroupWithValidation } from './utils.js';
@@ -20,23 +20,34 @@ function selectSoftmaxVariant(innerSize) {
20
20
 
21
21
  async function _softmax(target, input, axis, options = {}) {
22
22
  const { batchSize = 1, size, seqLen, temperature = 1.0, outputBuffer = null } = options;
23
+ if (input.dtype !== 'f32') {
24
+ throw new Error(`Softmax requires f32 input, got ${input.dtype}.`);
25
+ }
23
26
 
24
- const bytesPerElement = input.dtype === 'f16' ? 2 : 4;
27
+ const bytesPerElement = 4;
25
28
  const inferredSize = size || seqLen || (input.buffer.size / (batchSize * bytesPerElement));
26
29
  const variant = selectSoftmaxVariant(inferredSize);
27
30
  trace.kernels(`Softmax: size=${inferredSize}, variant=${variant}`);
28
31
 
29
32
  const outputSize = batchSize * inferredSize * bytesPerElement;
30
33
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'softmax_output');
34
+ const ownedOutput = outputBuffer ? null : output;
35
+
36
+ try {
37
+ await unifiedKernelWrapper(
38
+ 'softmax', target, variant,
39
+ [input, output],
40
+ { inner_size: inferredSize, outer_size: batchSize, temperature },
41
+ batchSize
42
+ );
43
+ } catch (error) {
44
+ if (ownedOutput) {
45
+ releaseBuffer(ownedOutput);
46
+ }
47
+ throw error;
48
+ }
31
49
 
32
- await unifiedKernelWrapper(
33
- 'softmax', target, variant,
34
- [input, output],
35
- { inner_size: inferredSize, outer_size: batchSize, temperature },
36
- batchSize
37
- );
38
-
39
- return createTensor(output, input.dtype, [batchSize, inferredSize], 'softmax_output');
50
+ return createTensor(output, 'f32', [batchSize, inferredSize], 'softmax_output');
40
51
  }
41
52
 
42
53
  export async function runSoftmax(input, axis, options = {}) {
@@ -76,6 +87,7 @@ export async function runSoftmaxTopK(logits, numTokens, numExperts, topK, option
76
87
 
77
88
  const indices = acquireBuffer(indicesSize, undefined, 'softmax_topk_indices');
78
89
  const weights = acquireBuffer(weightsSize, undefined, 'softmax_topk_weights');
90
+ let completed = false;
79
91
 
80
92
  const uniformBuffer = createUniformBufferWithView(
81
93
  'softmax_topk_uniforms', 16,
@@ -88,19 +100,26 @@ export async function runSoftmaxTopK(logits, numTokens, numExperts, topK, option
88
100
  null, device
89
101
  );
90
102
 
91
- const bindGroup = await createBindGroupWithValidation(device, {
92
- label: 'softmax_topk_bind_group',
93
- layout: pipeline.getBindGroupLayout(0),
94
- entries: [
95
- { binding: 0, resource: { buffer: uniformBuffer } },
96
- { binding: 1, resource: { buffer: logits } },
97
- { binding: 2, resource: { buffer: indices } },
98
- { binding: 3, resource: { buffer: weights } },
99
- ],
100
- }, `topk:${variant}`);
101
-
102
- dispatchKernel(null, pipeline, bindGroup, numTokens, 'softmax_topk');
103
- uniformBuffer.destroy();
104
-
105
- return { indices, weights };
103
+ try {
104
+ const bindGroup = await createBindGroupWithValidation(device, {
105
+ label: 'softmax_topk_bind_group',
106
+ layout: pipeline.getBindGroupLayout(0),
107
+ entries: [
108
+ { binding: 0, resource: { buffer: uniformBuffer } },
109
+ { binding: 1, resource: { buffer: logits } },
110
+ { binding: 2, resource: { buffer: indices } },
111
+ { binding: 3, resource: { buffer: weights } },
112
+ ],
113
+ }, `topk:${variant}`);
114
+
115
+ dispatchKernel(null, pipeline, bindGroup, numTokens, 'softmax_topk');
116
+ completed = true;
117
+ return { indices, weights };
118
+ } finally {
119
+ uniformBuffer.destroy();
120
+ if (!completed) {
121
+ releaseBuffer(indices);
122
+ releaseBuffer(weights);
123
+ }
124
+ }
106
125
  }