@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { dispatch, recordDispatch } from './dispatch.js';
7
7
  import { createPipeline, createUniformBufferWithView } from './utils.js';
@@ -44,6 +44,7 @@ export async function castF32ToF16(
44
44
  ) {
45
45
  const device = getDevice();
46
46
  const { outputBuffer = null } = options;
47
+ const ownsOutput = outputBuffer == null;
47
48
  const numElements = input.shape.reduce((a, b) => a * b, 1);
48
49
 
49
50
  const pipeline = await createPipeline('cast', 'f32_to_f16');
@@ -51,35 +52,41 @@ export async function castF32ToF16(
51
52
  const outputSize = numElements * DTYPE_SIZES.f16;
52
53
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'cast_f32_to_f16_output');
53
54
 
54
- const uniformBuffer = createUniformBufferWithView(
55
- 'cast_f32_to_f16_uniforms',
56
- 16,
57
- (view) => {
58
- view.setUint32(0, numElements, true);
59
- },
60
- null,
61
- device
62
- );
63
-
64
- const bindGroup = device.createBindGroup({
65
- label: 'cast_f32_to_f16_bind_group',
66
- layout: pipeline.getBindGroupLayout(0),
67
- entries: [
68
- { binding: 0, resource: { buffer: uniformBuffer } },
69
- { binding: 1, resource: { buffer: input.buffer } },
70
- { binding: 2, resource: { buffer: output } },
71
- ],
72
- });
73
-
74
- // Use 2D dispatch for large tensors (like embeddings with 300M+ elements)
75
- const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
76
- const dispatchSize = calculate2DDispatch(workgroups);
55
+ let uniformBuffer = null;
56
+ try {
57
+ uniformBuffer = createUniformBufferWithView(
58
+ 'cast_f32_to_f16_uniforms',
59
+ 16,
60
+ (view) => {
61
+ view.setUint32(0, numElements, true);
62
+ },
63
+ null,
64
+ device
65
+ );
77
66
 
78
- dispatch(device, pipeline, bindGroup, dispatchSize, 'cast_f32_to_f16');
67
+ const bindGroup = device.createBindGroup({
68
+ label: 'cast_f32_to_f16_bind_group',
69
+ layout: pipeline.getBindGroupLayout(0),
70
+ entries: [
71
+ { binding: 0, resource: { buffer: uniformBuffer } },
72
+ { binding: 1, resource: { buffer: input.buffer } },
73
+ { binding: 2, resource: { buffer: output } },
74
+ ],
75
+ });
79
76
 
80
- uniformBuffer.destroy();
77
+ const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
78
+ const dispatchSize = calculate2DDispatch(workgroups);
81
79
 
82
- return createTensor(output, 'f16', [...input.shape], input.label ? `${input.label}_f16` : 'cast_f32_to_f16_output');
80
+ dispatch(device, pipeline, bindGroup, dispatchSize, 'cast_f32_to_f16');
81
+ return createTensor(output, 'f16', [...input.shape], input.label ? `${input.label}_f16` : 'cast_f32_to_f16_output');
82
+ } catch (error) {
83
+ if (ownsOutput) {
84
+ releaseBuffer(output);
85
+ }
86
+ throw error;
87
+ } finally {
88
+ uniformBuffer?.destroy();
89
+ }
83
90
  }
84
91
 
85
92
 
