@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -1,6 +1,5 @@
1
- import { CommandRecorder } from '../../command-recorder.js';
2
1
  import { getDevice } from '../../device.js';
3
- import { acquireBuffer } from '../../../memory/buffer-pool.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
4
3
  import { createTensor, dtypeBytes } from '../../tensor.js';
5
4
  import { castF16ToF32, recordCastF16ToF32 } from '../cast.js';
6
5
  import { runMatmul, recordMatmul } from '../matmul.js';
@@ -15,24 +14,16 @@ async function ensureF32(tensor, recorder = null) {
15
14
  if (!recorder) {
16
15
  return castF16ToF32(tensor);
17
16
  }
18
- const casted = await recordCastF16ToF32(recorder, tensor);
19
- recorder.trackTemporaryBuffer(casted.buffer);
20
- return casted;
17
+ return recordCastF16ToF32(recorder, tensor);
21
18
  }
22
19
 
23
- function createHeadSliceBuffers(recorder, headBytes, softmaxBytes) {
20
+ function createHeadSliceBuffers(headBytes, softmaxBytes) {
24
21
  const qHeadBuf = acquireBuffer(headBytes, undefined, 'attn_q_head');
25
22
  const kHeadBuf = acquireBuffer(headBytes, undefined, 'attn_k_head');
26
23
  const vHeadBuf = acquireBuffer(headBytes, undefined, 'attn_v_head');
27
24
  const sHeadBuf = acquireBuffer(softmaxBytes, undefined, 'attn_s_head');
28
25
  const dHeadBuf = acquireBuffer(headBytes, undefined, 'attn_d_head');
29
26
 
30
- recorder.trackTemporaryBuffer(qHeadBuf);
31
- recorder.trackTemporaryBuffer(kHeadBuf);
32
- recorder.trackTemporaryBuffer(vHeadBuf);
33
- recorder.trackTemporaryBuffer(sHeadBuf);
34
- recorder.trackTemporaryBuffer(dHeadBuf);
35
-
36
27
  return { qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf };
37
28
  }
38
29
 
@@ -49,6 +40,19 @@ function trackTensorBuffer(recorder, tensor) {
49
40
  recorder.trackTemporaryBuffer(tensor.buffer);
50
41
  }
51
42
 
43
+ function releaseTensorBuffer(tensor) {
44
+ if (tensor?.buffer) {
45
+ releaseBuffer(tensor.buffer);
46
+ }
47
+ }
48
+
49
+ function maybeTrackOwnedTensor(ownedTensors, originalTensor, resolvedTensor) {
50
+ if (resolvedTensor !== originalTensor) {
51
+ ownedTensors.push(resolvedTensor);
52
+ }
53
+ return resolvedTensor;
54
+ }
55
+
52
56
  async function runAttentionBackwardCore(
53
57
  q,
54
58
  k,
@@ -63,11 +67,23 @@ async function runAttentionBackwardCore(
63
67
  throw new Error('attention backward requires seqLen, numHeads, and headDim');
64
68
  }
65
69
 
66
- const qTensor = await ensureF32(q, recorder);
67
- const kTensor = await ensureF32(k, recorder);
68
- const vTensor = await ensureF32(v, recorder);
69
- const sTensor = await ensureF32(softmax, recorder);
70
- const dTensor = await ensureF32(gradOutput, recorder);
70
+ const ownedInputTensors = [];
71
+ const ownedRecorderInputTensors = [];
72
+ const qTensor = !recorder
73
+ ? maybeTrackOwnedTensor(ownedInputTensors, q, await ensureF32(q))
74
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, q, await ensureF32(q, recorder));
75
+ const kTensor = !recorder
76
+ ? maybeTrackOwnedTensor(ownedInputTensors, k, await ensureF32(k))
77
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, k, await ensureF32(k, recorder));
78
+ const vTensor = !recorder
79
+ ? maybeTrackOwnedTensor(ownedInputTensors, v, await ensureF32(v))
80
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, v, await ensureF32(v, recorder));
81
+ const sTensor = !recorder
82
+ ? maybeTrackOwnedTensor(ownedInputTensors, softmax, await ensureF32(softmax))
83
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, softmax, await ensureF32(softmax, recorder));
84
+ const dTensor = !recorder
85
+ ? maybeTrackOwnedTensor(ownedInputTensors, gradOutput, await ensureF32(gradOutput))
86
+ : maybeTrackOwnedTensor(ownedRecorderInputTensors, gradOutput, await ensureF32(gradOutput, recorder));
71
87
 
