@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor, dtypeBytes } from '../tensor.js';
6
6
  import { getBuffer } from '../weight-buffer.js';
7
7
  import { dispatch, recordDispatch } from './dispatch.js';
@@ -47,7 +47,12 @@ export async function runMatmulResidualFused(
47
47
  const pipelineVariant = resolveFusedResidualVariant(input, residual);
48
48
  const pipeline = await getPipelineFast('fused_matmul_residual', pipelineVariant);
49
49
 
50
- const output = outputBuffer || acquireBuffer(N * dtypeBytes(outputDtype), undefined, 'matmul_residual_output');
50
+ const ownedOutput = outputBuffer ? null : acquireBuffer(
51
+ N * dtypeBytes(outputDtype),
52
+ undefined,
53
+ 'matmul_residual_output'
54
+ );
55
+ const output = outputBuffer || ownedOutput;
51
56
 
52
57
  // Create uniform buffer (same layout as matmul_gemv)
53
58
  const uniformBuffer = createUniformBufferWithView(
@@ -68,21 +73,28 @@ export async function runMatmulResidualFused(
68
73
  );
69
74
 
70
75
  // Create bind group
71
- const bindGroup = device.createBindGroup({
72
- label: 'matmul_residual_bind_group',
73
- layout: pipeline.getBindGroupLayout(0),
74
- entries: [
75
- { binding: 0, resource: { buffer: uniformBuffer } },
76
- { binding: 1, resource: { buffer: input.buffer } },
77
- { binding: 2, resource: { buffer: weightBuffer } },
78
- { binding: 3, resource: { buffer: output } },
79
- { binding: 4, resource: { buffer: residual.buffer } },
80
- ],
81
- });
82
-
83
- // One workgroup per output element
84
- const workgroups = N;
85
- dispatch(device, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
76
+ try {
77
+ const bindGroup = device.createBindGroup({
78
+ label: 'matmul_residual_bind_group',
79
+ layout: pipeline.getBindGroupLayout(0),
80
+ entries: [
81
+ { binding: 0, resource: { buffer: uniformBuffer } },
82
+ { binding: 1, resource: { buffer: input.buffer } },
83
+ { binding: 2, resource: { buffer: weightBuffer } },
84
+ { binding: 3, resource: { buffer: output } },
85
+ { binding: 4, resource: { buffer: residual.buffer } },
86
+ ],
87
+ });
88
+
89
+ const workgroups = N;
90
+ dispatch(device, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
91
+ } catch (error) {
92
+ uniformBuffer.destroy();
93
+ if (ownedOutput) {
94
+ releaseBuffer(ownedOutput);
95
+ }
96
+ throw error;
97
+ }
86
98
 
87
99
  uniformBuffer.destroy();
88
100
 
@@ -112,7 +124,12 @@ export async function recordMatmulResidualFused(
112
124
  const pipelineVariant = resolveFusedResidualVariant(input, residual);
113
125
  const pipeline = await getPipelineFast('fused_matmul_residual', pipelineVariant);
114
126
 
115
- const output = outputBuffer || acquireBuffer(N * dtypeBytes(outputDtype), undefined, 'matmul_residual_output');
127
+ const ownedOutput = outputBuffer ? null : acquireBuffer(
128
+ N * dtypeBytes(outputDtype),
129
+ undefined,
130
+ 'matmul_residual_output'
131
+ );
132
+ const output = outputBuffer || ownedOutput;
116
133
 
117
134
  // Create uniform buffer
118
135
  const uniformBuffer = createUniformBufferWithView(
@@ -132,21 +149,27 @@ export async function recordMatmulResidualFused(
132
149
  );
133
150
 
134
151
  // Create bind group
135
- const bindGroup = device.createBindGroup({
136
- label: 'matmul_residual_bind_group',
137
- layout: pipeline.getBindGroupLayout(0),
138
- entries: [
139
- { binding: 0, resource: { buffer: uniformBuffer } },
140
- { binding: 1, resource: { buffer: input.buffer } },
141
- { binding: 2, resource: { buffer: weightBuffer } },
142
- { binding: 3, resource: { buffer: output } },
143
- { binding: 4, resource: { buffer: residual.buffer } },
144
- ],
145
- });
146
-
147
- // One workgroup per output element
148
- const workgroups = N;
149
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
152
+ try {
153
+ const bindGroup = device.createBindGroup({
154
+ label: 'matmul_residual_bind_group',
155
+ layout: pipeline.getBindGroupLayout(0),
156
+ entries: [
157
+ { binding: 0, resource: { buffer: uniformBuffer } },
158
+ { binding: 1, resource: { buffer: input.buffer } },
159
+ { binding: 2, resource: { buffer: weightBuffer } },
160
+ { binding: 3, resource: { buffer: output } },
161
+ { binding: 4, resource: { buffer: residual.buffer } },
162
+ ],
163
+ });
164
+
165
+ const workgroups = N;
166
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, 'matmul_residual_fused');
167
+ } catch (error) {
168
+ if (ownedOutput) {
169
+ releaseBuffer(ownedOutput);
170
+ }
171
+ throw error;
172
+ }
150
173
 
151
174
  return createTensor(output, outputDtype, [1, N], 'matmul_residual_output');
152
175
  }
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { getBuffer } from '../weight-buffer.js';
7
7
  import { dispatch, recordDispatch } from './dispatch.js';
@@ -91,7 +91,8 @@ export async function runMatmulRMSNormFused(
91
91
  // Output buffer: [1, N] - size depends on dtype
92
92
  const bytesPerElement = dtype === 'f16' ? 2 : 4;
93
93
  const outputSize = N * bytesPerElement;
94
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
94
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
95
+ const output = outputBuffer || ownedOutput;
95
96
 
96
97
  // Create uniform buffer (8 u32/f32 = 32 bytes, padded for alignment)
97
98
  const uniformBuffer = createUniformBufferWithView(
@@ -110,36 +111,44 @@ export async function runMatmulRMSNormFused(
110
111
  );
111
112
 
112
113
  // Create placeholder for residual if not provided
114
+ const ownsResidualBuffer = !residual;
113
115
  const residualBuffer = residual || device.createBuffer({
114
116
  label: 'matmul_rmsnorm_residual_placeholder',
115
117
  size: 4,
116
118
  usage: GPUBufferUsage.STORAGE,
117
119
  });
118
120
 
119
- // Create bind group
120
- const bindGroup = device.createBindGroup({
121
- label: 'matmul_rmsnorm_fused_bind_group',
122
- layout: pipeline.getBindGroupLayout(0),
123
- entries: [
124
- { binding: 0, resource: { buffer: uniformBuffer } },
125
- { binding: 1, resource: { buffer: input.buffer } },
126
- { binding: 2, resource: { buffer: weightBuffer } },
127
- { binding: 3, resource: { buffer: normWeightBuffer } },
128
- { binding: 4, resource: { buffer: output } },
129
- { binding: 5, resource: { buffer: residualBuffer } },
130
- ],
131
- });
132
-
133
- // Calculate workgroups
134
-
135
- const workgroups = 1;
136
-
137
- const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
138
- dispatch(device, pipeline, bindGroup, workgroups, dispatchLabel);
121
+ try {
122
+ const bindGroup = device.createBindGroup({
123
+ label: 'matmul_rmsnorm_fused_bind_group',
124
+ layout: pipeline.getBindGroupLayout(0),
125
+ entries: [
126
+ { binding: 0, resource: { buffer: uniformBuffer } },
127
+ { binding: 1, resource: { buffer: input.buffer } },
128
+ { binding: 2, resource: { buffer: weightBuffer } },
129
+ { binding: 3, resource: { buffer: normWeightBuffer } },
130
+ { binding: 4, resource: { buffer: output } },
131
+ { binding: 5, resource: { buffer: residualBuffer } },
132
+ ],
133
+ });
134
+
135
+ const workgroups = 1;
136
+ const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
137
+ dispatch(device, pipeline, bindGroup, workgroups, dispatchLabel);
138
+ } catch (error) {
139
+ uniformBuffer.destroy();
140
+ if (ownsResidualBuffer) {
141
+ residualBuffer.destroy();
142
+ }
143
+ if (ownedOutput) {
144
+ releaseBuffer(ownedOutput);
145
+ }
146
+ throw error;
147
+ }
139
148
 
140
149
  // Cleanup
141
150
  uniformBuffer.destroy();
142
- if (!residual) residualBuffer.destroy();
151
+ if (ownsResidualBuffer) residualBuffer.destroy();
143
152
 
144
153
  // Output dtype matches input dtype
145
154
  return createTensor(output, input.dtype, [1, N], 'matmul_rmsnorm_fused_output');
@@ -199,7 +208,8 @@ export async function recordMatmulRMSNormFused(
199
208
  // Output buffer - size depends on dtype
200
209
  const bytesPerElement = dtype === 'f16' ? 2 : 4;
201
210
  const outputSize = N * bytesPerElement;
202
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
211
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
212
+ const output = outputBuffer || ownedOutput;
203
213
 
204
214
  // Uniform buffer via recorder (8 u32/f32 = 32 bytes, padded for alignment)
205
215
  const uniformBuffer = createUniformBufferWithView(
@@ -217,35 +227,42 @@ export async function recordMatmulRMSNormFused(
217
227
  );
218
228
 
219
229
  // Placeholder for residual
230
+ const ownsResidualBuffer = !residual;
220
231
  const residualBuffer = residual || device.createBuffer({
221
232
  label: 'matmul_rmsnorm_residual_placeholder',
222
233
  size: 4,
223
234
  usage: GPUBufferUsage.STORAGE,
224
235
  });
225
236
 
226
- // Bind group
227
- const bindGroup = device.createBindGroup({
228
- label: 'matmul_rmsnorm_fused_bind_group',
229
- layout: pipeline.getBindGroupLayout(0),
230
- entries: [
231
- { binding: 0, resource: { buffer: uniformBuffer } },
232
- { binding: 1, resource: { buffer: input.buffer } },
233
- { binding: 2, resource: { buffer: weightBuffer } },
234
- { binding: 3, resource: { buffer: normWeightBuffer } },
235
- { binding: 4, resource: { buffer: output } },
236
- { binding: 5, resource: { buffer: residualBuffer } },
237
- ],
238
- });
239
-
240
- // Calculate workgroups
241
-
242
- const workgroups = 1;
243
-
244
- const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
245
- recordDispatch(recorder, pipeline, bindGroup, workgroups, dispatchLabel);
237
+ try {
238
+ const bindGroup = device.createBindGroup({
239
+ label: 'matmul_rmsnorm_fused_bind_group',
240
+ layout: pipeline.getBindGroupLayout(0),
241
+ entries: [
242
+ { binding: 0, resource: { buffer: uniformBuffer } },
243
+ { binding: 1, resource: { buffer: input.buffer } },
244
+ { binding: 2, resource: { buffer: weightBuffer } },
245
+ { binding: 3, resource: { buffer: normWeightBuffer } },
246
+ { binding: 4, resource: { buffer: output } },
247
+ { binding: 5, resource: { buffer: residualBuffer } },
248
+ ],
249
+ });
250
+
251
+ const workgroups = 1;
252
+ const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
253
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, dispatchLabel);
254
+ } catch (error) {
255
+ if (ownsResidualBuffer) {
256
+ residualBuffer.destroy();
257
+ }
258
+ if (ownedOutput) {
259
+ releaseBuffer(ownedOutput);
260
+ }
261
+ throw error;
262
+ }
246
263
 
247
264
  // Track placeholder for cleanup
248
- if (!residual) {
265
+ if (ownsResidualBuffer) {
249
266
  recorder.trackTemporaryBuffer(residualBuffer);
250
267
  }
251
268
 
@@ -1,5 +1,5 @@
1
1
  import { getKernelCapabilities } from '../device.js';
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { WORKGROUP_SIZES, VEC4_ELEMENTS_PER_WG } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
5
5
  import { trace } from '../../debug/index.js';
@@ -26,7 +26,6 @@ async function _gather(
26
26
  options = {}
27
27
  ) {
28
28
  const {
29
- useVec4 = true,
30
29
  outputBuffer = null,
31
30
  embeddingDtype,
32
31
  outputDtype,
@@ -43,9 +42,22 @@ async function _gather(
43
42
  if (outputDtype == null) {
44
43
  throw new Error('[Gather] outputDtype is required.');
45
44
  }
45
+ if (embeddingDtype === 'f16' && !caps.hasF16) {
46
+ throw new Error('[Gather] embeddingDtype=f16 requires shader-f16 support.');
47
+ }
48
+ if (outputDtype === 'f16' && !caps.hasF16) {
49
+ throw new Error('[Gather] outputDtype=f16 requires shader-f16 support.');
50
+ }
46
51
 
47
- const useF16Input = embeddingDtype === 'f16' && caps.hasF16;
48
- const useF16Output = outputDtype === 'f16' && caps.hasF16;
52
+ const requestedVec4 = options.useVec4;
53
+ const wantsVec4 = requestedVec4 ?? true;
54
+ if (requestedVec4 === true && hiddenSize % 4 !== 0) {
55
+ throw new Error('[Gather] useVec4=true requires hiddenSize to be divisible by 4.');
56
+ }
57
+
58
+ const useF16Input = embeddingDtype === 'f16';
59
+ const useF16Output = outputDtype === 'f16';
60
+ const useVec4 = wantsVec4 && hiddenSize % 4 === 0;
49
61
 
50
62
  trace.embed(
51
63
  `Gather: numTokens=${numTokens}, hiddenSize=${hiddenSize}, vocabSize=${vocabSize}, ` +
@@ -64,6 +76,7 @@ async function _gather(
64
76
  const paddedHiddenSize = padToQ4KBlock(hiddenSize);
65
77
  const outputSize = numTokens * paddedHiddenSize * bytesPerElement;
66
78
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gather_output');
79
+ const ownedOutput = outputBuffer ? null : output;
67
80
 
68
81
  const uniforms = {
69
82
  num_tokens: numTokens,
@@ -82,16 +95,22 @@ async function _gather(
82
95
  ? Math.ceil((numTokens * hiddenSize) / VEC4_ELEMENTS_PER_WG)
83
96
  : Math.ceil((numTokens * hiddenSize) / WORKGROUP_SIZES.DEFAULT));
84
97
 
85
- await unifiedKernelWrapper(
86
- 'gather',
87
- target,
88
- variant,
89
- [indices, embeddings, output],
90
- uniforms,
91
- workgroups
92
- );
93
-
94
- return createTensor(output, actualDtype, [numTokens, hiddenSize], 'gather_output');
98
+ try {
99
+ await unifiedKernelWrapper(
100
+ 'gather',
101
+ target,
102
+ variant,
103
+ [indices, embeddings, output],
104
+ uniforms,
105
+ workgroups
106
+ );
107
+ return createTensor(output, actualDtype, [numTokens, hiddenSize], 'gather_output');
108
+ } catch (error) {
109
+ if (ownedOutput) {
110
+ releaseBuffer(ownedOutput);
111
+ }
112
+ throw error;
113
+ }
95
114
  }
96
115
 
97
116
  export async function runGather(
@@ -116,4 +135,3 @@ export async function recordGather(
116
135
  ) {
117
136
  return _gather(recorder, indices, embeddings, numTokens, hiddenSize, vocabSize, options);
118
137
  }
119
-
@@ -1,5 +1,5 @@
1
1
 
2
- import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor, dtypeBytes } from '../tensor.js';
4
4
  import { WORKGROUP_SIZES } from './constants.js';
5
5
  import { unifiedKernelWrapper } from './utils.js';
@@ -26,16 +26,24 @@ async function _gelu(target, input, options = {}) {
26
26
  const outputSize = inferredSize * bytesPerElement;
27
27
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gelu_output');
28
28
  const gateBuffer = gate ?? input;
29
-
30
- await unifiedKernelWrapper(
31
- 'gelu', target, variant,
32
- [input, output, gateBuffer],
33
- { size: inferredSize, rowsplit_dim: 0 },
34
- Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT),
35
- overrides
36
- );
37
-
38
- return createTensor(output, input.dtype, [inferredSize], 'gelu_output');
29
+ const ownedOutput = outputBuffer ? null : output;
30
+
31
+ try {
32
+ await unifiedKernelWrapper(
33
+ 'gelu', target, variant,
34
+ [input, output, gateBuffer],
35
+ { size: inferredSize, rowsplit_dim: 0 },
36
+ Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT),
37
+ overrides
38
+ );
39
+
40
+ return createTensor(output, input.dtype, [inferredSize], 'gelu_output');
41
+ } catch (error) {
42
+ if (ownedOutput) {
43
+ releaseBuffer(ownedOutput);
44
+ }
45
+ throw error;
46
+ }
39
47
  }
40
48
 
41
49
  export async function runGeLU(input, options = {}) {
@@ -42,6 +42,7 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
42
42
  const bytesPerElement = dtypeBytes(input.dtype);
43
43
  const outputSize = outChannels * height * width * bytesPerElement;
44
44
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');
45
+ const spatial = height * width;
45
46
 
46
47
  const weightBuffer = getBuffer(weight);
47
48
  let biasBuffer = getBuffer(bias);
@@ -54,33 +55,43 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
54
55
  device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
55
56
  }
56
57
 
57
- await unifiedKernelWrapper(
58
- 'grouped_pointwise_conv2d',
59
- target,
60
- variant,
61
- [input, weightBuffer, biasBuffer, output],
62
- {
63
- in_channels: inChannels,
64
- out_channels: outChannels,
65
- height,
66
- width,
67
- groups,
68
- _pad0: 0,
69
- _pad1: 0,
70
- _pad2: 0,
71
- },
72
- Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
73
- );
58
+ try {
59
+ await unifiedKernelWrapper(
60
+ 'grouped_pointwise_conv2d',
61
+ target,
62
+ variant,
63
+ [input, weightBuffer, biasBuffer, output],
64
+ {
65
+ in_channels: inChannels,
66
+ out_channels: outChannels,
67
+ height,
68
+ width,
69
+ groups,
70
+ _pad0: 0,
71
+ _pad1: 0,
72
+ _pad2: 0,
73
+ },
74
+ [Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
75
+ );
76
+
77
+ if (tempBias) {
78
+ if (recorder) {
79
+ recorder.trackTemporaryBuffer(tempBias);
80
+ } else {
81
+ releaseBuffer(tempBias);
82
+ }
83
+ }
74
84
 
75
- if (tempBias) {
76
- if (recorder) {
77
- recorder.trackTemporaryBuffer(tempBias);
78
- } else {
85
+ return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
86
+ } catch (error) {
87
+ if (tempBias) {
79
88
  releaseBuffer(tempBias);
80
89
  }
90
+ if (!outputBuffer) {
91
+ releaseBuffer(output);
92
+ }
93
+ throw error;
81
94
  }
82
-
83
- return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
84
95
  }
85
96
 
86
97
  export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
@@ -19,17 +19,14 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
22
  let spatial = u.height * u.width;
24
- let out_size = u.out_channels * spatial;
25
- if (idx >= out_size) {
23
+ let spatial_idx = gid.x;
24
+ let out_channel = gid.y;
25
+ if (spatial_idx >= spatial || out_channel >= u.out_channels) {
26
26
  return;
27
27
  }
28
-
29
- let out_channel = idx / spatial;
30
- let rem = idx - out_channel * spatial;
31
- let y = rem / u.width;
32
- let x = rem - y * u.width;
28
+ let y = spatial_idx / u.width;
29
+ let x = spatial_idx - y * u.width;
33
30
 
34
31
  let in_per_group = u.in_channels / u.groups;
35
32
  let out_per_group = u.out_channels / u.groups;
@@ -43,5 +40,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
43
40
  sum = sum + input[input_idx] * weight[weight_idx];
44
41
  }
45
42
 
46
- output[idx] = sum;
43
+ output[out_channel * spatial + spatial_idx] = sum;
47
44
  }
@@ -23,17 +23,14 @@ struct Uniforms {
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
27
26
  let spatial = u.height * u.width;
28
- let out_size = u.out_channels * spatial;
29
- if (idx >= out_size) {
27
+ let spatial_idx = gid.x;
28
+ let out_channel = gid.y;
29
+ if (spatial_idx >= spatial || out_channel >= u.out_channels) {
30
30
  return;
31
31
  }
32
-
33
- let out_channel = idx / spatial;
34
- let rem = idx - out_channel * spatial;
35
- let y = rem / u.width;
36
- let x = rem - y * u.width;
32
+ let y = spatial_idx / u.width;
33
+ let x = spatial_idx - y * u.width;
37
34
 
38
35
  let in_per_group = u.in_channels / u.groups;
39
36
  let out_per_group = u.out_channels / u.groups;
@@ -47,5 +44,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
47
44
  sum = sum + f32(input[input_idx]) * f32(weight[weight_idx]);
48
45
  }
49
46
 
50
- output[idx] = f16(sum);
47
+ output[out_channel * spatial + spatial_idx] = f16(sum);
51
48
  }
@@ -17,6 +17,9 @@ function validateOptions(options) {
17
17
  if (!Number.isFinite(numGroups) || numGroups <= 0) {
18
18
  throw new Error('GroupNorm requires numGroups > 0.');
19
19
  }
20
+ if (channels % numGroups !== 0) {
21
+ throw new Error('GroupNorm requires channels to be divisible by numGroups.');
22
+ }
20
23
  if (!Number.isFinite(eps)) {
21
24
  throw new Error('GroupNorm requires eps.');
22
25
  }
@@ -44,34 +47,42 @@ async function _groupNorm(target, input, weight, bias, options = {}) {
44
47
 
45
48
  const statsSize = numGroups * 2 * 4;
46
49
  const statsBuffer = acquireBuffer(statsSize, undefined, 'groupnorm_stats');
47
-
48
- await unifiedKernelWrapper(
49
- 'groupnorm_stats',
50
- target,
51
- statsVariant,
52
- [input, statsBuffer],
53
- uniforms,
54
- numGroups
55
- );
56
-
57
50
  const bytesPerElement = dtypeBytes(input.dtype);
58
51
  const outputSize = channels * height * width * bytesPerElement;
59
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'groupnorm_output');
52
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'groupnorm_output');
53
+ const output = outputBuffer || ownedOutput;
60
54
 
61
- const weightBuffer = getBuffer(weight);
62
- const biasBuffer = getBuffer(bias);
55
+ try {
56
+ await unifiedKernelWrapper(
57
+ 'groupnorm_stats',
58
+ target,
59
+ statsVariant,
60
+ [input, statsBuffer],
61
+ uniforms,
62
+ numGroups
63
+ );
63
64
 
64
- const total = channels * height * width;
65
- const workgroups = Math.ceil(total / WORKGROUP_SIZES.DEFAULT);
65
+ const weightBuffer = getBuffer(weight);
66
+ const biasBuffer = getBuffer(bias);
66
67
 
67
- await unifiedKernelWrapper(
68
- 'groupnorm_apply',
69
- target,
70
- applyVariant,
71
- [input, statsBuffer, weightBuffer, biasBuffer, output],
72
- uniforms,
73
- workgroups
74
- );
68
+ const total = channels * height * width;
69
+ const workgroups = Math.ceil(total / WORKGROUP_SIZES.DEFAULT);
70
+
71
+ await unifiedKernelWrapper(
72
+ 'groupnorm_apply',
73
+ target,
74
+ applyVariant,
75
+ [input, statsBuffer, weightBuffer, biasBuffer, output],
76
+ uniforms,
77
+ workgroups
78
+ );
79
+ } catch (error) {
80
+ releaseBuffer(statsBuffer);
81
+ if (ownedOutput) {
82
+ releaseBuffer(ownedOutput);
83
+ }
84
+ throw error;
85
+ }
75
86
 
76
87
  if (recorder) {
77
88
  recorder.trackTemporaryBuffer(statsBuffer);
@@ -78,8 +78,11 @@ export async function runKVQuantize(
78
78
  });
79
79
 
80
80
  const workgroups = [numKVHeads, numTokens, 1];
81
- dispatch(device, pipeline, bindGroup, workgroups, 'kv_quantize');
82
- uniformBuffer.destroy();
81
+ try {
82
+ dispatch(device, pipeline, bindGroup, workgroups, 'kv_quantize');
83
+ } finally {
84
+ uniformBuffer.destroy();
85
+ }
83
86
  }
84
87
 
85
88