@simulatte/doppler 0.1.5 → 0.1.7

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1,6 +1,5 @@
1
- import { CommandRecorder } from '../../command-recorder.js';
2
1
  import { getDevice } from '../../device.js';
3
- import { acquireBuffer } from '../../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
4
3
  import { createTensor, dtypeBytes } from '../../tensor.js';
5
4
  import { castF16ToF32, recordCastF16ToF32 } from '../cast.js';
6
5
  import { runMatmul, recordMatmul } from '../matmul.js';
@@ -15,24 +14,16 @@ async function ensureF32(tensor, recorder = null) {
15
14
  if (!recorder) {
16
15
  return castF16ToF32(tensor);
17
16
  }
18
- const casted = await recordCastF16ToF32(recorder, tensor);
19
- recorder.trackTemporaryBuffer(casted.buffer);
20
- return casted;
17
+ return recordCastF16ToF32(recorder, tensor);
21
18
  }
22
19
 
23
- function createHeadSliceBuffers(recorder, headBytes, softmaxBytes) {
20
+ function createHeadSliceBuffers(headBytes, softmaxBytes) {
24
21
  const qHeadBuf = acquireBuffer(headBytes, undefined, 'attn_q_head');
25
22
  const kHeadBuf = acquireBuffer(headBytes, undefined, 'attn_k_head');
26
23
  const vHeadBuf = acquireBuffer(headBytes, undefined, 'attn_v_head');
27
24
  const sHeadBuf = acquireBuffer(softmaxBytes, undefined, 'attn_s_head');
28
25
  const dHeadBuf = acquireBuffer(headBytes, undefined, 'attn_d_head');
29
26
 
30
- recorder.trackTemporaryBuffer(qHeadBuf);
31
- recorder.trackTemporaryBuffer(kHeadBuf);
32
- recorder.trackTemporaryBuffer(vHeadBuf);
33
- recorder.trackTemporaryBuffer(sHeadBuf);
34
- recorder.trackTemporaryBuffer(dHeadBuf);
35
-
36
27
  return { qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf };
37
28
  }
38
29
 
@@ -49,6 +40,19 @@ function trackTensorBuffer(recorder, tensor) {
49
40
  recorder.trackTemporaryBuffer(tensor.buffer);
50
41
  }
51
42
 
43
+ function releaseTensorBuffer(tensor) {
44
+ if (tensor?.buffer) {
45
+ releaseBuffer(tensor.buffer);
46
+ }
47
+ }
48
+
49
+ function maybeTrackOwnedTensor(ownedTensors, originalTensor, resolvedTensor) {
50
+ if (resolvedTensor !== originalTensor) {
51
+ ownedTensors.push(resolvedTensor);
52
+ }
53
+ return resolvedTensor;
54
+ }
55
+
52
56
  async function runAttentionBackwardCore(
53
57
  q,
54
58
  k,
@@ -63,11 +67,23 @@ async function runAttentionBackwardCore(
63
67
  throw new Error('attention backward requires seqLen, numHeads, and headDim');
64
68
  }
65
69
 
66
- const qTensor = await ensureF32(q, recorder);
67
- const kTensor = await ensureF32(k, recorder);
68
- const vTensor = await ensureF32(v, recorder);
69
- const sTensor = await ensureF32(softmax, recorder);
70
- const dTensor = await ensureF32(gradOutput, recorder);
70
+ const ownedInputTensors = [];
71
+ const ownedRecorderInputTensors = [];
72
+ const qTensor = !recorder
73
+ ? maybeTrackOwnedTensor(ownedInputTensors, q, await ensureF32(q))
74
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, q, await ensureF32(q, recorder));
75
+ const kTensor = !recorder
76
+ ? maybeTrackOwnedTensor(ownedInputTensors, k, await ensureF32(k))
77
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, k, await ensureF32(k, recorder));
78
+ const vTensor = !recorder
79
+ ? maybeTrackOwnedTensor(ownedInputTensors, v, await ensureF32(v))
80
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, v, await ensureF32(v, recorder));
81
+ const sTensor = !recorder
82
+ ? maybeTrackOwnedTensor(ownedInputTensors, softmax, await ensureF32(softmax))
83
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, softmax, await ensureF32(softmax, recorder));
84
+ const dTensor = !recorder
85
+ ? maybeTrackOwnedTensor(ownedInputTensors, gradOutput, await ensureF32(gradOutput))
86
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, gradOutput, await ensureF32(gradOutput, recorder));
71
87
 