72
88
  const headElements = seqLen * headDim;
73
89
  const headBytes = headElements * dtypeBytes(qTensor.dtype);
@@ -77,171 +93,247 @@ async function runAttentionBackwardCore(
77
93
  const gradQBuf = acquireBuffer(totalBytes, undefined, 'attn_grad_q');
78
94
  const gradKBuf = acquireBuffer(totalBytes, undefined, 'attn_grad_k');
79
95
  const gradVBuf = acquireBuffer(totalBytes, undefined, 'attn_grad_v');
96
+ let completed = false;
80
97
 
81
- if (!recorder) {
82
- for (let h = 0; h < numHeads; h += 1) {
83
- const qOffset = h * headBytes;
84
- const kOffset = h * headBytes;
85
- const vOffset = h * headBytes;
86
- const dOffset = h * headBytes;
87
- const sOffset = h * softmaxBytes;
98
+ try {
99
+ if (!recorder) {
100
+ for (let h = 0; h < numHeads; h += 1) {
101
+ const qOffset = h * headBytes;
102
+ const kOffset = h * headBytes;
103
+ const vOffset = h * headBytes;
104
+ const dOffset = h * headBytes;
105
+ const sOffset = h * softmaxBytes;
88
106
 
89
- const qHeadBuf = acquireBuffer(headBytes, undefined, 'attn_q_head');
90
- const kHeadBuf = acquireBuffer(headBytes, undefined, 'attn_k_head');
91
- const vHeadBuf = acquireBuffer(headBytes, undefined, 'attn_v_head');
92
- const sHeadBuf = acquireBuffer(softmaxBytes, undefined, 'attn_s_head');
93
- const dHeadBuf = acquireBuffer(headBytes, undefined, 'attn_d_head');
107
+ const qHeadBuf = acquireBuffer(headBytes, undefined, 'attn_q_head');
108
+ const kHeadBuf = acquireBuffer(headBytes, undefined, 'attn_k_head');
109
+ const vHeadBuf = acquireBuffer(headBytes, undefined, 'attn_v_head');
110
+ const sHeadBuf = acquireBuffer(softmaxBytes, undefined, 'attn_s_head');
111
+ const dHeadBuf = acquireBuffer(headBytes, undefined, 'attn_d_head');
112
+ let sTransposed = null;
113
+ let dV = null;
114
+ let vTransposed = null;
115
+ let dS = null;
116
+ let dQK = null;
117
+ let dQ = null;
118
+ let dQKTransposed = null;
119
+ let dK = null;
94
120
 
95
- const sliceEncoder = getDevice().createCommandEncoder();
96
- sliceEncoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
97
- sliceEncoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
98
- sliceEncoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
99
- sliceEncoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
100
- sliceEncoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
101
- getDevice().queue.submit([sliceEncoder.finish()]);
121
+ try {
122
+ const sliceEncoder = getDevice().createCommandEncoder();
123
+ sliceEncoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
124
+ sliceEncoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
125
+ sliceEncoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
126
+ sliceEncoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
127
+ sliceEncoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
128
+ getDevice().queue.submit([sliceEncoder.finish()]);
102
129
 
103
- const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
104
- qHeadBuf,
105
- kHeadBuf,
106
- vHeadBuf,
107
- sHeadBuf,
108
- dHeadBuf,
109
- seqLen,
110
- headDim
111
- );
130
+ const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
131
+ qHeadBuf,
132
+ kHeadBuf,
133
+ vHeadBuf,
134
+ sHeadBuf,
135
+ dHeadBuf,
136
+ seqLen,
137
+ headDim
138
+ );
112
139
 
113
- const sTransposed = await runTranspose(sHead, seqLen, seqLen);
114
- const dV = await runMatmul(sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
115
- transposeB: false,
116
- bDtype: 'f32',
117
- });
140
+ sTransposed = await runTranspose(sHead, seqLen, seqLen);
141
+ dV = await runMatmul(sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
142
+ transposeB: false,
143
+ bDtype: 'f32',
144
+ });
118
145
 
119
- const vTransposed = await runTranspose(vHead, seqLen, headDim);
120
- const dS = await runMatmul(dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
121
- transposeB: false,
122
- bDtype: 'f32',
123
- });
124
- const dQK = causal
125
- ? await runBackwardKernel(
126
- 'attention_backward',
127
- sHead,
128
- dS,
129
- 16,
130
- (view) => {
131
- view.setUint32(0, seqLen, true);
132
- view.setUint32(4, seqLen, true);
133
- view.setUint32(8, 1, true);
134
- }
135
- )
136
- : await runSoftmaxBackward(sHead, dS, { rows: seqLen, cols: seqLen });
146
+ vTransposed = await runTranspose(vHead, seqLen, headDim);
147
+ dS = await runMatmul(dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
148
+ transposeB: false,
149
+ bDtype: 'f32',
150
+ });
151
+ dQK = causal
152
+ ? await runBackwardKernel(
153
+ 'attention_backward',
154
+ sHead,
155
+ dS,
156
+ 16,
157
+ (view) => {
158
+ view.setUint32(0, seqLen, true);
159
+ view.setUint32(4, seqLen, true);
160
+ view.setUint32(8, 1, true);
161
+ }
162
+ )
163
+ : await runSoftmaxBackward(sHead, dS, { rows: seqLen, cols: seqLen });
137
164
 
138
- const dQ = await runMatmul(dQK, kHead.buffer, seqLen, headDim, seqLen, {
139
- transposeB: false,
140
- alpha: scale,
141
- bDtype: 'f32',
142
- });
143
- const dQKTransposed = await runTranspose(dQK, seqLen, seqLen);
144
- const dK = await runMatmul(dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
145
- transposeB: false,
146
- alpha: scale,
147
- bDtype: 'f32',
148
- });
165
+ dQ = await runMatmul(dQK, kHead.buffer, seqLen, headDim, seqLen, {
166
+ transposeB: false,
167
+ alpha: scale,
168
+ bDtype: 'f32',
169
+ });
170
+ dQKTransposed = await runTranspose(dQK, seqLen, seqLen);
171
+ dK = await runMatmul(dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
172
+ transposeB: false,
173
+ alpha: scale,
174
+ bDtype: 'f32',
175
+ });
149
176
 
