@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -19,6 +19,18 @@ function fillRandom(data, rng) {
19
19
  for (let i = 0; i < data.length; i++) data[i] = rng();
20
20
  }
21
21
 
22
// Safely release a GPU buffer: no-op for null/undefined or objects
// without a destroy() method (e.g. already-released or mock buffers).
function destroyBuffer(buffer) {
  if (typeof buffer?.destroy === 'function') {
    buffer.destroy();
  }
}

// Release any number of GPU buffers in one call; each entry is passed
// through destroyBuffer, so missing/invalid entries are tolerated.
function destroyBuffers(...buffers) {
  buffers.forEach((buffer) => destroyBuffer(buffer));
}
33
+
22
34
 
23
35
  export async function benchmarkPipeline(
24
36
  device,
@@ -98,7 +110,6 @@ export async function tuneMatmul(
98
110
  deviceInfo: capabilities?.adapterInfo,
99
111
  };
100
112
 
101
- // Create test buffers
102
113
  const bufferA = device.createBuffer({
103
114
  size: M * K * 4,
104
115
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
@@ -112,93 +123,86 @@ export async function tuneMatmul(
112
123
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
113
124
  });
114
125
 
115
- // Initialize with random data
116
- const dataA = new Float32Array(M * K);
117
- const dataB = new Float32Array(K * N);
118
- const matmulRng = createRng(0x13579bdf);
119
- fillRandom(dataA, matmulRng);
120
- fillRandom(dataB, matmulRng);
121
- device.queue.writeBuffer(bufferA, 0, dataA);
122
- device.queue.writeBuffer(bufferB, 0, dataB);
123
-
124
- for (const [wgX, wgY] of matmulCandidates) {
125
- try {
126
- // Create shader with this workgroup size
127
- const shader = createMatmulShader(wgX, wgY);
128
- const pipeline = await createComputePipeline(device, shader, 'main');
129
-
130
- // Create bind group
131
- const uniformBuffer = device.createBuffer({
132
- size: 16,
133
- usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
134
- });
135
- const uniformData = new Uint32Array([M, N, K, 0]);
136
- device.queue.writeBuffer(uniformBuffer, 0, uniformData);
137
-
138
- const bindGroup = device.createBindGroup({
139
- layout: pipeline.getBindGroupLayout(0),
140
- entries: [
141
- { binding: 0, resource: { buffer: uniformBuffer } },
142
- { binding: 1, resource: { buffer: bufferA } },
143
- { binding: 2, resource: { buffer: bufferB } },
144
- { binding: 3, resource: { buffer: bufferC } },
145
- ],
146
- });
147
-
148
- // Warmup
149
- for (let i = 0; i < warmup; i++) {
150
- const encoder = device.createCommandEncoder();
151
- const pass = encoder.beginComputePass();
152
- pass.setPipeline(pipeline);
153
- pass.setBindGroup(0, bindGroup);
154
- pass.dispatchWorkgroups(Math.ceil(M / wgX), Math.ceil(N / wgY));
155
- pass.end();
156
- device.queue.submit([encoder.finish()]);
157
- }
158
- await device.queue.onSubmittedWorkDone();
159
-
160
- // Benchmark
161
-
162
- const times = [];
163
- for (let i = 0; i < iterations; i++) {
164
- const start = performance.now();
165
- const encoder = device.createCommandEncoder();
166
- const pass = encoder.beginComputePass();
167
- pass.setPipeline(pipeline);
168
- pass.setBindGroup(0, bindGroup);
169
- pass.dispatchWorkgroups(Math.ceil(M / wgX), Math.ceil(N / wgY));
170
- pass.end();
171
- device.queue.submit([encoder.finish()]);
126
+ try {
127
+ const dataA = new Float32Array(M * K);
128
+ const dataB = new Float32Array(K * N);
129
+ const matmulRng = createRng(0x13579bdf);
130
+ fillRandom(dataA, matmulRng);
131
+ fillRandom(dataB, matmulRng);
132
+ device.queue.writeBuffer(bufferA, 0, dataA);
133
+ device.queue.writeBuffer(bufferB, 0, dataB);
134
+
135
+ for (const [wgX, wgY] of matmulCandidates) {
136
+ let uniformBuffer = null;
137
+ try {
138
+ const shader = createMatmulShader(wgX, wgY);
139
+ const pipeline = await createComputePipeline(device, shader, 'main');
140
+
141
+ uniformBuffer = device.createBuffer({
142
+ size: 16,
143
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
144
+ });
145
+ const uniformData = new Uint32Array([M, N, K, 0]);
146
+ device.queue.writeBuffer(uniformBuffer, 0, uniformData);
147
+
148
+ const bindGroup = device.createBindGroup({
149
+ layout: pipeline.getBindGroupLayout(0),
150
+ entries: [
151
+ { binding: 0, resource: { buffer: uniformBuffer } },
152
+ { binding: 1, resource: { buffer: bufferA } },
153
+ { binding: 2, resource: { buffer: bufferB } },
154
+ { binding: 3, resource: { buffer: bufferC } },
155
+ ],
156
+ });
157
+
158
+ for (let i = 0; i < warmup; i++) {
159
+ const encoder = device.createCommandEncoder();
160
+ const pass = encoder.beginComputePass();
161
+ pass.setPipeline(pipeline);
162
+ pass.setBindGroup(0, bindGroup);
163
+ pass.dispatchWorkgroups(Math.ceil(M / wgX), Math.ceil(N / wgY));
164
+ pass.end();
165
+ device.queue.submit([encoder.finish()]);
166
+ }
172
167
  await device.queue.onSubmittedWorkDone();
173
- times.push(performance.now() - start);
174
- }
175
168
 
176
- const avgTime = times.reduce((a, b) => a + b, 0) / times.length;
177
- const flops = 2 * M * N * K; // multiply-add = 2 ops
178
- const gflops = (flops / avgTime) / 1e6; // GFLOPS
179
-
180
- if (avgTime < best.timeMs) {
181
- best = {
182
- optimalWorkgroupSize: [wgX, wgY, 1],
183
- optimalTileSize: wgX,
184
- throughput: gflops,
185
- timeMs: avgTime,
186
- deviceInfo: capabilities?.adapterInfo,
187
- };
169
+ const times = [];
170
+ for (let i = 0; i < iterations; i++) {
171
+ const start = performance.now();
172
+ const encoder = device.createCommandEncoder();
173
+ const pass = encoder.beginComputePass();
174
+ pass.setPipeline(pipeline);
175
+ pass.setBindGroup(0, bindGroup);
176
+ pass.dispatchWorkgroups(Math.ceil(M / wgX), Math.ceil(N / wgY));
177
+ pass.end();
178
+ device.queue.submit([encoder.finish()]);
179
+ await device.queue.onSubmittedWorkDone();
180
+ times.push(performance.now() - start);
181
+ }
182
+
183
+ const avgTime = times.reduce((a, b) => a + b, 0) / times.length;
184
+ const flops = 2 * M * N * K;
185
+ const gflops = (flops / avgTime) / 1e6;
186
+
187
+ if (avgTime < best.timeMs) {
188
+ best = {
189
+ optimalWorkgroupSize: [wgX, wgY, 1],
190
+ optimalTileSize: wgX,
191
+ throughput: gflops,
192
+ timeMs: avgTime,
193
+ deviceInfo: capabilities?.adapterInfo,
194
+ };
195
+ }
196
+ } catch (e) {
197
+ continue;
198
+ } finally {
199
+ destroyBuffer(uniformBuffer);
188
200
  }
189
-
190
- uniformBuffer.destroy();
191
- } catch (e) {
192
- // Skip invalid configurations
193
- continue;
194
201
  }
202
+ } finally {
203
+ destroyBuffers(bufferA, bufferB, bufferC);
195
204
  }
196
205
 
197
- // Cleanup
198
- bufferA.destroy();
199
- bufferB.destroy();
200
- bufferC.destroy();
201
-
202
206
  return best;
203
207
  }
204
208
 
@@ -277,68 +281,69 @@ export async function tuneAttention(
277
281
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
278
282
  });
279
283
 
280
- const dataQ = new Float32Array(totalElements);
281
- const dataK = new Float32Array(totalElements);
282
- const attentionRng = createRng(0x2468ace1);
283
- fillRandom(dataQ, attentionRng);
284
- fillRandom(dataK, attentionRng);
285
- device.queue.writeBuffer(bufferQ, 0, dataQ);
286
- device.queue.writeBuffer(bufferK, 0, dataK);
287
-
288
- for (const [wgX] of attentionCandidates) {
289
- try {
290
- const shader = createAttentionShader(wgX);
291
- const pipeline = await createComputePipeline(device, shader, 'main');
292
-
293
- const uniformBuffer = device.createBuffer({
294
- size: 16,
295
- usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
296
- });
297
- const uniformData = new Uint32Array([headDim, numHeads, benchSeqLen, 0]);
298
- device.queue.writeBuffer(uniformBuffer, 0, uniformData);
299
-
300
- const bindGroup = device.createBindGroup({
301
- layout: pipeline.getBindGroupLayout(0),
302
- entries: [
303
- { binding: 0, resource: { buffer: uniformBuffer } },
304
- { binding: 1, resource: { buffer: bufferQ } },
305
- { binding: 2, resource: { buffer: bufferK } },
306
- { binding: 3, resource: { buffer: bufferOut } },
307
- ],
308
- });
309
-
310
- const avgTime = await benchmarkPipeline(
311
- device,
312
- pipeline,
313
- bindGroup,
314
- [totalHeads, 1, 1],
315
- warmup,
316
- iterations
317
- );
318
-
319
- const flops = 2 * totalHeads * headDim;
320
- const gflops = avgTime > 0 ? (flops / avgTime) / 1e6 : 0;
321
-
322
- if (avgTime < best.timeMs) {
323
- best = {
324
- optimalWorkgroupSize: [wgX, 1, 1],
325
- optimalTileSize: wgX,
326
- throughput: gflops,
327
- timeMs: avgTime,
328
- deviceInfo: capabilities?.adapterInfo,
329
- };
284
+ try {
285
+ const dataQ = new Float32Array(totalElements);
286
+ const dataK = new Float32Array(totalElements);
287
+ const attentionRng = createRng(0x2468ace1);
288
+ fillRandom(dataQ, attentionRng);
289
+ fillRandom(dataK, attentionRng);
290
+ device.queue.writeBuffer(bufferQ, 0, dataQ);
291
+ device.queue.writeBuffer(bufferK, 0, dataK);
292
+
293
+ for (const [wgX] of attentionCandidates) {
294
+ let uniformBuffer = null;
295
+ try {
296
+ const shader = createAttentionShader(wgX);
297
+ const pipeline = await createComputePipeline(device, shader, 'main');
298
+
299
+ uniformBuffer = device.createBuffer({
300
+ size: 16,
301
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
302
+ });
303
+ const uniformData = new Uint32Array([headDim, numHeads, benchSeqLen, 0]);
304
+ device.queue.writeBuffer(uniformBuffer, 0, uniformData);
305
+
306
+ const bindGroup = device.createBindGroup({
307
+ layout: pipeline.getBindGroupLayout(0),
308
+ entries: [
309
+ { binding: 0, resource: { buffer: uniformBuffer } },
310
+ { binding: 1, resource: { buffer: bufferQ } },
311
+ { binding: 2, resource: { buffer: bufferK } },
312
+ { binding: 3, resource: { buffer: bufferOut } },
313
+ ],
314
+ });
315
+
316
+ const avgTime = await benchmarkPipeline(
317
+ device,
318
+ pipeline,
319
+ bindGroup,
320
+ [totalHeads, 1, 1],
321
+ warmup,
322
+ iterations
323
+ );
324
+
325
+ const flops = 2 * totalHeads * headDim;
326
+ const gflops = avgTime > 0 ? (flops / avgTime) / 1e6 : 0;
327
+
328
+ if (avgTime < best.timeMs) {
329
+ best = {
330
+ optimalWorkgroupSize: [wgX, 1, 1],
331
+ optimalTileSize: wgX,
332
+ throughput: gflops,
333
+ timeMs: avgTime,
334
+ deviceInfo: capabilities?.adapterInfo,
335
+ };
336
+ }
337
+ } catch (e) {
338
+ continue;
339
+ } finally {
340
+ destroyBuffer(uniformBuffer);
330
341
  }
331
-
332
- uniformBuffer.destroy();
333
- } catch (e) {
334
- continue;
335
342
  }
343
+ } finally {
344
+ destroyBuffers(bufferQ, bufferK, bufferOut);
336
345
  }
337
346
 
338
- bufferQ.destroy();
339
- bufferK.destroy();
340
- bufferOut.destroy();
341
-
342
347
  return best;
343
348
  }
344
349
 
@@ -437,63 +442,65 @@ export async function tuneSoftmax(
437
442
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
438
443
  });
439
444
 
440
- const dataIn = new Float32Array(totalElements);
441
- const softmaxRng = createRng(0x31415926);
442
- fillRandom(dataIn, softmaxRng);
443
- device.queue.writeBuffer(bufferIn, 0, dataIn);
444
-
445
- for (const [wgX] of softmaxCandidates) {
446
- try {
447
- const shader = createSoftmaxShader(wgX);
448
- const pipeline = await createComputePipeline(device, shader, 'main');
449
-
450
- const uniformBuffer = device.createBuffer({
451
- size: 16,
452
- usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
453
- });
454
- const uniformData = new Uint32Array([innerSize, outerSize, 0, 0]);
455
- device.queue.writeBuffer(uniformBuffer, 0, uniformData);
456
-
457
- const bindGroup = device.createBindGroup({
458
- layout: pipeline.getBindGroupLayout(0),
459
- entries: [
460
- { binding: 0, resource: { buffer: uniformBuffer } },
461
- { binding: 1, resource: { buffer: bufferIn } },
462
- { binding: 2, resource: { buffer: bufferOut } },
463
- ],
464
- });
465
-
466
- const avgTime = await benchmarkPipeline(
467
- device,
468
- pipeline,
469
- bindGroup,
470
- [outerSize, 1, 1],
471
- warmup,
472
- iterations
473
- );
474
-
475
- const ops = 2 * totalElements;
476
- const gops = avgTime > 0 ? (ops / avgTime) / 1e6 : 0;
477
-
478
- if (avgTime < best.timeMs) {
479
- best = {
480
- optimalWorkgroupSize: [wgX, 1, 1],
481
- optimalTileSize: wgX,
482
- throughput: gops,
483
- timeMs: avgTime,
484
- deviceInfo: capabilities?.adapterInfo,
485
- };
445
+ try {
446
+ const dataIn = new Float32Array(totalElements);
447
+ const softmaxRng = createRng(0x31415926);
448
+ fillRandom(dataIn, softmaxRng);
449
+ device.queue.writeBuffer(bufferIn, 0, dataIn);
450
+
451
+ for (const [wgX] of softmaxCandidates) {
452
+ let uniformBuffer = null;
453
+ try {
454
+ const shader = createSoftmaxShader(wgX);
455
+ const pipeline = await createComputePipeline(device, shader, 'main');
456
+
457
+ uniformBuffer = device.createBuffer({
458
+ size: 16,
459
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
460
+ });
461
+ const uniformData = new Uint32Array([innerSize, outerSize, 0, 0]);
462
+ device.queue.writeBuffer(uniformBuffer, 0, uniformData);
463
+
464
+ const bindGroup = device.createBindGroup({
465
+ layout: pipeline.getBindGroupLayout(0),
466
+ entries: [
467
+ { binding: 0, resource: { buffer: uniformBuffer } },
468
+ { binding: 1, resource: { buffer: bufferIn } },
469
+ { binding: 2, resource: { buffer: bufferOut } },
470
+ ],
471
+ });
472
+
473
+ const avgTime = await benchmarkPipeline(
474
+ device,
475
+ pipeline,
476
+ bindGroup,
477
+ [outerSize, 1, 1],
478
+ warmup,
479
+ iterations
480
+ );
481
+
482
+ const ops = 2 * totalElements;
483
+ const gops = avgTime > 0 ? (ops / avgTime) / 1e6 : 0;
484
+
485
+ if (avgTime < best.timeMs) {
486
+ best = {
487
+ optimalWorkgroupSize: [wgX, 1, 1],
488
+ optimalTileSize: wgX,
489
+ throughput: gops,
490
+ timeMs: avgTime,
491
+ deviceInfo: capabilities?.adapterInfo,
492
+ };
493
+ }
494
+ } catch (e) {
495
+ continue;
496
+ } finally {
497
+ destroyBuffer(uniformBuffer);
486
498
  }
487
-
488
- uniformBuffer.destroy();
489
- } catch (e) {
490
- continue;
491
499
  }
500
+ } finally {
501
+ destroyBuffers(bufferIn, bufferOut);
492
502
  }
493
503
 
494
- bufferIn.destroy();
495
- bufferOut.destroy();
496
-
497
504
  return best;
498
505
  }
499
506
 
@@ -620,73 +627,74 @@ export async function tuneRMSNorm(
620
627
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
621
628
  });
622
629
 
623
- const dataIn = new Float32Array(totalElements);
624
- const dataWeight = new Float32Array(hiddenSize);
625
- const rmsRng = createRng(0x27182818);
626
- fillRandom(dataIn, rmsRng);
627
- fillRandom(dataWeight, rmsRng);
628
- device.queue.writeBuffer(bufferIn, 0, dataIn);
629
- device.queue.writeBuffer(bufferWeight, 0, dataWeight);
630
-
631
- for (const [wgX] of rmsCandidates) {
632
- try {
633
- const shader = createRMSNormShader(wgX);
634
- const pipeline = await createComputePipeline(device, shader, 'main');
635
-
636
- const uniformBuffer = device.createBuffer({
637
- size: 16,
638
- usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
639
- });
640
- const uniformData = new ArrayBuffer(16);
641
- const uniformView = new DataView(uniformData);
642
- uniformView.setUint32(0, hiddenSize, true);
643
- uniformView.setUint32(4, numTokens, true);
644
- uniformView.setFloat32(8, DEFAULT_RMS_NORM_EPS, true);
645
- uniformView.setUint32(12, 0, true);
646
- device.queue.writeBuffer(uniformBuffer, 0, uniformData);
647
-
648
- const bindGroup = device.createBindGroup({
649
- layout: pipeline.getBindGroupLayout(0),
650
- entries: [
651
- { binding: 0, resource: { buffer: uniformBuffer } },
652
- { binding: 1, resource: { buffer: bufferIn } },
653
- { binding: 2, resource: { buffer: bufferWeight } },
654
- { binding: 3, resource: { buffer: bufferOut } },
655
- ],
656
- });
657
-
658
- const avgTime = await benchmarkPipeline(
659
- device,
660
- pipeline,
661
- bindGroup,
662
- [numTokens, 1, 1],
663
- warmup,
664
- iterations
665
- );
666
-
667
- const ops = 2 * totalElements;
668
- const gops = avgTime > 0 ? (ops / avgTime) / 1e6 : 0;
669
-
670
- if (avgTime < best.timeMs) {
671
- best = {
672
- optimalWorkgroupSize: [wgX, 1, 1],
673
- optimalTileSize: wgX,
674
- throughput: gops,
675
- timeMs: avgTime,
676
- deviceInfo: capabilities?.adapterInfo,
677
- };
630
+ try {
631
+ const dataIn = new Float32Array(totalElements);
632
+ const dataWeight = new Float32Array(hiddenSize);
633
+ const rmsRng = createRng(0x27182818);
634
+ fillRandom(dataIn, rmsRng);
635
+ fillRandom(dataWeight, rmsRng);
636
+ device.queue.writeBuffer(bufferIn, 0, dataIn);
637
+ device.queue.writeBuffer(bufferWeight, 0, dataWeight);
638
+
639
+ for (const [wgX] of rmsCandidates) {
640
+ let uniformBuffer = null;
641
+ try {
642
+ const shader = createRMSNormShader(wgX);
643
+ const pipeline = await createComputePipeline(device, shader, 'main');
644
+
645
+ uniformBuffer = device.createBuffer({
646
+ size: 16,
647
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
648
+ });
649
+ const uniformData = new ArrayBuffer(16);
650
+ const uniformView = new DataView(uniformData);
651
+ uniformView.setUint32(0, hiddenSize, true);
652
+ uniformView.setUint32(4, numTokens, true);
653
+ uniformView.setFloat32(8, DEFAULT_RMS_NORM_EPS, true);
654
+ uniformView.setUint32(12, 0, true);
655
+ device.queue.writeBuffer(uniformBuffer, 0, uniformData);
656
+
657
+ const bindGroup = device.createBindGroup({
658
+ layout: pipeline.getBindGroupLayout(0),
659
+ entries: [
660
+ { binding: 0, resource: { buffer: uniformBuffer } },
661
+ { binding: 1, resource: { buffer: bufferIn } },
662
+ { binding: 2, resource: { buffer: bufferWeight } },
663
+ { binding: 3, resource: { buffer: bufferOut } },
664
+ ],
665
+ });
666
+
667
+ const avgTime = await benchmarkPipeline(
668
+ device,
669
+ pipeline,
670
+ bindGroup,
671
+ [numTokens, 1, 1],
672
+ warmup,
673
+ iterations
674
+ );
675
+
676
+ const ops = 2 * totalElements;
677
+ const gops = avgTime > 0 ? (ops / avgTime) / 1e6 : 0;
678
+
679
+ if (avgTime < best.timeMs) {
680
+ best = {
681
+ optimalWorkgroupSize: [wgX, 1, 1],
682
+ optimalTileSize: wgX,
683
+ throughput: gops,
684
+ timeMs: avgTime,
685
+ deviceInfo: capabilities?.adapterInfo,
686
+ };
687
+ }
688
+ } catch (e) {
689
+ continue;
690
+ } finally {
691
+ destroyBuffer(uniformBuffer);
678
692
  }
679
-
680
- uniformBuffer.destroy();
681
- } catch (e) {
682
- continue;
683
693
  }
694
+ } finally {
695
+ destroyBuffers(bufferIn, bufferWeight, bufferOut);
684
696
  }
685
697
 
686
- bufferIn.destroy();
687
- bufferWeight.destroy();
688
- bufferOut.destroy();
689
-
690
698
  return best;
691
699
  }
692
700
 
@@ -789,70 +797,72 @@ export async function tuneDequant(
789
797
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
790
798
  });
791
799
 
792
- const dataIn = new Uint32Array(numElements);
793
- for (let i = 0; i < numElements; i++) {
794
- dataIn[i] = i & 0xffff;
795
- }
796
- device.queue.writeBuffer(bufferIn, 0, dataIn);
797
-
798
- for (const [wgX] of dequantCandidates) {
799
- try {
800
- const shader = createDequantShader(wgX);
801
- const pipeline = await createComputePipeline(device, shader, 'main');
802
-
803
- const uniformBuffer = device.createBuffer({
804
- size: 16,
805
- usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
806
- });
807
- const uniformData = new ArrayBuffer(16);
808
- const uniformView = new DataView(uniformData);
809
- uniformView.setUint32(0, numElements, true);
810
- uniformView.setFloat32(4, 0.01, true);
811
- uniformView.setUint32(8, 0, true);
812
- uniformView.setUint32(12, 0, true);
813
- device.queue.writeBuffer(uniformBuffer, 0, uniformData);
814
-
815
- const bindGroup = device.createBindGroup({
816
- layout: pipeline.getBindGroupLayout(0),
817
- entries: [
818
- { binding: 0, resource: { buffer: uniformBuffer } },
819
- { binding: 1, resource: { buffer: bufferIn } },
820
- { binding: 2, resource: { buffer: bufferOut } },
821
- ],
822
- });
823
-
824
- const workgroups = Math.ceil(numElements / wgX);
825
- const avgTime = await benchmarkPipeline(
826
- device,
827
- pipeline,
828
- bindGroup,
829
- [workgroups, 1, 1],
830
- warmup,
831
- iterations
832
- );
833
-
834
- const ops = numElements;
835
- const gops = avgTime > 0 ? (ops / avgTime) / 1e6 : 0;
836
-
837
- if (avgTime < best.timeMs) {
838
- best = {
839
- optimalWorkgroupSize: [wgX, 1, 1],
840
- optimalTileSize: wgX,
841
- throughput: gops,
842
- timeMs: avgTime,
843
- deviceInfo: capabilities?.adapterInfo,
844
- };
800
+ try {
801
+ const dataIn = new Uint32Array(numElements);
802
+ for (let i = 0; i < numElements; i++) {
803
+ dataIn[i] = i & 0xffff;
804
+ }
805
+ device.queue.writeBuffer(bufferIn, 0, dataIn);
806
+
807
+ for (const [wgX] of dequantCandidates) {
808
+ let uniformBuffer = null;
809
+ try {
810
+ const shader = createDequantShader(wgX);
811
+ const pipeline = await createComputePipeline(device, shader, 'main');
812
+
813
+ uniformBuffer = device.createBuffer({
814
+ size: 16,
815
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
816
+ });
817
+ const uniformData = new ArrayBuffer(16);
818
+ const uniformView = new DataView(uniformData);
819
+ uniformView.setUint32(0, numElements, true);
820
+ uniformView.setFloat32(4, 0.01, true);
821
+ uniformView.setUint32(8, 0, true);
822
+ uniformView.setUint32(12, 0, true);
823
+ device.queue.writeBuffer(uniformBuffer, 0, uniformData);
824
+
825
+ const bindGroup = device.createBindGroup({
826
+ layout: pipeline.getBindGroupLayout(0),
827
+ entries: [
828
+ { binding: 0, resource: { buffer: uniformBuffer } },
829
+ { binding: 1, resource: { buffer: bufferIn } },
830
+ { binding: 2, resource: { buffer: bufferOut } },
831
+ ],
832
+ });
833
+
834
+ const workgroups = Math.ceil(numElements / wgX);
835
+ const avgTime = await benchmarkPipeline(
836
+ device,
837
+ pipeline,
838
+ bindGroup,
839
+ [workgroups, 1, 1],
840
+ warmup,
841
+ iterations
842
+ );
843
+
844
+ const ops = numElements;
845
+ const gops = avgTime > 0 ? (ops / avgTime) / 1e6 : 0;
846
+
847
+ if (avgTime < best.timeMs) {
848
+ best = {
849
+ optimalWorkgroupSize: [wgX, 1, 1],
850
+ optimalTileSize: wgX,
851
+ throughput: gops,
852
+ timeMs: avgTime,
853
+ deviceInfo: capabilities?.adapterInfo,
854
+ };
855
+ }
856
+ } catch (e) {
857
+ continue;
858
+ } finally {
859
+ destroyBuffer(uniformBuffer);
845
860
  }
846
-
847
- uniformBuffer.destroy();
848
- } catch (e) {
849
- continue;
850
861
  }
862
+ } finally {
863
+ destroyBuffers(bufferIn, bufferOut);
851
864
  }
852
865
 
853
- bufferIn.destroy();
854
- bufferOut.destroy();
855
-
856
866
  return best;
857
867
  }
858
868