@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355)
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -113,6 +113,130 @@ export function resolveBatchStop(tokens, stopFlags, stopTokenIds, eosTokenId) {
113
113
  return actualCount;
114
114
  }
115
115
 
116
/**
 * Scans generated token ids in order and reports the first one that cannot be
 * a valid sampled token.
 *
 * A token id is rejected when any of the following holds:
 *   - it is not an integer (NaN, +/-Infinity, fractional, or non-numeric),
 *   - it is negative,
 *   - it is greater than or equal to `vocabSize`,
 *   - it equals the reserved id: `padTokenId` when one is given, otherwise 0.
 *
 * @param {ArrayLike<number>} tokens - Token ids to inspect.
 * @param {number} vocabSize - Exclusive upper bound for a valid token id.
 * @param {number|null} [padTokenId=null] - Reserved id that must never be
 *   generated; when null/undefined, id 0 is treated as the reserved id.
 * @returns {{ index: number, tokenId: number } | null} Position and value of
 *   the first invalid token, or null when every token passes.
 */
export function findInvalidGeneratedToken(tokens, vocabSize, padTokenId = null) {
  const reservedTokenId = padTokenId ?? 0;
  for (let index = 0; index < tokens.length; index += 1) {
    const tokenId = tokens[index];
    // Number.isInteger (rather than isFinite) also rejects fractional ids,
    // which can never address a vocabulary entry.
    const isInvalid = !Number.isInteger(tokenId)
      || tokenId < 0
      || tokenId >= vocabSize
      || tokenId === reservedTokenId;
    if (isInvalid) {
      return { index, tokenId };
    }
  }
  return null;
}
129
+
130
/**
 * Maps a staging buffer for reading, extracts the sampled token (word 0) and
 * the finiteness status, and always releases the buffer afterwards.
 *
 * Cleanup runs in a `finally` block, so it happens whether or not mapping or
 * parsing throws: the buffer is unmapped only if mapping succeeded, destroyed
 * only if the caller declared ownership, and the ring (if any) is advanced.
 *
 * @param {object} stagingBuffer - Buffer exposing `mapAsync`, `getMappedRange`,
 *   `unmap` and `destroy` (the WebGPU GPUBuffer methods used here).
 * @param {object} [options]
 * @param {boolean} [options.ownsStagingBuffer] - Destroy the buffer after the
 *   read only when this is exactly `true`.
 * @param {boolean} [options.hasFinitenessBuffer] - When exactly `true`, the
 *   status is parsed starting at word 1; otherwise at word 0.
 *   NOTE(review): the second argument of `parseFinitenessStatusWords` is
 *   presumably a word offset - confirm against its definition.
 * @param {{ advance: Function } | null} [options.ring] - Optional object whose
 *   `advance()` is called once the read has finished or failed.
 * @returns {Promise<{ nextToken: number, finitenessStatus: * }>}
 */
export async function readSampledTokenFromStagingBuffer(stagingBuffer, options = {}) {
  // The default parameter only covers `undefined`; tolerate an explicit null
  // so a bad options value cannot bypass the cleanup below.
  const settings = options ?? {};
  const ownsStagingBuffer = settings.ownsStagingBuffer === true;
  const hasFinitenessBuffer = settings.hasFinitenessBuffer === true;
  const ring = settings.ring ?? null;
  let mapped = false;

  try {
    await stagingBuffer.mapAsync(GPUMapMode.READ);
    mapped = true;
    const mappedWords = new Uint32Array(stagingBuffer.getMappedRange());
    const finitenessWordOffset = hasFinitenessBuffer ? 1 : 0;
    // Both values are read before the `finally` block unmaps the buffer.
    return {
      nextToken: mappedWords[0],
      finitenessStatus: parseFinitenessStatusWords(mappedWords, finitenessWordOffset),
    };
  } finally {
    if (mapped) {
      stagingBuffer.unmap();
    }
    if (ownsStagingBuffer) {
      stagingBuffer.destroy();
    }
    ring?.advance();
  }
}
156
+
157
/**
 * Maps a staging buffer for reading and returns a detached copy of its
 * contents, so the data stays usable after the buffer is unmapped.
 *
 * Cleanup runs in a `finally` block: the buffer is unmapped only if mapping
 * succeeded, and destroyed unless the caller opts out.
 *
 * @param {object} stagingBuffer - Buffer exposing `mapAsync`, `getMappedRange`,
 *   `unmap` and `destroy` (the WebGPU GPUBuffer methods used here).
 * @param {object} [options]
 * @param {boolean} [options.ownsStagingBuffer] - The buffer is destroyed after
 *   the read unless this is exactly `false` (i.e. ownership is the default).
 * @returns {Promise<ArrayBuffer>} Copy of the full mapped range.
 */
