@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -19,6 +19,52 @@ export declare function resolveBatchStop(
19
19
  eosTokenId: number | undefined | null
20
20
  ): number;
21
21
 
22
/**
 * Minimal staging-buffer surface required by the readback helpers declared
 * below: they only call `mapAsync`, `getMappedRange`, `unmap` and `destroy`.
 */
+ export interface SampledTokenStagingBuffer {
23
+ mapAsync(mode: number): Promise<void>;
24
+ getMappedRange(): ArrayBufferLike;
25
+ unmap(): void;
26
+ destroy(): void;
27
+ }
28
+
29
/**
 * Maps `stagingBuffer` for reading and returns the sampled token (word 0)
 * plus the finiteness status. The status is parsed starting at word 1 when
 * `options.hasFinitenessBuffer` is true, and at word 0 otherwise.
 * The buffer is unmapped afterwards. It is destroyed only when
 * `options.ownsStagingBuffer` is true (default: false).
 * `options.ring.advance()` is called once in every case, including when
 * mapping fails.
 */
+ export declare function readSampledTokenFromStagingBuffer(
30
+ stagingBuffer: SampledTokenStagingBuffer,
31
+ options?: {
32
+ ownsStagingBuffer?: boolean;
33
+ hasFinitenessBuffer?: boolean;
34
+ ring?: { advance(): void } | null;
35
+ }
36
+ ): Promise<{
37
+ nextToken: number;
38
+ finitenessStatus: {
39
+ triggered: boolean;
40
+ metadata: string;
41
+ };
42
+ }>;
43
+
44
/**
 * Maps `stagingBuffer` for reading and resolves with a copy of the whole
 * mapped range. The buffer is then unmapped.
 * NOTE: unlike `readSampledTokenFromStagingBuffer`, the buffer is destroyed
 * by default. Pass `ownsStagingBuffer: false` to keep it.
 */
+ export declare function readMappedBufferCopy(
45
+ stagingBuffer: SampledTokenStagingBuffer,
46
+ options?: {
47
+ ownsStagingBuffer?: boolean;
48
+ }
49
+ ): Promise<ArrayBuffer>;
50
+
51
/**
 * Maps the tokens staging buffer, plus the stop-flag and finiteness staging
 * buffers when provided. All map requests are issued concurrently.
 *
 * Resolves with:
 * - `tokens`: the first `tokenCount` token ids.
 * - `stopFlags`: a copy of the first `tokenCount` stop-flag words, or null
 *   when there is no stop buffer.
 * - `finitenessStatus`: the parsed status, or `{ triggered: false, metadata: '' }`
 *   when there is no finiteness buffer.
 *
 * The finiteness buffer is always destroyed. The tokens and stop buffers are
 * destroyed only when the matching `owns*` flag is true (both default to
 * false). `ring.advance()` is called once in every case.
 */