72
88
  const headElements = seqLen * headDim;
73
89
  const headBytes = headElements * dtypeBytes(qTensor.dtype);
@@ -77,171 +93,247 @@ async function runAttentionBackwardCore(
77
93
  const gradQBuf = acquireBuffer(totalBytes, undefined, 'attn_grad_q');
78
94
  const gradKBuf = acquireBuffer(totalBytes, undefined, 'attn_grad_k');
79
95
  const gradVBuf = acquireBuffer(totalBytes, undefined, 'attn_grad_v');
96
+ let completed = false;
80
97
 
81
- if (!recorder) {
82
- for (let h = 0; h < numHeads; h += 1) {
83
- const qOffset = h * headBytes;
84
- const kOffset = h * headBytes;
85
- const vOffset = h * headBytes;
86
- const dOffset = h * headBytes;
87
- const sOffset = h * softmaxBytes;
98
+ try {
99
+ if (!recorder) {
100
+ for (let h = 0; h < numHeads; h += 1) {
101
+ const qOffset = h * headBytes;
102
+ const kOffset = h * headBytes;
103
+ const vOffset = h * headBytes;
104
+ const dOffset = h * headBytes;
105
+ const sOffset = h * softmaxBytes;
88
106
 
89
- const qHeadBuf = acquireBuffer(headBytes, undefined, 'attn_q_head');
90
- const kHeadBuf = acquireBuffer(headBytes, undefined, 'attn_k_head');
91
- const vHeadBuf = acquireBuffer(headBytes, undefined, 'attn_v_head');
92
- const sHeadBuf = acquireBuffer(softmaxBytes, undefined, 'attn_s_head');
93
- const dHeadBuf = acquireBuffer(headBytes, undefined, 'attn_d_head');
107
+ const qHeadBuf = acquireBuffer(headBytes, undefined, 'attn_q_head');
108
+ const kHeadBuf = acquireBuffer(headBytes, undefined, 'attn_k_head');
109
+ const vHeadBuf = acquireBuffer(headBytes, undefined, 'attn_v_head');
110
+ const sHeadBuf = acquireBuffer(softmaxBytes, undefined, 'attn_s_head');
111
+ const dHeadBuf = acquireBuffer(headBytes, undefined, 'attn_d_head');
112
+ let sTransposed = null;
113
+ let dV = null;
114
+ let vTransposed = null;
115
+ let dS = null;
116
+ let dQK = null;
117
+ let dQ = null;
118
+ let dQKTransposed = null;
119
+ let dK = null;
94
120
 
95
- const sliceEncoder = getDevice().createCommandEncoder();
96
- sliceEncoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
97
- sliceEncoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
98
- sliceEncoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
99
- sliceEncoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
100
- sliceEncoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
101
- getDevice().queue.submit([sliceEncoder.finish()]);
121
+ try {
122
+ const sliceEncoder = getDevice().createCommandEncoder();
123
+ sliceEncoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
124
+ sliceEncoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
125
+ sliceEncoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
126
+ sliceEncoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
127
+ sliceEncoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
128
+ getDevice().queue.submit([sliceEncoder.finish()]);
102
129
 
103
- const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
104
- qHeadBuf,
105
- kHeadBuf,
106
- vHeadBuf,
107
- sHeadBuf,
108
- dHeadBuf,
109
- seqLen,
110
- headDim
111
- );
130
+ const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
131
+ qHeadBuf,
132
+ kHeadBuf,
133
+ vHeadBuf,
134
+ sHeadBuf,
135
+ dHeadBuf,
136
+ seqLen,
137
+ headDim
138
+ );
112
139
 
113
- const sTransposed = await runTranspose(sHead, seqLen, seqLen);
114
- const dV = await runMatmul(sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
115
- transposeB: false,
116
- bDtype: 'f32',
117
- });
140
+ sTransposed = await runTranspose(sHead, seqLen, seqLen);
141
+ dV = await runMatmul(sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
142
+ transposeB: false,
143
+ bDtype: 'f32',
144
+ });
118
145
 
119
- const vTransposed = await runTranspose(vHead, seqLen, headDim);
120
- const dS = await runMatmul(dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
121
- transposeB: false,
122
- bDtype: 'f32',
123
- });
124
- const dQK = causal
125
- ? await runBackwardKernel(
126
- 'attention_backward',
127
- sHead,
128
- dS,
129
- 16,
130
- (view) => {
131
- view.setUint32(0, seqLen, true);
132
- view.setUint32(4, seqLen, true);
133
- view.setUint32(8, 1, true);
134
- }
135
- )
136
- : await runSoftmaxBackward(sHead, dS, { rows: seqLen, cols: seqLen });
146
+ vTransposed = await runTranspose(vHead, seqLen, headDim);
147
+ dS = await runMatmul(dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
148
+ transposeB: false,
149
+ bDtype: 'f32',
150
+ });
151
+ dQK = causal
152
+ ? await runBackwardKernel(
153
+ 'attention_backward',
154
+ sHead,
155
+ dS,
156
+ 16,
157
+ (view) => {
158
+ view.setUint32(0, seqLen, true);
159
+ view.setUint32(4, seqLen, true);
160
+ view.setUint32(8, 1, true);
161
+ }
162
+ )
163
+ : await runSoftmaxBackward(sHead, dS, { rows: seqLen, cols: seqLen });
137
164
 
138
- const dQ = await runMatmul(dQK, kHead.buffer, seqLen, headDim, seqLen, {
139
- transposeB: false,
140
- alpha: scale,
141
- bDtype: 'f32',
142
- });
143
- const dQKTransposed = await runTranspose(dQK, seqLen, seqLen);
144
- const dK = await runMatmul(dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
145
- transposeB: false,
146
- alpha: scale,
147
- bDtype: 'f32',
148
- });
165
+ dQ = await runMatmul(dQK, kHead.buffer, seqLen, headDim, seqLen, {
166
+ transposeB: false,
167
+ alpha: scale,
168
+ bDtype: 'f32',
169
+ });
170
+ dQKTransposed = await runTranspose(dQK, seqLen, seqLen);
171
+ dK = await runMatmul(dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
172
+ transposeB: false,
173
+ alpha: scale,
174
+ bDtype: 'f32',
175
+ });
149
176
 
