@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { unifiedKernelWrapper } from './utils.js';
4
4
  import { selectRuleValue } from './rule-registry.js';
@@ -32,23 +32,31 @@ async function _repeatChannels(target, input, options = {}) {
32
32
  const bytesPerElement = dtypeBytes(input.dtype);
33
33
  const outputSize = outChannels * height * width * bytesPerElement;
34
34
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
35
+ const ownedOutput = outputBuffer ? null : output;
35
36
 
36
- await unifiedKernelWrapper(
37
- 'repeat_channels',
38
- target,
39
- variant,
40
- [input, output],
41
- {
42
- in_channels: inChannels,
43
- height,
44
- width,
45
- repeats,
46
- _pad0: 0,
47
- },
48
- [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
49
- );
50
-
51
- return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
37
+ try {
38
+ await unifiedKernelWrapper(
39
+ 'repeat_channels',
40
+ target,
41
+ variant,
42
+ [input, output],
43
+ {
44
+ in_channels: inChannels,
45
+ height,
46
+ width,
47
+ repeats,
48
+ _pad0: 0,
49
+ },
50
+ [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
51
+ );
52
+
53
+ return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
54
+ } catch (error) {
55
+ if (ownedOutput) {
56
+ releaseBuffer(ownedOutput);
57
+ }
58
+ throw error;
59
+ }
52
60
  }
53
61
 
54
62
  export async function runRepeatChannels(input, options = {}) {
@@ -82,6 +82,7 @@ function planResidualDispatch(target, size, elementsPerWorkgroup) {
82
82
  async function _residualAdd(target, a, b, size, options = {}) {
83
83
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
84
84
  const { useVec4 = true, outputBuffer = null } = options;
85
+ const ownsOutput = outputBuffer == null;
85
86
 
86
87
  const { a: aAligned, b: bAligned, temps } = await alignResidualInputs(a, b, recorder);
87
88
  const outputDtype = inferOutputDtype(aAligned, bAligned);
@@ -97,15 +98,22 @@ async function _residualAdd(target, a, b, size, options = {}) {
97
98
  useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
98
99
  );
99
100
 
100
- await unifiedKernelWrapper(
101
- 'residual', target, variant,
102
- [aAligned, bAligned, output],
103
- { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
104
- dispatchPlan.workgroups
105
- );
106
-
107
- cleanupTemps(temps, recorder);
108
- return createTensor(output, outputDtype, [size], 'residual_output');
101
+ try {
102
+ await unifiedKernelWrapper(
103
+ 'residual', target, variant,
104
+ [aAligned, bAligned, output],
105
+ { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
106
+ dispatchPlan.workgroups
107
+ );
108
+ return createTensor(output, outputDtype, [size], 'residual_output');
109
+ } catch (error) {
110
+ if (ownsOutput) {
111
+ releaseBuffer(output);
112
+ }
113
+ throw error;
114
+ } finally {
115
+ cleanupTemps(temps, recorder);
116
+ }
109
117
  }
110
118
 
111
119
  async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
@@ -126,24 +134,26 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
126
134
  Math.ceil(numTokens / tokenStride),
127
135
  ];
128
136
 
129
- await unifiedKernelWrapper(
130
- 'bias_add', target, variant,
131
- [data, biasAligned],
132
- {
133
- num_tokens: numTokens,
134
- dim,
135
- data_offset: dataOffset,
136
- bias_offset: biasOffset,
137
- token_stride: tokenStride,
138
- _pad0: 0,
139
- _pad1: 0,
140
- _pad2: 0,
141
- },
142
- workgroups
143
- );
144
-
145
- cleanupTemps(temps, recorder);
146
- return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
137
+ try {
138
+ await unifiedKernelWrapper(
139
+ 'bias_add', target, variant,
140
+ [data, biasAligned],
141
+ {
142
+ num_tokens: numTokens,
143
+ dim,
144
+ data_offset: dataOffset,
145
+ bias_offset: biasOffset,
146
+ token_stride: tokenStride,
147
+ _pad0: 0,
148
+ _pad1: 0,
149
+ _pad2: 0,
150
+ },
151
+ workgroups
152
+ );
153
+ return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
154
+ } finally {
155
+ cleanupTemps(temps, recorder);
156
+ }
147
157
  }
148
158
 
149
159
  export async function runResidualAdd(a, b, size, options = {}) {
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { getKernelThresholds, padToQ4KBlock } from '../../config/schema/index.js';
7
7
  import { selectRuleValue } from './rule-registry.js';
@@ -119,31 +119,39 @@ export async function runRMSNorm(
119
119
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
120
120
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
121
121
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
122
+ const ownedOutput = outputBuffer ? null : outputBuf;
122
123
  const dispatchPlan = planRMSNormDispatch(null, batchSize);
123
124
 
124
125
  // Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
125
126
  const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
126
127
 
127
- await unifiedKernelWrapper(
128
- 'rmsnorm',
129
- null,
130
- variant,
131
- [input, normWeightBuffer, outputBuf, residualBuf],
132
- {
133
- hidden_size: inferredHiddenSize,
134
- num_tokens: batchSize,
135
- eps,
136
- has_residual: residual ? 1 : 0,
137
- token_stride: dispatchPlan.tokenStride,
138
- _pad0: 0,
139
- _pad1: 0,
140
- _pad2: 0,
141
- },
142
- dispatchPlan.workgroups,
143
- { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
144
- );
145
-
146
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
128
+ try {
129
+ await unifiedKernelWrapper(
130
+ 'rmsnorm',
131
+ null,
132
+ variant,
133
+ [input, normWeightBuffer, outputBuf, residualBuf],
134
+ {
135
+ hidden_size: inferredHiddenSize,
136
+ num_tokens: batchSize,
137
+ eps,
138
+ has_residual: residual ? 1 : 0,
139
+ token_stride: dispatchPlan.tokenStride,
140
+ _pad0: 0,
141
+ _pad1: 0,
142
+ _pad2: 0,
143
+ },
144
+ dispatchPlan.workgroups,
145
+ { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
146
+ );
147
+
148
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
149
+ } catch (error) {
150
+ if (ownedOutput) {
151
+ releaseBuffer(ownedOutput);
152
+ }
153
+ throw error;
154
+ }
147
155
  }
148
156
 
149
157
  export async function recordRMSNorm(
@@ -165,28 +173,36 @@ export async function recordRMSNorm(
165
173
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
166
174
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
167
175
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
176
+ const ownedOutput = outputBuffer ? null : outputBuf;
168
177
  const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
169
178
 
170
179
  const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
171
180
 
172
- await unifiedKernelWrapper(
173
- 'rmsnorm',
174
- recorder,
175
- variant,
176
- [input, normWeightBuffer, outputBuf, residualBuf],
177
- {
178
- hidden_size: inferredHiddenSize,
179
- num_tokens: batchSize,
180
- eps,
181
- has_residual: residual ? 1 : 0,
182
- token_stride: dispatchPlan.tokenStride,
183
- _pad0: 0,
184
- _pad1: 0,
185
- _pad2: 0,
186
- },
187
- dispatchPlan.workgroups,
188
- { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
189
- );
190
-
191
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
181
+ try {
182
+ await unifiedKernelWrapper(
183
+ 'rmsnorm',
184
+ recorder,
185
+ variant,
186
+ [input, normWeightBuffer, outputBuf, residualBuf],
187
+ {
188
+ hidden_size: inferredHiddenSize,
189
+ num_tokens: batchSize,
190
+ eps,
191
+ has_residual: residual ? 1 : 0,
192
+ token_stride: dispatchPlan.tokenStride,
193
+ _pad0: 0,
194
+ _pad1: 0,
195
+ _pad2: 0,
196
+ },
197
+ dispatchPlan.workgroups,
198
+ { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
199
+ );
200
+
201
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
202
+ } catch (error) {
203
+ if (ownedOutput) {
204
+ releaseBuffer(ownedOutput);
205
+ }
206
+ throw error;
207
+ }
192
208
  }
@@ -27,6 +27,9 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
27
27
  if (rotaryDim <= 0 || rotaryDim > headDim) {
28
28
  throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
29
29
  }
30
+ if (input.dtype === 'f16' && (rotaryDim !== headDim || interleaved)) {
31
+ throw new Error('RoPE f16 kernel requires rotaryDim === headDim and interleaved === false.');
32
+ }
30
33
 
31
34
  const caps = getKernelCapabilities();
32
35
  const useF16 = input.dtype === 'f16' && caps.hasF16;
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, readBufferSlice, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { WORKGROUP_SIZES } from './constants.js';
6
6
  import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout } from './utils.js';
7
7
  import { allowReadback } from '../perf-guards.js';
@@ -156,18 +156,19 @@ function ensureOutputBufferSize(outputBuffer, minBytes, label) {
156
156
  }
157
157
  }
158
158
 
159
- function readTokenFromOutput(device, outputBuffer, outputIndex, label) {
160
- const stagingBuffer = device.createBuffer({
161
- label,
162
- size: 4,
163
- usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
164
- });
165
-
166
- const copyEncoder = device.createCommandEncoder({ label: `${label}_copy` });
167
- copyEncoder.copyBufferToBuffer(outputBuffer, outputIndex * 4, stagingBuffer, 0, 4);
168
- device.queue.submit([copyEncoder.finish()]);
159
+ async function readTokenFromOutput(outputBuffer, outputIndex) {
160
+ return new Uint32Array(await readBufferSlice(outputBuffer, outputIndex * 4, 4))[0];
161
+ }
169
162
 
170
- return stagingBuffer;
163
+ function cleanupRunResources(uniformBuffer, ownedBuffers) {
164
+ if (uniformBuffer) {
165
+ uniformBuffer.destroy();
166
+ }
167
+ for (const buffer of ownedBuffers) {
168
+ if (buffer) {
169
+ releaseBuffer(buffer);
170
+ }
171
+ }
171
172
  }
172
173
 
173
174
  async function executeArgmaxRun(logits, vocabSize, options) {
@@ -238,20 +239,14 @@ async function executeArgmaxRun(logits, vocabSize, options) {
238
239
 
239
240
  device.queue.submit([encoder.finish()]);
240
241
 
241
- const stagingBuffer = readTokenFromOutput(device, outputBuffer, outputIndex, 'argmax_staging');
242
- await stagingBuffer.mapAsync(GPUMapMode.READ);
243
- const tokenId = new Uint32Array(stagingBuffer.getMappedRange())[0];
244
- stagingBuffer.unmap();
245
-
246
- stagingBuffer.destroy();
247
- uniformBuffer.destroy();
248
- releaseBuffer(tempLogits);
249
- releaseBuffer(tempIndices);
250
- if (ownsOutputBuffer) {
251
- releaseBuffer(outputBuffer);
242
+ try {
243
+ return await readTokenFromOutput(outputBuffer, outputIndex);
244
+ } finally {
245
+ cleanupRunResources(
246
+ uniformBuffer,
247
+ [tempLogits, tempIndices, ownsOutputBuffer ? outputBuffer : null]
248
+ );
252
249
  }
253
-
254
- return tokenId;
255
250
  }
256
251
 
257
252
  async function executeArgmaxRecord(recorder, logits, vocabSize, options) {
@@ -428,20 +423,14 @@ export async function runGPUSample(
428
423
 
429
424
  device.queue.submit([encoder.finish()]);
430
425
 
431
- const stagingBuffer = readTokenFromOutput(device, outputBuffer, outputIndex, 'sample_staging');
432
- await stagingBuffer.mapAsync(GPUMapMode.READ);
433
- const tokenId = new Uint32Array(stagingBuffer.getMappedRange())[0];
434
- stagingBuffer.unmap();
435
-
436
- stagingBuffer.destroy();
437
- uniformBuffer.destroy();
438
- releaseBuffer(topkLogits);
439
- releaseBuffer(topkIndices);
440
- if (ownsOutputBuffer) {
441
- releaseBuffer(outputBuffer);
426
+ try {
427
+ return await readTokenFromOutput(outputBuffer, outputIndex);
428
+ } finally {
429
+ cleanupRunResources(
430
+ uniformBuffer,
431
+ [topkLogits, topkIndices, ownsOutputBuffer ? outputBuffer : null]
432
+ );
442
433
  }
443
-
444
- return tokenId;
445
434
  }
446
435
 
447
436
 
@@ -64,6 +64,8 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
64
64
  outputBuffer = null,
65
65
  summaryBuffer = null,
66
66
  } = options;
67
+ const ownsSummary = summaryBuffer == null;
68
+ const ownsOutput = outputBuffer == null;
67
69
 
68
70
  if (
69
71
  !Number.isFinite(numHeads) ||
@@ -98,18 +100,24 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
98
100
  eps,
99
101
  };
100
102
 
101
- await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
102
- await runApply(target, query, temporarySummary, output, uniforms, variant);
103
-
104
- if (!summaryBuffer) {
105
- if (recorder) {
106
- recorder.trackTemporaryBuffer(temporarySummary);
107
- } else {
108
- releaseBuffer(temporarySummary);
103
+ try {
104
+ await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
105
+ await runApply(target, query, temporarySummary, output, uniforms, variant);
106
+ return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
107
+ } catch (error) {
108
+ if (ownsOutput) {
109
+ releaseBuffer(output);
110
+ }
111
+ throw error;
112
+ } finally {
113
+ if (ownsSummary) {
114
+ if (recorder) {
115
+ recorder.trackTemporaryBuffer(temporarySummary);
116
+ } else {
117
+ releaseBuffer(temporarySummary);
118
+ }
109
119
  }
110
120
  }
111
-
112
- return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
113
121
  }
114
122
 
115
123
  export async function runSanaLinearAttention(query, key, value, options = {}) {
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { WORKGROUP_SIZES } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
@@ -6,6 +6,7 @@ import { selectRuleValue } from './rule-registry.js';
6
6
 
7
7
  async function _scale(target, input, scale, options = {}) {
8
8
  const { count, outputBuffer = null, inplace = false } = options;
9
+ const ownsOutput = !inplace && outputBuffer == null;
9
10
 
10
11
  const bytesPerElement = dtypeBytes(input.dtype);
11
12
  const inferredCount = count ?? Math.floor(input.buffer.size / bytesPerElement);
@@ -16,16 +17,22 @@ async function _scale(target, input, scale, options = {}) {
16
17
 
17
18
  const bindings = inplace ? [outputBuf, outputBuf] : [input, outputBuf];
18
19
 
19
- await unifiedKernelWrapper(
20
- 'scale',
21
- target,
22
- variant,
23
- bindings,
24
- { size: inferredCount, scale },
25
- Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
26
- );
27
-
28
- return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
20
+ try {
21
+ await unifiedKernelWrapper(
22
+ 'scale',
23
+ target,
24
+ variant,
25
+ bindings,
26
+ { size: inferredCount, scale },
27
+ Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
28
+ );
29
+ return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
30
+ } catch (error) {
31
+ if (ownsOutput) {
32
+ releaseBuffer(outputBuf);
33
+ }
34
+ throw error;
35
+ }
29
36
  }
30
37
 
31
38
  export async function runScale(input, scale, options = {}) {
@@ -138,8 +138,10 @@ export async function compileShader(
138
138
  code: source,
139
139
  });
140
140
 
141
- // Check for compilation errors
142
- const compilationInfo = await module.getCompilationInfo();
141
+ // Check for compilation errors (getCompilationInfo not available in all WebGPU providers)
142
+ const compilationInfo = typeof module.getCompilationInfo === 'function'
143
+ ? await module.getCompilationInfo()
144
+ : { messages: [] };
143
145
  if (compilationInfo.messages.length > 0) {
144
146
  for (const msg of compilationInfo.messages) {
145
147
  if (msg.type === 'error') {