150
- const copyEncoder = getDevice().createCommandEncoder();
151
- copyEncoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
152
- copyEncoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
153
- copyEncoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
154
- getDevice().queue.submit([copyEncoder.finish()]);
155
- }
156
- } else {
157
- const encoder = recorder.getEncoder();
158
- for (let h = 0; h < numHeads; h += 1) {
159
- const qOffset = h * headBytes;
160
- const kOffset = h * headBytes;
161
- const vOffset = h * headBytes;
162
- const dOffset = h * headBytes;
163
- const sOffset = h * softmaxBytes;
177
+ const copyEncoder = getDevice().createCommandEncoder();
178
+ copyEncoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
179
+ copyEncoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
180
+ copyEncoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
181
+ getDevice().queue.submit([copyEncoder.finish()]);
182
+ await getDevice().queue.onSubmittedWorkDone();
183
+ } finally {
184
+ releaseTensorBuffer(sTransposed);
185
+ releaseTensorBuffer(dV);
186
+ releaseTensorBuffer(vTransposed);
187
+ releaseTensorBuffer(dS);
188
+ releaseTensorBuffer(dQK);
189
+ releaseTensorBuffer(dQ);
190
+ releaseTensorBuffer(dQKTransposed);
191
+ releaseTensorBuffer(dK);
192
+ releaseBuffer(qHeadBuf);
193
+ releaseBuffer(kHeadBuf);
194
+ releaseBuffer(vHeadBuf);
195
+ releaseBuffer(sHeadBuf);
196
+ releaseBuffer(dHeadBuf);
197
+ }
198
+ }
199
+ } else {
200
+ const encoder = recorder.getEncoder();
201
+ for (let h = 0; h < numHeads; h += 1) {
202
+ const qOffset = h * headBytes;
203
+ const kOffset = h * headBytes;
204
+ const vOffset = h * headBytes;
205
+ const dOffset = h * headBytes;
206
+ const sOffset = h * softmaxBytes;
164
207
 
165
- const { qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf } = createHeadSliceBuffers(
166
- recorder,
167
- headBytes,
168
- softmaxBytes
169
- );
208
+ const { qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf } = createHeadSliceBuffers(
209
+ headBytes,
210
+ softmaxBytes
211
+ );
212
+ const headBuffers = [qHeadBuf, kHeadBuf, vHeadBuf, sHeadBuf, dHeadBuf];
213
+ let sTransposed = null;
214
+ let dV = null;
215
+ let vTransposed = null;
216
+ let dS = null;
217
+ let dQK = null;
218
+ let dQ = null;
219
+ let dQKTransposed = null;
220
+ let dK = null;
170
221
 
171
- encoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
172
- encoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
173
- encoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
174
- encoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
175
- encoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
222
+ try {
223
+ encoder.copyBufferToBuffer(qTensor.buffer, qOffset, qHeadBuf, 0, headBytes);
224
+ encoder.copyBufferToBuffer(kTensor.buffer, kOffset, kHeadBuf, 0, headBytes);
225
+ encoder.copyBufferToBuffer(vTensor.buffer, vOffset, vHeadBuf, 0, headBytes);
226
+ encoder.copyBufferToBuffer(sTensor.buffer, sOffset, sHeadBuf, 0, softmaxBytes);
227
+ encoder.copyBufferToBuffer(dTensor.buffer, dOffset, dHeadBuf, 0, headBytes);
176
228
 
177
- const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
178
- qHeadBuf,
179
- kHeadBuf,
180
- vHeadBuf,
181
- sHeadBuf,
182
- dHeadBuf,
183
- seqLen,
184
- headDim
185
- );
229
+ const { qHead, kHead, vHead, sHead, dHead } = createHeadTensors(
230
+ qHeadBuf,
231
+ kHeadBuf,
232
+ vHeadBuf,
233
+ sHeadBuf,
234
+ dHeadBuf,
235
+ seqLen,
236
+ headDim
237
+ );
186
238
 
187
- const sTransposed = await recordTranspose(recorder, sHead, seqLen, seqLen);
188
- const dV = await recordMatmul(recorder, sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
189
- transposeB: false,
190
- bDtype: 'f32',
191
- });
239
+ sTransposed = await recordTranspose(recorder, sHead, seqLen, seqLen);
240
+ dV = await recordMatmul(recorder, sTransposed, dHead.buffer, seqLen, headDim, seqLen, {
241
+ transposeB: false,
242
+ bDtype: 'f32',
243
+ });
192
244
 
193
- const vTransposed = await recordTranspose(recorder, vHead, seqLen, headDim);
194
- const dS = await recordMatmul(recorder, dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
195
- transposeB: false,
196
- bDtype: 'f32',
197
- });
198
- const dQK = causal
199
- ? await recordBackwardKernel(
200
- recorder,
201
- 'attention_backward',
202
- sHead,
203
- dS,
204
- 16,
205
- (view) => {
206
- view.setUint32(0, seqLen, true);
207
- view.setUint32(4, seqLen, true);
208
- view.setUint32(8, 1, true);
209
- }
210
- )
211
- : await recordSoftmaxBackward(recorder, sHead, dS, { rows: seqLen, cols: seqLen });
245
+ vTransposed = await recordTranspose(recorder, vHead, seqLen, headDim);
246
+ dS = await recordMatmul(recorder, dHead, vTransposed.buffer, seqLen, seqLen, headDim, {
247
+ transposeB: false,
248
+ bDtype: 'f32',
249
+ });
250
+ dQK = causal
251
+ ? await recordBackwardKernel(
252
+ recorder,
253
+ 'attention_backward',
254
+ sHead,
255
+ dS,
256
+ 16,
257
+ (view) => {
258
+ view.setUint32(0, seqLen, true);
259
+ view.setUint32(4, seqLen, true);
260
+ view.setUint32(8, 1, true);
261
+ }
262
+ )
263
+ : await recordSoftmaxBackward(recorder, sHead, dS, { rows: seqLen, cols: seqLen });
212
264
 