150
- const copyEncoder = getDevice().createCommandEncoder();
151
- copyEncoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
152
- copyEncoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
153
- copyEncoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
154
- getDevice().queue.submit([copyEncoder.finish()]);
155
- }
156
- } else {
157
- const encoder = recorder.getEncoder();
158
- for (let h = 0; h < numHeads; h += 1) {
159
- const qOffset = h * headBytes;
160
- const kOffset = h * headBytes;
161
- const vOffset = h * headBytes;
162
- const dOffset = h * headBytes;
163
- const sOffset = h * softmaxBytes;
177
+ const copyEncoder = getDevice().createCommandEncoder();
178
+ copyEncoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
179
+ copyEncoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
180
+ copyEncoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
181
+ getDevice().queue.submit([copyEncoder.finish()]);
182
+ await getDevice().queue.onSubmittedWorkDone();
183
+ } finally {
184
+ releaseTensorBuffer(sTransposed);
185
+ releaseTensorBuffer(dV);
186
+ releaseTensorBuffer(vTransposed);
187
+ releaseTensorBuffer(dS);
188
+ releaseTensorBuffer(dQK);
189
+ releaseTensorBuffer(dQ);
190
+ releaseTensorBuffer(dQKTransposed);
191
+ releaseTensorBuffer(dK);
192
+ releaseBuffer(qHeadBuf);
193
+ releaseBuffer(kHeadBuf);
194
+ releaseBuffer(vHeadBuf);
195
+ releaseBuffer(sHeadBuf);
196
+ releaseBuffer(dHeadBuf);
197
+ }
198
+ }
199
+ } else {
200
+ const encoder = recorder.getEncoder();
201
+ for (let h = 0; h < numHeads; h += 1) {
202
+ const qOffset = h * headBytes;
203
+ const kOffset = h * headBytes;
204
+ const vOffset = h * headBytes;
205
+ const dOffset = h * headBytes;
206
+ const sOffset = h * softmaxBytes;
164
207
 
165
- const { qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf } = createHeadSliceBuffers(
166
- recorder,
167
- headBytes,
168
- softmaxBytes
169
- );
208
+ const { qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf } = createHeadSliceBuffers(
209
+ headBytes,
210
+ softmaxBytes
211
+ );
212
+ const headBuffers = [qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf];
213
+ let sTransposed = null;
214
+ let dV = null;
215
+ let vTransposed = null;
216
+ let dS = null;
217
+ let dQK = null;
218
+ let dQ = null;
219
+ let dQKTransposed = null;
220
+ let dK = null;
170
221
 
171
- encoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
172
- encoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
173
- encoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
174
- encoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
175
- encoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
222
+ try {
223
+ encoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
224
+ encoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
225
+ encoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
226
+ encoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
227
+ encoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
176
228
 
177
- const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
178
- qHeadBuf,
179
- kHeadBuf,
180
- vHeadBuf,
181
- sHeadBuf,
182
- dHeadBuf,
183
- seqLen,
184
- headDim
185
- );
229
+ const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
230
+ qHeadBuf,
231
+ kHeadBuf,
232
+ vHeadBuf,
233
+ sHeadBuf,
234
+ dHeadBuf,
235
+ seqLen,
236
+ headDim
237
+ );
186
238
 
187
- const sTransposed = await recordTranspose(recorder, sHead, seqLen, seqLen);
188
- const dV = await recordMatmul(recorder, sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
189
- transposeB: false,
190
- bDtype: 'f32',
191
- });
239
+ sTransposed = await recordTranspose(recorder, sHead, seqLen, seqLen);
240
+ dV = await recordMatmul(recorder, sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
241
+ transposeB: false,
242
+ bDtype: 'f32',
243
+ });
192
244
 
