@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355)
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { unifiedKernelWrapper } from './utils.js';
4
4
  import { selectRuleValue } from './rule-registry.js';
@@ -25,19 +25,27 @@ async function _pixelShuffle(target, input, options = {}) {
25
25
  const bytesPerElement = dtypeBytes(input.dtype);
26
26
  const outputSize = outChannels * outHeight * outWidth * bytesPerElement;
27
27
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'pixel_shuffle_output');
28
-
29
- await unifiedKernelWrapper(
30
- 'pixel_shuffle', target, variant,
31
- [input, output],
32
- {
33
- out_channels: outChannels, out_height: outHeight, out_width: outWidth,
34
- grid_width: gridWidth, grid_height: gridHeight, patch_size: patchSize,
35
- patch_channels: inferredPatchChannels, _pad0: 0,
36
- },
37
- [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
38
- );
39
-
40
- return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'pixel_shuffle_output');
28
+ const ownedOutput = outputBuffer ? null : output;
29
+
30
+ try {
31
+ await unifiedKernelWrapper(
32
+ 'pixel_shuffle', target, variant,
33
+ [input, output],
34
+ {
35
+ out_channels: outChannels, out_height: outHeight, out_width: outWidth,
36
+ grid_width: gridWidth, grid_height: gridHeight, patch_size: patchSize,
37
+ patch_channels: inferredPatchChannels, _pad0: 0,
38
+ },
39
+ [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
40
+ );
41
+
42
+ return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'pixel_shuffle_output');
43
+ } catch (error) {
44
+ if (ownedOutput) {
45
+ releaseBuffer(ownedOutput);
46
+ }
47
+ throw error;
48
+ }
41
49
  }
42
50
 