213
- const dQ = await recordMatmul(recorder, dQK, kHead.buffer, seqLen, headDim, seqLen, {
214
- transposeB: false,
215
- alpha: scale,
216
- bDtype: 'f32',
217
- });
218
- const dQKTransposed = await recordTranspose(recorder, dQK, seqLen, seqLen);
219
- const dK = await recordMatmul(recorder, dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
220
- transposeB: false,
221
- alpha: scale,
222
- bDtype: 'f32',
223
- });
265
+ dQ = await recordMatmul(recorder, dQK, kHead.buffer, seqLen, headDim, seqLen, {
266
+ transposeB: false,
267
+ alpha: scale,
268
+ bDtype: 'f32',
269
+ });
270
+ dQKTransposed = await recordTranspose(recorder, dQK, seqLen, seqLen);
271
+ dK = await recordMatmul(recorder, dQKTransposed, qHead.buffer, seqLen, headDim, seqLen, {
272
+ transposeB: false,
273
+ alpha: scale,
274
+ bDtype: 'f32',
275
+ });
224
276
 
225
- encoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
226
- encoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
227
- encoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
277
+ encoder.copyBufferToBuffer(dQ.buffer, 0, gradQBuf, qOffset, headBytes);
278
+ encoder.copyBufferToBuffer(dK.buffer, 0, gradKBuf, kOffset, headBytes);
279
+ encoder.copyBufferToBuffer(dV.buffer, 0, gradVBuf, vOffset, headBytes);
280
+ } catch (error) {
281
+ releaseTensorBuffer(sTransposed);
282
+ releaseTensorBuffer(dV);
283
+ releaseTensorBuffer(vTransposed);
284
+ releaseTensorBuffer(dS);
285
+ releaseTensorBuffer(dQK);
286
+ releaseTensorBuffer(dQ);
287
+ releaseTensorBuffer(dQKTransposed);
288
+ releaseTensorBuffer(dK);
289
+ releaseBuffer(qHeadBuf);
290
+ releaseBuffer(kHeadBuf);
291
+ releaseBuffer(vHeadBuf);
292
+ releaseBuffer(sHeadBuf);
293
+ releaseBuffer(dHeadBuf);
294
+ throw error;
295
+ }
228
296
 