193
- const vTransposed = await recordTranspose(recorder, vHead, seqLen, headDim);
194
- const dS = await recordMatmul(recorder, dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
195
- transposeB: false,
196
- bDtype: 'f32',
197
- });
198
- const dQK = causal
199
- ? await recordBackwardKernel(
200
- recorder,
201
- 'attention_backward',
202
- sHead,
203
- dS,
204
- 16,
205
- (view) => {
206
- view.setUint32(0, seqLen, true);
207
- view.setUint32(4, seqLen, true);
208
- view.setUint32(8, 1, true);
209
- }
210
- )
211
- : await recordSoftmaxBackward(recorder, sHead, dS, { rows: seqLen, cols: seqLen });
245
+ vTransposed = await recordTranspose(recorder, vHead, seqLen, headDim);
246
+ dS = await recordMatmul(recorder, dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
247
+ transposeB: false,
248
+ bDtype: 'f32',
249
+ });
250
+ dQK = causal
251
+ ? await recordBackwardKernel(
252
+ recorder,
253
+ 'attention_backward',
254
+ sHead,
255
+ dS,
256
+ 16,
257
+ (view) => {
258
+ view.setUint32(0, seqLen, true);
259
+ view.setUint32(4, seqLen, true);
260
+ view.setUint32(8, 1, true);
261
+ }
262
+ )
263
+ : await recordSoftmaxBackward(recorder, sHead, dS, { rows: seqLen, cols: seqLen });
212
264
 
213
- const dQ = await recordMatmul(recorder, dQK, kHead.buffer, seqLen, headDim, seqLen, {
214
- transposeB: false,
215
- alpha: scale,
216
- bDtype: 'f32',
217
- });
218
- const dQKTransposed = await recordTranspose(recorder, dQK, seqLen, seqLen);
219
- const dK = await recordMatmul(recorder, dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
220
- transposeB: false,
221
- alpha: scale,
222
- bDtype: 'f32',
223
- });
265
+ dQ = await recordMatmul(recorder, dQK, kHead.buffer, seqLen, headDim, seqLen, {
266
+ transposeB: false,
267
+ alpha: scale,
268
+ bDtype: 'f32',
269
+ });
270
+ dQKTransposed = await recordTranspose(recorder, dQK, seqLen, seqLen);
271
+ dK = await recordMatmul(recorder, dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
272
+ transposeB: false,
273
+ alpha: scale,
274
+ bDtype: 'f32',
275
+ });
224
276
 