@@ -89,6 +96,7 @@ export async function castF16ToF32(
89
96
  ) {
90
97
  const device = getDevice();
91
98
  const { outputBuffer = null } = options;
99
+ const ownsOutput = outputBuffer == null;
92
100
  const numElements = input.shape.reduce((a, b) => a * b, 1);
93
101
 
94
102
  const pipeline = await createPipeline('cast', 'f16_to_f32');
@@ -96,34 +104,41 @@ export async function castF16ToF32(
96
104
  const outputSize = numElements * DTYPE_SIZES.f32;
97
105
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'cast_f16_to_f32_output');
98
106
 
99
- const uniformBuffer = createUniformBufferWithView(
100
- 'cast_f16_to_f32_uniforms',
101
- 16,
102
- (view) => {
103
- view.setUint32(0, numElements, true);
104
- },
105
- null,
106
- device
107
- );
108
-
109
- const bindGroup = device.createBindGroup({
110
- label: 'cast_f16_to_f32_bind_group',
111
- layout: pipeline.getBindGroupLayout(0),
112
- entries: [
113
- { binding: 0, resource: { buffer: uniformBuffer } },
114
- { binding: 1, resource: { buffer: input.buffer } },
115
- { binding: 2, resource: { buffer: output } },
116
- ],
117
- });
118
-
119
- const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
120
- const dispatchSize = calculate2DDispatch(workgroups);
107
+ let uniformBuffer = null;
108
+ try {
109
+ uniformBuffer = createUniformBufferWithView(
110
+ 'cast_f16_to_f32_uniforms',
111
+ 16,
112
+ (view) => {
113
+ view.setUint32(0, numElements, true);
114
+ },
115
+ null,
116
+ device
117
+ );
121
118
 
122
- dispatch(device, pipeline, bindGroup, dispatchSize, 'cast_f16_to_f32');
119
+ const bindGroup = device.createBindGroup({
120
+ label: 'cast_f16_to_f32_bind_group',
121
+ layout: pipeline.getBindGroupLayout(0),
122
+ entries: [
123
+ { binding: 0, resource: { buffer: uniformBuffer } },
124
+ { binding: 1, resource: { buffer: input.buffer } },
125
+ { binding: 2, resource: { buffer: output } },
126
+ ],
127
+ });
123
128
 
124
- uniformBuffer.destroy();
129
+ const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
130
+ const dispatchSize = calculate2DDispatch(workgroups);
125
131
 
126
- return createTensor(output, 'f32', [...input.shape], input.label ? `${input.label}_f32` : 'cast_f16_to_f32_output');
132
+ dispatch(device, pipeline, bindGroup, dispatchSize, 'cast_f16_to_f32');
133
+ return createTensor(output, 'f32', [...input.shape], input.label ? `${input.label}_f32` : 'cast_f16_to_f32_output');
134
+ } catch (error) {
135
+ if (ownsOutput) {
136
+ releaseBuffer(output);
137
+ }
138
+ throw error;
139
+ } finally {
140
+ uniformBuffer?.destroy();
141
+ }
127
142
  }
128
143
 
129
144
 
@@ -134,6 +149,7 @@ export async function recordCastF32ToF16(
134
149
  ) {
135
150
  const device = recorder.device;
136
151
  const { outputBuffer = null } = options;
152
+ const ownsOutput = outputBuffer == null;
137
153
  const numElements = input.shape.reduce((a, b) => a * b, 1);
138
154
 
139
155
  const pipeline = await createPipeline('cast', 'f32_to_f16');
@@ -141,32 +157,37 @@ export async function recordCastF32ToF16(
141
157
  const outputSize = numElements * DTYPE_SIZES.f16;
142
158
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'cast_f32_to_f16_output');
143
159
 
144
- const uniformBuffer = createUniformBufferWithView(
145
- 'cast_f32_to_f16_uniforms',
146
- 16,
147
- (view) => {
148
- view.setUint32(0, numElements, true);
149
- },
150
- recorder
151
- );
152
-
153
- const bindGroup = device.createBindGroup({
154
- label: 'cast_f32_to_f16_bind_group',
155
- layout: pipeline.getBindGroupLayout(0),
156
- entries: [
157
- { binding: 0, resource: { buffer: uniformBuffer } },
158
- { binding: 1, resource: { buffer: input.buffer } },
159
- { binding: 2, resource: { buffer: output } },
160
- ],
161
- });
160
+ try {
161
+ const uniformBuffer = createUniformBufferWithView(
162
+ 'cast_f32_to_f16_uniforms',
163
+ 16,
164
+ (view) => {
165
+ view.setUint32(0, numElements, true);
166
+ },
167
+ recorder
168
+ );
162
169
 
163
- // Use 2D dispatch for large tensors
164
- const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
165
- const dispatchSize = calculate2DDispatch(workgroups);
170
+ const bindGroup = device.createBindGroup({
171
+ label: 'cast_f32_to_f16_bind_group',
172
+ layout: pipeline.getBindGroupLayout(0),
173
+ entries: [
174
+ { binding: 0, resource: { buffer: uniformBuffer } },
175
+ { binding: 1, resource: { buffer: input.buffer } },
176
+ { binding: 2, resource: { buffer: output } },
177
+ ],
178
+ });
166
179
 
167
- recordDispatch(recorder, pipeline, bindGroup, dispatchSize, 'cast_f32_to_f16');
180
+ const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
181
+ const dispatchSize = calculate2DDispatch(workgroups);
168
182
 
169
- return createTensor(output, 'f16', [...input.shape], input.label ? `${input.label}_f16` : 'cast_f32_to_f16_output');
183
+ recordDispatch(recorder, pipeline, bindGroup, dispatchSize, 'cast_f32_to_f16');
184
+ return createTensor(output, 'f16', [...input.shape], input.label ? `${input.label}_f16` : 'cast_f32_to_f16_output');
185
+ } catch (error) {
186
+ if (ownsOutput) {
187
+ releaseBuffer(output);
188
+ }
189
+ throw error;
190
+ }
170
191
  }