229
- trackTensorBuffer(recorder, sTransposed);
230
- trackTensorBuffer(recorder, dV);
231
- trackTensorBuffer(recorder, vTransposed);
232
- trackTensorBuffer(recorder, dS);
233
- trackTensorBuffer(recorder, dQK);
234
- trackTensorBuffer(recorder, dQ);
235
- trackTensorBuffer(recorder, dQKTransposed);
236
- trackTensorBuffer(recorder, dK);
297
+ for (const buffer of headBuffers) {
298
+ recorder.trackTemporaryBuffer(buffer);
299
+ }
300
+ trackTensorBuffer(recorder, sTransposed);
301
+ trackTensorBuffer(recorder, dV);
302
+ trackTensorBuffer(recorder, vTransposed);
303
+ trackTensorBuffer(recorder, dS);
304
+ trackTensorBuffer(recorder, dQK);
305
+ trackTensorBuffer(recorder, dQ);
306
+ trackTensorBuffer(recorder, dQKTransposed);
307
+ trackTensorBuffer(recorder, dK);
308
+ }
309
+ }
310
+ if (recorder) {
311
+ for (const tensor of ownedRecorderInputTensors) {
312
+ trackTensorBuffer(recorder, tensor);
313
+ }
314
+ }
315
+ completed = true;
316
+ return {
317
+ gradQ: createTensor(gradQBuf, 'f32', [...q.shape], 'attn_grad_q'),
318
+ gradK: createTensor(gradKBuf, 'f32', [...k.shape], 'attn_grad_k'),
319
+ gradV: createTensor(gradVBuf, 'f32', [...v.shape], 'attn_grad_v'),
320
+ };
321
+ } finally {
322
+ if (!completed) {
323
+ releaseBuffer(gradQBuf);
324
+ releaseBuffer(gradKBuf);
325
+ releaseBuffer(gradVBuf);
326
+ }
327
+ if (!recorder) {
328
+ for (const tensor of ownedInputTensors) {
329
+ releaseTensorBuffer(tensor);
330
+ }
331
+ } else {
332
+ for (const tensor of ownedRecorderInputTensors) {
333
+ releaseTensorBuffer(tensor);
334
+ }
237
335
  }
238
336
  }