225
- encoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
226
- encoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
227
- encoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
277
+ encoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
278
+ encoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
279
+ encoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
280
+ } catch (error) {
281
+ releaseTensorBuffer(sTransposed);
282
+ releaseTensorBuffer(dV);
283
+ releaseTensorBuffer(vTransposed);
284
+ releaseTensorBuffer(dS);
285
+ releaseTensorBuffer(dQK);
286
+ releaseTensorBuffer(dQ);
287
+ releaseTensorBuffer(dQKTransposed);
288
+ releaseTensorBuffer(dK);
289
+ releaseBuffer(qHeadBuf);
290
+ releaseBuffer(kHeadBuf);
291
+ releaseBuffer(vHeadBuf);
292
+ releaseBuffer(sHeadBuf);
293
+ releaseBuffer(dHeadBuf);
294
+ throw error;
295
+ }
228
296
 
229
- trackTensorBuffer(recorder, sTransposed);
230
- trackTensorBuffer(recorder, dV);
231
- trackTensorBuffer(recorder, vTransposed);
232
- trackTensorBuffer(recorder, dS);
233
- trackTensorBuffer(recorder, dQK);
234
- trackTensorBuffer(recorder, dQ);
235
- trackTensorBuffer(recorder, dQKTransposed);
236
- trackTensorBuffer(recorder, dK);
297
+ for (const buffer of headBuffers) {
298
+ recorder.trackTemporaryBuffer(buffer);
299
+ }
300
+ trackTensorBuffer(recorder, sTransposed);
301
+ trackTensorBuffer(recorder, dV);
302
+ trackTensorBuffer(recorder, vTransposed);
303
+ trackTensorBuffer(recorder, dS);
304
+ trackTensorBuffer(recorder, dQK);
305
+ trackTensorBuffer(recorder, dQ);
306
+ trackTensorBuffer(recorder, dQKTransposed);
307
+ trackTensorBuffer(recorder, dK);
308
+ }
309
+ }
310
+ if (recorder) {
311
+ for (const tensor of ownedRecorderInputTensors) {
312
+ trackTensorBuffer(recorder, tensor);
313
+ }
314
+ }
315
+ completed = true;
316
+ return {
317
+ gradQ: createTensor(gradQBuf, 'f32', [...q.shape], 'attn_grad_q'),
318
+ gradK: createTensor(gradKBuf, 'f32', [...k.shape], 'attn_grad_k'),
319
+ gradV: createTensor(gradVBuf, 'f32', [...v.shape], 'attn_grad_v'),
320
+ };
321
+ } finally {
322
+ if (!completed) {
323
+ releaseBuffer(gradQBuf);
324
+ releaseBuffer(gradKBuf);
325
+ releaseBuffer(gradVBuf);
326
+ }
327
+ if (!recorder) {
328
+ for (const tensor of ownedInputTensors) {
329
+ releaseTensorBuffer(tensor);
330
+ }
331
+ } else {
332
+ for (const tensor of ownedRecorderInputTensors) {
333
+ releaseTensorBuffer(tensor);
334
+ }
237
335
  }
238
336
  }
239
-
240
- return {
241
- gradQ: createTensor(gradQBuf, 'f32', [...q.shape], 'attn_grad_q'),
242
- gradK: createTensor(gradKBuf, 'f32', [...k.shape], 'attn_grad_k'),
243
- gradV: createTensor(gradVBuf, 'f32', [...v.shape], 'attn_grad_v'),
244
- };
245
337
  }
246
338
 
