@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1043,18 +1043,9 @@ export class PipelineGenerator {
1043
1043
  if (allowReadback(`pipeline.prefill.layer-${l}`)) {
1044
1044
  try {
1045
1045
  const sampleSize = config.hiddenSize * activationBytes;
1046
- const staging = device.createBuffer({
1047
- size: sampleSize,
1048
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
1049
- });
1050
- const enc = device.createCommandEncoder();
1051
1046
  const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
1052
- enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
1053
- device.queue.submit([enc.finish()]);
1054
- await staging.mapAsync(GPUMapMode.READ);
1055
- const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
1056
- staging.unmap();
1057
- staging.destroy();
1047
+ const readback = await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize);
1048
+ const data = decodeReadback(readback, activationDtype);
1058
1049
  let min = Infinity;
1059
1050
  let max = -Infinity;
1060
1051
  let maxAbs = 0;
@@ -1112,20 +1103,12 @@ export class PipelineGenerator {
1112
1103
  if (opts.debug) {
1113
1104
  log.debug('Pipeline', `LAYER_LOOP_DONE, currentHiddenBuffer type=${currentHiddenBuffer?.constructor?.name}`);
1114
1105
  if (currentHiddenBuffer && allowReadback('pipeline.prefill.final-hidden')) {
1115
- const device = getDevice();
1116
1106
  const lastTokenOffset = (numTokens - 1) * config.hiddenSize * activationBytes;
1117
1107
  const sampleSize = config.hiddenSize * activationBytes;
1118
- const staging = device.createBuffer({
1119
- size: sampleSize,
1120
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
1121
- });
1122
- const enc = device.createCommandEncoder();
1123
- enc.copyBufferToBuffer(currentHiddenBuffer, lastTokenOffset, staging, 0, sampleSize);
1124
- device.queue.submit([enc.finish()]);
1125
- await staging.mapAsync(GPUMapMode.READ);
1126
- const data = decodeReadback(staging.getMappedRange().slice(0), activationDtype);
1127
- staging.unmap();
1128
- staging.destroy();
1108
+ const data = decodeReadback(
1109
+ await readBufferSlice(currentHiddenBuffer, lastTokenOffset, sampleSize),
1110
+ activationDtype
1111
+ );
1129
1112
  const nanCount = Array.from(data).filter(x => !Number.isFinite(x)).length;
1130
1113
  const nonZero = Array.from(data).filter(x => Number.isFinite(x) && x !== 0).slice(0, 5);
1131
1114
  log.debug('Pipeline', `FINAL_HIDDEN[pos=${numTokens - 1}]: nan=${nanCount}/${data.length}, sample=[${nonZero.map(x => x.toFixed(4)).join(', ')}]`);
@@ -71,9 +71,13 @@ export interface PipelineContexts {
71
71
  */
72
72
  export interface RoPEConfig {
73
73
  headDim: number;
74
+ rotaryDim?: number;
74
75
  maxSeqLen: number;
75
76
  ropeTheta: number;
76
77
  ropeLocalTheta?: number | null;
78
+ mropeInterleaved?: boolean;
79
+ mropeSection?: number[] | null;
80
+ partialRotaryFactor?: number | null;
77
81
  ropeScale: number;
78
82
  ropeLocalScale?: number;
79
83
  ropeScalingType?: string | null;
@@ -2,7 +2,7 @@
2
2
 
3
3
  import { parseModelConfig } from './config.js';
4
4
  import { getDevice, getDeviceLimits, getKernelCapabilities } from '../../../gpu/device.js';
5
- import { acquireBuffer } from '../../../memory/buffer-pool.js';
5
+ import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
6
6
  import { KVCache, SlidingWindowKVCache, TieredKVCache, BasisDecomposedPagedCache } from '../../kv-cache.js';
7
7
  import { Tokenizer } from '../../tokenizer.js';
8
8
  import { MoERouter } from '../../moe-router.js';
@@ -14,6 +14,10 @@ import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js'
14
14
  import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
15
15
  import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
16
16
  import { selectRuleValue } from '../../../rules/rule-registry.js';
17
+ import {
18
+ createSourceStorageContext,
19
+ getSourceRuntimeMetadata,
20
+ } from '../../../tooling/source-runtime-bundle.js';
17
21
 
18
22
  function resolveErrorMessage(error) {
19
23
  if (error && typeof error === 'object' && typeof error.message === 'string') {
@@ -56,12 +60,61 @@ function normalizeBaseUrl(baseUrl) {
56
60
  return baseUrl.replace(/\/$/, '');
57
61
  }
58
62
 
63
+ async function fetchBytes(url, offset = null, length = null) {
64
+ const headers = {};
65
+ if (Number.isFinite(offset) && Number.isFinite(length) && length > 0) {
66
+ const start = Math.max(0, Math.floor(offset));
67
+ const end = start + Math.max(0, Math.floor(length)) - 1;
68
+ headers.Range = `bytes=${start}-${end}`;
69
+ }
70
+ const response = await fetch(url, { headers });
71
+ if (!response.ok) {
72
+ throw new Error(`Failed to fetch ${url}: ${response.status}`);
73
+ }
74
+ return new Uint8Array(await response.arrayBuffer());
75
+ }
76
+
59
77
  function createRemoteStorageContext(baseUrl, manifest) {
60
78
  const root = normalizeBaseUrl(baseUrl);
61
79
  if (!root || !isRDRRManifest(manifest)) {
62
80
  return null;
63
81
  }
64
82
 
83
+ const sourceRuntime = getSourceRuntimeMetadata(manifest);
84
+ if (sourceRuntime) {
85
+ const readRange = async (relativePath, offset, length) => {
86
+ const filename = String(relativePath || '').replace(/^\/+/, '');
87
+ if (!filename) {
88
+ throw new Error('Direct-source artifact path is required.');
89
+ }
90
+ const url = `${root}/${filename}`;
91
+ return fetchBytes(url, offset, length);
92
+ };
93
+ const readText = async (relativePath) => {
94
+ const filename = String(relativePath || '').replace(/^\/+/, '');
95
+ if (!filename) return null;
96
+ const response = await fetch(`${root}/${filename}`);
97
+ if (!response.ok) {
98
+ throw new Error(`Failed to fetch ${filename} from ${root}: ${response.status}`);
99
+ }
100
+ return response.text();
101
+ };
102
+ const readBinary = async (relativePath) => {
103
+ const filename = String(relativePath || '').replace(/^\/+/, '');
104
+ if (!filename) {
105
+ throw new Error('Direct-source binary asset path is required.');
106
+ }
107
+ return fetchBytes(`${root}/${filename}`);
108
+ };
109
+ return createSourceStorageContext({
110
+ manifest,
111
+ readRange,
112
+ readText,
113
+ readBinary,
114
+ verifyHashes: true,
115
+ });
116
+ }
117
+
65
118
  return {
66
119
  async loadShard(index) {
67
120
  const shard = manifest.shards[index];
@@ -69,11 +122,7 @@ function createRemoteStorageContext(baseUrl, manifest) {
69
122
  if (!filename) {
70
123
  throw new Error(`Manifest shard ${index} is missing filename.`);
71
124
  }
72
- const response = await fetch(`${root}/${filename.replace(/^\/+/, '')}`);
73
- if (!response.ok) {
74
- throw new Error(`Failed to fetch shard ${index} from ${root}: ${response.status}`);
75
- }
76
- return new Uint8Array(await response.arrayBuffer());
125
+ return fetchBytes(`${root}/${filename.replace(/^\/+/, '')}`);
77
126
  },
78
127
  };
79
128
  }
@@ -206,13 +255,45 @@ function isSameRoPEScalingConfig(
206
255
  === (rightScaling?.original_max_position_embeddings ?? null);
207
256
  }
208
257
 
258
+ function resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor) {
259
+ if (rotaryDim != null) {
260
+ if (!Number.isFinite(rotaryDim) || rotaryDim <= 0 || (rotaryDim % 2) !== 0) {
261
+ throw new Error(`RoPE rotary dim must be a positive even integer; got "${rotaryDim}".`);
262
+ }
263
+ if (rotaryDim > headDim) {
264
+ throw new Error(`RoPE rotary dim ${rotaryDim} cannot exceed headDim ${headDim}.`);
265
+ }
266
+ return rotaryDim;
267
+ }
268
+ if (partialRotaryFactor == null) {
269
+ return headDim;
270
+ }
271
+ if (!Number.isFinite(partialRotaryFactor) || partialRotaryFactor <= 0 || partialRotaryFactor > 1) {
272
+ throw new Error(
273
+ `RoPE partialRotaryFactor must be a number in (0, 1]; got "${partialRotaryFactor}".`
274
+ );
275
+ }
276
+ const resolved = Math.trunc(headDim * partialRotaryFactor);
277
+ if (resolved <= 0 || (resolved % 2) !== 0) {
278
+ throw new Error(
279
+ `RoPE partialRotaryFactor=${partialRotaryFactor} with headDim=${headDim} resolves ` +
280
+ `to rotaryDim=${resolved}, but rotaryDim must be a positive even integer.`
281
+ );
282
+ }
283
+ return resolved;
284
+ }
285
+
209
286
 
210
287
  export async function initRoPEFrequencies(config, useGPU) {
211
288
  const {
212
289
  headDim,
290
+ rotaryDim,
213
291
  maxSeqLen,
214
292
  ropeTheta,
215
293
  ropeLocalTheta,
294
+ mropeInterleaved,
295
+ mropeSection,
296
+ partialRotaryFactor,
216
297
  ropeScale,
217
298
  ropeLocalScale,
218
299
  ropeScalingType,
@@ -230,14 +311,23 @@ export async function initRoPEFrequencies(config, useGPU) {
230
311
  const resolvedLocalTheta = ropeLocalTheta ?? ropeTheta;
231
312
  const resolvedLocalScalingType = ropeLocalScalingType ?? ropeScalingType;
232
313
  const resolvedLocalScaling = ropeLocalScaling ?? ropeScaling;
314
+ const resolvedRotaryDim = resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor);
315
+ const halfDim = resolvedRotaryDim / 2;
316
+ if (mropeInterleaved === true && Array.isArray(mropeSection)) {
317
+ const expandedDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
318
+ if (expandedDim !== resolvedRotaryDim) {
319
+ throw new Error(
320
+ `RoPE mropeSection expands to ${expandedDim} dims, but rotaryDim is ${resolvedRotaryDim}.`
321
+ );
322
+ }
323
+ }
233
324
 
234
- const halfDim = headDim / 2;
235
325
  const isYarn = ropeScalingType === 'yarn';
236
326
  const isLocalYarn = resolvedLocalScalingType === 'yarn';
237
327
 
238
328
  // Compute global (full_attention) frequencies
239
329
  const globalFreqs = computeRoPEFreqsForTheta(
240
- ropeTheta, headDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
330
+ ropeTheta, resolvedRotaryDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
241
331
  );
242
332
 
243
333
  // Compute local (sliding_attention) frequencies if different from global.
@@ -256,7 +346,7 @@ export async function initRoPEFrequencies(config, useGPU) {
256
346
  if (hasDistinctLocalTheta || hasDistinctLocalScaling) {
257
347
  localFreqs = computeRoPEFreqsForTheta(
258
348
  resolvedLocalTheta,
259
- headDim,
349
+ resolvedRotaryDim,
260
350
  maxSeqLen,
261
351
  resolvedLocalScale,
262
352
  resolvedLocalScalingType,
@@ -285,27 +375,37 @@ export async function initRoPEFrequencies(config, useGPU) {
285
375
  // Upload to GPU if available
286
376
  const device = getDevice();
287
377
  if (device && useGPU) {
288
- const cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
289
- const sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
290
- device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
291
- device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
292
-
293
-
294
- let localCosBuffer;
295
-
296
- let localSinBuffer;
297
- if (localFreqs) {
298
- localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
299
- localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
300
- device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
301
- device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
378
+ let cosBuffer = null;
379
+ let sinBuffer = null;
380
+ let localCosBuffer = null;
381
+ let localSinBuffer = null;
382
+ try {
383
+ cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
384
+ sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
385
+ device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
386
+ device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
387
+
388
+ if (localFreqs) {
389
+ localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
390
+ localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
391
+ device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
392
+ device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
393
+ }
394
+ } catch (error) {
395
+ for (const buffer of [cosBuffer, sinBuffer, localCosBuffer, localSinBuffer]) {
396
+ if (buffer) {
397
+ releaseBuffer(buffer);
398
+ }
399
+ }
400
+ throw error;
302
401
  }
303
402
 
304
403
  log.debug(
305
404
  'Pipeline',
306
- `RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
405
+ `RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
307
406
  `theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
308
- `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
407
+ `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
408
+ `interleaved=${mropeInterleaved === true}`
309
409
  );
310
410
 
311
411
  return {
@@ -318,9 +418,10 @@ export async function initRoPEFrequencies(config, useGPU) {
318
418
 
319
419
  log.debug(
320
420
  'Pipeline',
321
- `RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
421
+ `RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
322
422
  `theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
323
- `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
423
+ `scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
424
+ `interleaved=${mropeInterleaved === true}`
324
425
  );
325
426
 
326
427
  return {
@@ -688,6 +789,10 @@ function applyChatMLTemplate(prompt) {
688
789
  return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
689
790
  }
690
791
 
792
/**
 * Wrap a raw prompt in a Qwen ChatML user turn and open the assistant turn.
 *
 * The assistant turn is pre-filled with an empty `<think>` block, presumably so
 * the model answers without emitting a reasoning section (confirm against the
 * model's chat template).
 *
 * @param {string} prompt - User message text.
 * @returns {string} The formatted prompt string.
 */
function applyQwenTemplate(prompt) {
  const userTurn = `<|im_start|>user\n${prompt}<|im_end|>\n`;
  const assistantOpening = '<|im_start|>assistant\n<think>\n\n</think>\n\n';
  return userTurn + assistantOpening;
}
795
+
691
796
  function applyTranslateGemmaTemplate() {
692
797
  throw new Error(
693
798
  'TranslateGemma template requires structured messages. ' +
@@ -702,7 +807,7 @@ const PROMPT_TEMPLATES = {
702
807
  'llama3': applyHeaderBasedTemplate,
703
808
  'gpt-oss': applyChannelBasedTemplate,
704
809
  'chatml': applyChatMLTemplate,
705
- 'qwen': applyChatMLTemplate, // Qwen uses ChatML format
810
+ 'qwen': applyQwenTemplate,
706
811
  'translategemma': applyTranslateGemmaTemplate,
707
812
  };
708
813
 
@@ -721,7 +826,7 @@ export function applyChatTemplate(prompt, templateType) {
721
826
  export const applyGemmaChatTemplate = applyTurnBasedTemplate;
722
827
  export const applyLlama3ChatTemplate = applyHeaderBasedTemplate;
723
828
  export const applyGptOssChatTemplate = applyChannelBasedTemplate;
724
- export const applyQwenChatTemplate = applyChatMLTemplate;
829
+ export const applyQwenChatTemplate = applyQwenTemplate;
725
830
 
726
831
 
727
832
  export function isStopToken(token, stopTokenIds, eosTokenId) {
@@ -78,6 +78,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
78
78
 
79
79
  const normalizedPolicy = resolveKernelPathPolicy(kernelPathPolicy);
80
80
  const hasSubgroups = capabilities?.hasSubgroups === true;
81
+ const hasF16 = capabilities?.hasF16 === true;
81
82
  const normalizedSource = normalizeKernelPathSource(kernelPathSource);
82
83
  const allowCapabilityAutoSelection = normalizedPolicy.mode === 'capability-aware'
83
84
  && normalizedPolicy.sourceScope.includes(normalizedSource);
@@ -85,6 +86,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
85
86
  return selectRuleValue('inference', 'kernelPath', 'autoSelect', {
86
87
  kernelPathRef: configuredKernelPathRef,
87
88
  hasSubgroups,
89
+ hasF16,
88
90
  allowCapabilityAutoSelection,
89
91
  });
90
92
  }
@@ -12,6 +12,8 @@
12
12
  * Snapshot of a tensor's statistics (no full data, just stats).
13
13
  */
14
14
  export interface TensorSnapshot {
15
+ ok: boolean;
16
+ error: string | null;
15
17
  shape: number[];
16
18
  dtype: string;
17
19
  stats: {
@@ -283,6 +283,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
283
283
  if (layer >= 0 && !kernelTrace.shouldTraceLayer(layer)) return;
284
284
 
285
285
  const output = await snapshotTensor(outputBuffer, outputShape);
286
+ if (!output.ok) {
287
+ throw new Error(`[TRACE] Failed to snapshot output for ${label}: ${output.error}`);
288
+ }
286
289
 
287
290
  // Snapshot inputs if provided (expensive - only do if tracing)
288
291
 
@@ -290,6 +293,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
290
293
  if (options?.inputs && options?.inputShapes) {
291
294
  for (let i = 0; i < options.inputs.length; i++) {
292
295
  const snap = await snapshotTensor(options.inputs[i], options.inputShapes[i]);
296
+ if (!snap.ok) {
297
+ throw new Error(`[TRACE] Failed to snapshot input ${i} for ${label}: ${snap.error}`);
298
+ }
293
299
  inputs.push(snap);
294
300
  }
295
301
  }
@@ -2,7 +2,7 @@
2
2
 
3
3
  import { log, trace } from '../../../debug/index.js';
4
4
  import { getDevice } from '../../../gpu/device.js';
5
- import { releaseBuffer } from '../../../memory/buffer-pool.js';
5
+ import { releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
6
6
  import { allowReadback } from '../../../gpu/perf-guards.js';
7
7
  import { createTensor } from '../../../gpu/tensor.js';
8
8
  import {
@@ -228,6 +228,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
228
228
  linearRuntime: context.linearAttentionRuntime ?? null,
229
229
  getWeightBuffer: (weight, label) => getWeightBuffer(weight, label),
230
230
  getNormWeightBuffer: (weight, label) => getNormWeightBuffer(weight, label, weightConfig, debugFlags),
231
+ debugProbes: context.debugProbes,
231
232
  recorder: recorder ?? null,
232
233
  });
233
234
  } else {
@@ -259,6 +260,8 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
259
260
  attentionOutputGate: config.attentionOutputGate,
260
261
  causalAttention: config.causalAttention,
261
262
  rmsNormWeightOffset: config.rmsNormWeightOffset,
263
+ ropeRotaryDim: config.ropeRotaryDim,
264
+ ropeInterleaved: config.ropeInterleaved,
262
265
  tokenIds: context.currentTokenIds ?? null,
263
266
  kernelPath: context.kernelPath ?? null,
264
267
  disableRoPE,
@@ -312,14 +315,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
312
315
  if (allowReadback(`layer.attn-out.${layerIdx}`)) {
313
316
  try {
314
317
  const sampleSize = Math.min(128, attnOutput.buffer.size);
315
- const staging = device.createBuffer({ size: sampleSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
316
- const enc = device.createCommandEncoder();
317
- enc.copyBufferToBuffer(attnOutput.buffer, 0, staging, 0, sampleSize);
318
- device.queue.submit([enc.finish()]);
319
- await staging.mapAsync(GPUMapMode.READ);
320
- const data = new Float32Array(staging.getMappedRange().slice(0));
321
- staging.unmap();
322
- staging.destroy();
318
+ const data = new Float32Array(await readBuffer(attnOutput.buffer, sampleSize));
323
319
  let maxAbs = 0;
324
320
  for (let i = 0; i < data.length; i++) {
325
321
  const abs = Math.abs(data[i]);
@@ -661,6 +657,8 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
661
657
  attentionOutputGate: config.attentionOutputGate,
662
658
  causalAttention: config.causalAttention,
663
659
  rmsNormWeightOffset: config.rmsNormWeightOffset,
660
+ ropeRotaryDim: config.ropeRotaryDim,
661
+ ropeInterleaved: config.ropeInterleaved,
664
662
  tokenIds: context.currentTokenIds ?? null,
665
663
  skipInputNorm: step.skipInputNorm === true,
666
664
  activationDtype,
@@ -690,6 +688,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
690
688
  hiddenSize,
691
689
  probes: context.debugProbes,
692
690
  recorder,
691
+ dtype: outputDtype,
693
692
  });
694
693
  }
695
694
  break;
@@ -733,6 +732,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
733
732
  hiddenSize,
734
733
  probes: context.debugProbes,
735
734
  recorder,
735
+ dtype: outputDtype,
736
736
  });
737
737
  }
738
738
  break;
@@ -767,6 +767,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
767
767
  hiddenSize,
768
768
  probes: context.debugProbes,
769
769
  recorder,
770
+ dtype: outputDtype,
770
771
  });
771
772
  }
772
773
  break;
@@ -801,6 +802,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
801
802
  hiddenSize,
802
803
  probes: context.debugProbes,
803
804
  recorder,
805
+ dtype: outputDtype,
804
806
  });
805
807
  }
806
808
  break;
@@ -825,6 +827,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
825
827
  hiddenSize,
826
828
  probes: context.debugProbes,
827
829
  recorder,
830
+ dtype: outputDtype,
828
831
  });
829
832
  }
830
833
  break;
@@ -851,6 +854,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
851
854
  hiddenSize,
852
855
  probes: context.debugProbes,
853
856
  recorder,
857
+ dtype: toDtype,
854
858
  });
855
859
  }
856
860
  break;
@@ -880,6 +884,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
880
884
  hiddenSize,
881
885
  probes: context.debugProbes,
882
886
  recorder,
887
+ dtype: getSlotDtype('state') ?? activationDtype,
883
888
  });
884
889
 
885
890
  const computeConfig = context.runtimeComputeConfig ?? null;
@@ -3,6 +3,7 @@ import type { Tensor } from '../../../gpu/tensor.js';
3
3
  import type { WeightBuffer } from '../../../gpu/weight-buffer.js';
4
4
  import type { CommandRecorder } from '../../../gpu/command-recorder.js';
5
5
  import type { LinearNormMode } from '../../../config/schema/index.js';
6
+ import type { ProbeConfigSchema } from '../../../config/schema/index.js';
6
7
 
7
8
  export interface LinearLayerRuntimeState {
8
9
  layerIdx: number;
@@ -67,6 +68,7 @@ export interface RunLinearAttentionLayerOptions {
67
68
  weight: GPUBuffer | Float32Array | ArrayBuffer,
68
69
  label: string
69
70
  ) => GPUBuffer;
71
+ debugProbes?: ProbeConfigSchema[] | null;
70
72
  recorder?: CommandRecorder | null;
71
73
  }
72
74
 
@@ -74,6 +76,14 @@ export declare function hasLinearAttentionLayers(layerTypes: unknown): boolean;
74
76
 
75
77
  export declare function createLinearAttentionRuntime(): LinearAttentionRuntime;
76
78
 
79
/**
 * Infer the linear-attention norm mode from the size of a norm weight.
 *
 * @param weight - Norm weight as a raw buffer, typed view, tracked GPU/weight buffer,
 *   or any object exposing `size` (and optionally `dtype`).
 * @param projectionLayout - Per-head and total value dimensions used to classify the size.
 * @returns The inferred mode, or null when the weight's size does not resolve one.
 */
export declare function inferLinearNormMode(
  weight:
    | null
    | undefined
    | ArrayBuffer
    | ArrayBufferView
    | WeightBuffer
    | GPUBuffer
    | { size?: number; dtype?: string },
  projectionLayout: { headVDim: number; valueDim: number }
): LinearNormMode | null;
86
+
77
87
  export declare function resetLinearAttentionRuntime(
78
88
  runtime: LinearAttentionRuntime | null | undefined
79
89
  ): LinearAttentionRuntime;
@@ -4,6 +4,7 @@ import { readBuffer, releaseBuffer, uploadData, acquireBuffer } from '../../../m
4
4
  import { log } from '../../../debug/index.js';
5
5
  import { decodeReadback } from './debug-utils/index.js';
6
6
  import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
7
+ import { runProbes } from './probes.js';
7
8
 
8
9
  const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
9
10
  const QK_L2NORM_EPS = 1e-6;
@@ -173,9 +174,22 @@ function inferLinearNormModeFromWeight(weight, projectionLayout) {
173
174
  if (weight instanceof ArrayBuffer) {
174
175
  return classify(Math.trunc(weight.byteLength / Float32Array.BYTES_PER_ELEMENT));
175
176
  }
177
+ const explicitDtype = typeof weight?.dtype === 'string' ? weight.dtype.toLowerCase() : null;
178
+ const trackedDtype = isGpuBuffer(weight) ? String(getBufferDtype(weight) ?? '').toLowerCase() : '';
179
+ const bytesPerElement = bytesFromDtype(explicitDtype || trackedDtype || null);
180
+ const sizedElements = Number.isFinite(weight?.size)
181
+ ? Math.trunc(Number(weight.size) / bytesPerElement)
182
+ : null;
183
+ if (sizedElements && Number(weight.size) % bytesPerElement === 0) {
184
+ return classify(sizedElements);
185
+ }
176
186
  return null;
177
187
  }
178
188
 
189
/**
 * Public entry point for inferring a linear-attention norm mode from a norm
 * weight's size; delegates to the module-private inferLinearNormModeFromWeight.
 *
 * @param {*} weight - Norm weight (buffer, typed view, or object with `size`/`dtype`).
 * @param {{headVDim: number, valueDim: number}} projectionLayout - Layout used to classify the size.
 * @returns {LinearNormMode|null} The inferred mode, or null when it cannot be resolved.
 */
export function inferLinearNormMode(weight, projectionLayout) {
  const inferredMode = inferLinearNormModeFromWeight(weight, projectionLayout);
  return inferredMode;
}
192
+
179
193
  function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, layerIdx) {
180
194
  const configuredMode = normalizeLinearNormMode(configNormMode);
181
195
  const inferredMode = inferLinearNormModeFromWeight(normWeight, projectionLayout);
@@ -185,7 +199,15 @@ function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, lay
185
199
  `but norm.weight shape implies "${inferredMode}".`
186
200
  );
187
201
  }
188
- return configuredMode ?? inferredMode ?? 'shared';
202
+ if (configuredMode) {
203
+ return configuredMode;
204
+ }
205
+ if (inferredMode) {
206
+ return inferredMode;
207
+ }
208
+ throw new Error(
209
+ `linear_attention layer ${layerIdx} requires explicit linearNormMode or a norm.weight shape that resolves it.`
210
+ );
189
211
  }
190
212
 
191
213
  async function readWeightAsF32(weight, expectedElements, label) {
@@ -395,10 +417,17 @@ async function createLayerRuntimeState(
395
417
 
396
418
  let convKernelSize = toPositiveInt(config.linearConvKernelDim) ?? null;
397
419
  if (isWeightBuffer(convKernel) && Array.isArray(convKernel.shape) && convKernel.shape.length >= 3) {
398
- convKernelSize = toPositiveInt(convKernel.shape[2]) ?? convKernelSize;
420
+ const shapeKernelSize = toPositiveInt(convKernel.shape[2]) ?? null;
421
+ if (convKernelSize != null && shapeKernelSize != null && convKernelSize !== shapeKernelSize) {
422
+ throw new Error(
423
+ `linear_attention layer ${layerIdx} declares linearConvKernelDim=${convKernelSize}, ` +
424
+ `but conv1d weight shape implies ${shapeKernelSize}.`
425
+ );
426
+ }
427
+ convKernelSize = shapeKernelSize ?? convKernelSize;
399
428
  }
400
429
  if (!convKernelSize) {
401
- convKernelSize = 4;
430
+ throw new Error(`linear_attention layer ${layerIdx} requires linearConvKernelDim.`);
402
431
  }
403
432
 
404
433
  const convWeight = await readWeightAsF32(
@@ -435,6 +464,11 @@ async function createLayerRuntimeState(
435
464
  const recurrentState = new Float32Array(
436
465
  projectionLayout.numVHeads * projectionLayout.headKDim * projectionLayout.headVDim
437
466
  );
467
+ const rmsNormEps = Number(config.rmsNormEps);
468
+ if (!Number.isFinite(rmsNormEps) || rmsNormEps <= 0) {
469
+ throw new Error(`linear_attention layer ${layerIdx} requires a positive rmsNormEps.`);
470
+ }
471
+
438
472
  const layerState = {
439
473
  layerIdx,
440
474
  seqLen: currentSeqLen,
@@ -452,7 +486,7 @@ async function createLayerRuntimeState(
452
486
  vSize: projectionLayout.vSize,
453
487
  qRep: projectionLayout.qRep,
454
488
  normMode,
455
- rmsNormEps: Number(config.rmsNormEps) || 1e-6,
489
+ rmsNormEps,
456
490
  convWeight,
457
491
  dtBias,
458
492
  aNegExp,
@@ -681,13 +715,13 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
681
715
  const normWeightBuffer = getNormWeightBuffer(layerWeights.inputNorm, `L${layerIdx}.linear_input_norm`);
682
716
  try {
683
717
  if (recorder) {
684
- normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, Number(config.rmsNormEps) || 1e-6, {
718
+ normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, layerState.rmsNormEps, {
685
719
  batchSize: numTokens,
686
720
  hiddenSize,
687
721
  rmsNormWeightOffset: config.rmsNormWeightOffset,
688
722
  });
689
723
  } else {
690
- normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, Number(config.rmsNormEps) || 1e-6, {
724
+ normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, layerState.rmsNormEps, {
691
725
  batchSize: numTokens,
692
726
  hiddenSize,
693
727
  rmsNormWeightOffset: config.rmsNormWeightOffset,
@@ -755,6 +789,38 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
755
789
  });
756
790
 
757
791
  try {
792
+ await runProbes('linear_qkv_proj', qkvTensor.buffer, {
793
+ layerIdx,
794
+ numTokens,
795
+ hiddenSize: projectionLayout.convDim,
796
+ probes: options.debugProbes,
797
+ recorder,
798
+ dtype: qkvTensor.dtype,
799
+ });
800
+ await runProbes('linear_z_proj', zTensor.buffer, {
801
+ layerIdx,
802
+ numTokens,
803
+ hiddenSize: projectionLayout.valueDim,
804
+ probes: options.debugProbes,
805
+ recorder,
806
+ dtype: zTensor.dtype,
807
+ });
808
+ await runProbes('linear_a_proj', aTensor.buffer, {
809
+ layerIdx,
810
+ numTokens,
811
+ hiddenSize: projectionLayout.numVHeads,
812
+ probes: options.debugProbes,
813
+ recorder,
814
+ dtype: aTensor.dtype,
815
+ });
816
+ await runProbes('linear_b_proj', bTensor.buffer, {
817
+ layerIdx,
818
+ numTokens,
819
+ hiddenSize: projectionLayout.numVHeads,
820
+ probes: options.debugProbes,
821
+ recorder,
822
+ dtype: bTensor.dtype,
823
+ });
758
824
  const coreTensor = await runLinearAttentionCoreGPU(
759
825
  qkvTensor,
760
826
  zTensor,
@@ -768,6 +834,14 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
768
834
  recorder,
769
835
  }
770
836
  );
837
+ await runProbes('linear_core_out', coreTensor.buffer, {
838
+ layerIdx,
839
+ numTokens,
840
+ hiddenSize: projectionLayout.valueDim,
841
+ probes: options.debugProbes,
842
+ recorder,
843
+ dtype: coreTensor.dtype,
844
+ });
771
845
  layerState.seqLen = currentSeqLen + numTokens;
772
846
  const outProjWeight = getWeightBuffer(layerWeights.oProj, `L${layerIdx}.linear_out_proj`);
773
847
  try {