171
192
 
172
193
 
@@ -177,6 +198,7 @@ export async function recordCastF16ToF32(
177
198
  ) {
178
199
  const device = recorder.device;
179
200
  const { outputBuffer = null } = options;
201
+ const ownsOutput = outputBuffer == null;
180
202
  const numElements = input.shape.reduce((a, b) => a * b, 1);
181
203
 
182
204
  const pipeline = await createPipeline('cast', 'f16_to_f32');
@@ -184,31 +206,37 @@ export async function recordCastF16ToF32(
184
206
  const outputSize = numElements * DTYPE_SIZES.f32;
185
207
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'cast_f16_to_f32_output');
186
208
 
187
- const uniformBuffer = createUniformBufferWithView(
188
- 'cast_f16_to_f32_uniforms',
189
- 16,
190
- (view) => {
191
- view.setUint32(0, numElements, true);
192
- },
193
- recorder
194
- );
195
-
196
- const bindGroup = device.createBindGroup({
197
- label: 'cast_f16_to_f32_bind_group',
198
- layout: pipeline.getBindGroupLayout(0),
199
- entries: [
200
- { binding: 0, resource: { buffer: uniformBuffer } },
201
- { binding: 1, resource: { buffer: input.buffer } },
202
- { binding: 2, resource: { buffer: output } },
203
- ],
204
- });
209
+ try {
210
+ const uniformBuffer = createUniformBufferWithView(
211
+ 'cast_f16_to_f32_uniforms',
212
+ 16,
213
+ (view) => {
214
+ view.setUint32(0, numElements, true);
215
+ },
216
+ recorder
217
+ );
205
218
 
206
- const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
207
- const dispatchSize = calculate2DDispatch(workgroups);
219
+ const bindGroup = device.createBindGroup({
220
+ label: 'cast_f16_to_f32_bind_group',
221
+ layout: pipeline.getBindGroupLayout(0),
222
+ entries: [
223
+ { binding: 0, resource: { buffer: uniformBuffer } },
224
+ { binding: 1, resource: { buffer: input.buffer } },
225
+ { binding: 2, resource: { buffer: output } },
226
+ ],
227
+ });
208
228
 
209
- recordDispatch(recorder, pipeline, bindGroup, dispatchSize, 'cast_f16_to_f32');
229
+ const workgroups = Math.ceil(numElements / WORKGROUP_SIZES.DEFAULT);
230
+ const dispatchSize = calculate2DDispatch(workgroups);
210
231
 
211
- return createTensor(output, 'f32', [...input.shape], input.label ? `${input.label}_f32` : 'cast_f16_to_f32_output');
232
+ recordDispatch(recorder, pipeline, bindGroup, dispatchSize, 'cast_f16_to_f32');
233
+ return createTensor(output, 'f32', [...input.shape], input.label ? `${input.label}_f32` : 'cast_f16_to_f32_output');
234
+ } catch (error) {
235
+ if (ownsOutput) {
236
+ releaseBuffer(output);
237
+ }
238
+ throw error;
239
+ }
212
240
  }
213
241
 
214
242
 
@@ -276,11 +304,15 @@ export async function runBF16ToF32(
276
304
  const dispatchSize = calculate2DDispatch(workgroups);
277
305
 
278
306
  trace.kernels(`BF16ToF32: Dispatching ${dispatchSize[0]}x${dispatchSize[1]} workgroups for ${numPairs} pairs (${numElements} elements)`);
279
- dispatch(device, pipeline, bindGroup, dispatchSize, 'bf16_to_f32');
280
-
281
- uniformBuffer.destroy();
282
-
283
- return createTensor(output, 'f32', [...shape], name);
307
+ try {
308
+ dispatch(device, pipeline, bindGroup, dispatchSize, 'bf16_to_f32');
309
+ return createTensor(output, 'f32', [...shape], name);
310
+ } catch (error) {
311
+ releaseBuffer(output);
312
+ throw error;
313
+ } finally {
314
+ uniformBuffer.destroy();
315
+ }
284
316
  }
285
317
 
286
318
 