+ export declare function readBatchTokensFromStagingBuffers(options: {
52
+ tokensStagingBuffer: SampledTokenStagingBuffer;
53
+ stopStagingBuffer?: SampledTokenStagingBuffer | null;
54
+ finitenessStagingBuffer?: SampledTokenStagingBuffer | null;
55
+ tokenCount: number;
56
+ ownsTokensStaging?: boolean;
57
+ ownsStopStaging?: boolean;
58
+ ring?: { advance(): void } | null;
59
+ }): Promise<{
60
+ tokens: number[];
61
+ stopFlags: Uint32Array | null;
62
+ finitenessStatus: {
63
+ triggered: boolean;
64
+ metadata: string;
65
+ };
66
+ }>;
67
+
22
68
  export declare function decodeStep(
23
69
  state: unknown,
24
70
  currentIds: number[],
@@ -113,6 +113,116 @@ export function resolveBatchStop(tokens, stopFlags, stopTokenIds, eosTokenId) {
113
113
  return actualCount;
114
114
  }
115
115
 
116
/**
 * Reads one sampled token, plus the finiteness status words, out of a
 * staging buffer.
 *
 * @param {GPUBuffer} stagingBuffer - Buffer holding the sampled token in word 0.
 * @param {object} [options]
 * @param {boolean} [options.ownsStagingBuffer=false] - Destroy the buffer once read.
 * @param {boolean} [options.hasFinitenessBuffer=false] - When true, the status words
 *   start at word 1; otherwise they are parsed starting at word 0.
 * @param {{advance(): void}|null} [options.ring=null] - Advanced once, whatever the outcome.
 * @returns {Promise<{nextToken: number, finitenessStatus: {triggered: boolean, metadata: string}}>}
 */
export async function readSampledTokenFromStagingBuffer(stagingBuffer, options = {}) {
  const destroyWhenDone = options.ownsStagingBuffer === true;
  const statusWordOffset = options.hasFinitenessBuffer === true ? 1 : 0;
  const ringSlot = options.ring ?? null;
  let isMapped = false;

  try {
    await stagingBuffer.mapAsync(GPUMapMode.READ);
    isMapped = true;

    // Everything is pulled out of the mapped range before `finally` unmaps it.
    const words = new Uint32Array(stagingBuffer.getMappedRange());
    const nextToken = words[0];
    const finitenessStatus = parseFinitenessStatusWords(words, statusWordOffset);
    return { nextToken, finitenessStatus };
  } finally {
    // Only unmap when mapAsync actually succeeded.
    if (isMapped) {
      stagingBuffer.unmap();
    }
    if (destroyWhenDone) {
      stagingBuffer.destroy();
    }
    ringSlot?.advance();
  }
}
142
+
143
/**
 * Maps a staging buffer for reading and returns a detached copy of its
 * whole mapped range.
 *
 * Unlike the other staging helpers, the buffer is destroyed by default.
 * Callers that keep reusing it must pass `ownsStagingBuffer: false`.
 *
 * @param {GPUBuffer} stagingBuffer - MAP_READ-capable buffer to copy from.
 * @param {object} [options]
 * @param {boolean} [options.ownsStagingBuffer=true] - Destroy the buffer once read.
 * @returns {Promise<ArrayBuffer>} A copy that stays valid after unmap/destroy.
 */
export async function readMappedBufferCopy(stagingBuffer, options = {}) {
  const keepBuffer = options.ownsStagingBuffer === false;
  let isMapped = false;

  try {
    await stagingBuffer.mapAsync(GPUMapMode.READ);
    isMapped = true;
    // slice(0) copies the bytes: the mapped range becomes detached on unmap.
    const mappedRange = stagingBuffer.getMappedRange();
    return mappedRange.slice(0);
  } finally {
    if (isMapped) {
      stagingBuffer.unmap();
    }
    if (!keepBuffer) {
      stagingBuffer.destroy();
    }
  }
}
160
+
161
/**
 * Maps the batch-readback staging buffers, copies out the sampled tokens, the
 * optional stop flags and the optional finiteness status, then releases the
 * buffers.
 *
 * All provided buffers are mapped concurrently. Every buffer whose map
 * succeeded is unmapped afterwards, even when another buffer's map rejected.
 * This keeps non-owned (for example ring-slot) staging buffers from being
 * left mapped, which would make later copies into them invalid.
 *
 * @param {object} options
 * @param {GPUBuffer} options.tokensStagingBuffer - Sampled token ids as u32 words.
 * @param {GPUBuffer|null} [options.stopStagingBuffer=null] - Optional u32 stop-flag words.
 * @param {GPUBuffer|null} [options.finitenessStagingBuffer=null] - Optional finiteness
 *   status words. This buffer is always destroyed afterwards.
 * @param {number} options.tokenCount - Number of leading words read from the token/stop buffers.
 * @param {boolean} [options.ownsTokensStaging=false] - Destroy the tokens buffer when done.
 * @param {boolean} [options.ownsStopStaging=false] - Destroy the stop buffer when done.
 * @param {{advance(): void}|null} [options.ring=null] - Advanced once, whatever the outcome.
 * @returns {Promise<{tokens: number[], stopFlags: Uint32Array|null,
 *   finitenessStatus: {triggered: boolean, metadata: string}}>}
 * @throws Rethrows the first (in argument order) `mapAsync` rejection, after cleanup.
 */
export async function readBatchTokensFromStagingBuffers(options) {
  const {
    tokensStagingBuffer,
    stopStagingBuffer = null,
    finitenessStagingBuffer = null,
    tokenCount,
    ownsTokensStaging = false,
    ownsStopStaging = false,
    ring = null,
  } = options;
  let tokensMapped = false;
  let stopMapped = false;
  let finitenessMapped = false;

  try {
    // Use allSettled rather than all. Promise.all would reject on the first
    // failure while sibling maps may already have succeeded (or are still
    // pending). Those buffers would then never be flagged as mapped, so they
    // would never be unmapped below.
    const [tokensResult, stopResult, finitenessResult] = await Promise.allSettled([
      tokensStagingBuffer.mapAsync(GPUMapMode.READ),
      stopStagingBuffer ? stopStagingBuffer.mapAsync(GPUMapMode.READ) : null,
      finitenessStagingBuffer ? finitenessStagingBuffer.mapAsync(GPUMapMode.READ) : null,
    ]);
    tokensMapped = tokensResult.status === 'fulfilled';
    stopMapped = Boolean(stopStagingBuffer) && stopResult.status === 'fulfilled';
    finitenessMapped = Boolean(finitenessStagingBuffer) && finitenessResult.status === 'fulfilled';

    const failure = [tokensResult, stopResult, finitenessResult].find(
      (result) => result.status === 'rejected'
    );
    if (failure) {
      throw failure.reason;
    }

    // Copy everything out of the mapped ranges before `finally` unmaps them.
    const tokens = Array.from(
      new Uint32Array(tokensStagingBuffer.getMappedRange()).subarray(0, tokenCount)
    );
    const stopFlags = stopStagingBuffer
      ? new Uint32Array(stopStagingBuffer.getMappedRange().slice(0, tokenCount * 4))
      : null;
    const finitenessStatus = finitenessStagingBuffer
      ? parseFinitenessStatusWords(new Uint32Array(finitenessStagingBuffer.getMappedRange()), 0)
      : { triggered: false, metadata: '' };

    return {
      tokens,
      stopFlags,
      finitenessStatus,
    };
  } finally {
    if (finitenessMapped) {
      finitenessStagingBuffer.unmap();
    }
    if (tokensMapped) {
      tokensStagingBuffer.unmap();
    }
    if (stopMapped) {
      stopStagingBuffer.unmap();
    }
    // The finiteness staging buffer is not reused by callers here; it is
    // always released.
    if (finitenessStagingBuffer) {
      finitenessStagingBuffer.destroy();
    }
    if (ownsTokensStaging) {
      tokensStagingBuffer.destroy();
    }
    if (ownsStopStaging) {
      stopStagingBuffer?.destroy();
    }
    ring?.advance();
  }
}
225
+
116
226
  async function runDecodeLayers(state, tokenId, opts, helpers) {
117
227
  const config = state.modelConfig;
118
228
  const debugCheckBuffer = state.debug ? helpers.debugCheckBuffer : undefined;
@@ -352,17 +462,11 @@ export async function decodeStep(state, currentIds, opts, helpers) {
352
462
  throw new Error('[Pipeline] GPU readback disabled for sampling');
353
463
  }
354
464
 
355
- await stagingBuffer.mapAsync(GPUMapMode.READ);
356
- const mapped = new Uint32Array(stagingBuffer.getMappedRange());
357
- const nextToken = mapped[0];
358
- const finitenessStatus = state.finitenessBuffer
359
- ? parseFinitenessStatusWords(mapped, 1)
360
- : parseFinitenessStatusWords(mapped, 0);
361
- stagingBuffer.unmap();
362
- if (ownsStagingBuffer) {
363
- stagingBuffer.destroy();
364
- }
365
- ring?.advance();
465
+ const { nextToken, finitenessStatus } = await readSampledTokenFromStagingBuffer(stagingBuffer, {
466
+ ownsStagingBuffer,
467
+ hasFinitenessBuffer: Boolean(state.finitenessBuffer),
468
+ ring,
469
+ });
366
470
 
367
471
  if (finitenessStatus.triggered) {
368
472
  releaseBuffer(logitsBuffer);
@@ -499,10 +603,7 @@ export async function decodeStep(state, currentIds, opts, helpers) {
499
603
  const enc = debugDevice.createCommandEncoder();
500
604
  enc.copyBufferToBuffer(hiddenStates, 0, staging, 0, sampleSize);
501
605
  debugDevice.queue.submit([enc.finish()]);
502
- await staging.mapAsync(GPUMapMode.READ);
503
- const data = new Float32Array(staging.getMappedRange().slice(0));
504
- staging.unmap();
505
- staging.destroy();
606
+ const data = new Float32Array(await readMappedBufferCopy(staging));
506
607
  const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
507
608
  const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
508
609
  log.debug('Decode', `[1] HIDDEN_AFTER_LAYERS: nan=${nanCount}/${data.length}, nonZero=${nonZero.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);
@@ -854,225 +955,215 @@ export async function generateNTokensGPU(state, startToken, N, currentIds, opts,
854
955
  })
855
956
  : null;
856
957
  const ownsStopStaging = useGpuStopFlags && !ringSlot?.stagingStop;
958
+ let finitenessStagingBuffer = null;
959
+ let readbackCleanupDelegated = false;
960
+ try {
961
+ if (state.finitenessBuffer) {
962
+ device.queue.writeBuffer(state.finitenessBuffer, 0, new Uint32Array([0, 0, 0, 0]));
963
+ }
857
964
 
858
- if (state.finitenessBuffer) {
859
- device.queue.writeBuffer(state.finitenessBuffer, 0, new Uint32Array([0, 0, 0, 0]));
860
- }
965
+ device.queue.writeBuffer(tokensBuffer, 0, new Uint32Array([startToken]));
966
+ if (stopBuffer) {
967
+ const stopElements = stopBuffer.size / 4;
968
+ const zeroStopData = ringSlot?.zeroStopData;
969
+ const clearData = zeroStopData && zeroStopData.length <= stopElements
970
+ ? zeroStopData
971
+ : new Uint32Array(stopElements);
972
+ device.queue.writeBuffer(stopBuffer, 0, clearData);
973
+ }
861
974
 
862
- device.queue.writeBuffer(tokensBuffer, 0, new Uint32Array([startToken]));
863
- if (stopBuffer) {
864
- const stopElements = stopBuffer.size / 4;
865
- const zeroStopData = ringSlot?.zeroStopData;
866
- const clearData = zeroStopData && zeroStopData.length <= stopElements
867
- ? zeroStopData
868
- : new Uint32Array(stopElements);
869
- device.queue.writeBuffer(stopBuffer, 0, clearData);
870
- }
975
+ const context = helpers.buildLayerContext(recorder, true, opts.debugLayers, executionPlan);
976
+ const embedBufferRaw = state.weights.get('embed');
977
+ if (isCpuWeightBuffer(embedBufferRaw)) {
978
+ throw new Error('[Pipeline] GPU-only decode not supported with CPU-resident embeddings.');
979
+ }
980
+ if (!(embedBufferRaw instanceof GPUBuffer) && !isWeightBuffer(embedBufferRaw)) {
981
+ throw new Error('Embed buffer not found or not a GPUBuffer/WeightBuffer');
982
+ }
983
+ const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
984
+ const embedDtype = isWeightBuffer(embedBufferRaw) ? getWeightDtype(embedBufferRaw) : null;
985
+ const activationDtype = getEffectiveActivationDtype(state, opts);
986
+
987
+ for (let i = 0; i < N; i++) {
988
+ const currentPos = state.currentSeqLen + i;
989
+ context.currentSeqLen = currentPos;
990
+ context.currentTokenIds = [startToken];
991
+ context.decodeBuffers?.resetPingPong();
992
+
993
+ const hiddenTensor = await embed(tokensBuffer, embedBuffer, {
994
+ hiddenSize: config.hiddenSize,
995
+ vocabSize: config.vocabSize,
996
+ scaleEmbeddings: config.scaleEmbeddings,
997
+ recorder,
998
+ transpose: state.embeddingTranspose,
999
+ debugProbes: state.runtimeConfig.shared.debug.probes,
1000
+ activationDtype,
1001
+ embeddingDtype: selectRuleValue('inference', 'dtype', 'f16OrF32FromDtype', { dtype: embedDtype }),
1002
+ numTokens: 1,
1003
+ indexOffset: i,
1004
+ });
871
1005
 
872
- const context = helpers.buildLayerContext(recorder, true, opts.debugLayers, executionPlan);
873
- const embedBufferRaw = state.weights.get('embed');
874
- if (isCpuWeightBuffer(embedBufferRaw)) {
875
- throw new Error('[Pipeline] GPU-only decode not supported with CPU-resident embeddings.');
876
- }
877
- if (!(embedBufferRaw instanceof GPUBuffer) && !isWeightBuffer(embedBufferRaw)) {
878
- throw new Error('Embed buffer not found or not a GPUBuffer/WeightBuffer');
879
- }
880
- const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
881
- const embedDtype = isWeightBuffer(embedBufferRaw) ? getWeightDtype(embedBufferRaw) : null;
882
- const activationDtype = getEffectiveActivationDtype(state, opts);
1006
+ let hiddenStatesBuffer = hiddenTensor.buffer;
1007
+ for (let l = 0; l < config.numLayers; l++) {
1008
+ const prevStates = hiddenStatesBuffer;
1009
+ hiddenStatesBuffer = (await processLayer(l, hiddenStatesBuffer, 1, false, context));
1010
+ context.decodeBuffers?.swapPingPong();
1011
+ if (prevStates instanceof GPUBuffer && prevStates !== hiddenStatesBuffer) {
1012
+ const ownsBuffer = context.decodeBuffers?.ownsBuffer(prevStates);
1013
+ if (!ownsBuffer) {
1014
+ recorder.trackTemporaryBuffer(prevStates);
1015
+ }
1016
+ }
1017
+ }
883
1018
 
884
- for (let i = 0; i < N; i++) {
885
- const currentPos = state.currentSeqLen + i;
886
- context.currentSeqLen = currentPos;
887
- context.currentTokenIds = [startToken];
888
- context.decodeBuffers?.resetPingPong();
1019
+ const logits = await recordLogitsGPU(
1020
+ recorder,
1021
+ hiddenStatesBuffer,
1022
+ 1,
1023
+ helpers.getLogitsWeights(),
1024
+ helpers.getLogitsConfig()
1025
+ );
1026
+ const { logitsBuffer, vocabSize, logitsDtype } = logits;
889
1027
 
890
- const hiddenTensor = await embed(tokensBuffer, embedBuffer, {
891
- hiddenSize: config.hiddenSize,
892
- vocabSize: config.vocabSize,
893
- scaleEmbeddings: config.scaleEmbeddings,
894
- recorder,
895
- transpose: state.embeddingTranspose,
896
- debugProbes: state.runtimeConfig.shared.debug.probes,
897
- activationDtype,
898
- embeddingDtype: selectRuleValue('inference', 'dtype', 'f16OrF32FromDtype', { dtype: embedDtype }),
899
- numTokens: 1,
900
- indexOffset: i,
901
- });
1028
+ const outputIndex = i + 1;
1029
+ if (opts.temperature < samplingDefaults.greedyThreshold) {
1030
+ await recordArgmax(recorder, logitsBuffer, vocabSize, {
1031
+ padTokenId,
1032
+ logitSoftcap,
1033
+ logitsDtype,
1034
+ outputBuffer: tokensBuffer,
1035
+ outputIndex,
1036
+ });
1037
+ } else {
1038
+ await recordGPUSample(recorder, logitsBuffer, vocabSize, {
1039
+ temperature: opts.temperature,
1040
+ topK: opts.topK,
1041
+ padTokenId,
1042
+ logitSoftcap,
1043
+ logitsDtype,
1044
+ outputBuffer: tokensBuffer,
1045
+ outputIndex,
1046
+ greedyThreshold: samplingDefaults.greedyThreshold,
1047
+ });
1048
+ }
902
1049
 
903
- let hiddenStatesBuffer = hiddenTensor.buffer;
904
- for (let l = 0; l < config.numLayers; l++) {
905
- const prevStates = hiddenStatesBuffer;
906
- hiddenStatesBuffer = (await processLayer(l, hiddenStatesBuffer, 1, false, context));
907
- context.decodeBuffers?.swapPingPong();
908
- if (prevStates instanceof GPUBuffer && prevStates !== hiddenStatesBuffer) {
909
- const ownsBuffer = context.decodeBuffers?.ownsBuffer(prevStates);
910
- if (!ownsBuffer) {
911
- recorder.trackTemporaryBuffer(prevStates);
912
- }
1050
+ const stopCheck = useGpuStopFlags
1051
+ ? recordCheckStop(recorder, {
1052
+ sampledTokenBuffer: tokensBuffer,
1053
+ shouldStopBuffer: stopBuffer,
1054
+ tokenIndex: outputIndex,
1055
+ eosTokenId,
1056
+ maxTokens: maxSeqLen,
1057
+ currentPos,
1058
+ })
1059
+ : null;
1060
+
1061
+ if (hiddenStatesBuffer instanceof GPUBuffer && !context.decodeBuffers?.ownsBuffer(hiddenStatesBuffer)) {
1062
+ recorder.trackTemporaryBuffer(hiddenStatesBuffer);
1063
+ }
1064
+ if (logitsBuffer instanceof GPUBuffer) {
1065
+ recorder.trackTemporaryBuffer(logitsBuffer);
1066
+ }
1067
+ if (stopCheck instanceof GPUBuffer && stopCheck !== stopBuffer) {
1068
+ recorder.trackTemporaryBuffer(stopCheck);
913
1069
  }
914
1070
  }
915
1071
 
916
- const logits = await recordLogitsGPU(
917
- recorder,
918
- hiddenStatesBuffer,
919
- 1,
920
- helpers.getLogitsWeights(),
921
- helpers.getLogitsConfig()
922
- );
923
- const { logitsBuffer, vocabSize, logitsDtype } = logits;
1072
+ const recordMs = performance.now() - recordStart;
1073
+ state.stats.decodeRecordMs = (state.stats.decodeRecordMs ?? 0) + recordMs;
924
1074
 
925
- const outputIndex = i + 1;
926
- if (opts.temperature < samplingDefaults.greedyThreshold) {
927
- await recordArgmax(recorder, logitsBuffer, vocabSize, {
928
- padTokenId,
929
- logitSoftcap,
930
- logitsDtype,
931
- outputBuffer: tokensBuffer,
932
- outputIndex,
933
- });
934
- } else {
935
- await recordGPUSample(recorder, logitsBuffer, vocabSize, {
936
- temperature: opts.temperature,
937
- topK: opts.topK,
938
- padTokenId,
939
- logitSoftcap,
940
- logitsDtype,
941
- outputBuffer: tokensBuffer,
942
- outputIndex,
943
- greedyThreshold: samplingDefaults.greedyThreshold,
944
- });
1075
+ const encoder = recorder.getEncoder();
1076
+ encoder.copyBufferToBuffer(tokensBuffer, 4, tokensStagingBuffer, 0, N * 4);
1077
+ if (useGpuStopFlags && stopBuffer && stopStagingBuffer) {
1078
+ encoder.copyBufferToBuffer(stopBuffer, 4, stopStagingBuffer, 0, N * 4);
945
1079
  }
946
1080
 
947
- const stopCheck = useGpuStopFlags
948
- ? recordCheckStop(recorder, {
949
- sampledTokenBuffer: tokensBuffer,
950
- shouldStopBuffer: stopBuffer,
951
- tokenIndex: outputIndex,
952
- eosTokenId,
953
- maxTokens: maxSeqLen,
954
- currentPos,
955
- })
956
- : null;
957
-
958
- if (hiddenStatesBuffer instanceof GPUBuffer && !context.decodeBuffers?.ownsBuffer(hiddenStatesBuffer)) {
959
- recorder.trackTemporaryBuffer(hiddenStatesBuffer);
960
- }
961
- if (logitsBuffer instanceof GPUBuffer) {
962
- recorder.trackTemporaryBuffer(logitsBuffer);
963
- }
964
- if (stopCheck instanceof GPUBuffer && stopCheck !== stopBuffer) {
965
- recorder.trackTemporaryBuffer(stopCheck);
1081
+ if (state.finitenessBuffer) {
1082
+ finitenessStagingBuffer = device.createBuffer({
1083
+ size: 16,
1084
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
1085
+ });
1086
+ encoder.copyBufferToBuffer(state.finitenessBuffer, 0, finitenessStagingBuffer, 0, 16);
966
1087
  }
967
- }
968
1088
 
969
- const recordMs = performance.now() - recordStart;
970
- state.stats.decodeRecordMs = (state.stats.decodeRecordMs ?? 0) + recordMs;
1089
+ recorder.submit();
971
1090
 
972
- const encoder = recorder.getEncoder();
973
- encoder.copyBufferToBuffer(tokensBuffer, 4, tokensStagingBuffer, 0, N * 4);
974
- if (useGpuStopFlags && stopBuffer && stopStagingBuffer) {
975
- encoder.copyBufferToBuffer(stopBuffer, 4, stopStagingBuffer, 0, N * 4);
976
- }
1091
+ if (!allowReadback('pipeline.decode.sample')) {
1092
+ throw new Error('[Pipeline] GPU readback disabled for sampling');
1093
+ }
977
1094
 
978
- let finitenessStagingBuffer = null;
979
- if (state.finitenessBuffer) {
980
- finitenessStagingBuffer = device.createBuffer({
981
- size: 16,
982
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
1095
+ const readbackStart = performance.now();
1096
+ readbackCleanupDelegated = true;
1097
+ const readback = await readBatchTokensFromStagingBuffers({
1098
+ tokensStagingBuffer,
1099
+ stopStagingBuffer,
1100
+ finitenessStagingBuffer,
1101
+ tokenCount: N,
1102
+ ownsTokensStaging,
1103
+ ownsStopStaging,
1104
+ ring,
983
1105
  });
984
- encoder.copyBufferToBuffer(state.finitenessBuffer, 0, finitenessStagingBuffer, 0, 16);
985
- }
986
-
987
- recorder.submit();
988
-
989
- if (!allowReadback('pipeline.decode.sample')) {
990
- throw new Error('[Pipeline] GPU readback disabled for sampling');
991
- }
992
-
993
- const readbackStart = performance.now();
994
- const mapPromises = [tokensStagingBuffer.mapAsync(GPUMapMode.READ)];
995
- if (stopStagingBuffer) {
996
- mapPromises.push(stopStagingBuffer.mapAsync(GPUMapMode.READ));
997
- }
998
- if (finitenessStagingBuffer) {
999
- mapPromises.push(finitenessStagingBuffer.mapAsync(GPUMapMode.READ));
1000
- }
1001
- await Promise.all(mapPromises);
1002
- const readbackWaitMs = performance.now() - readbackStart;
1003
- state.stats.decodeReadbackWaitMs = (state.stats.decodeReadbackWaitMs ?? 0) + readbackWaitMs;
1004
-
1005
- let isInfinite = false;
1006
- let metadata = '';
1007
- if (finitenessStagingBuffer) {
1008
- const finitenessData = new Uint32Array(finitenessStagingBuffer.getMappedRange());
1009
- const finitenessStatus = parseFinitenessStatusWords(finitenessData, 0);
1010
- isInfinite = finitenessStatus.triggered;
1011
- metadata = finitenessStatus.metadata;
1012
- finitenessStagingBuffer.unmap();
1013
- finitenessStagingBuffer.destroy();
1014
- }
1015
-
1016
- const submitWaitMs = recorder.getSubmitLatencyMs();
1017
- if (submitWaitMs != null) {
1018
- state.stats.decodeSubmitWaitMs = (state.stats.decodeSubmitWaitMs ?? 0) + submitWaitMs;
1019
- }
1020
-
1021
- getUniformCache().flushPendingDestruction();
1022
-
1023
- const tokensView = new Uint32Array(tokensStagingBuffer.getMappedRange());
1024
- const tokens = Array.from(tokensView.subarray(0, N));
1106
+ const readbackWaitMs = performance.now() - readbackStart;
1107
+ state.stats.decodeReadbackWaitMs = (state.stats.decodeReadbackWaitMs ?? 0) + readbackWaitMs;
1025
1108
 
1026
- const stopFlags = stopStagingBuffer
1027
- ? new Uint32Array(stopStagingBuffer.getMappedRange().slice(0, N * 4))
1028
- : null;
1109
+ const isInfinite = readback.finitenessStatus.triggered;
1110
+ const metadata = readback.finitenessStatus.metadata;
1029
1111
 
1030
- if (stopFlags) {
1031
- log.debug('Pipeline', `[STOP] N=${N} flags=[${Array.from(stopFlags).join(',')}] tokens=[${tokens.join(',')}] eos=${eosTokenId}`);
1032
- }
1112
+ const submitWaitMs = recorder.getSubmitLatencyMs();
1113
+ if (submitWaitMs != null) {
1114
+ state.stats.decodeSubmitWaitMs = (state.stats.decodeSubmitWaitMs ?? 0) + submitWaitMs;
1115
+ }
1033
1116
 
1034
- const actualCount = resolveBatchStop(tokens, stopFlags, stopTokenIds, eosToken);
1117
+ getUniformCache().flushPendingDestruction();
1035
1118
 
1036
- tokensStagingBuffer.unmap();
1037
- if (stopStagingBuffer) {
1038
- stopStagingBuffer.unmap();
1039
- }
1119
+ const tokens = readback.tokens;
1120
+ const stopFlags = readback.stopFlags;
1040
1121
 
1041
- const generatedTokens = tokens.slice(0, actualCount);
1122
+ if (stopFlags) {
1123
+ log.debug('Pipeline', `[STOP] N=${N} flags=[${Array.from(stopFlags).join(',')}] tokens=[${tokens.join(',')}] eos=${eosTokenId}`);
1124
+ }
1042
1125
 
1043
- if (ownsTokensBuffer) tokensBuffer.destroy();
1044
- if (ownsStopBuffer) stopBuffer?.destroy();
1045
- if (ownsTokensStaging) tokensStagingBuffer.destroy();
1046
- if (ownsStopStaging) stopStagingBuffer?.destroy();
1126
+ const actualCount = resolveBatchStop(tokens, stopFlags, stopTokenIds, eosToken);
1127
+ const generatedTokens = tokens.slice(0, actualCount);
1047
1128
 
1048
- if (isInfinite) {
1049
- throw new FinitenessError(`F16 bounds exceeded during batch generation${metadata}`);
1050
- }
1129
+ if (isInfinite) {
1130
+ throw new FinitenessError(`F16 bounds exceeded during batch generation${metadata}`);
1131
+ }
1051
1132
 
1052
- if (opts.profile && recorder.isProfilingEnabled()) {
1053
- const timings = await recorder.resolveProfileTimings();
1054
- const total = sumProfileTimings(timings);
1055
- if (total !== null) {
1056
- state.stats.gpuTimeDecodeMs = (state.stats.gpuTimeDecodeMs ?? 0) + total;
1133
+ if (opts.profile && recorder.isProfilingEnabled()) {
1134
+ const timings = await recorder.resolveProfileTimings();
1135
+ const total = sumProfileTimings(timings);
1136
+ if (total !== null) {
1137
+ state.stats.gpuTimeDecodeMs = (state.stats.gpuTimeDecodeMs ?? 0) + total;
1138
+ }
1139
+ if (timings) {
1140
+ recordDecodeProfileStep(state, {
1141
+ batch: true,
1142
+ stepStart: state.decodeStepCount + 1,
1143
+ stepCount: actualCount,
1144
+ timings,
1145
+ totalMs: total ?? undefined,
1146
+ });
1147
+ const stepStart = state.decodeStepCount + 1;
1148
+ if (shouldLogProfileStep(state, stepStart)) {
1149
+ log.warn('Profile', `Batch decode (N=${N}):`);
1150
+ log.warn('Profile', CommandRecorder.formatProfileReport(timings));
1151
+ }
1152
+ }
1057
1153
  }
1058
- if (timings) {
1059
- recordDecodeProfileStep(state, {
1060
- batch: true,
1061
- stepStart: state.decodeStepCount + 1,
1062
- stepCount: actualCount,
1063
- timings,
1064
- totalMs: total ?? undefined,
1065
- });
1066
- const stepStart = state.decodeStepCount + 1;
1067
- if (shouldLogProfileStep(state, stepStart)) {
1068
- log.warn('Profile', `Batch decode (N=${N}):`);
1069
- log.warn('Profile', CommandRecorder.formatProfileReport(timings));
1154
+
1155
+ state.currentSeqLen += actualCount;
1156
+ return { tokens: generatedTokens, actualCount };
1157
+ } finally {
1158
+ if (!readbackCleanupDelegated) {
1159
+ if (finitenessStagingBuffer) {
1160
+ finitenessStagingBuffer.destroy();
1070
1161
  }
1162
+ if (ownsTokensStaging) tokensStagingBuffer.destroy();
1163
+ if (ownsStopStaging) stopStagingBuffer?.destroy();
1164
+ ring?.advance();
1071
1165
  }
1166
+ if (ownsTokensBuffer) tokensBuffer.destroy();
1167
+ if (ownsStopBuffer) stopBuffer?.destroy();
1072
1168
  }
1073
-
1074
- state.currentSeqLen += actualCount;
1075
- ring?.advance();
1076
-
1077
- return { tokens: generatedTokens, actualCount };
1078
1169
  }
@@ -1043,18 +1043,9 @@ export class PipelineGenerator {
1043
1043
  if (allowReadback(`pipeline.prefill.layer-${l}`)) {
1044
1044
  try {
1045
1045
  const sampleSize = config.hiddenSize * activationBytes;
1046
- const staging = device.createBuffer({
1047
- size: sampleSize,
1048
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
1049
- });
1050
- const enc = device.createCommandEncoder();
1051
1046
  const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
1052
- enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
1053
- device.queue.submit([enc.finish()]);
1054
- await staging.mapAsync(GPUMapMode.READ);
1055
- const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
1056
- staging.unmap();
1057
- staging.destroy();
1047
+ const readback = await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize);
1048
+ const data = decodeReadback(readback, activationDtype);
1058
1049
  let min = Infinity;
1059
1050
  let max = -Infinity;
1060
1051
  let maxAbs = 0;
@@ -1112,20 +1103,12 @@ export class PipelineGenerator {
1112
1103
  if (opts.debug) {
1113
1104
  log.debug('Pipeline', `LAYER_LOOP_DONE, currentHiddenBuffer type=${currentHiddenBuffer?.constructor?.name}`);
1114
1105
  if (currentHiddenBuffer && allowReadback('pipeline.prefill.final-hidden')) {
1115
- const device = getDevice();
1116
1106
  const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
1117
1107
  const sampleSize = config.hiddenSize * activationBytes;
1118
- const staging = device.createBuffer({
1119
- size: sampleSize,
1120
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
1121
- });
1122
- const enc = device.createCommandEncoder();
1123
- enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
1124
- device.queue.submit([enc.finish()]);
1125
- await staging.mapAsync(GPUMapMode.READ);
1126
- const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
1127
- staging.unmap();
1128
- staging.destroy();
1108
+ const data = decodeReadback(
1109
+ await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize),
1110
+ activationDtype
1111
+ );
1129
1112
  const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
1130
1113
  const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
1131
1114
  log.debug('Pipeline', `FINAL_HIDDEN[pos=${numTokens - 1}]: nan=${nanCount}/${data.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);