@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355) hide show
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -122,6 +122,20 @@ function resolveTokenText(tokenizer, tokenIds, fallbackText = '?', renderTokenTe
122
122
  return fallbackText;
123
123
  }
124
124
 
125
/**
 * Decide whether a thrown value qualifies for the finiteness-fallback retry.
 *
 * A value qualifies when either:
 *   - its `name` is `'FinitenessError'`, or
 *   - its message (the `message` property when that is a string, otherwise the
 *     value itself when it is a string) starts with `[Sampling]` and contains
 *     one of the two known "no finite candidates" phrases.
 *
 * @param {unknown} error - The caught value; may be an Error, a string, or nullish.
 * @returns {boolean} `true` when the failure matches one of the cases above.
 */
export function shouldRetryWithFinitenessFallback(error) {
  if (error?.name === 'FinitenessError') return true;

  // Normalise the caught value to a message string; anything else becomes ''.
  let text = '';
  if (typeof error?.message === 'string') {
    text = error.message;
  } else if (typeof error === 'string') {
    text = error;
  }

  const retryableFragments = [
    'no finite candidate logits after masking the pad token',
    'Softmax produced no finite candidate probabilities',
  ];

  // Only sampling-stage failures are considered; the prefix check comes first.
  return text.startsWith('[Sampling]')
    && retryableFragments.some((fragment) => text.includes(fragment));
}
138
+
125
139
  export class PipelineGenerator {
126
140
 
127
141
  #state;
@@ -351,7 +365,7 @@ export class PipelineGenerator {
351
365
  try {
352
366
  prefillLogits = await this._prefill(inputIds, opts);
353
367
  } catch (error) {
354
- if (error.name === 'FinitenessError') {
368
+ if (shouldRetryWithFinitenessFallback(error)) {
355
369
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefill. Retrying with F32 precision.`);
356
370
  prefillLogits = await this._retryWithFinitenessFallback(
357
371
  opts,
@@ -395,13 +409,34 @@ export class PipelineGenerator {
395
409
  log.debug('Pipeline', `After rep penalty top-5: ${topAfterPenalty.map(t => `"${t.text}"(${(t.prob * 100).toFixed(1)}%)`).join(', ')}`);
396
410
  }
397
411
 
398
- const firstToken = sample(prefillLogits, {
399
- temperature: opts.temperature,
400
- topP: opts.topP,
401
- topK: opts.topK,
402
- padTokenId,
403
- seed: opts.seed,
404
- });
412
+ let firstToken;
413
+ try {
414
+ firstToken = sample(prefillLogits, {
415
+ temperature: opts.temperature,
416
+ topP: opts.topP,
417
+ topK: opts.topK,
418
+ padTokenId,
419
+ seed: opts.seed,
420
+ });
421
+ } catch (error) {
422
+ if (!shouldRetryWithFinitenessFallback(error)) {
423
+ throw error;
424
+ }
425
+ log.warn('Pipeline', 'FinitenessGuard caught non-finite prefill logits at sampling. Retrying with F32 precision.');
426
+ prefillLogits = await this._retryWithFinitenessFallback(
427
+ opts,
428
+ 'prefill-sample',
429
+ () => this._prefill(inputIds, opts)
430
+ );
431
+ applyRepetitionPenalty(prefillLogits, generatedIds, opts.repetitionPenalty);
432
+ firstToken = sample(prefillLogits, {
433
+ temperature: opts.temperature,
434
+ topP: opts.topP,
435
+ topK: opts.topK,
436
+ padTokenId,
437
+ seed: opts.seed,
438
+ });
439
+ }
405
440
 
406
441
  if (opts.debug) {
407
442
  const firstTokenText = resolveTokenText(this.#state.tokenizer, [firstToken], `[${firstToken}]`, (tokens) => this.#state.tokenizer?.decode?.(tokens, true, false));
@@ -479,7 +514,7 @@ export class PipelineGenerator {
479
514
  try {
480
515
  prefillResult = await this._prefillToHidden(inputIds, opts);
481
516
  } catch (error) {
482
- if (error.name === 'FinitenessError') {
517
+ if (shouldRetryWithFinitenessFallback(error)) {
483
518
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillKVOnly. Retrying with F32 precision.`);
484
519
  prefillResult = await this._retryWithFinitenessFallback(
485
520
  opts,
@@ -544,7 +579,7 @@ export class PipelineGenerator {
544
579
  try {
545
580
  prefillResult = await this._prefillToHidden(inputIds, opts);
546
581
  } catch (error) {
547
- if (error.name === 'FinitenessError') {
582
+ if (shouldRetryWithFinitenessFallback(error)) {
548
583
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillWithEmbedding. Retrying with F32 precision.`);
549
584
  prefillResult = await this._retryWithFinitenessFallback(
550
585
  opts,
@@ -833,7 +868,7 @@ export class PipelineGenerator {
833
868
  try {
834
869
  nextToken = await this._decodeStep(generatedIds, opts);
835
870
  } catch (singleTokenError) {
836
- if (singleTokenError.name === 'FinitenessError') {
871
+ if (shouldRetryWithFinitenessFallback(singleTokenError)) {
837
872
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at batch step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
838
873
  nextToken = await this._retryDecodeStepWithFinitenessWindow(
839
874
  generatedIds,
@@ -858,7 +893,7 @@ export class PipelineGenerator {
858
893
  try {
859
894
  nextToken = await this._decodeStep(generatedIds, opts);
860
895
  } catch (error) {
861
- if (error.name === 'FinitenessError') {
896
+ if (shouldRetryWithFinitenessFallback(error)) {
862
897
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
863
898
  nextToken = await this._retryDecodeStepWithFinitenessWindow(
864
899
  generatedIds,
@@ -918,11 +953,9 @@ export class PipelineGenerator {
918
953
  throw new Error('Embed buffer not found or not a supported buffer type');
919
954
  }
920
955
  const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
921
- const embedDtype = isWeightBuffer(embedBufferRaw)
922
- ? getWeightDtype(embedBufferRaw)
923
- : isCpuWeightBuffer(embedBufferRaw)
924
- ? embedBufferRaw.dtype
925
- : null;
956
+ const embedDtype = isCpuWeightBuffer(embedBufferRaw)
957
+ ? embedBufferRaw.dtype
958
+ : getWeightDtype(embedBufferRaw);
926
959
  if (opts.debug) {
927
960
  const embedSize = embedBuffer instanceof GPUBuffer ? embedBuffer.size : 'N/A';
928
961
  log.debug('Pipeline', `Embed buffer: type=${embedBuffer?.constructor?.name}, size=${embedSize}, dtype=${embedDtype}`);
@@ -1043,18 +1076,9 @@ export class PipelineGenerator {
1043
1076
  if (allowReadback(`pipeline.prefill.layer-${l}`)) {
1044
1077
  try {
1045
1078
  const sampleSize = config.hiddenSize * activationBytes;
1046
- const staging = device.createBuffer({
1047
- size: sampleSize,
1048
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
1049
- });
1050
- const enc = device.createCommandEncoder();
1051
1079
  const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
1052
- enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
1053
- device.queue.submit([enc.finish()]);
1054
- await staging.mapAsync(GPUMapMode.READ);
1055
- const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
1056
- staging.unmap();
1057
- staging.destroy();
1080
+ const readback = await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize);
1081
+ const data = decodeReadback(readback, activationDtype);
1058
1082
  let min = Infinity;
1059
1083
  let max = -Infinity;
1060
1084
  let maxAbs = 0;
@@ -1112,20 +1136,12 @@ export class PipelineGenerator {
1112
1136
  if (opts.debug) {
1113
1137
  log.debug('Pipeline', `LAYER_LOOP_DONE, currentHiddenBuffer type=${currentHiddenBuffer?.constructor?.name}`);
1114
1138
  if (currentHiddenBuffer && allowReadback('pipeline.prefill.final-hidden')) {
1115
- const device = getDevice();
1116
1139
  const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
1117
1140
  const sampleSize = config.hiddenSize * activationBytes;
1118
- const staging = device.createBuffer({
1119
- size: sampleSize,
1120
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
1121
- });
1122
- const enc = device.createCommandEncoder();
1123
- enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
1124
- device.queue.submit([enc.finish()]);
1125
- await staging.mapAsync(GPUMapMode.READ);
1126
- const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
1127
- staging.unmap();
1128
- staging.destroy();
1141
+ const data = decodeReadback(
1142
+ await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize),
1143
+ activationDtype
1144
+ );
1129
1145
  const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
1130
1146
  const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
1131
1147
  log.debug('Pipeline', `FINAL_HIDDEN[pos=${numTokens - 1}]: nan=${nanCount}/${data.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);
@@ -190,6 +190,12 @@ export interface WeightLoadResult {
190
190
  layerRouterWeights: Map<number, RouterWeights>;
191
191
  }
192
192
 
193
/**
 * Result shape of `resolveQ4KConfig`.
 *
 * NOTE(review): the field descriptions below are inferred from the names only;
 * the implementation is not visible here — confirm before relying on them.
 */
export interface ResolvedQ4KConfig {
  /** Presumably whether the fused Q4K kernel path is in use — TODO confirm. */
  useFusedQ4K: boolean;
  /** Q4K layout, `'row'` or `'col'`; `null` is permitted (meaning not shown here). */
  q4kLayout: 'row' | 'col' | null;
  /** Presumably whether F32 weights are retained — TODO confirm. */
  keepF32Weights: boolean;
}
198
+
193
199
  /** Options for loadWeights */
194
200
  export interface LoadWeightsOptions {
195
201
  storageContext?: PipelineStorageContext;
@@ -211,6 +217,13 @@ export function loadWeights(
211
217
  options?: LoadWeightsOptions
212
218
  ): Promise<WeightLoadResult>;
213
219
 
220
/**
 * Resolve the Q4K configuration for a manifest.
 *
 * @param manifest - The model manifest.
 * @param kernelPath - Optional kernel path schema; `null` is accepted.
 * @param kernelPathSource - Optional kernel path source.
 * @param keepF32Weights - Optional flag; NOTE(review): presumably carried into
 *   the result's `keepF32Weights` — confirm against the implementation.
 * @returns The resolved Q4K configuration.
 */
export function resolveQ4KConfig(
  manifest: Manifest,
  kernelPath?: KernelPathSchema | null,
  kernelPathSource?: KernelPathSource,
  keepF32Weights?: boolean
): ResolvedQ4KConfig;
226
+
214
227
  /**
215
228
  * Apply Gemma chat template to a prompt.
216
229
  */
@@ -2,7 +2,7 @@
2
2
 
3
3
  import { parseModelConfig } from './config.js';
4
4
  import { getDevice, getDeviceLimits, getKernelCapabilities } from '../../../gpu/device.js';
5
- import { acquireBuffer } from '../../../memory/buffer-pool.js';
5
+ import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
6
6
  import { KVCache, SlidingWindowKVCache, TieredKVCache, BasisDecomposedPagedCache } from '../../kv-cache.js';
7
7
  import { Tokenizer } from '../../tokenizer.js';
8
8
  import { MoERouter } from '../../moe-router.js';
@@ -11,9 +11,13 @@ import { getDopplerLoader } from '../../../loader/doppler-loader.js';
11
11
  import { log, setGPUDevice, trace as debugTrace } from '../../../debug/index.js';
12
12
  import { getRuntimeConfig } from '../../../config/runtime.js';
13
13
  import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js';
14
- import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
14
+ import { isKernelPathFusedQ4K, kernelPathRequiresF32MatmulWeights } from '../../../config/kernel-path-loader.js';
15
15
  import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
16
16
  import { selectRuleValue } from '../../../rules/rule-registry.js';
17
+ import {
18
+ createSourceStorageContext,
19
+ getSourceRuntimeMetadata,
20
+ } from '../../../tooling/source-runtime-bundle.js';
17
21
 
18
22
  function resolveErrorMessage(error) {
19
23
  if (error && typeof error === 'object' && typeof error.message === 'string') {
@@ -56,12 +60,61 @@ function normalizeBaseUrl(baseUrl) {
56
60
  return baseUrl.replace(/\/$/, '');
57
61
  }
58
62
 
63
+ async function fetchBytes(url, offset = null, length = null) {
64
+ const headers = {};
65
+ if (Number.isFinite(offset) && Number.isFinite(length) && length > 0) {
66
+ const start = Math.max(0, Math.floor(offset));
67
+ const end = start + Math.max(0, Math.floor(length)) - 1;
68
+ headers.Range = `bytes=${start}-${end}`;
69
+ }
70
+ const response = await fetch(url, { headers });
71
+ if (!response.ok) {
72
+ throw new Error(`Failed to fetch ${url}: ${response.status}`);
73
+ }
74
+ return new Uint8Array(await response.arrayBuffer());
75
+ }
76
+
59
77
  function createRemoteStorageContext(baseUrl, manifest) {
60
78
  const root = normalizeBaseUrl(baseUrl);
61
79
  if (!root || !isRDRRManifest(manifest)) {
62
80
  return null;
63
81
  }
64
82
 
83
+ const sourceRuntime = getSourceRuntimeMetadata(manifest);
84
+ if (sourceRuntime) {
85
+ const readRange = async (relativePath, offset, length) => {
86
+ const filename = String(relativePath || '').replace(/^\/+/, '');
87
+ if (!filename) {
88
+ throw new Error('Direct-source artifact path is required.');
89
+ }
90
+ const url = `${root}/${filename}`;
91
+ return fetchBytes(url, offset, length);
92
+ };
93
+ const readText = async (relativePath) => {
94
+ const filename = String(relativePath || '').replace(/^\/+/, '');
95
+ if (!filename) return null;
96
+ const response = await fetch(`${root}/${filename}`);
97
+ if (!response.ok) {
98
+ throw new Error(`Failed to fetch ${filename} from ${root}: ${response.status}`);
99
+ }
100
+ return response.text();
101
+ };
102
+ const readBinary = async (relativePath) => {
103
+ const filename = String(relativePath || '').replace(/^\/+/, '');
104
+ if (!filename) {
105
+ throw new Error('Direct-source binary asset path is required.');
106
+ }
107
+ return fetchBytes(`${root}/${filename}`);
108
+ };
109
+ return createSourceStorageContext({
110
+ manifest,
111
+ readRange,
112
+ readText,
113
+ readBinary,
114
+ verifyHashes: true,
115
+ });
116
+ }
117
+
65
118
  return {
66
119
  async loadShard(index) {
67
120
  const shard = manifest.shards[index];
@@ -69,17 +122,13 @@ function createRemoteStorageContext(baseUrl, manifest) {
69
122
  if (!filename) {
70
123
  throw new Error(`Manifest shard ${index} is missing filename.`);
71
124
  }
72
- const response = await fetch(`${root}/${filename.replace(/^\/+/, '')}`);
73
- if (!response.ok) {
74
- throw new Error(`Failed to fetch shard ${index} from ${root}: ${response.status}`);
75
- }
76
- return new Uint8Array(await response.arrayBuffer());
125
+ return fetchBytes(`${root}/${filename.replace(/^\/+/, '')}`);
77
126
  },
78
127
  };
79
128
  }
80
129
 
81
130
 
82
- function resolveQ4KConfig(
131
+ export function resolveQ4KConfig(
83
132
  manifest,
84
133
  kernelPath,
85
134
  kernelPathSource = 'none',
@@ -101,18 +150,23 @@ function resolveQ4KConfig(
101
150
  );
102
151
  }
103
152
  let useFused = kernelPath ? isKernelPathFusedQ4K(kernelPath) : hasSubgroups;
153
+ const kernelPathKeepsF32Weights = kernelPathRequiresF32MatmulWeights(kernelPath);
104
154
  if (q4kLayout === 'col') {
105
155
  useFused = false;
106
156
  }
157
+ const resolvedKeepF32Weights = keepF32Weights || kernelPathKeepsF32Weights;
107
158
 
108
159
  const pathLabel = kernelPath?.id ?? 'auto';
109
160
  const layoutLabel = q4kLayout ?? 'none';
110
- debugTrace.loader(`Q4K config: fused=${useFused}, kernelPath=${pathLabel}, source=${kernelPathSource}, layout=${layoutLabel}, subgroups=${hasSubgroups}`);
161
+ debugTrace.loader(
162
+ `Q4K config: fused=${useFused}, kernelPath=${pathLabel}, source=${kernelPathSource}, ` +
163
+ `layout=${layoutLabel}, keepF32Weights=${resolvedKeepF32Weights}, subgroups=${hasSubgroups}`
164
+ );
111
165
 
112
166
  return {
113
167
  useFusedQ4K: useFused,
114
168
  q4kLayout,
115
- keepF32Weights,
169
+ keepF32Weights: resolvedKeepF32Weights,
116
170
  };
117
171
  }
118
172
 
@@ -326,20 +380,29 @@ export async function initRoPEFrequencies(config, useGPU) {
326
380
  // Upload to GPU if available
327
381
  const device = getDevice();
328
382
  if (device && useGPU) {
329
- const cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
330
- const sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
331
- device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
332
- device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
333
-
334
-
335
- let localCosBuffer;
336
-
337
- let localSinBuffer;
338
- if (localFreqs) {
339
- localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
340
- localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
341
- device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
342
- device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
383
+ let cosBuffer = null;
384
+ let sinBuffer = null;
385
+ let localCosBuffer = null;
386
+ let localSinBuffer = null;
387
+ try {
388
+ cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
389
+ sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
390
+ device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
391
+ device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
392
+
393
+ if (localFreqs) {
394
+ localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
395
+ localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
396
+ device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
397
+ device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
398
+ }
399
+ } catch (error) {
400
+ for (const buffer of [cosBuffer, sinBuffer, localCosBuffer, localSinBuffer]) {
401
+ if (buffer) {
402
+ releaseBuffer(buffer);
403
+ }
404
+ }
405
+ throw error;
343
406
  }
344
407
 
345
408
  log.debug(
@@ -444,6 +507,12 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
444
507
  cacheLayout = 'paged';
445
508
  layoutSource = 'threshold';
446
509
  }
510
+ if (forceContiguousKVCache && cacheLayout === 'paged') {
511
+ throw new Error(
512
+ 'Paged KV cache layout is not supported for models with full-attention layers. ' +
513
+ 'Set runtime.inference.kvcache.layout to "contiguous" instead.'
514
+ );
515
+ }
447
516
  if (debug && cacheLayout !== runtimeKV.layout) {
448
517
  log.debug('Pipeline', `KV cache layout override: ${runtimeKV.layout} -> ${cacheLayout} (${layoutSource})`);
449
518
  }
@@ -541,7 +610,7 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
541
610
 
542
611
  if (debug) {
543
612
  if (forceContiguousKVCache && modelConfig.layerTypes) {
544
- log.debug('Pipeline', 'Layer pattern includes full-attention layers; forcing contiguous KV cache.');
613
+ log.debug('Pipeline', 'Layer pattern includes full-attention layers; paged layout blocked, contiguous enforced.');
545
614
  }
546
615
  const isSliding = kvCache instanceof SlidingWindowKVCache;
547
616
  log.debug('Pipeline', `KV cache: type=${kvCache?.constructor?.name || 'unknown'}, kvDtype=${kvCache.kvDtype}, layout=${kvCache.layout}, maxSeqLen=${kvCache.maxSeqLen}, windowSize=${isSliding ? kvCache.windowSize : null}`);
@@ -78,6 +78,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
78
78
 
79
79
  const normalizedPolicy = resolveKernelPathPolicy(kernelPathPolicy);
80
80
  const hasSubgroups = capabilities?.hasSubgroups === true;
81
+ const hasF16 = capabilities?.hasF16 === true;
81
82
  const normalizedSource = normalizeKernelPathSource(kernelPathSource);
82
83
  const allowCapabilityAutoSelection = normalizedPolicy.mode === 'capability-aware'
83
84
  && normalizedPolicy.sourceScope.includes(normalizedSource);
@@ -85,6 +86,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
85
86
  return selectRuleValue('inference', 'kernelPath', 'autoSelect', {
86
87
  kernelPathRef: configuredKernelPathRef,
87
88
  hasSubgroups,
89
+ hasF16,
88
90
  allowCapabilityAutoSelection,
89
91
  });
90
92
  }
@@ -12,6 +12,8 @@
12
12
  * Snapshot of a tensor's statistics (no full data, just stats).
13
13
  */
14
14
  export interface TensorSnapshot {
15
+ ok: boolean;
16
+ error: string | null;
15
17
  shape: number[];
16
18
  dtype: string;
17
19
  stats: {
@@ -283,6 +283,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
283
283
  if (layer >= 0 && !kernelTrace.shouldTraceLayer(layer)) return;
284
284
 
285
285
  const output = await snapshotTensor(outputBuffer, outputShape);
286
+ if (!output.ok) {
287
+ throw new Error(`[TRACE] Failed to snapshot output for ${label}: ${output.error}`);
288
+ }
286
289
 
287
290
  // Snapshot inputs if provided (expensive - only do if tracing)
288
291
 
@@ -290,6 +293,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
290
293
  if (options?.inputs && options?.inputShapes) {
291
294
  for (let i = 0; i < options.inputs.length; i++) {
292
295
  const snap = await snapshotTensor(options.inputs[i], options.inputShapes[i]);
296
+ if (!snap.ok) {
297
+ throw new Error(`[TRACE] Failed to snapshot input ${i} for ${label}: ${snap.error}`);
298
+ }
293
299
  inputs.push(snap);
294
300
  }
295
301
  }
@@ -2,7 +2,7 @@
2
2
 
3
3
  import { log, trace } from '../../../debug/index.js';
4
4
  import { getDevice } from '../../../gpu/device.js';
5
- import { releaseBuffer } from '../../../memory/buffer-pool.js';
5
+ import { releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
6
6
  import { allowReadback } from '../../../gpu/perf-guards.js';
7
7
  import { createTensor } from '../../../gpu/tensor.js';
8
8
  import {
@@ -228,6 +228,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
228
228
  linearRuntime: context.linearAttentionRuntime ?? null,
229
229
  getWeightBuffer: (weight, label) => getWeightBuffer(weight, label),
230
230
  getNormWeightBuffer: (weight, label) => getNormWeightBuffer(weight, label, weightConfig, debugFlags),
231
+ debugProbes: context.debugProbes,
231
232
  recorder: recorder ?? null,
232
233
  });
233
234
  } else {
@@ -275,6 +276,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
275
276
  : (ropeFreqsSin),
276
277
  kvCache: ((kvCache)),
277
278
  stats: context.stats,
279
+ debugProbes: context.debugProbes,
278
280
  linearRuntime: context.linearAttentionRuntime ?? null,
279
281
  };
280
282
 
@@ -314,14 +316,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
314
316
  if (allowReadback(`layer.attn-out.${layerIdx}`)) {
315
317
  try {
316
318
  const sampleSize = Math.min(128, attnOutput.buffer.size);
317
- const staging = device.createBuffer({ size: sampleSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
318
- const enc = device.createCommandEncoder();
319
- enc.copyBufferToBuffer(attnOutput.buffer, 0, staging, 0, sampleSize);
320
- device.queue.submit([enc.finish()]);
321
- await staging.mapAsync(GPUMapMode.READ);
322
- const data = new Float32Array(staging.getMappedRange().slice(0));
323
- staging.unmap();
324
- staging.destroy();
319
+ const data = new Float32Array(await readBuffer(attnOutput.buffer, sampleSize));
325
320
  let maxAbs = 0;
326
321
  for (let i = 0; i < data.length; i++) {
327
322
  const abs = Math.abs(data[i]);
@@ -3,6 +3,7 @@ import type { Tensor } from '../../../gpu/tensor.js';
3
3
  import type { WeightBuffer } from '../../../gpu/weight-buffer.js';
4
4
  import type { CommandRecorder } from '../../../gpu/command-recorder.js';
5
5
  import type { LinearNormMode } from '../../../config/schema/index.js';
6
+ import type { ProbeConfigSchema } from '../../../config/schema/index.js';
6
7
 
7
8
  export interface LinearLayerRuntimeState {
8
9
  layerIdx: number;
@@ -67,6 +68,7 @@ export interface RunLinearAttentionLayerOptions {
67
68
  weight: GPUBuffer | Float32Array | ArrayBuffer,
68
69
  label: string
69
70
  ) => GPUBuffer;
71
+ debugProbes?: ProbeConfigSchema[] | null;
70
72
  recorder?: CommandRecorder | null;
71
73
  }
72
74
 
@@ -74,6 +76,19 @@ export declare function hasLinearAttentionLayers(layerTypes: unknown): boolean;
74
76
 
75
77
  export declare function createLinearAttentionRuntime(): LinearAttentionRuntime;
76
78
 
79
+ export declare function inferLinearNormMode(
80
+ weight: { size?: number; dtype?: string } | GPUBuffer | WeightBuffer | ArrayBufferView | ArrayBuffer | null | undefined,
81
+ projectionLayout: {
82
+ headVDim: number;
83
+ valueDim: number;
84
+ }
85
+ ): LinearNormMode | null;
86
+
87
+ export declare function applyLinearNormWeightOffset(
88
+ values: Float32Array,
89
+ rmsNormWeightOffset: boolean
90
+ ): Float32Array;
91
+
77
92
  export declare function resetLinearAttentionRuntime(
78
93
  runtime: LinearAttentionRuntime | null | undefined
79
94
  ): LinearAttentionRuntime;