@@ -337,11 +369,15 @@ export async function runBF16ToF16(
337
369
  const workgroups = Math.ceil(numPairs / WORKGROUP_SIZES.DEFAULT);
338
370
  const dispatchSize = calculate2DDispatch(workgroups);
339
371
 
340
- dispatch(device, pipeline, bindGroup, dispatchSize, 'bf16_to_f16');
341
-
342
- uniformBuffer.destroy();
343
-
344
- return createTensor(output, 'f16', [...shape], name);
372
+ try {
373
+ dispatch(device, pipeline, bindGroup, dispatchSize, 'bf16_to_f16');
374
+ return createTensor(output, 'f16', [...shape], name);
375
+ } catch (error) {
376
+ releaseBuffer(output);
377
+ throw error;
378
+ } finally {
379
+ uniformBuffer.destroy();
380
+ }
345
381
  }
346
382
 
347
383
 
@@ -375,48 +411,54 @@ async function runBF16ToF32Chunked(
375
411
 
376
412
  trace.kernels(`BF16ToF32: Chunking ${numElements} elements in ${numChunks} chunks`);
377
413
 
378
- for (let chunkIdx = 0; chunkIdx < numChunks; chunkIdx++) {
379
- const chunkStart = chunkIdx * maxElementsPerChunk;
380
- const chunkEnd = Math.min((chunkIdx + 1) * maxElementsPerChunk, numElements);
381
- const chunkSize = chunkEnd - chunkStart;
382
-
383
- const uniformBuffer = createUniformBufferWithView(
384
- `bf16_to_f32_chunk${chunkIdx}_uniforms`,
385
- 16,
386
- (view) => {
387
- view.setUint32(0, chunkSize, true);
388
- view.setUint32(4, 0, true);
389
- view.setUint32(8, 0, true);
390
- },
391
- null,
392
- device
393
- );
394
-
395
- const inputOffsetBytes = chunkStart * DTYPE_SIZES.bf16;
396
- const outputOffsetBytes = chunkStart * DTYPE_SIZES.f32;
397
- const inputPairs = Math.ceil(chunkSize / 2);
398
- const inputSizeBytes = inputPairs * DTYPE_SIZES.f32; // Pairs read as u32
399
- const outputSizeBytes = chunkSize * DTYPE_SIZES.f32;
400
-
401
- const bindGroup = device.createBindGroup({
402
- label: `bf16_to_f32_chunk${chunkIdx}_bind_group`,
403
- layout: pipeline.getBindGroupLayout(0),
404
- entries: [
405
- { binding: 0, resource: { buffer: uniformBuffer } },
406
- { binding: 1, resource: { buffer: input, offset: inputOffsetBytes, size: inputSizeBytes } },
407
- { binding: 2, resource: { buffer: output, offset: outputOffsetBytes, size: outputSizeBytes } },
408
- ],
409
- });
410
-
411
- // Each thread processes 2 BF16 values
412
- const numPairs = Math.ceil(chunkSize / 2);
413
- const workgroups = Math.ceil(numPairs / WORKGROUP_SIZES.DEFAULT);
414
- const dispatchSize = calculate2DDispatch(workgroups);
415
-
416
- dispatch(device, pipeline, bindGroup, dispatchSize, `bf16_to_f32_chunk${chunkIdx}`);
414
+ try {
415
+ for (let chunkIdx = 0; chunkIdx < numChunks; chunkIdx++) {
416
+ const chunkStart = chunkIdx * maxElementsPerChunk;
417
+ const chunkEnd = Math.min((chunkIdx + 1) * maxElementsPerChunk, numElements);
418
+ const chunkSize = chunkEnd - chunkStart;
419
+
420
+ const uniformBuffer = createUniformBufferWithView(
421
+ `bf16_to_f32_chunk${chunkIdx}_uniforms`,
422
+ 16,
423
+ (view) => {
424
+ view.setUint32(0, chunkSize, true);
425
+ view.setUint32(4, 0, true);
426
+ view.setUint32(8, 0, true);
427
+ },
428
+ null,
429
+ device
430
+ );
431
+
432
+ try {
433
+ const inputOffsetBytes = chunkStart * DTYPE_SIZES.bf16;
434
+ const outputOffsetBytes = chunkStart * DTYPE_SIZES.f32;
435
+ const inputPairs = Math.ceil(chunkSize / 2);
436
+ const inputSizeBytes = inputPairs * DTYPE_SIZES.f32;
437
+ const outputSizeBytes = chunkSize * DTYPE_SIZES.f32;
438
+
439
+ const bindGroup = device.createBindGroup({
440
+ label: `bf16_to_f32_chunk${chunkIdx}_bind_group`,
441
+ layout: pipeline.getBindGroupLayout(0),
442
+ entries: [
443
+ { binding: 0, resource: { buffer: uniformBuffer } },
444
+ { binding: 1, resource: { buffer: input, offset: inputOffsetBytes, size: inputSizeBytes } },
445
+ { binding: 2, resource: { buffer: output, offset: outputOffsetBytes, size: outputSizeBytes } },
446
+ ],
447
+ });
448
+
449
+ const numPairs = Math.ceil(chunkSize / 2);
450
+ const workgroups = Math.ceil(numPairs / WORKGROUP_SIZES.DEFAULT);
451
+ const dispatchSize = calculate2DDispatch(workgroups);
452
+
453
+ dispatch(device, pipeline, bindGroup, dispatchSize, `bf16_to_f32_chunk${chunkIdx}`);
454
+ } finally {
455
+ uniformBuffer.destroy();
456
+ }
457
+ }
417
458
 
418
- uniformBuffer.destroy();
459
+ return createTensor(output, 'f32', [...shape], name);
460
+ } catch (error) {
461
+ releaseBuffer(output);
462
+ throw error;
419
463
  }
420
-
421
- return createTensor(output, 'f32', [...shape], name);
422
464
  }
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getDeviceEpoch } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, readBufferSlice } from '../../memory/buffer-pool.js';
5
5
  import { recordDispatch } from './dispatch.js';
