@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -1,4 +1,4 @@
1
- import { acquireBuffer } from '../../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
2
2
  import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { unifiedKernelWrapper } from './utils.js';
4
4
  import { selectRuleValue } from './rule-registry.js';
@@ -32,23 +32,31 @@ async function _repeatChannels(target, input, options = {}) {
32
32
  const bytesPerElement = dtypeBytes(input.dtype);
33
33
  const outputSize = outChannels * height * width * bytesPerElement;
34
34
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
35
+ const ownedOutput = outputBuffer ? null : output;
35
36
 
36
- await unifiedKernelWrapper(
37
- 'repeat_channels',
38
- target,
39
- variant,
40
- [input, output],
41
- {
42
- in_channels: inChannels,
43
- height,
44
- width,
45
- repeats,
46
- _pad0: 0,
47
- },
48
- Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
49
- );
50
-
51
- return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
37
+ try {
38
+ await unifiedKernelWrapper(
39
+ 'repeat_channels',
40
+ target,
41
+ variant,
42
+ [input, output],
43
+ {
44
+ in_channels: inChannels,
45
+ height,
46
+ width,
47
+ repeats,
48
+ _pad0: 0,
49
+ },
50
+ [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
51
+ );
52
+
53
+ return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
54
+ } catch (error) {
55
+ if (ownedOutput) {
56
+ releaseBuffer(ownedOutput);
57
+ }
58
+ throw error;
59
+ }
52
60
  }
53
61
 
54
62
  export async function runRepeatChannels(input, options = {}) {
@@ -14,16 +14,15 @@ struct Uniforms {
14
14
 
15
15
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
16
16
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
17
- let idx = gid.x;
18
17
  let spatial = u.height * u.width;
19
18
  let out_channels = u.in_channels * u.repeats;
20
- let total = out_channels * spatial;
21
- if (idx >= total) {
19
+ let spatial_idx = gid.x;
20
+ let out_channel = gid.y;
21
+ if (out_channel >= out_channels || spatial_idx >= spatial) {
22
22
  return;
23
23
  }
24
24
 
25
- let out_channel = idx / spatial;
26
25
  let channel = out_channel / u.repeats;
27
- let spatial_idx = idx - out_channel * spatial;
26
+ let idx = out_channel * spatial + spatial_idx;
28
27
  output[idx] = input[channel * spatial + spatial_idx];
29
28
  }
@@ -16,16 +16,15 @@ struct Uniforms {
16
16
 
17
17
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
18
18
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
19
- let idx = gid.x;
20
19
  let spatial = u.height * u.width;
21
20
  let out_channels = u.in_channels * u.repeats;
22
- let total = out_channels * spatial;
23
- if (idx >= total) {
21
+ let spatial_idx = gid.x;
22
+ let out_channel = gid.y;
23
+ if (out_channel >= out_channels || spatial_idx >= spatial) {
24
24
  return;
25
25
  }
26
26
 
27
- let out_channel = idx / spatial;
28
27
  let channel = out_channel / u.repeats;
29
- let spatial_idx = idx - out_channel * spatial;
28
+ let idx = out_channel * spatial + spatial_idx;
30
29
  output[idx] = input[channel * spatial + spatial_idx];
31
30
  }
@@ -63,9 +63,26 @@ function cleanupTemps(temps, recorder) {
63
63
  }
64
64
  }
65
65
 
66
+ function planResidualDispatch(target, size, elementsPerWorkgroup) {
67
+ const device = target?.device;
68
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
69
+ ? device.limits.maxComputeWorkgroupsPerDimension
70
+ : 65535;
71
+ const dispatchStride = Math.min(size, maxPerDim * elementsPerWorkgroup);
72
+ return {
73
+ dispatchStride,
74
+ workgroups: [
75
+ Math.ceil(dispatchStride / elementsPerWorkgroup),
76
+ Math.ceil(size / dispatchStride),
77
+ 1,
78
+ ],
79
+ };
80
+ }
81
+
66
82
  async function _residualAdd(target, a, b, size, options = {}) {
67
83
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
68
84
  const { useVec4 = true, outputBuffer = null } = options;
85
+ const ownsOutput = outputBuffer == null;
69
86
 
70
87
  const { a: aAligned, b: bAligned, temps } = await alignResidualInputs(a, b, recorder);
71
88
  const outputDtype = inferOutputDtype(aAligned, bAligned);
@@ -75,19 +92,28 @@ async function _residualAdd(target, a, b, size, options = {}) {
75
92
  const outputSize = size * bytesPerElement;
76
93
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'residual_output');
77
94
 
78
- const workgroups = useVec4
79
- ? Math.ceil(size / VEC4_ELEMENTS_PER_WG)
80
- : Math.ceil(size / WORKGROUP_SIZES.DEFAULT);
81
-
82
- await unifiedKernelWrapper(
83
- 'residual', target, variant,
84
- [aAligned, bAligned, output],
85
- { size },
86
- workgroups
95
+ const dispatchPlan = planResidualDispatch(
96
+ target,
97
+ size,
98
+ useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
87
99
  );
88
100
 
89
- cleanupTemps(temps, recorder);
90
- return createTensor(output, outputDtype, [size], 'residual_output');
101
+ try {
102
+ await unifiedKernelWrapper(
103
+ 'residual', target, variant,
104
+ [aAligned, bAligned, output],
105
+ { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
106
+ dispatchPlan.workgroups
107
+ );
108
+ return createTensor(output, outputDtype, [size], 'residual_output');
109
+ } catch (error) {
110
+ if (ownsOutput) {
111
+ releaseBuffer(output);
112
+ }
113
+ throw error;
114
+ } finally {
115
+ cleanupTemps(temps, recorder);
116
+ }
91
117
  }
92
118
 
93
119
  async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
@@ -96,18 +122,38 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
96
122
 
97
123
  const { bias: biasAligned, temps } = await alignBiasTensor(data, bias, recorder);
98
124
  const variant = selectBiasAddVariant(data.dtype, biasAligned.dtype);
99
-
100
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
101
-
102
- await unifiedKernelWrapper(
103
- 'bias_add', target, variant,
104
- [data, biasAligned],
105
- { num_tokens: numTokens, dim, data_offset: dataOffset, bias_offset: biasOffset },
106
- workgroups
107
- );
108
-
109
- cleanupTemps(temps, recorder);
110
- return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
125
+ const device = target?.device;
126
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
127
+ ? device.limits.maxComputeWorkgroupsPerDimension
128
+ : 65535;
129
+ const tokenStride = Math.min(numTokens, maxPerDim);
130
+
131
+ const workgroups = [
132
+ Math.ceil(dim / WORKGROUP_SIZES.DEFAULT),
133
+ tokenStride,
134
+ Math.ceil(numTokens / tokenStride),
135
+ ];
136
+
137
+ try {
138
+ await unifiedKernelWrapper(
139
+ 'bias_add', target, variant,
140
+ [data, biasAligned],
141
+ {
142
+ num_tokens: numTokens,
143
+ dim,
144
+ data_offset: dataOffset,
145
+ bias_offset: biasOffset,
146
+ token_stride: tokenStride,
147
+ _pad0: 0,
148
+ _pad1: 0,
149
+ _pad2: 0,
150
+ },
151
+ workgroups
152
+ );
153
+ return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
154
+ } finally {
155
+ cleanupTemps(temps, recorder);
156
+ }
111
157
  }
112
158
 
113
159
  export async function runResidualAdd(a, b, size, options = {}) {
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE: u32 = 256u;
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
26
+ let dispatch_stride = max(u._pad1, 1u);
27
+ let idx = gid.y * dispatch_stride + gid.x;
27
28
  if (idx >= u.size) {
28
29
  return;
29
30
  }
@@ -35,7 +36,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
35
36
  // This avoids requiring a different bind group layout with read_write on 'a'
36
37
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
37
38
  fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
38
- let idx = gid.x;
39
+ let dispatch_stride = max(u._pad1, 1u);
40
+ let idx = gid.y * dispatch_stride + gid.x;
39
41
  if (idx >= u.size) {
40
42
  return;
41
43
  }
@@ -45,7 +47,8 @@ fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
45
47
  // Fused residual + scale: output = a + scale * b
46
48
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
47
49
  fn add_scaled(@builtin(global_invocation_id) gid: vec3<u32>) {
48
- let idx = gid.x;
50
+ let dispatch_stride = max(u._pad1, 1u);
51
+ let idx = gid.y * dispatch_stride + gid.x;
49
52
  if (idx >= u.size) {
50
53
  return;
51
54
  }
@@ -27,7 +27,8 @@ override WORKGROUP_SIZE: u32 = 256u;
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
30
+ let dispatch_stride = max(u._pad1, 1u);
31
+ let idx = gid.y * dispatch_stride + gid.x;
31
32
  if (idx >= u.size) {
32
33
  return;
33
34
  }
@@ -25,7 +25,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
25
25
  // Vectorized version for better throughput
26
26
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
27
27
  fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
28
- let idx = gid.x * 4u;
28
+ let dispatch_stride = max(u._pad1, 4u);
29
+ let idx = gid.y * dispatch_stride + gid.x * 4u;
29
30
  let size = u.size;
30
31
 
31
32
  if (idx >= size) {
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
23
23
  // Vectorized version for better throughput
24
24
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
25
25
  fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x * 4u;
26
+ let dispatch_stride = max(u._pad1, 4u);
27
+ let idx = gid.y * dispatch_stride + gid.x * 4u;
27
28
  let size = u.size;
28
29
 
29
30
  if (idx >= size) {
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { getKernelThresholds, padToQ4KBlock } from '../../config/schema/index.js';
7
7
  import { selectRuleValue } from './rule-registry.js';
@@ -58,6 +58,36 @@ function resolveNormWeightDtype(weight, hiddenSize) {
58
58
  return 'f32';
59
59
  }
60
60
 
61
+ function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
62
+ const isGpuBuffer = weightBuffer && (
63
+ typeof GPUBuffer === 'undefined'
64
+ ? true
65
+ : weightBuffer instanceof GPUBuffer
66
+ );
67
+ if (isGpuBuffer) {
68
+ return;
69
+ }
70
+ const weightLabel = weight?.label ?? 'unknown';
71
+ const weightType = weight === null ? 'null' : weight === undefined ? 'undefined' : weight.constructor?.name || typeof weight;
72
+ const bufferType = weightBuffer === null ? 'null' : weightBuffer === undefined ? 'undefined' : weightBuffer.constructor?.name || typeof weightBuffer;
73
+ throw new Error(
74
+ `[rmsnorm] weight "${weightLabel}" requires a GPUBuffer ` +
75
+ `(weightType=${weightType}, bufferType=${bufferType}, hiddenSize=${hiddenSize ?? 'unknown'}).`
76
+ );
77
+ }
78
+
79
+ function planRMSNormDispatch(target, numTokens) {
80
+ const device = target?.device;
81
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
82
+ ? device.limits.maxComputeWorkgroupsPerDimension
83
+ : 65535;
84
+ const tokenStride = Math.min(numTokens, maxPerDim);
85
+ return {
86
+ tokenStride,
87
+ workgroups: [tokenStride, Math.ceil(numTokens / tokenStride), 1],
88
+ };
89
+ }
90
+
61
91
  export function selectRMSNormKernel(options = {}, isF16 = false) {
62
92
  const { residual = null, hiddenSize = null } = options;
63
93
  const { smallThreshold } = getKernelThresholds().rmsnorm;
@@ -82,27 +112,46 @@ export async function runRMSNorm(
82
112
  const variant = selectRMSNormKernel(options, isF16);
83
113
  const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
84
114
  const normWeightBuffer = getBuffer(weight);
115
+ assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
85
116
  const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
86
117
 
87
118
  const bytesPerElement = isF16 ? 2 : 4;
88
119
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
89
120
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
90
121
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
122
+ const ownedOutput = outputBuffer ? null : outputBuf;
123
+ const dispatchPlan = planRMSNormDispatch(null, batchSize);
91
124
 
92
125
  // Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
93
- const residualBuf = residual?.buffer || input.buffer;
94
-
95
- await unifiedKernelWrapper(
96
- 'rmsnorm',
97
- null,
98
- variant,
99
- [input, normWeightBuffer, outputBuf, residualBuf],
100
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps, has_residual: residual ? 1 : 0 },
101
- batchSize,
102
- { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
103
- );
104
-
105
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
126
+ const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
127
+
128
+ try {
129
+ await unifiedKernelWrapper(
130
+ 'rmsnorm',
131
+ null,
132
+ variant,
133
+ [input, normWeightBuffer, outputBuf, residualBuf],
134
+ {
135
+ hidden_size: inferredHiddenSize,
136
+ num_tokens: batchSize,
137
+ eps,
138
+ has_residual: residual ? 1 : 0,
139
+ token_stride: dispatchPlan.tokenStride,
140
+ _pad0: 0,
141
+ _pad1: 0,
142
+ _pad2: 0,
143
+ },
144
+ dispatchPlan.workgroups,
145
+ { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
146
+ );
147
+
148
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
149
+ } catch (error) {
150
+ if (ownedOutput) {
151
+ releaseBuffer(ownedOutput);
152
+ }
153
+ throw error;
154
+ }
106
155
  }
107
156
 
108
157
  export async function recordRMSNorm(
@@ -117,24 +166,43 @@ export async function recordRMSNorm(
117
166
  const variant = selectRMSNormKernel(options, isF16);
118
167
  const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
119
168
  const normWeightBuffer = getBuffer(weight);
169
+ assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
120
170
  const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
121
171
 
122
172
  const bytesPerElement = isF16 ? 2 : 4;
123
173
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
124
174
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
125
175
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
126
-
127
- const residualBuf = residual?.buffer || input.buffer;
128
-
129
- await unifiedKernelWrapper(
130
- 'rmsnorm',
131
- recorder,
132
- variant,
133
- [input, normWeightBuffer, outputBuf, residualBuf],
134
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps, has_residual: residual ? 1 : 0 },
135
- batchSize,
136
- { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
137
- );
138
-
139
- return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
176
+ const ownedOutput = outputBuffer ? null : outputBuf;
177
+ const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
178
+
179
+ const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
180
+
181
+ try {
182
+ await unifiedKernelWrapper(
183
+ 'rmsnorm',
184
+ recorder,
185
+ variant,
186
+ [input, normWeightBuffer, outputBuf, residualBuf],
187
+ {
188
+ hidden_size: inferredHiddenSize,
189
+ num_tokens: batchSize,
190
+ eps,
191
+ has_residual: residual ? 1 : 0,
192
+ token_stride: dispatchPlan.tokenStride,
193
+ _pad0: 0,
194
+ _pad1: 0,
195
+ _pad2: 0,
196
+ },
197
+ dispatchPlan.workgroups,
198
+ { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
199
+ );
200
+
201
+ return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
202
+ } catch (error) {
203
+ if (ownedOutput) {
204
+ releaseBuffer(ownedOutput);
205
+ }
206
+ throw error;
207
+ }
140
208
  }
@@ -39,6 +39,10 @@ struct Uniforms {
39
39
  num_tokens: u32, // Number of tokens to process
40
40
  eps: f32, // Epsilon for numerical stability (typically 1e-5 or 1e-6)
41
41
  has_residual: u32, // Runtime flag: 1 = add residual after norm
42
+ token_stride: u32, // Workgroup rows per dispatch row
43
+ _pad0: u32,
44
+ _pad1: u32,
45
+ _pad2: u32,
42
46
  }
43
47
 
44
48
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -82,6 +86,10 @@ fn should_add_residual() -> bool {
82
86
  return HAS_RESIDUAL || (u.has_residual != 0u);
83
87
  }
84
88
 
89
+ fn token_index(wg_id: vec3<u32>) -> u32 {
90
+ return wg_id.y * max(u.token_stride, 1u) + wg_id.x;
91
+ }
92
+
85
93
  // =============================================================================
86
94
  // Main Entry Point
87
95
  // =============================================================================
@@ -93,7 +101,7 @@ fn main(
93
101
  @builtin(local_invocation_id) local_id: vec3<u32>,
94
102
  @builtin(workgroup_id) wg_id: vec3<u32>
95
103
  ) {
96
- let token_idx = wg_id.x;
104
+ let token_idx = token_index(wg_id);
97
105
  let thread_idx = local_id.x;
98
106
  let size = u.size;
99
107
 
@@ -163,7 +171,7 @@ fn main_small(
163
171
  @builtin(local_invocation_id) local_id: vec3<u32>,
164
172
  @builtin(workgroup_id) wg_id: vec3<u32>
165
173
  ) {
166
- let token_idx = wg_id.x;
174
+ let token_idx = token_index(wg_id);
167
175
  let thread_idx = local_id.x;
168
176
  let size = u.size;
169
177
 
@@ -219,7 +227,7 @@ fn main_cached(
219
227
  @builtin(local_invocation_id) local_id: vec3<u32>,
220
228
  @builtin(workgroup_id) wg_id: vec3<u32>
221
229
  ) {
222
- let token_idx = wg_id.x;
230
+ let token_idx = token_index(wg_id);
223
231
  let thread_idx = local_id.x;
224
232
  let size = u.size;
225
233
 
@@ -288,7 +296,7 @@ fn main_subgroup(
288
296
  @builtin(subgroup_invocation_id) sg_lane: u32,
289
297
  @builtin(subgroup_size) sg_size: u32,
290
298
  ) {
291
- let token_idx = wg_id.x;
299
+ let token_idx = token_index(wg_id);
292
300
  let thread_idx = local_id.x;
293
301
  let size = u.size;
294
302
 
@@ -362,7 +370,7 @@ fn main_small_subgroup(
362
370
  @builtin(subgroup_invocation_id) sg_lane: u32,
363
371
  @builtin(subgroup_size) sg_size: u32,
364
372
  ) {
365
- let token_idx = wg_id.x;
373
+ let token_idx = token_index(wg_id);
366
374
  let thread_idx = local_id.x;
367
375
  let size = u.size;
368
376
 
@@ -414,4 +422,4 @@ fn main_small_subgroup(
414
422
  }
415
423
  output[base_offset + thread_idx] = result;
416
424
  }
417
- }
425
+ }
@@ -20,6 +20,10 @@ struct Uniforms {
20
20
  num_tokens: u32, // Number of tokens to process
21
21
  eps: f32, // Epsilon for numerical stability
22
22
  has_residual: u32, // 1 if residual input provided, 0 otherwise
23
+ token_stride: u32, // Workgroup rows per dispatch row
24
+ _pad0: u32,
25
+ _pad1: u32,
26
+ _pad2: u32,
23
27
  }
24
28
 
25
29
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -47,6 +51,10 @@ fn load_weight(idx: u32) -> f32 {
47
51
  return bitcast<f32>(weight[idx]);
48
52
  }
49
53
 
54
+ fn token_index(wg_id: vec3<u32>) -> u32 {
55
+ return wg_id.y * max(u.token_stride, 1u) + wg_id.x;
56
+ }
57
+
50
58
  // Main RMSNorm kernel - one workgroup per token
51
59
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
52
60
  fn main(
@@ -54,7 +62,7 @@ fn main(
54
62
  @builtin(local_invocation_id) local_id: vec3<u32>,
55
63
  @builtin(workgroup_id) wg_id: vec3<u32>
56
64
  ) {
57
- let token_idx = wg_id.x;
65
+ let token_idx = token_index(wg_id);
58
66
  let thread_idx = local_id.x;
59
67
  let size = u.size;
60
68
 
@@ -121,7 +129,7 @@ fn rmsnorm_small_f16(
121
129
  @builtin(local_invocation_id) local_id: vec3<u32>,
122
130
  @builtin(workgroup_id) wg_id: vec3<u32>
123
131
  ) {
124
- let token_idx = wg_id.x;
132
+ let token_idx = token_index(wg_id);
125
133
  let thread_idx = local_id.x;
126
134
  let size = u.size;
127
135
 
@@ -15,6 +15,8 @@ import type { OutputBufferOptions } from './types.js';
15
15
  export interface RoPEOptions extends OutputBufferOptions {
16
16
  numHeads?: number;
17
17
  headDim?: number;
18
+ rotaryDim?: number;
19
+ interleaved?: boolean;
18
20
  ropeTheta?: number;
19
21
  startPos?: number;
20
22
  }
@@ -13,18 +13,29 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
13
13
  const {
14
14
  numHeads = 1,
15
15
  headDim = 64,
16
+ rotaryDim = headDim,
17
+ interleaved = false,
16
18
  ropeTheta = ropeDefaults.defaultTheta,
17
19
  } = options;
18
20
 
19
21
  if (headDim % 2 !== 0) {
20
22
  throw new Error(`RoPE headDim must be even, got ${headDim}`);
21
23
  }
24
+ if (rotaryDim % 2 !== 0) {
25
+ throw new Error(`RoPE rotaryDim must be even, got ${rotaryDim}`);
26
+ }
27
+ if (rotaryDim <= 0 || rotaryDim > headDim) {
28
+ throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
29
+ }
30
+ if (input.dtype === 'f16' && (rotaryDim !== headDim || interleaved)) {
31
+ throw new Error('RoPE f16 kernel requires rotaryDim === headDim and interleaved === false.');
32
+ }
22
33
 
23
34
  const caps = getKernelCapabilities();
24
35
  const useF16 = input.dtype === 'f16' && caps.hasF16;
25
36
  const variant = selectRuleValue('rope', 'variant', { useF16 });
26
37
 
27
- const halfDim = headDim / 2;
38
+ const halfDim = rotaryDim / 2;
28
39
  const workgroups = Math.ceil((seqLen * numHeads * halfDim) / WORKGROUP_SIZES.DEFAULT);
29
40
 
30
41
  await unifiedKernelWrapper(
@@ -34,9 +45,11 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
34
45
  seq_len: seqLen,
35
46
  num_heads: numHeads,
36
47
  head_dim: headDim,
48
+ rotary_dim: rotaryDim,
37
49
  start_pos: options.startPos ?? ropeDefaults.defaultStartPos,
38
50
  rope_base: ropeTheta,
39
51
  rope_scale: 1.0,
52
+ interleaved: interleaved ? 1 : 0,
40
53
  },
41
54
  workgroups
42
55
  );