export async function readMappedBufferCopy(stagingBuffer, options = {}) {
  // The default parameter only covers `undefined`; tolerate an explicit null.
  const ownsStagingBuffer = (options ?? {}).ownsStagingBuffer !== false;
  let mapped = false;

  try {
    await stagingBuffer.mapAsync(GPUMapMode.READ);
    mapped = true;
    // slice(0) copies the bytes out before the `finally` block unmaps.
    return stagingBuffer.getMappedRange().slice(0);
  } finally {
    if (mapped) {
      stagingBuffer.unmap();
    }
    if (ownsStagingBuffer) {
      stagingBuffer.destroy();
    }
  }
}
174
+
175
/**
 * Maps the batch-decode staging buffers concurrently, copies out the generated
 * tokens, the optional stop flags and the optional finiteness status, and
 * always releases the buffers afterwards.
 *
 * Every requested mapping is awaited to completion before the result is
 * inspected. If any mapping fails, the buffers that did map are still unmapped
 * and the first failure is rethrown, so a buffer the caller keeps (one it does
 * not own) is never left in the mapped state.
 *
 * @param {object} options
 * @param {object} options.tokensStagingBuffer - Buffer holding the generated
 *   token ids as 32-bit words.
 * @param {object|null} [options.stopStagingBuffer=null] - Optional buffer
 *   holding one 32-bit stop flag per token.
 * @param {object|null} [options.finitenessStagingBuffer=null] - Optional buffer
 *   parsed with `parseFinitenessStatusWords`; always destroyed when provided.
 * @param {number} options.tokenCount - Number of leading words to read from the
 *   token and stop buffers.
 * @param {boolean} [options.ownsTokensStaging=false] - Destroy the token buffer
 *   after the read.
 * @param {boolean} [options.ownsStopStaging=false] - Destroy the stop buffer
 *   after the read.
 * @param {{ advance: Function } | null} [options.ring=null] - Optional object
 *   whose `advance()` is called once the read has finished or failed.
 * @returns {Promise<{ tokens: number[], stopFlags: Uint32Array|null, finitenessStatus: * }>}
 *   `finitenessStatus` is `{ triggered: false, metadata: '' }` when no
 *   finiteness buffer is given.
 */