6
6
  import { createUniformBufferFromData, getOrCreateBindGroupLayout, getOrCreatePipelineLayout } from './utils.js';
7
7
  import { allowReadback } from '../perf-guards.js';
@@ -133,49 +133,38 @@ export async function checkStop(params) {
133
133
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
134
134
  });
135
135
  const ownsStopBuffer = !params.shouldStopBuffer;
136
- if (shouldStopBuffer.size < requiredBytes) {
137
- throw new Error('[CheckStop] shouldStopBuffer too small for tokenIndex.');
138
- }
139
136
 
140
- const bindGroup = device.createBindGroup({
141
- layout: getCheckStopBindGroupLayout(device),
142
- entries: [
143
- { binding: 0, resource: { buffer: uniformBuffer } },
144
- { binding: 1, resource: { buffer: params.sampledTokenBuffer } },
145
- { binding: 2, resource: { buffer: shouldStopBuffer } },
146
- ],
147
- });
148
-
149
- const encoder = device.createCommandEncoder();
150
- const pass = encoder.beginComputePass();
151
- pass.setPipeline(pipeline);
152
- pass.setBindGroup(0, bindGroup);
153
- pass.dispatchWorkgroups(1, 1, 1);
154
- pass.end();
155
-
156
- // Readback result
157
- const stagingBuffer = device.createBuffer({
158
- size: U32_BYTES,
159
- usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
160
- });
161
- encoder.copyBufferToBuffer(
162
- shouldStopBuffer,
163
- tokenIndex * U32_BYTES,
164
- stagingBuffer,
165
- 0,
166
- U32_BYTES
167
- );
168
- device.queue.submit([encoder.finish()]);
169
-
170
- await stagingBuffer.mapAsync(GPUMapMode.READ);
171
- const result = new Uint32Array(stagingBuffer.getMappedRange())[0];
172
- stagingBuffer.unmap();
173
-
174
- uniformBuffer.destroy();
175
- if (ownsStopBuffer) {
176
- shouldStopBuffer.destroy();
137
+ try {
138
+ if (shouldStopBuffer.size < requiredBytes) {
139
+ throw new Error('[CheckStop] shouldStopBuffer too small for tokenIndex.');
140
+ }
141
+
142
+ const bindGroup = device.createBindGroup({
143
+ layout: getCheckStopBindGroupLayout(device),
144
+ entries: [
145
+ { binding: 0, resource: { buffer: uniformBuffer } },
146
+ { binding: 1, resource: { buffer: params.sampledTokenBuffer } },
147
+ { binding: 2, resource: { buffer: shouldStopBuffer } },
148
+ ],
149
+ });
150
+
151
+ const encoder = device.createCommandEncoder();
152
+ const pass = encoder.beginComputePass();
153
+ pass.setPipeline(pipeline);
154
+ pass.setBindGroup(0, bindGroup);
155
+ pass.dispatchWorkgroups(1, 1, 1);
156
+ pass.end();
157
+
158
+ device.queue.submit([encoder.finish()]);
159
+
160
+ const result = new Uint32Array(
161
+ await readBufferSlice(shouldStopBuffer, tokenIndex * U32_BYTES, U32_BYTES)
162
+ )[0];
163
+ return result === 1;
164
+ } finally {
165
+ uniformBuffer.destroy();
166
+ if (ownsStopBuffer) {
167
+ shouldStopBuffer.destroy();
168
+ }
177
169
  }
178
- stagingBuffer.destroy();
179
-
180
- return result === 1;
181
170
  }