43
51
  export async function runPixelShuffle(input, options = {}) {
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { unifiedKernelWrapper } from './utils.js';
4
4
  import { selectRuleValue } from './rule-registry.js';
@@ -35,18 +35,26 @@ async function _relu(target, input, options = {}) {
35
35
  const size = resolveCount(input, count);
36
36
  const variant = selectReluVariant(input.dtype);
37
37
  const output = outputBuffer || acquireBuffer(size * dtypeBytes(input.dtype), undefined, 'relu_output');
38
+ const ownedOutput = outputBuffer ? null : output;
38
39
  const dispatchPlan = planReluDispatch(target, size);
39
40
 
40
- await unifiedKernelWrapper(
41
- 'relu',
42
- target,
43
- variant,
44
- [input, output],
45
- { size, _pad0: dispatchPlan.dispatchStride, _pad1: 0, _pad2: 0 },
46
- dispatchPlan.workgroups
47
- );
41
+ try {
42
+ await unifiedKernelWrapper(
43
+ 'relu',
44
+ target,
45
+ variant,
46
+ [input, output],
47
+ { size, _pad0: dispatchPlan.dispatchStride, _pad1: 0, _pad2: 0 },
48
+ dispatchPlan.workgroups
49
+ );
48
50
 
49
- return createTensor(output, input.dtype, [...input.shape], 'relu_output');
51
+ return createTensor(output, input.dtype, [...input.shape], 'relu_output');
52
+ } catch (error) {
53
+ if (ownedOutput) {
54
+ releaseBuffer(ownedOutput);
55
+ }
56
+ throw error;
57
+ }
50
58
  }
51
59
 
52
60
  export async function runReLU(input, options = {}) {
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { unifiedKernelWrapper } from './utils.js';
4
4
  import { selectRuleValue } from './rule-registry.js';
@@ -32,23 +32,31 @@ async function _repeatChannels(target, input, options = {}) {
32
32
  const bytesPerElement = dtypeBytes(input.dtype);
33
33
  const outputSize = outChannels * height * width * bytesPerElement;
34
34
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
35
+ const ownedOutput = outputBuffer ? null : output;
35
36
 
36
- await unifiedKernelWrapper(
37
- 'repeat_channels',
38
- target,
39
- variant,
40
- [input, output],
41
- {
42
- in_channels: inChannels,
43
- height,
44
- width,
45
- repeats,
46
- _pad0: 0,
47
- },
48
- [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
49
- );
50
-
51
- return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
37
+ try {
38
+ await unifiedKernelWrapper(
39
+ 'repeat_channels',
40
+ target,
41
+ variant,
42
+ [input, output],
43
+ {
44
+ in_channels: inChannels,
45
+ height,
46
+ width,
47
+ repeats,
48
+ _pad0: 0,
49
+ },
50
+ [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
51
+ );
52
+
53
+ return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
54
+ } catch (error) {
55
+ if (ownedOutput) {
56
+ releaseBuffer(ownedOutput);
57
+ }
58
+ throw error;
59
+ }
52
60
  }
53
61
 
54
62
  export async function runRepeatChannels(input, options = {}) {
@@ -82,6 +82,7 @@ function planResidualDispatch(target, size, elementsPerWorkgroup) {
82
82
  async function _residualAdd(target, a, b, size, options = {}) {
83
83
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
84
84
  const { useVec4 = true, outputBuffer = null } = options;
85
+ const ownsOutput = outputBuffer == null;
85
86
 
86
87
  const { a: aAligned, b: bAligned, temps } = await alignResidualInputs(a, b, recorder);
87
88
  const outputDtype = inferOutputDtype(aAligned, bAligned);
@@ -97,15 +98,22 @@ async function _residualAdd(target, a, b, size, options = {}) {
97
98
  useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
98
99
  );
99
100
 
100
- await unifiedKernelWrapper(
101
- 'residual', target, variant,
102
- [aAligned, bAligned, output],
103
- { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
104
- dispatchPlan.workgroups
105
- );
106
-
107
- cleanupTemps(temps, recorder);
108
- return createTensor(output, outputDtype, [size], 'residual_output');
101
+ try {
102
+ await unifiedKernelWrapper(
103
+ 'residual', target, variant,
104
+ [aAligned, bAligned, output],
105
+ { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
106
+ dispatchPlan.workgroups
107
+ );
108
+ return createTensor(output, outputDtype, [size], 'residual_output');
109
+ } catch (error) {
110
+ if (ownsOutput) {
111
+ releaseBuffer(output);
112
+ }
113
+ throw error;
114
+ } finally {
115
+ cleanupTemps(temps, recorder);
116
+ }
109
117
  }
110
118
 
111
119
  async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
@@ -126,24 +134,26 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
126
134
  Math.ceil(numTokens / tokenStride),
127
135
  ];
128
136
 
129
- await unifiedKernelWrapper(
130
- 'bias_add', target, variant,
131
- [data, biasAligned],
132
- {
133
- num_tokens: numTokens,
134
- dim,
135
- data_offset: dataOffset,
136
- bias_offset: biasOffset,
137
- token_stride: tokenStride,
138
- _pad0: 0,
139
- _pad1: 0,
140
- _pad2: 0,
141
- },
142
- workgroups
143
- );
144
-
145
- cleanupTemps(temps, recorder);
146
- return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
137
+ try {
138
+ await unifiedKernelWrapper(
139
+ 'bias_add', target, variant,
140
+ [data, biasAligned],
141
+ {
142
+ num_tokens: numTokens,
143
+ dim,
144
+ data_offset: dataOffset,
145
+ bias_offset: biasOffset,
146
+ token_stride: tokenStride,
147
+ _pad0: 0,
148
+ _pad1: 0,
149
+ _pad2: 0,
150
+ },
151
+ workgroups
152
+ );
153
+ return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
154
+ } finally {
155
+ cleanupTemps(temps, recorder);
156
+ }
147
157
  }
148
158
 
149
159
  export async function runResidualAdd(a, b, size, options = {}) {
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { getKernelThresholds, padToQ4KBlock } from '../../config/schema/index.js';
7
7
  import { selectRuleValue } from './rule-registry.js';
@@ -9,6 +9,9 @@ import { selectRuleValue as selectLoaderRule } from '../../rules/rule-registry.j
9
9
  import { getBuffer, getWeightDtype, getBufferDtype } from '../weight-buffer.js';
10
10
  import { unifiedKernelWrapper } from './utils.js';
11
11
 
12
+ // Conservative fallback dtype for norm weight inference when metadata is unavailable.
13
+ const DEFAULT_DTYPE = 'f32';
14
+
12
15
  function inferHiddenSize(input, hiddenSize) {
13
16
  if (hiddenSize != null) return hiddenSize;
14
17
  const shape = input?.shape;
@@ -39,9 +42,12 @@ function resolveNormWeightDtype(weight, hiddenSize) {
39
42
  return taggedDtype;
40
43
  }
41
44
 
45
+ // Conservative fallback: f32 avoids precision loss when dtype cannot be determined.
46
+ // This path fires for non-GPU buffers or missing hiddenSize, both of which prevent
47
+ // size-based dtype inference below.
42
48
  const hasGPUBufferType = typeof GPUBuffer !== 'undefined';
43
49
  if (!hasGPUBufferType || !(weightBuffer instanceof GPUBuffer) || hiddenSize == null || hiddenSize <= 0) {
44
- return 'f32';
50
+ return DEFAULT_DTYPE;
45
51
  }
46
52
 
47
53
  const byteSize = getBufferRequestedSize(weightBuffer);
@@ -55,7 +61,8 @@ function resolveNormWeightDtype(weight, hiddenSize) {
55
61
  sizeMatchesF32,
56
62
  });
57
63
  }
58
- return 'f32';
64
+ // Buffer size matches neither f16 nor f32 for given hiddenSize; fall back to f32.
65
+ return DEFAULT_DTYPE;
59
66
  }
60
67
 
61
68
  function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
@@ -119,31 +126,39 @@ export async function runRMSNorm(
119
126
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
120
127
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
121
128
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
129
+ const ownedOutput = outputBuffer ? null : outputBuf;
122
130
  const dispatchPlan = planRMSNormDispatch(null, batchSize);
123
131
 
124
132
  // Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
125
133
  const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
126
134
 
127
- await unifiedKernelWrapper(
128
- 'rmsnorm',
129
- null,
130
- variant,
131
- [input, normWeightBuffer, outputBuf, residualBuf],
132
- {
133
- hidden_size: inferredHiddenSize,
134
- num_tokens: batchSize,
135
- eps,
136
- has_residual: residual ? 1 : 0,
137
- token_stride: dispatchPlan.tokenStride,
138
- _pad0: 0,
139
- _pad1: 0,
140
- _pad2: 0,
141
- },
142
- dispatchPlan.workgroups,
143
- { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
144
- );
145
-
146
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
135
+ try {
136
+ await unifiedKernelWrapper(
137
+ 'rmsnorm',
138
+ null,
139
+ variant,
140
+ [input, normWeightBuffer, outputBuf, residualBuf],
141
+ {
142
+ hidden_size: inferredHiddenSize,
143
+ num_tokens: batchSize,
144
+ eps,
145
+ has_residual: residual ? 1 : 0,
146
+ token_stride: dispatchPlan.tokenStride,
147
+ _pad0: 0,
148
+ _pad1: 0,
149
+ _pad2: 0,
150
+ },
151
+ dispatchPlan.workgroups,
152
+ { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
153
+ );
154
+
155
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
156
+ } catch (error) {
157
+ if (ownedOutput) {
158
+ releaseBuffer(ownedOutput);
159
+ }
160
+ throw error;
161
+ }
147
162
  }
148
163
 
149
164
  export async function recordRMSNorm(
@@ -165,28 +180,36 @@ export async function recordRMSNorm(
165
180
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
166
181
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
167
182
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
183
+ const ownedOutput = outputBuffer ? null : outputBuf;
168
184
  const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
169
185
 
170
186
  const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
171
187
 
172
- await unifiedKernelWrapper(
173
- 'rmsnorm',
174
- recorder,
175
- variant,
176
- [input, normWeightBuffer, outputBuf, residualBuf],
177
- {
178
- hidden_size: inferredHiddenSize,
179
- num_tokens: batchSize,
180
- eps,
181
- has_residual: residual ? 1 : 0,
182
- token_stride: dispatchPlan.tokenStride,
183
- _pad0: 0,
184
- _pad1: 0,
185
- _pad2: 0,
186
- },
187
- dispatchPlan.workgroups,
188
- { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
189
- );
190
-
191
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
188
+ try {
189
+ await unifiedKernelWrapper(
190
+ 'rmsnorm',
191
+ recorder,
192
+ variant,
193
+ [input, normWeightBuffer, outputBuf, residualBuf],
194
+ {
195
+ hidden_size: inferredHiddenSize,
196
+ num_tokens: batchSize,
197
+ eps,
198
+ has_residual: residual ? 1 : 0,
199
+ token_stride: dispatchPlan.tokenStride,
200
+ _pad0: 0,
201
+ _pad1: 0,
202
+ _pad2: 0,
203
+ },
204
+ dispatchPlan.workgroups,
205
+ { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
206
+ );
207
+
208
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
209
+ } catch (error) {
210
+ if (ownedOutput) {
211
+ releaseBuffer(ownedOutput);
212
+ }
213
+ throw error;
214
+ }
192
215
  }
@@ -27,6 +27,9 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
27
27
  if (rotaryDim <= 0 || rotaryDim > headDim) {
28
28
  throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
29
29
  }
30
+ if (input.dtype === 'f16' && (rotaryDim !== headDim || interleaved)) {
31
+ throw new Error('RoPE f16 kernel requires rotaryDim === headDim and interleaved === false.');
32
+ }
30
33
 
31
34
  const caps = getKernelCapabilities();
32
35
  const useF16 = input.dtype === 'f16' && caps.hasF16;
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, readBufferSlice, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { WORKGROUP_SIZES } from './constants.js';
6
6
  import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout } from './utils.js';
7
7
  import { allowReadback } from '../perf-guards.js';
@@ -156,18 +156,19 @@ function ensureOutputBufferSize(outputBuffer, minBytes, label) {
156
156
  }
157
157
  }
158
158
 
159
- function readTokenFromOutput(device, outputBuffer, outputIndex, label) {
160
- const stagingBuffer = device.createBuffer({
161
- label,
162
- size: 4,
163
- usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
164
- });
165
-
166
- const copyEncoder = device.createCommandEncoder({ label: `${label}_copy` });
167
- copyEncoder.copyBufferToBuffer(outputBuffer, outputIndex * 4, stagingBuffer, 0, 4);
168
- device.queue.submit([copyEncoder.finish()]);
159
+ async function readTokenFromOutput(outputBuffer, outputIndex) {
160
+ return new Uint32Array(await readBufferSlice(outputBuffer, outputIndex * 4, 4))[0];
161
+ }
169
162
 
170
- return stagingBuffer;
163
+ function cleanupRunResources(uniformBuffer, ownedBuffers) {
164
+ if (uniformBuffer) {
165
+ uniformBuffer.destroy();
166
+ }
167
+ for (const buffer of ownedBuffers) {
168
+ if (buffer) {
169
+ releaseBuffer(buffer);
170
+ }
171
+ }
171
172
  }
172
173
 
173
174
  async function executeArgmaxRun(logits, vocabSize, options) {
@@ -238,20 +239,14 @@ async function executeArgmaxRun(logits, vocabSize, options) {
238
239
 
239
240
  device.queue.submit([encoder.finish()]);
240
241
 
241
- const stagingBuffer = readTokenFromOutput(device, outputBuffer, outputIndex, 'argmax_staging');
242
- await stagingBuffer.mapAsync(GPUMapMode.READ);
243
- const tokenId = new Uint32Array(stagingBuffer.getMappedRange())[0];
244
- stagingBuffer.unmap();
245
-
246
- stagingBuffer.destroy();
247
- uniformBuffer.destroy();
248
- releaseBuffer(tempLogits);
249
- releaseBuffer(tempIndices);
250
- if (ownsOutputBuffer) {
251
- releaseBuffer(outputBuffer);
242
+ try {
243
+ return await readTokenFromOutput(outputBuffer, outputIndex);
244
+ } finally {
245
+ cleanupRunResources(
246
+ uniformBuffer,
247
+ [tempLogits, tempIndices, ownsOutputBuffer ? outputBuffer : null]
248
+ );
252
249
  }
253
-
254
- return tokenId;
255
250
  }
256
251
 
257
252
  async function executeArgmaxRecord(recorder, logits, vocabSize, options) {
@@ -428,20 +423,14 @@ export async function runGPUSample(
428
423
 
429
424
  device.queue.submit([encoder.finish()]);
430
425
 
431
- const stagingBuffer = readTokenFromOutput(device, outputBuffer, outputIndex, 'sample_staging');
432
- await stagingBuffer.mapAsync(GPUMapMode.READ);
433
- const tokenId = new Uint32Array(stagingBuffer.getMappedRange())[0];
434
- stagingBuffer.unmap();
435
-
436
- stagingBuffer.destroy();
437
- uniformBuffer.destroy();
438
- releaseBuffer(topkLogits);
439
- releaseBuffer(topkIndices);
440
- if (ownsOutputBuffer) {
441
- releaseBuffer(outputBuffer);
426
+ try {
427
+ return await readTokenFromOutput(outputBuffer, outputIndex);
428
+ } finally {
429
+ cleanupRunResources(
430
+ uniformBuffer,
431
+ [topkLogits, topkIndices, ownsOutputBuffer ? outputBuffer : null]
432
+ );
442
433
  }
443
-
444
- return tokenId;
445
434
  }
446
435
 
447
436
 
@@ -64,6 +64,8 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
64
64
  outputBuffer = null,
65
65
  summaryBuffer = null,
66
66
  } = options;
67
+ const ownsSummary = summaryBuffer == null;
68
+ const ownsOutput = outputBuffer == null;
67
69
 
68
70
  if (
69
71
  !Number.isFinite(numHeads) ||
@@ -98,18 +100,24 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
98
100
  eps,
99
101
  };
100
102
 
101
- await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
102
- await runApply(target, query, temporarySummary, output, uniforms, variant);
103
-
104
- if (!summaryBuffer) {
105
- if (recorder) {
106
- recorder.trackTemporaryBuffer(temporarySummary);
107
- } else {
108
- releaseBuffer(temporarySummary);
103
+ try {
104
+ await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
105
+ await runApply(target, query, temporarySummary, output, uniforms, variant);
106
+ return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
107
+ } catch (error) {
108
+ if (ownsOutput) {
109
+ releaseBuffer(output);
110
+ }
111
+ throw error;
112
+ } finally {
113
+ if (ownsSummary) {
114
+ if (recorder) {
115
+ recorder.trackTemporaryBuffer(temporarySummary);
116
+ } else {
117
+ releaseBuffer(temporarySummary);
118
+ }
109
119
  }
110
120
  }
111
-
112
- return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
113
121
  }
114
122
 
115
123
  export async function runSanaLinearAttention(query, key, value, options = {}) {
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { WORKGROUP_SIZES } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
@@ -6,6 +6,7 @@ import { selectRuleValue } from './rule-registry.js';
6
6
 
7
7
  async function _scale(target, input, scale, options = {}) {
8
8
  const { count, outputBuffer = null, inplace = false } = options;
9
+ const ownsOutput = !inplace && outputBuffer == null;
9
10
 
10
11
  const bytesPerElement = dtypeBytes(input.dtype);
11
12
  const inferredCount = count ?? Math.floor(input.buffer.size / bytesPerElement);
@@ -16,16 +17,22 @@ async function _scale(target, input, scale, options = {}) {
16
17
 
17
18
  const bindings = inplace ? [outputBuf, outputBuf] : [input, outputBuf];
18
19
 
19
- await unifiedKernelWrapper(
20
- 'scale',
21
- target,
22
- variant,
23
- bindings,
24
- { size: inferredCount, scale },
25
- Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
26
- );
27
-
28
- return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
20
+ try {
21
+ await unifiedKernelWrapper(
22
+ 'scale',
23
+ target,
24
+ variant,
25
+ bindings,
26
+ { size: inferredCount, scale },
27
+ Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
28
+ );
29
+ return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
30
+ } catch (error) {
31
+ if (ownsOutput) {
32
+ releaseBuffer(outputBuf);
33
+ }
34
+ throw error;
35
+ }
29
36
  }
30
37
 
31
38
  export async function runScale(input, scale, options = {}) {
@@ -138,8 +138,10 @@ export async function compileShader(
138
138
  code: source,
139
139
  });
140
140
 
141
- // Check for compilation errors
142
- const compilationInfo = await module.getCompilationInfo();
141
+ // Check for compilation errors (getCompilationInfo not available in all WebGPU providers)
142
+ const compilationInfo = typeof module.getCompilationInfo === 'function'
143
+ ? await module.getCompilationInfo()
144
+ : { messages: [] };
143
145
  if (compilationInfo.messages.length > 0) {
144
146
  for (const msg of compilationInfo.messages) {
145
147
  if (msg.type === 'error') {