239
-
240
- return {
241
- gradQ: createTensor(gradQBuf, 'f32', [...q.shape], 'attn_grad_q'),
242
- gradK: createTensor(gradKBuf, 'f32', [...k.shape], 'attn_grad_k'),
243
- gradV: createTensor(gradVBuf, 'f32', [...v.shape], 'attn_grad_v'),
244
- };
245
337
  }
246
338
 
247
339
  export async function runAttentionBackward(
@@ -256,11 +348,7 @@ export async function runAttentionBackward(
256
348
  if (!device) {
257
349
  throw new Error('runAttentionBackward requires a GPU device');
258
350
  }
259
-
260
- const recorder = new CommandRecorder(device, 'attention_backward');
261
- const result = await runAttentionBackwardCore(q, k, v, softmax, gradOutput, options, recorder);
262
- recorder.submit();
263
- return result;
351
+ return runAttentionBackwardCore(q, k, v, softmax, gradOutput, options);
264
352
  }
265
353
 
266
354
  export async function recordAttentionBackward(
@@ -4,6 +4,19 @@ import { createPipeline, createUniformBufferWithView } from '../utils.js';
4
4
  import { dispatch, recordDispatch } from '../dispatch.js';
5
5
  import { getDevice } from '../../device.js';
6
6
 
7
+ function destroyAfterSubmit(device, buffer) {
8
+ if (!buffer) {
9
+ return;
10
+ }
11
+ device.queue.onSubmittedWorkDone()
12
+ .then(() => {
13
+ buffer.destroy();
14
+ })
15
+ .catch(() => {
16
+ buffer.destroy();
17
+ });
18
+ }
19
+
7
20
  export async function runConv2DBackward(input, weight, gradOutput, options = {}) {
8
21
  const { inChannels, outChannels, height, width, outHeight, outWidth, kernelH, kernelW, stride, pad, computeGradInput = true, computeGradWeight = true } = options;
9
22
 
@@ -67,7 +80,7 @@ export async function runConv2DBackward(input, weight, gradOutput, options = {})
67
80
  gradWeight = createTensor(outputBuf, 'f32', [outChannels, inChannels, kernelH, kernelW], 'conv2d_grad_weight');
68
81
  }
69
82
 
70
- uniformBuffer.destroy();
83
+ destroyAfterSubmit(device, uniformBuffer);
71
84
  return { gradInput, gradWeight };
72
85
  }
73
86