@@ -49,27 +49,37 @@ async function _conv2d(target, input, weight, bias, options = {}) {
49
49
  device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
50
50
  }
51
51
 
52
- await unifiedKernelWrapper(
53
- 'conv2d', target, variant,
54
- [input, weightBuffer, biasBuffer, output],
55
- {
56
- in_channels: inChannels, out_channels: outChannels,
57
- height, width, out_height: outHeight, out_width: outWidth,
58
- kernel_h: kernelH, kernel_w: kernelW,
59
- stride, pad, _pad0: 0, _pad1: 0,
60
- },
61
- Math.ceil((outChannels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
62
- );
52
+ try {
53
+ await unifiedKernelWrapper(
54
+ 'conv2d', target, variant,
55
+ [input, weightBuffer, biasBuffer, output],
56
+ {
57
+ in_channels: inChannels, out_channels: outChannels,
58
+ height, width, out_height: outHeight, out_width: outWidth,
59
+ kernel_h: kernelH, kernel_w: kernelW,
60
+ stride, pad, _pad0: 0, _pad1: 0,
61
+ },
62
+ [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
63
+ );
63
64
 
64
- if (tempBias) {
65
- if (recorder) {
66
- recorder.trackTemporaryBuffer(tempBias);
67
- } else {
65
+ if (tempBias) {
66
+ if (recorder) {
67
+ recorder.trackTemporaryBuffer(tempBias);
68
+ } else {
69
+ releaseBuffer(tempBias);
70
+ }
71
+ }
72
+
73
+ return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'conv2d_output');
74
+ } catch (error) {
75
+ if (tempBias) {
68
76
  releaseBuffer(tempBias);
69
77
  }
78
+ if (!outputBuffer) {
79
+ releaseBuffer(output);
80
+ }
81
+ throw error;
70
82
  }
71
-
72
- return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'conv2d_output');
73
83
  }
74
84
 
75
85
  export async function runConv2D(input, weight, bias, options = {}) {
@@ -27,19 +27,18 @@ struct Uniforms {
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
31
30
  let out_height = u.out_height;
32
31
  let out_width = u.out_width;
33
- let out_size = u.out_channels * out_height * out_width;
34
- if (idx >= out_size) {
32
+ let out_spatial = out_height * out_width;
33
+ let out_spatial_idx = gid.x;
34
+ let out_c = gid.y;
35
+ if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
35
36
  return;
36
37
  }
37
38
 
38
- let out_spatial = out_height * out_width;
39
- let out_c = idx / out_spatial;
40
- let rem = idx - out_c * out_spatial;
41
- let out_y = rem / out_width;
42
- let out_x = rem - out_y * out_width;
39
+ let out_y = out_spatial_idx / out_width;
40
+ let out_x = out_spatial_idx - out_y * out_width;
41
+ let idx = out_c * out_spatial + out_spatial_idx;
43
42
 
44
43
  var sum: f32 = bias[out_c];
45
44
 
@@ -29,19 +29,18 @@ struct Uniforms {
29
29
 
30
30
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
31
31
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
32
- let idx = gid.x;
33
32
  let out_height = u.out_height;
34
33
  let out_width = u.out_width;
35
- let out_size = u.out_channels * out_height * out_width;
36
- if (idx >= out_size) {
34
+ let out_spatial = out_height * out_width;
35
+ let out_spatial_idx = gid.x;
36
+ let out_c = gid.y;
37
+ if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
37
38
  return;
38
39
  }
39
40
 
40
- let out_spatial = out_height * out_width;
41
- let out_c = idx / out_spatial;
42
- let rem = idx - out_c * out_spatial;
43
- let out_y = rem / out_width;
44
- let out_x = rem - out_y * out_width;
41
+ let out_y = out_spatial_idx / out_width;
42
+ let out_x = out_spatial_idx - out_y * out_width;
43
+ let idx = out_c * out_spatial + out_spatial_idx;
45
44
 
46
45
  var sum: f32 = f32(bias[out_c]);
47
46