247
339
  export async function runAttentionBackward(
@@ -256,11 +348,7 @@ export async function runAttentionBackward(
256
348
  if (!device) {
257
349
  throw new Error('runAttentionBackward requires a GPU device');
258
350
  }
259
-
260
- const recorder = new CommandRecorder(device, 'attention_backward');
261
- const result = await runAttentionBackwardCore(q, k, v, softmax, gradOutput, options, recorder);
262
- recorder.submit();
263
- return result;
351
+ return runAttentionBackwardCore(q, k, v, softmax, gradOutput, options);
264
352
  }
265
353
 
266
354
  export async function recordAttentionBackward(
@@ -4,6 +4,19 @@ import { createPipeline, createUniformBufferWithView } from '../utils.js';
4
4
  import { dispatch, recordDispatch } from '../dispatch.js';
5
5
  import { getDevice } from '../../device.js';
6
6
 
7
+ function destroyAfterSubmit(device, buffer) {
8
+ if (!buffer) {
9
+ return;
10
+ }
11
+ device.queue.onSubmittedWorkDone()
12
+ .then(() => {
13
+ buffer.destroy();
14
+ })
15
+ .catch(() => {
16
+ buffer.destroy();
17
+ });
18
+ }
19
+
7
20
  export async function runConv2DBackward(input, weight, gradOutput, options = {}) {
8
21
  const { inChannels, outChannels, height, width, outHeight, outWidth, kernelH, kernelW, stride, pad, computeGradInput = true, computeGradWeight = true } = options;
9
22
 
@@ -67,7 +80,7 @@ export async function runConv2DBackward(input, weight, gradOutput, options = {})
67
80
  gradWeight = createTensor(outputBuf, 'f32', [outChannels, inChannels, kernelH, kernelW], 'conv2d_grad_weight');
68
81
  }
69
82
 
70
- uniformBuffer.destroy();
83
+ destroyAfterSubmit(device, uniformBuffer);
71
84
  return { gradInput, gradWeight };
72
85
  }
73
86
 
@@ -14,6 +14,10 @@ struct Uniforms {
14
14
  dim: u32,
15
15
  data_offset: u32, // byte offset into data buffer (divide by 4 for F32)
16
16
  bias_offset: u32, // byte offset into bias buffer (divide by 4 for F32)
17
+ token_stride: u32,
18
+ _pad0: u32,
19
+ _pad1: u32,
20
+ _pad2: u32,
17
21
  }
18
22
 
19
23
  override WORKGROUP_SIZE: u32 = 256u;
@@ -24,17 +28,15 @@ override WORKGROUP_SIZE: u32 = 256u;
24
28
 
25
29
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
26
30
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
27
- let idx = gid.x;
28
- let total = u.num_tokens * u.dim;
29
- if (idx >= total) {
31
+ let d = gid.x;
32
+ let token = gid.z * max(u.token_stride, 1u) + gid.y;
33
+ if (token >= u.num_tokens || d >= u.dim) {
30
34
  return;
31
35
  }
32
36
 
33
37
  // Convert byte offsets to F32 indices
34
38
  let data_base = u.data_offset / 4u;
35
39
  let bias_base = u.bias_offset / 4u;
36
-
37
- let d = idx % u.dim;
40
+ let idx = token * u.dim + d;
38
41
  data[data_base + idx] = data[data_base + idx] + bias[bias_base + d];
39
42
  }
40
-
@@ -18,6 +18,10 @@ struct Uniforms {
18
18
  dim: u32,
19
19
  data_offset: u32, // byte offset into data buffer (divide by 2 for F16)
20
20
  bias_offset: u32, // byte offset into bias buffer (divide by 2 for F16)
21
+ token_stride: u32,
22
+ _pad0: u32,
23
+ _pad1: u32,
24
+ _pad2: u32,
21
25
  }
22
26
 
23
27
  override WORKGROUP_SIZE: u32 = 256u;
@@ -28,17 +32,16 @@ override WORKGROUP_SIZE: u32 = 256u;
28
32
 
29
33
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
30
34
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
31
- let idx = gid.x;
32
- let total = u.num_tokens * u.dim;
33
- if (idx >= total) {
35
+ let d = gid.x;
36
+ let token = gid.z * max(u.token_stride, 1u) + gid.y;
37
+ if (token >= u.num_tokens || d >= u.dim) {
34
38
  return;
35
39
  }
36
40
 
37
41
  // Convert byte offsets to F16 indices
38
42
  let data_base = u.data_offset / 2u;
39
43
  let bias_base = u.bias_offset / 2u;
40
-
41
- let d = idx % u.dim;
44
+ let idx = token * u.dim + d;
42
45
  let out = f32(data[data_base + idx]) + f32(bias[bias_base + d]);
43
46
  data[data_base + idx] = f16(out);
44
47
  }