export async function readBatchTokensFromStagingBuffers(options) {
  const {
    tokensStagingBuffer,
    stopStagingBuffer = null,
    finitenessStagingBuffer = null,
    tokenCount,
    ownsTokensStaging = false,
    ownsStopStaging = false,
    ring = null,
  } = options;
  // Buffers whose mapAsync actually resolved; only these may be unmapped.
  const mappedBuffers = [];

  try {
    const targets = [tokensStagingBuffer];
    if (stopStagingBuffer) {
      targets.push(stopStagingBuffer);
    }
    if (finitenessStagingBuffer) {
      targets.push(finitenessStagingBuffer);
    }

    // allSettled (not all): a rejection from one buffer must not hide the fact
    // that another buffer mapped successfully and therefore needs unmapping.
    const outcomes = await Promise.allSettled(
      targets.map(async (buffer) => {
        await buffer.mapAsync(GPUMapMode.READ);
      })
    );
    outcomes.forEach((outcome, position) => {
      if (outcome.status === 'fulfilled') {
        mappedBuffers.push(targets[position]);
      }
    });
    const failure = outcomes.find((outcome) => outcome.status === 'rejected');
    if (failure) {
      throw failure.reason;
    }

    // Copy everything out of the mapped ranges; they become unusable on unmap.
    const tokens = Array.from(
      new Uint32Array(tokensStagingBuffer.getMappedRange()).subarray(0, tokenCount)
    );
    const stopFlags = stopStagingBuffer
      ? new Uint32Array(stopStagingBuffer.getMappedRange().slice(0, tokenCount * 4))
      : null;
    const finitenessStatus = finitenessStagingBuffer
      ? parseFinitenessStatusWords(new Uint32Array(finitenessStagingBuffer.getMappedRange()), 0)
      : { triggered: false, metadata: '' };

    return {
      tokens,
      stopFlags,
      finitenessStatus,
    };
  } finally {
    for (const buffer of mappedBuffers) {
      buffer.unmap();
    }
    if (finitenessStagingBuffer) {
      finitenessStagingBuffer.destroy();
    }
    if (ownsTokensStaging) {
      tokensStagingBuffer.destroy();
    }
    if (ownsStopStaging) {
      stopStagingBuffer?.destroy();
    }
    ring?.advance();
  }
}
239
+
116
240
  async function runDecodeLayers(state, tokenId, opts, helpers) {
117
241
  const config = state.modelConfig;
118
242
  const debugCheckBuffer = state.debug ? helpers.debugCheckBuffer : undefined;
@@ -130,11 +254,9 @@ async function runDecodeLayers(state, tokenId, opts, helpers) {
130
254
  throw new Error('Embed buffer not found or not a supported buffer type');
131
255
  }
132
256
  const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
133
- const embedDtype = isWeightBuffer(embedBufferRaw)
134
- ? getWeightDtype(embedBufferRaw)
135
- : isCpuWeightBuffer(embedBufferRaw)
136
- ? embedBufferRaw.dtype
137
- : null;
257
+ const embedDtype = isCpuWeightBuffer(embedBufferRaw)
258
+ ? embedBufferRaw.dtype
259
+ : getWeightDtype(embedBufferRaw);
138
260
  const activationDtype = getEffectiveActivationDtype(state, opts);
139
261
 
140
262
  const embedTensor = await embed([tokenId], embedBuffer, {
@@ -216,11 +338,9 @@ export async function decodeStep(state, currentIds, opts, helpers) {
216
338
  throw new Error('Embed buffer not found or not a supported buffer type');
217
339
  }
218
340
  const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
219
- const embedDtype = isWeightBuffer(embedBufferRaw)
220
- ? getWeightDtype(embedBufferRaw)
221
- : isCpuWeightBuffer(embedBufferRaw)
222
- ? embedBufferRaw.dtype
223
- : null;
341
+ const embedDtype = isCpuWeightBuffer(embedBufferRaw)
342
+ ? embedBufferRaw.dtype
343
+ : getWeightDtype(embedBufferRaw);
224
344
  const activationDtype = getEffectiveActivationDtype(state, opts);
225
345
  const activationBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: activationDtype });
226
346
 
@@ -352,17 +472,11 @@ export async function decodeStep(state, currentIds, opts, helpers) {
352
472
  throw new Error('[Pipeline] GPU readback disabled for sampling');
353
473
  }
354
474
 
355
- await stagingBuffer.mapAsync(GPUMapMode.READ);
356
- const mapped = new Uint32Array(stagingBuffer.getMappedRange());
357
- const nextToken = mapped[0];
358
- const finitenessStatus = state.finitenessBuffer
359
- ? parseFinitenessStatusWords(mapped, 1)
360
- : parseFinitenessStatusWords(mapped, 0);
361
- stagingBuffer.unmap();
362
- if (ownsStagingBuffer) {
363
- stagingBuffer.destroy();
364
- }
365
- ring?.advance();
475
+ const { nextToken, finitenessStatus } = await readSampledTokenFromStagingBuffer(stagingBuffer, {
476
+ ownsStagingBuffer,
477
+ hasFinitenessBuffer: Boolean(state.finitenessBuffer),
478
+ ring,
479
+ });
366
480
 
367
481
  if (finitenessStatus.triggered) {
368
482
  releaseBuffer(logitsBuffer);
@@ -499,10 +613,7 @@ export async function decodeStep(state, currentIds, opts, helpers) {
499
613
  const enc = debugDevice.createCommandEncoder();
500
614
  enc.copyBufferToBuffer(hiddenStates, 0, staging, 0, sampleSize);
501
615
  debugDevice.queue.submit([enc.finish()]);
502
- await staging.mapAsync(GPUMapMode.READ);
503
- const data = new Float32Array(staging.getMappedRange().slice(0));
504
- staging.unmap();
505
- staging.destroy();
616
+ const data = new Float32Array(await readMappedBufferCopy(staging));
506
617
  const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
507
618
  const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
508
619
  log.debug('Decode', `[1] HIDDEN_AFTER_LAYERS: nan=${nanCount}/${data.length}, nonZero=${nonZero.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);
@@ -535,11 +646,21 @@ export async function decodeStep(state, currentIds, opts, helpers) {
535
646
  });
536
647
 
537
648
  releaseBuffer(logitsBuffer);
538
- if (!context.decodeBuffers?.ownsBuffer(hiddenStates)) {
539
- releaseBuffer(hiddenStates);
649
+ const invalidGpuToken = nextToken >= config.vocabSize
650
+ || (padTokenId != null && nextToken === padTokenId)
651
+ || (padTokenId == null && nextToken === 0);
652
+ if (!invalidGpuToken) {
653
+ if (!context.decodeBuffers?.ownsBuffer(hiddenStates)) {
654
+ releaseBuffer(hiddenStates);
655
+ }
656
+ state.currentSeqLen++;
657
+ return nextToken;
540
658
  }
541
- state.currentSeqLen++;
542
- return nextToken;
659
+ state.disableFusedDecode = true;
660
+ log.warn(
661
+ 'Decode',
662
+ `GPU sampling produced invalid token ${nextToken} (vocabSize=${config.vocabSize}, step=${state.decodeStepCount}); falling back to CPU sampling.`
663
+ );
543
664
  }
544
665
  }
545
666
 
@@ -854,225 +975,223 @@ export async function generateNTokensGPU(state, startToken, N, currentIds, opts,
854
975
  })
855
976
  : null;
856
977
  const ownsStopStaging = useGpuStopFlags && !ringSlot?.stagingStop;
978
+ let finitenessStagingBuffer = null;
979
+ let readbackCleanupDelegated = false;
980
+ try {
981
+ if (state.finitenessBuffer) {
982
+ device.queue.writeBuffer(state.finitenessBuffer, 0, new Uint32Array([0, 0, 0, 0]));
983
+ }
857
984
 
858
- if (state.finitenessBuffer) {
859
- device.queue.writeBuffer(state.finitenessBuffer, 0, new Uint32Array([0, 0, 0, 0]));
860
- }
985
+ device.queue.writeBuffer(tokensBuffer, 0, new Uint32Array([startToken]));
986
+ if (stopBuffer) {
987
+ const stopElements = stopBuffer.size / 4;
988
+ const zeroStopData = ringSlot?.zeroStopData;
989
+ const clearData = zeroStopData && zeroStopData.length <= stopElements
990
+ ? zeroStopData
991
+ : new Uint32Array(stopElements);
992
+ device.queue.writeBuffer(stopBuffer, 0, clearData);
993
+ }
861
994
 
862
- device.queue.writeBuffer(tokensBuffer, 0, new Uint32Array([startToken]));
863
- if (stopBuffer) {
864
- const stopElements = stopBuffer.size / 4;
865
- const zeroStopData = ringSlot?.zeroStopData;
866
- const clearData = zeroStopData && zeroStopData.length <= stopElements
867
- ? zeroStopData
868
- : new Uint32Array(stopElements);
869
- device.queue.writeBuffer(stopBuffer, 0, clearData);
870
- }
995
+ const context = helpers.buildLayerContext(recorder, true, opts.debugLayers, executionPlan);
996
+ const embedBufferRaw = state.weights.get('embed');
997
+ if (isCpuWeightBuffer(embedBufferRaw)) {
998
+ throw new Error('[Pipeline] GPU-only decode not supported with CPU-resident embeddings.');
999
+ }
1000
+ if (!(embedBufferRaw instanceof GPUBuffer) && !isWeightBuffer(embedBufferRaw)) {
1001
+ throw new Error('Embed buffer not found or not a GPUBuffer/WeightBuffer');
1002
+ }
1003
+ const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
1004
+ const embedDtype = getWeightDtype(embedBufferRaw);
1005
+ const activationDtype = getEffectiveActivationDtype(state, opts);
1006
+
1007
+ for (let i = 0; i < N; i++) {
1008
+ const currentPos = state.currentSeqLen + i;
1009
+ context.currentSeqLen = currentPos;
1010
+ context.currentTokenIds = [startToken];
1011
+ context.decodeBuffers?.resetPingPong();
1012
+
1013
+ const hiddenTensor = await embed(tokensBuffer, embedBuffer, {
1014
+ hiddenSize: config.hiddenSize,
1015
+ vocabSize: config.vocabSize,
1016
+ scaleEmbeddings: config.scaleEmbeddings,
1017
+ recorder,
1018
+ transpose: state.embeddingTranspose,
1019
+ debugProbes: state.runtimeConfig.shared.debug.probes,
1020
+ activationDtype,
1021
+ embeddingDtype: selectRuleValue('inference', 'dtype', 'f16OrF32FromDtype', { dtype: embedDtype }),
1022
+ numTokens: 1,
1023
+ indexOffset: i,
1024
+ });
871
1025
 
872
- const context = helpers.buildLayerContext(recorder, true, opts.debugLayers, executionPlan);
873
- const embedBufferRaw = state.weights.get('embed');
874
- if (isCpuWeightBuffer(embedBufferRaw)) {
875
- throw new Error('[Pipeline] GPU-only decode not supported with CPU-resident embeddings.');
876
- }
877
- if (!(embedBufferRaw instanceof GPUBuffer) && !isWeightBuffer(embedBufferRaw)) {
878
- throw new Error('Embed buffer not found or not a GPUBuffer/WeightBuffer');
879
- }
880
- const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
881
- const embedDtype = isWeightBuffer(embedBufferRaw) ? getWeightDtype(embedBufferRaw) : null;
882
- const activationDtype = getEffectiveActivationDtype(state, opts);
1026
+ let hiddenStatesBuffer = hiddenTensor.buffer;
1027
+ for (let l = 0; l < config.numLayers; l++) {
1028
+ const prevStates = hiddenStatesBuffer;
1029
+ hiddenStatesBuffer = (await processLayer(l, hiddenStatesBuffer, 1, false, context));
1030
+ context.decodeBuffers?.swapPingPong();
1031
+ if (prevStates instanceof GPUBuffer && prevStates !== hiddenStatesBuffer) {
1032
+ const ownsBuffer = context.decodeBuffers?.ownsBuffer(prevStates);
1033
+ if (!ownsBuffer) {
1034
+ recorder.trackTemporaryBuffer(prevStates);
1035
+ }
1036
+ }
1037
+ }
883
1038
 
884
- for (let i = 0; i < N; i++) {
885
- const currentPos = state.currentSeqLen + i;
886
- context.currentSeqLen = currentPos;
887
- context.currentTokenIds = [startToken];
888
- context.decodeBuffers?.resetPingPong();
1039
+ const logits = await recordLogitsGPU(
1040
+ recorder,
1041
+ hiddenStatesBuffer,
1042
+ 1,
1043
+ helpers.getLogitsWeights(),
1044
+ helpers.getLogitsConfig()
1045
+ );
1046
+ const { logitsBuffer, vocabSize, logitsDtype } = logits;
889
1047
 
890
- const hiddenTensor = await embed(tokensBuffer, embedBuffer, {
891
- hiddenSize: config.hiddenSize,
892
- vocabSize: config.vocabSize,
893
- scaleEmbeddings: config.scaleEmbeddings,
894
- recorder,
895
- transpose: state.embeddingTranspose,
896
- debugProbes: state.runtimeConfig.shared.debug.probes,
897
- activationDtype,
898
- embeddingDtype: selectRuleValue('inference', 'dtype', 'f16OrF32FromDtype', { dtype: embedDtype }),
899
- numTokens: 1,
900
- indexOffset: i,
901
- });
1048
+ const outputIndex = i + 1;
1049
+ if (opts.temperature < samplingDefaults.greedyThreshold) {
1050
+ await recordArgmax(recorder, logitsBuffer, vocabSize, {
1051
+ padTokenId,
1052
+ logitSoftcap,
1053
+ logitsDtype,
1054
+ outputBuffer: tokensBuffer,
1055
+ outputIndex,
1056
+ });
1057
+ } else {
1058
+ await recordGPUSample(recorder, logitsBuffer, vocabSize, {
1059
+ temperature: opts.temperature,
1060
+ topK: opts.topK,
1061
+ padTokenId,
1062
+ logitSoftcap,
1063
+ logitsDtype,
1064
+ outputBuffer: tokensBuffer,
1065
+ outputIndex,
1066
+ greedyThreshold: samplingDefaults.greedyThreshold,
1067
+ });
1068
+ }
902
1069
 
903
- let hiddenStatesBuffer = hiddenTensor.buffer;
904
- for (let l = 0; l < config.numLayers; l++) {
905
- const prevStates = hiddenStatesBuffer;
906
- hiddenStatesBuffer = (await processLayer(l, hiddenStatesBuffer, 1, false, context));
907
- context.decodeBuffers?.swapPingPong();
908
- if (prevStates instanceof GPUBuffer && prevStates !== hiddenStatesBuffer) {
909
- const ownsBuffer = context.decodeBuffers?.ownsBuffer(prevStates);
910
- if (!ownsBuffer) {
911
- recorder.trackTemporaryBuffer(prevStates);
912
- }
1070
+ const stopCheck = useGpuStopFlags
1071
+ ? recordCheckStop(recorder, {
1072
+ sampledTokenBuffer: tokensBuffer,
1073
+ shouldStopBuffer: stopBuffer,
1074
+ tokenIndex: outputIndex,
1075
+ eosTokenId,
1076
+ maxTokens: maxSeqLen,
1077
+ currentPos,
1078
+ })
1079
+ : null;
1080
+
1081
+ if (hiddenStatesBuffer instanceof GPUBuffer && !context.decodeBuffers?.ownsBuffer(hiddenStatesBuffer)) {
1082
+ recorder.trackTemporaryBuffer(hiddenStatesBuffer);
1083
+ }
1084
+ if (logitsBuffer instanceof GPUBuffer) {
1085
+ recorder.trackTemporaryBuffer(logitsBuffer);
1086
+ }
1087
+ if (stopCheck instanceof GPUBuffer && stopCheck !== stopBuffer) {
1088
+ recorder.trackTemporaryBuffer(stopCheck);
913
1089
  }
914
1090
  }
915
1091
 
916
- const logits = await recordLogitsGPU(
917
- recorder,
918
- hiddenStatesBuffer,
919
- 1,
920
- helpers.getLogitsWeights(),
921
- helpers.getLogitsConfig()
922
- );
923
- const { logitsBuffer, vocabSize, logitsDtype } = logits;
1092
+ const recordMs = performance.now() - recordStart;
1093
+ state.stats.decodeRecordMs = (state.stats.decodeRecordMs ?? 0) + recordMs;
924
1094
 
925
- const outputIndex = i + 1;
926
- if (opts.temperature < samplingDefaults.greedyThreshold) {
927
- await recordArgmax(recorder, logitsBuffer, vocabSize, {
928
- padTokenId,
929
- logitSoftcap,
930
- logitsDtype,
931
- outputBuffer: tokensBuffer,
932
- outputIndex,
933
- });
934
- } else {
935
- await recordGPUSample(recorder, logitsBuffer, vocabSize, {
936
- temperature: opts.temperature,
937
- topK: opts.topK,
938
- padTokenId,
939
- logitSoftcap,
940
- logitsDtype,
941
- outputBuffer: tokensBuffer,
942
- outputIndex,
943
- greedyThreshold: samplingDefaults.greedyThreshold,
944
- });
1095
+ const encoder = recorder.getEncoder();
1096
+ encoder.copyBufferToBuffer(tokensBuffer, 4, tokensStagingBuffer, 0, N * 4);
1097
+ if (useGpuStopFlags && stopBuffer && stopStagingBuffer) {
1098
+ encoder.copyBufferToBuffer(stopBuffer, 4, stopStagingBuffer, 0, N * 4);
945
1099
  }
946
1100
 
947
- const stopCheck = useGpuStopFlags
948
- ? recordCheckStop(recorder, {
949
- sampledTokenBuffer: tokensBuffer,
950
- shouldStopBuffer: stopBuffer,
951
- tokenIndex: outputIndex,
952
- eosTokenId,
953
- maxTokens: maxSeqLen,
954
- currentPos,
955
- })
956
- : null;
957
-
958
- if (hiddenStatesBuffer instanceof GPUBuffer && !context.decodeBuffers?.ownsBuffer(hiddenStatesBuffer)) {
959
- recorder.trackTemporaryBuffer(hiddenStatesBuffer);
960
- }
961
- if (logitsBuffer instanceof GPUBuffer) {
962
- recorder.trackTemporaryBuffer(logitsBuffer);
963
- }
964
- if (stopCheck instanceof GPUBuffer && stopCheck !== stopBuffer) {
965
- recorder.trackTemporaryBuffer(stopCheck);
1101
+ if (state.finitenessBuffer) {
1102
+ finitenessStagingBuffer = device.createBuffer({
1103
+ size: 16,
1104
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
1105
+ });
1106
+ encoder.copyBufferToBuffer(state.finitenessBuffer, 0, finitenessStagingBuffer, 0, 16);
966
1107
  }
967
- }
968
1108
 
969
- const recordMs = performance.now() - recordStart;
970
- state.stats.decodeRecordMs = (state.stats.decodeRecordMs ?? 0) + recordMs;
1109
+ recorder.submit();
971
1110
 
972
- const encoder = recorder.getEncoder();
973
- encoder.copyBufferToBuffer(tokensBuffer, 4, tokensStagingBuffer, 0, N * 4);
974
- if (useGpuStopFlags && stopBuffer && stopStagingBuffer) {
975
- encoder.copyBufferToBuffer(stopBuffer, 4, stopStagingBuffer, 0, N * 4);
976
- }
1111
+ if (!allowReadback('pipeline.decode.sample')) {
1112
+ throw new Error('[Pipeline] GPU readback disabled for sampling');
1113
+ }
977
1114
 
978
- let finitenessStagingBuffer = null;
979
- if (state.finitenessBuffer) {
980
- finitenessStagingBuffer = device.createBuffer({
981
- size: 16,
982
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
1115
+ const readbackStart = performance.now();
1116
+ readbackCleanupDelegated = true;
1117
+ const readback = await readBatchTokensFromStagingBuffers({
1118
+ tokensStagingBuffer,
1119
+ stopStagingBuffer,
1120
+ finitenessStagingBuffer,
1121
+ tokenCount: N,
1122
+ ownsTokensStaging,
1123
+ ownsStopStaging,
1124
+ ring,
983
1125
  });
984
- encoder.copyBufferToBuffer(state.finitenessBuffer, 0, finitenessStagingBuffer, 0, 16);
985
- }
986
-
987
- recorder.submit();
988
-
989
- if (!allowReadback('pipeline.decode.sample')) {
990
- throw new Error('[Pipeline] GPU readback disabled for sampling');
991
- }
992
-
993
- const readbackStart = performance.now();
994
- const mapPromises = [tokensStagingBuffer.mapAsync(GPUMapMode.READ)];
995
- if (stopStagingBuffer) {
996
- mapPromises.push(stopStagingBuffer.mapAsync(GPUMapMode.READ));
997
- }
998
- if (finitenessStagingBuffer) {
999
- mapPromises.push(finitenessStagingBuffer.mapAsync(GPUMapMode.READ));
1000
- }
1001
- await Promise.all(mapPromises);
1002
- const readbackWaitMs = performance.now() - readbackStart;
1003
- state.stats.decodeReadbackWaitMs = (state.stats.decodeReadbackWaitMs ?? 0) + readbackWaitMs;
1004
-
1005
- let isInfinite = false;
1006
- let metadata = '';
1007
- if (finitenessStagingBuffer) {
1008
- const finitenessData = new Uint32Array(finitenessStagingBuffer.getMappedRange());
1009
- const finitenessStatus = parseFinitenessStatusWords(finitenessData, 0);
1010
- isInfinite = finitenessStatus.triggered;
1011
- metadata = finitenessStatus.metadata;
1012
- finitenessStagingBuffer.unmap();
1013
- finitenessStagingBuffer.destroy();
1014
- }
1126
+ const readbackWaitMs = performance.now() - readbackStart;
1127
+ state.stats.decodeReadbackWaitMs = (state.stats.decodeReadbackWaitMs ?? 0) + readbackWaitMs;
1015
1128
 
1016
- const submitWaitMs = recorder.getSubmitLatencyMs();
1017
- if (submitWaitMs != null) {
1018
- state.stats.decodeSubmitWaitMs = (state.stats.decodeSubmitWaitMs ?? 0) + submitWaitMs;
1019
- }
1020
-
1021
- getUniformCache().flushPendingDestruction();
1022
-
1023
- const tokensView = new Uint32Array(tokensStagingBuffer.getMappedRange());
1024
- const tokens = Array.from(tokensView.subarray(0, N));
1129
+ const isInfinite = readback.finitenessStatus.triggered;
1130
+ const metadata = readback.finitenessStatus.metadata;
1025
1131
 
1026
- const stopFlags = stopStagingBuffer
1027
- ? new Uint32Array(stopStagingBuffer.getMappedRange().slice(0, N * 4))
1028
- : null;
1029
-
1030
- if (stopFlags) {
1031
- log.debug('Pipeline', `[STOP] N=${N} flags=[${Array.from(stopFlags).join(',')}] tokens=[${tokens.join(',')}] eos=${eosTokenId}`);
1032
- }
1132
+ const submitWaitMs = recorder.getSubmitLatencyMs();
1133
+ if (submitWaitMs != null) {
1134
+ state.stats.decodeSubmitWaitMs = (state.stats.decodeSubmitWaitMs ?? 0) + submitWaitMs;
1135
+ }
1033
1136
 
1034
- const actualCount = resolveBatchStop(tokens, stopFlags, stopTokenIds, eosToken);
1137
+ getUniformCache().flushPendingDestruction();
1035
1138
 
1036
- tokensStagingBuffer.unmap();
1037
- if (stopStagingBuffer) {
1038
- stopStagingBuffer.unmap();
1039
- }
1139
+ const tokens = readback.tokens;
1140
+ const stopFlags = readback.stopFlags;
1040
1141
 
1041
- const generatedTokens = tokens.slice(0, actualCount);
1142
+ if (stopFlags) {
1143
+ log.debug('Pipeline', `[STOP] N=${N} flags=[${Array.from(stopFlags).join(',')}] tokens=[${tokens.join(',')}] eos=${eosTokenId}`);
1144
+ }
1042
1145
 
1043
- if (ownsTokensBuffer) tokensBuffer.destroy();
1044
- if (ownsStopBuffer) stopBuffer?.destroy();
1045
- if (ownsTokensStaging) tokensStagingBuffer.destroy();
1046
- if (ownsStopStaging) stopStagingBuffer?.destroy();
1146
+ const actualCount = resolveBatchStop(tokens, stopFlags, stopTokenIds, eosToken);
1147
+ const generatedTokens = tokens.slice(0, actualCount);
1148
+ const invalidToken = findInvalidGeneratedToken(generatedTokens, config.vocabSize, padTokenId);
1047
1149
 
1048
- if (isInfinite) {
1049
- throw new FinitenessError(`F16 bounds exceeded during batch generation${metadata}`);
1050
- }
1150
+ if (isInfinite) {
1151
+ throw new FinitenessError(`F16 bounds exceeded during batch generation${metadata}`);
1152
+ }
1153
+ if (invalidToken) {
1154
+ state.disableFusedDecode = true;
1155
+ throw new Error(
1156
+ `[Pipeline] Batch decode produced invalid token ${invalidToken.tokenId} ` +
1157
+ `at batch index ${invalidToken.index} (vocabSize=${config.vocabSize}, padTokenId=${padTokenId ?? 'none'}).`
1158
+ );
1159
+ }
1051
1160
 
1052
- if (opts.profile && recorder.isProfilingEnabled()) {
1053
- const timings = await recorder.resolveProfileTimings();
1054
- const total = sumProfileTimings(timings);
1055
- if (total !== null) {
1056
- state.stats.gpuTimeDecodeMs = (state.stats.gpuTimeDecodeMs ?? 0) + total;
1161
+ if (opts.profile && recorder.isProfilingEnabled()) {
1162
+ const timings = await recorder.resolveProfileTimings();
1163
+ const total = sumProfileTimings(timings);
1164
+ if (total !== null) {
1165
+ state.stats.gpuTimeDecodeMs = (state.stats.gpuTimeDecodeMs ?? 0) + total;
1166
+ }
1167
+ if (timings) {
1168
+ recordDecodeProfileStep(state, {
1169
+ batch: true,
1170
+ stepStart: state.decodeStepCount + 1,
1171
+ stepCount: actualCount,
1172
+ timings,
1173
+ totalMs: total ?? undefined,
1174
+ });
1175
+ const stepStart = state.decodeStepCount + 1;
1176
+ if (shouldLogProfileStep(state, stepStart)) {
1177
+ log.warn('Profile', `Batch decode (N=${N}):`);
1178
+ log.warn('Profile', CommandRecorder.formatProfileReport(timings));
1179
+ }
1180
+ }
1057
1181
  }
1058
- if (timings) {
1059
- recordDecodeProfileStep(state, {
1060
- batch: true,
1061
- stepStart: state.decodeStepCount + 1,
1062
- stepCount: actualCount,
1063
- timings,
1064
- totalMs: total ?? undefined,
1065
- });
1066
- const stepStart = state.decodeStepCount + 1;
1067
- if (shouldLogProfileStep(state, stepStart)) {
1068
- log.warn('Profile', `Batch decode (N=${N}):`);
1069
- log.warn('Profile', CommandRecorder.formatProfileReport(timings));
1182
+
1183
+ state.currentSeqLen += actualCount;
1184
+ return { tokens: generatedTokens, actualCount };
1185
+ } finally {
1186
+ if (!readbackCleanupDelegated) {
1187
+ if (finitenessStagingBuffer) {
1188
+ finitenessStagingBuffer.destroy();
1070
1189
  }
1190
+ if (ownsTokensStaging) tokensStagingBuffer.destroy();
1191
+ if (ownsStopStaging) stopStagingBuffer?.destroy();
1192
+ ring?.advance();
1071
1193
  }
1194
+ if (ownsTokensBuffer) tokensBuffer.destroy();
1195
+ if (ownsStopBuffer) stopBuffer?.destroy();
1072
1196
  }
1073
-
1074
- state.currentSeqLen += actualCount;
1075
- ring?.advance();
1076
-
1077
- return { tokens: generatedTokens, actualCount };
1078
1197
  }