@simulatte/doppler 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (392) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +25 -17
  3. package/package.json +20 -4
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +39 -39
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +49 -7
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +43 -4
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +28 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/models/qwen3.json +9 -2
  45. package/src/config/presets/models/transformer.json +5 -0
  46. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  47. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  48. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  49. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  50. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  52. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  54. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  55. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  56. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  57. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  58. package/src/config/required-inference-fields-contract-check.js +6 -0
  59. package/src/config/runtime.js +6 -1
  60. package/src/config/schema/debug.schema.d.ts +5 -0
  61. package/src/config/schema/doppler.schema.js +16 -21
  62. package/src/config/schema/inference-defaults.schema.js +6 -3
  63. package/src/config/schema/inference.schema.d.ts +9 -0
  64. package/src/config/schema/kernel-path.schema.d.ts +11 -1
  65. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  66. package/src/config/schema/manifest.schema.d.ts +8 -1
  67. package/src/config/schema/manifest.schema.js +19 -3
  68. package/src/config/training-defaults.js +30 -22
  69. package/src/converter/conversion-plan.js +94 -9
  70. package/src/converter/core.d.ts +7 -0
  71. package/src/converter/core.js +14 -9
  72. package/src/converter/execution-v0-manifest.js +4 -1
  73. package/src/converter/index.d.ts +1 -0
  74. package/src/converter/index.js +1 -0
  75. package/src/converter/manifest-inference.js +43 -12
  76. package/src/converter/parsers/diffusion.js +0 -3
  77. package/src/converter/quantization-info.js +35 -15
  78. package/src/converter/rope-config.js +42 -0
  79. package/src/converter/shard-packer.d.ts +1 -1
  80. package/src/converter/shard-packer.js +4 -1
  81. package/src/debug/config.js +123 -11
  82. package/src/debug/signals.js +7 -1
  83. package/src/debug/tensor.d.ts +2 -0
  84. package/src/debug/tensor.js +13 -2
  85. package/src/distribution/p2p-control-plane.js +52 -12
  86. package/src/distribution/p2p-observability.js +43 -7
  87. package/src/distribution/p2p-webrtc-browser.js +20 -0
  88. package/src/distribution/shard-delivery.js +77 -26
  89. package/src/formats/gguf/types.js +33 -16
  90. package/src/formats/rdrr/groups.d.ts +12 -4
  91. package/src/formats/rdrr/groups.js +3 -6
  92. package/src/formats/rdrr/parsing.js +39 -2
  93. package/src/formats/rdrr/types.d.ts +2 -1
  94. package/src/gpu/command-recorder.js +86 -61
  95. package/src/gpu/device.d.ts +1 -0
  96. package/src/gpu/device.js +131 -19
  97. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  98. package/src/gpu/kernel-tuner/cache.js +71 -4
  99. package/src/gpu/kernel-tuner/tuner.js +22 -4
  100. package/src/gpu/kernels/attention.js +113 -34
  101. package/src/gpu/kernels/backward/adam.js +62 -58
  102. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  103. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  104. package/src/gpu/kernels/bias_add.wgsl +8 -6
  105. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  106. package/src/gpu/kernels/cast.js +191 -149
  107. package/src/gpu/kernels/check-stop.js +33 -44
  108. package/src/gpu/kernels/conv2d.js +27 -17
  109. package/src/gpu/kernels/conv2d.wgsl +7 -8
  110. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  111. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  112. package/src/gpu/kernels/depthwise_conv2d.js +37 -26
  113. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  114. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  115. package/src/gpu/kernels/dequant.js +178 -126
  116. package/src/gpu/kernels/energy.d.ts +3 -21
  117. package/src/gpu/kernels/energy.js +111 -88
  118. package/src/gpu/kernels/feature-check.js +1 -1
  119. package/src/gpu/kernels/fused_ffn.js +84 -65
  120. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  121. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  122. package/src/gpu/kernels/gather.js +33 -15
  123. package/src/gpu/kernels/gelu.js +19 -11
  124. package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
  125. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  126. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  127. package/src/gpu/kernels/groupnorm.js +34 -23
  128. package/src/gpu/kernels/kv-quantize.js +5 -2
  129. package/src/gpu/kernels/layernorm.js +35 -19
  130. package/src/gpu/kernels/logit-merge.js +5 -3
  131. package/src/gpu/kernels/matmul.js +83 -39
  132. package/src/gpu/kernels/modulate.js +23 -15
  133. package/src/gpu/kernels/moe.js +221 -175
  134. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  135. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  136. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  137. package/src/gpu/kernels/relu.js +31 -10
  138. package/src/gpu/kernels/relu.wgsl +2 -1
  139. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  140. package/src/gpu/kernels/repeat_channels.js +25 -17
  141. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  142. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  143. package/src/gpu/kernels/residual.js +69 -23
  144. package/src/gpu/kernels/residual.wgsl +6 -3
  145. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  146. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  147. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  148. package/src/gpu/kernels/rmsnorm.js +96 -28
  149. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  150. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  151. package/src/gpu/kernels/rope.d.ts +2 -0
  152. package/src/gpu/kernels/rope.js +14 -1
  153. package/src/gpu/kernels/rope.wgsl +56 -40
  154. package/src/gpu/kernels/sample.js +27 -38
  155. package/src/gpu/kernels/sana_linear_attention.js +19 -12
  156. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  157. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  158. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  159. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  160. package/src/gpu/kernels/scale.js +18 -11
  161. package/src/gpu/kernels/shader-cache.js +4 -2
  162. package/src/gpu/kernels/silu.d.ts +1 -0
  163. package/src/gpu/kernels/silu.js +148 -82
  164. package/src/gpu/kernels/silu.wgsl +19 -9
  165. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  166. package/src/gpu/kernels/softmax.js +44 -25
  167. package/src/gpu/kernels/split_qkv.js +23 -13
  168. package/src/gpu/kernels/transpose.js +31 -10
  169. package/src/gpu/kernels/transpose.wgsl +6 -5
  170. package/src/gpu/kernels/upsample2d.js +22 -13
  171. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  172. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  173. package/src/gpu/kernels/utils.js +35 -13
  174. package/src/gpu/partitioned-buffer-pool.js +10 -2
  175. package/src/gpu/perf-guards.js +2 -9
  176. package/src/gpu/profiler.js +27 -22
  177. package/src/gpu/readback-utils.d.ts +16 -0
  178. package/src/gpu/readback-utils.js +41 -0
  179. package/src/gpu/submit-tracker.js +13 -0
  180. package/src/gpu/uniform-cache.d.ts +1 -0
  181. package/src/gpu/uniform-cache.js +30 -9
  182. package/src/hotswap/intent-bundle.js +6 -0
  183. package/src/hotswap/manifest.d.ts +10 -1
  184. package/src/hotswap/manifest.js +12 -2
  185. package/src/hotswap/runtime.js +30 -8
  186. package/src/index-browser.d.ts +44 -0
  187. package/src/index-browser.js +14 -0
  188. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  189. package/src/inference/browser-harness-contract-helpers.js +28 -0
  190. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  191. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  192. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  193. package/src/inference/browser-harness-model-helpers.js +217 -0
  194. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  195. package/src/inference/browser-harness-report-helpers.js +42 -0
  196. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  197. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  198. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  199. package/src/inference/browser-harness-suite-helpers.js +268 -0
  200. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  201. package/src/inference/browser-harness-text-helpers.js +788 -0
  202. package/src/inference/browser-harness.d.ts +6 -0
  203. package/src/inference/browser-harness.js +130 -1950
  204. package/src/inference/kv-cache/base.js +140 -94
  205. package/src/inference/kv-cache/tiered.js +5 -3
  206. package/src/inference/moe-router.js +88 -56
  207. package/src/inference/multi-model-network.js +5 -3
  208. package/src/inference/network-evolution.d.ts +11 -2
  209. package/src/inference/network-evolution.js +20 -21
  210. package/src/inference/pipelines/context.d.ts +3 -0
  211. package/src/inference/pipelines/context.js +142 -2
  212. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  213. package/src/inference/pipelines/diffusion/pipeline.js +17 -7
  214. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  215. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  216. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  217. package/src/inference/pipelines/diffusion/vae.js +3 -7
  218. package/src/inference/pipelines/energy/pipeline.js +27 -21
  219. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  220. package/src/inference/pipelines/energy/quintel.js +11 -0
  221. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  222. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  223. package/src/inference/pipelines/text/attention/projections.js +151 -101
  224. package/src/inference/pipelines/text/attention/record.js +73 -10
  225. package/src/inference/pipelines/text/attention/run.js +73 -10
  226. package/src/inference/pipelines/text/chat-format.js +25 -1
  227. package/src/inference/pipelines/text/config.d.ts +4 -0
  228. package/src/inference/pipelines/text/config.js +71 -5
  229. package/src/inference/pipelines/text/embed.js +2 -8
  230. package/src/inference/pipelines/text/execution-plan.js +64 -50
  231. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  232. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  233. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  234. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  235. package/src/inference/pipelines/text/execution-v0.js +78 -1002
  236. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  237. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  238. package/src/inference/pipelines/text/generator-steps.js +298 -207
  239. package/src/inference/pipelines/text/generator.js +6 -23
  240. package/src/inference/pipelines/text/init.d.ts +4 -0
  241. package/src/inference/pipelines/text/init.js +134 -29
  242. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  243. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  244. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  245. package/src/inference/pipelines/text/layer.js +14 -9
  246. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  247. package/src/inference/pipelines/text/linear-attention.js +80 -6
  248. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  249. package/src/inference/pipelines/text/logits/index.js +10 -11
  250. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  251. package/src/inference/pipelines/text/logits/utils.js +9 -0
  252. package/src/inference/pipelines/text/lora-apply.js +50 -32
  253. package/src/inference/pipelines/text/model-load.js +279 -104
  254. package/src/inference/pipelines/text/moe-cache.js +5 -4
  255. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  256. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  257. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  258. package/src/inference/pipelines/text/ops.js +90 -90
  259. package/src/inference/pipelines/text/probes.js +9 -9
  260. package/src/inference/pipelines/text/weights.js +17 -7
  261. package/src/inference/pipelines/text.js +17 -1
  262. package/src/inference/speculative.d.ts +2 -2
  263. package/src/inference/speculative.js +4 -18
  264. package/src/inference/test-harness.d.ts +1 -1
  265. package/src/inference/test-harness.js +15 -5
  266. package/src/inference/tokenizer.d.ts +0 -5
  267. package/src/inference/tokenizer.js +4 -23
  268. package/src/inference/tokenizers/bpe.js +9 -0
  269. package/src/inference/tokenizers/bundled.js +176 -33
  270. package/src/inference/tokenizers/sentencepiece.js +12 -0
  271. package/src/loader/doppler-loader.js +38 -22
  272. package/src/loader/dtype-utils.js +3 -44
  273. package/src/loader/embedding-loader.js +7 -3
  274. package/src/loader/experts/expert-cache.js +13 -6
  275. package/src/loader/experts/expert-loader.js +10 -6
  276. package/src/loader/final-weights-loader.js +8 -4
  277. package/src/loader/layer-loader.js +2 -1
  278. package/src/loader/loader-state.js +2 -2
  279. package/src/loader/memory-monitor.js +8 -0
  280. package/src/loader/multi-model-loader.d.ts +14 -0
  281. package/src/loader/multi-model-loader.js +70 -24
  282. package/src/loader/shard-cache.js +81 -12
  283. package/src/loader/shard-resolver.js +25 -3
  284. package/src/loader/tensors/tensor-loader.js +209 -144
  285. package/src/loader/tensors/tensor-reader.js +76 -19
  286. package/src/loader/weight-downcast.js +1 -1
  287. package/src/memory/buffer-pool.d.ts +9 -1
  288. package/src/memory/buffer-pool.js +109 -44
  289. package/src/memory/unified-detect.js +1 -1
  290. package/src/rules/inference/kernel-path.rules.json +24 -8
  291. package/src/rules/rule-registry.js +25 -1
  292. package/src/rules/tooling/command-runtime.rules.json +18 -0
  293. package/src/storage/backends/opfs-store.js +68 -24
  294. package/src/storage/downloader.js +364 -83
  295. package/src/storage/index.d.ts +3 -0
  296. package/src/storage/index.js +3 -0
  297. package/src/storage/preflight.d.ts +2 -2
  298. package/src/storage/preflight.js +24 -2
  299. package/src/storage/quickstart-downloader.js +11 -5
  300. package/src/storage/registry.js +10 -4
  301. package/src/storage/reports.js +1 -1
  302. package/src/storage/shard-manager.d.ts +15 -1
  303. package/src/storage/shard-manager.js +51 -3
  304. package/src/storage/source-artifact-store.d.ts +52 -0
  305. package/src/storage/source-artifact-store.js +234 -0
  306. package/src/tooling/command-api-constants.d.ts +9 -0
  307. package/src/tooling/command-api-constants.js +9 -0
  308. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  309. package/src/tooling/command-api-family-normalizers.js +343 -0
  310. package/src/tooling/command-api-helpers.d.ts +25 -0
  311. package/src/tooling/command-api-helpers.js +262 -0
  312. package/src/tooling/command-api.d.ts +27 -1
  313. package/src/tooling/command-api.js +26 -473
  314. package/src/tooling/command-envelope.js +4 -1
  315. package/src/tooling/command-runner-shared.js +52 -18
  316. package/src/tooling/lean-execution-contract.js +150 -3
  317. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  318. package/src/tooling/node-browser-command-runner.js +218 -273
  319. package/src/tooling/node-command-runner.js +44 -3
  320. package/src/tooling/node-converter.js +27 -1
  321. package/src/tooling/node-source-runtime.d.ts +1 -1
  322. package/src/tooling/node-source-runtime.js +84 -3
  323. package/src/tooling/node-webgpu.js +30 -105
  324. package/src/tooling/opfs-cache.js +21 -4
  325. package/src/tooling/runtime-input-composition.d.ts +38 -0
  326. package/src/tooling/runtime-input-composition.js +86 -0
  327. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  328. package/src/tooling/source-runtime-bundle.js +261 -34
  329. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  330. package/src/tooling/source-runtime-materializer.js +93 -0
  331. package/src/training/attention-backward.js +32 -17
  332. package/src/training/autograd.js +80 -52
  333. package/src/training/checkpoint-watch.d.ts +8 -0
  334. package/src/training/checkpoint-watch.js +139 -0
  335. package/src/training/checkpoint.d.ts +6 -1
  336. package/src/training/checkpoint.js +46 -7
  337. package/src/training/clip.js +2 -1
  338. package/src/training/datasets/token-batch.js +20 -8
  339. package/src/training/distillation/artifacts.d.ts +71 -0
  340. package/src/training/distillation/artifacts.js +132 -0
  341. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  342. package/src/training/distillation/checkpoint-watch.js +58 -0
  343. package/src/training/distillation/dataset.d.ts +59 -0
  344. package/src/training/distillation/dataset.js +337 -0
  345. package/src/training/distillation/eval.d.ts +34 -0
  346. package/src/training/distillation/eval.js +310 -0
  347. package/src/training/distillation/index.d.ts +29 -0
  348. package/src/training/distillation/index.js +29 -0
  349. package/src/training/distillation/runtime.d.ts +20 -0
  350. package/src/training/distillation/runtime.js +121 -0
  351. package/src/training/distillation/scoreboard.d.ts +6 -0
  352. package/src/training/distillation/scoreboard.js +8 -0
  353. package/src/training/distillation/stage-a.d.ts +45 -0
  354. package/src/training/distillation/stage-a.js +338 -0
  355. package/src/training/distillation/stage-b.d.ts +24 -0
  356. package/src/training/distillation/stage-b.js +20 -0
  357. package/src/training/distillation/student-fixture.d.ts +22 -0
  358. package/src/training/distillation/student-fixture.js +846 -0
  359. package/src/training/distillation/suite-data.d.ts +45 -0
  360. package/src/training/distillation/suite-data.js +189 -0
  361. package/src/training/index.d.ts +10 -0
  362. package/src/training/index.js +10 -0
  363. package/src/training/lora-pipeline.d.ts +40 -0
  364. package/src/training/lora-pipeline.js +793 -0
  365. package/src/training/lora.js +26 -12
  366. package/src/training/loss.js +5 -6
  367. package/src/training/objectives/cross_entropy.js +2 -5
  368. package/src/training/objectives/distill_kd.js +4 -8
  369. package/src/training/objectives/distill_triplet.js +4 -8
  370. package/src/training/objectives/ul_stage2_base.js +4 -8
  371. package/src/training/operator-artifacts.d.ts +62 -0
  372. package/src/training/operator-artifacts.js +140 -0
  373. package/src/training/operator-command.d.ts +5 -0
  374. package/src/training/operator-command.js +455 -0
  375. package/src/training/operator-eval.d.ts +48 -0
  376. package/src/training/operator-eval.js +230 -0
  377. package/src/training/operator-scoreboard.d.ts +5 -0
  378. package/src/training/operator-scoreboard.js +44 -0
  379. package/src/training/optimizer.js +19 -7
  380. package/src/training/runner.d.ts +52 -0
  381. package/src/training/runner.js +31 -5
  382. package/src/training/suite.d.ts +112 -0
  383. package/src/training/suite.js +24 -984
  384. package/src/training/tensor-factory.d.ts +9 -0
  385. package/src/training/tensor-factory.js +13 -0
  386. package/src/training/trainer.js +3 -5
  387. package/src/training/ul_dataset.js +3 -5
  388. package/src/training/workloads.d.ts +164 -0
  389. package/src/training/workloads.js +530 -0
  390. package/src/version.js +1 -1
  391. package/tools/convert-safetensors-node.js +22 -16
  392. package/tools/doppler-cli.js +179 -63
@@ -94,6 +94,8 @@ export const GGML_TYPE_SIZE = {
94
94
  const GGUF_MAGIC = 0x46554747;
95
95
  const GGUF_VERSION_MIN = 2;
96
96
  const GGUF_VERSION_MAX = 3;
97
+ const MAX_SAFE_BIGINT = BigInt(Number.MAX_SAFE_INTEGER);
98
+ const MIN_SAFE_BIGINT = BigInt(Number.MIN_SAFE_INTEGER);
97
99
 
98
100
  const {
99
101
  contextLength: DEFAULT_GGUF_CONTEXT_LENGTH,
@@ -102,6 +104,13 @@ const {
102
104
  ropeFreqBase: DEFAULT_ROPE_FREQ_BASE,
103
105
  } = DEFAULT_GGUF_PARSER_DEFAULTS;
104
106
 
107
+ function toSafeInteger(value, label) {
108
+ if (value > MAX_SAFE_BIGINT || value < MIN_SAFE_BIGINT) {
109
+ throw new Error(`GGUF ${label} exceeds JavaScript safe integer range: ${value.toString()}`);
110
+ }
111
+ return Number(value);
112
+ }
113
+
105
114
  class GGUFReader {
106
115
  constructor(buffer) {
107
116
  this.view = new DataView(buffer);
@@ -144,18 +153,26 @@ class GGUFReader {
144
153
  return value;
145
154
  }
146
155
 
147
- readUint64() {
148
- const low = this.view.getUint32(this.offset, true);
149
- const high = this.view.getUint32(this.offset + 4, true);
156
+ readUint64BigInt() {
157
+ const low = BigInt(this.view.getUint32(this.offset, true));
158
+ const high = BigInt(this.view.getUint32(this.offset + 4, true));
150
159
  this.offset += 8;
151
- return high * 0x100000000 + low;
160
+ return (high << 32n) | low;
152
161
  }
153
162
 
154
- readInt64() {
155
- const low = this.view.getUint32(this.offset, true);
156
- const high = this.view.getInt32(this.offset + 4, true);
163
+ readUint64(label = 'u64 value') {
164
+ return toSafeInteger(this.readUint64BigInt(), label);
165
+ }
166
+
167
+ readInt64BigInt() {
168
+ const low = BigInt(this.view.getUint32(this.offset, true));
169
+ const high = BigInt(this.view.getInt32(this.offset + 4, true));
157
170
  this.offset += 8;
158
- return high * 0x100000000 + low;
171
+ return (high << 32n) | low;
172
+ }
173
+
174
+ readInt64(label = 'i64 value') {
175
+ return toSafeInteger(this.readInt64BigInt(), label);
159
176
  }
160
177
 
161
178
  readFloat32() {
@@ -175,7 +192,7 @@ class GGUFReader {
175
192
  }
176
193
 
177
194
  readString() {
178
- const length = this.readUint64();
195
+ const length = this.readUint64('string length');
179
196
  const bytes = new Uint8Array(this.view.buffer, this.offset, length);
180
197
  this.offset += length;
181
198
  return new TextDecoder().decode(bytes);
@@ -196,9 +213,9 @@ class GGUFReader {
196
213
  case GGUFValueType.INT32:
197
214
  return this.readInt32();
198
215
  case GGUFValueType.UINT64:
199
- return this.readUint64();
216
+ return this.readUint64('metadata uint64');
200
217
  case GGUFValueType.INT64:
201
- return this.readInt64();
218
+ return this.readInt64('metadata int64');
202
219
  case GGUFValueType.FLOAT32:
203
220
  return this.readFloat32();
204
221
  case GGUFValueType.FLOAT64:
@@ -216,7 +233,7 @@ class GGUFReader {
216
233
 
217
234
  readArray() {
218
235
  const elementType = this.readUint32();
219
- const length = this.readUint64();
236
+ const length = this.readUint64('array length');
220
237
  if (length > 10000000) {
221
238
  throw new Error(`Array too long: ${length}`);
222
239
  }
@@ -331,8 +348,8 @@ export function parseGGUF(buffer) {
331
348
  throw new Error(`Unsupported GGUF version: ${version}`);
332
349
  }
333
350
 
334
- const tensorCount = reader.readUint64();
335
- const metadataKVCount = reader.readUint64();
351
+ const tensorCount = reader.readUint64('tensor count');
352
+ const metadataKVCount = reader.readUint64('metadata count');
336
353
 
337
354
  const metadata = {};
338
355
  for (let i = 0; i < metadataKVCount; i++) {
@@ -351,10 +368,10 @@ export function parseGGUF(buffer) {
351
368
  const nDims = reader.readUint32();
352
369
  const shape = [];
353
370
  for (let d = 0; d < nDims; d++) {
354
- shape.push(reader.readUint64());
371
+ shape.push(reader.readUint64(`tensor "${name}" shape[${d}]`));
355
372
  }
356
373
  const type = reader.readUint32();
357
- const offset = reader.readUint64();
374
+ const offset = reader.readUint64(`tensor "${name}" offset`);
358
375
 
359
376
  tensors.push({
360
377
  name,
@@ -6,7 +6,7 @@
6
6
  * @module formats/rdrr/groups
7
7
  */
8
8
 
9
- import type { ComponentGroup } from './types.js';
9
+ import type { ComponentGroup, RDRRManifest } from './types.js';
10
10
 
11
11
  export declare function getGroup(groupId: string): ComponentGroup | null;
12
12
 
@@ -16,11 +16,19 @@ export declare function getShardsForGroup(groupId: string): number[];
16
16
 
17
17
  export declare function getTensorsForGroup(groupId: string): string[];
18
18
 
19
- export declare function getShardsForExpert(layerIdx: number, expertIdx: number): number[];
19
+ export declare function getShardsForExpert(
20
+ layerIdx: number,
21
+ expertIdx: number,
22
+ manifest?: RDRRManifest | null
23
+ ): number[];
20
24
 
21
- export declare function getTensorsForExpert(layerIdx: number, expertIdx: number): string[];
25
+ export declare function getTensorsForExpert(
26
+ layerIdx: number,
27
+ expertIdx: number,
28
+ manifest?: RDRRManifest | null
29
+ ): string[];
22
30
 
23
- export declare function getExpertBytes(): number;
31
+ export declare function getExpertBytes(manifest?: RDRRManifest | null): number;
24
32
 
25
33
  export declare function getLayerGroupIds(): string[];
26
34
 
@@ -19,8 +19,7 @@ export function getTensorsForGroup(groupId) {
19
19
  return getManifest()?.groups?.[groupId]?.tensors ?? [];
20
20
  }
21
21
 
22
- export function getShardsForExpert(layerIdx, expertIdx) {
23
- const manifest = getManifest();
22
+ export function getShardsForExpert(layerIdx, expertIdx, manifest = getManifest()) {
24
23
  const groupId = `layer.${layerIdx}.expert.${expertIdx}`;
25
24
  const group = manifest?.groups?.[groupId];
26
25
  if (group) {
@@ -29,8 +28,7 @@ export function getShardsForExpert(layerIdx, expertIdx) {
29
28
  throw new Error(`Missing expert group mapping: ${groupId}`);
30
29
  }
31
30
 
32
- export function getTensorsForExpert(layerIdx, expertIdx) {
33
- const manifest = getManifest();
31
+ export function getTensorsForExpert(layerIdx, expertIdx, manifest = getManifest()) {
34
32
  const groupId = `layer.${layerIdx}.expert.${expertIdx}`;
35
33
  const group = manifest?.groups?.[groupId];
36
34
  if (group) {
@@ -39,8 +37,7 @@ export function getTensorsForExpert(layerIdx, expertIdx) {
39
37
  throw new Error(`Missing expert group mapping: ${groupId}`);
40
38
  }
41
39
 
42
- export function getExpertBytes() {
43
- const manifest = getManifest();
40
+ export function getExpertBytes(manifest = getManifest()) {
44
41
  const expertGroups = Object.entries(manifest?.groups || {})
45
42
  .filter(([id]) => id.includes('.expert.'));
46
43
 
@@ -44,9 +44,13 @@ export function parseManifest(jsonString) {
44
44
  export function parseTensorMap(jsonString) {
45
45
  try {
46
46
  const tensorMap = JSON.parse(jsonString);
47
+ const normalizedTensorMap = {};
47
48
 
48
49
  for (const [name, loc] of Object.entries(tensorMap)) {
49
- if (typeof loc.shard !== 'number') {
50
+ const shardIndex = typeof loc.shardIndex === 'number'
51
+ ? loc.shardIndex
52
+ : loc.shard;
53
+ if (typeof shardIndex !== 'number') {
50
54
  throw new Error(`Tensor '${name}' missing shard index`);
51
55
  }
52
56
  if (typeof loc.offset !== 'number') {
@@ -61,9 +65,42 @@ export function parseTensorMap(jsonString) {
61
65
  if (typeof loc.role !== 'string') {
62
66
  throw new Error(`Tensor '${name}' missing role`);
63
67
  }
68
+
69
+ let spans = undefined;
70
+ if (loc.spans !== undefined) {
71
+ if (!Array.isArray(loc.spans)) {
72
+ throw new Error(`Tensor '${name}' has invalid spans array`);
73
+ }
74
+ spans = loc.spans.map((span, spanIndex) => {
75
+ const spanShardIndex = typeof span?.shardIndex === 'number'
76
+ ? span.shardIndex
77
+ : span?.shard;
78
+ if (typeof spanShardIndex !== 'number') {
79
+ throw new Error(`Tensor '${name}' span[${spanIndex}] missing shard index`);
80
+ }
81
+ if (typeof span?.offset !== 'number') {
82
+ throw new Error(`Tensor '${name}' span[${spanIndex}] missing offset`);
83
+ }
84
+ if (typeof span?.size !== 'number') {
85
+ throw new Error(`Tensor '${name}' span[${spanIndex}] missing size`);
86
+ }
87
+ return {
88
+ shardIndex: spanShardIndex,
89
+ offset: span.offset,
90
+ size: span.size,
91
+ };
92
+ });
93
+ }
94
+
95
+ normalizedTensorMap[name] = {
96
+ ...loc,
97
+ shard: shardIndex,
98
+ shardIndex,
99
+ spans,
100
+ };
64
101
  }
65
102
 
66
- return tensorMap;
103
+ return normalizedTensorMap;
67
104
  } catch (e) {
68
105
  if (e instanceof Error && e.message.includes('Tensor')) {
69
106
  throw e;
@@ -75,13 +75,14 @@ export interface ComponentGroup extends ComponentGroupSchema {}
75
75
 
76
76
export interface TensorLocation {
  /** Shard holding the tensor. `parseTensorMap` sets this to the same value as `shardIndex`. */
  shard: number;
  /** Alternate name for the shard field; raw manifests may provide either `shard` or `shardIndex`. */
  shardIndex?: number;
  /** Offset of the tensor data inside its shard (presumably bytes — confirm against the loader). */
  offset: number;
  /** Size of the tensor data (presumably bytes — confirm against the loader). */
  size: number;
  /** Tensor shape. */
  shape: number[];
  /** Storage dtype name. */
  dtype: string;
  /** Role of the tensor within the model. */
  role: TensorRole;
  /** Optional component-group id this tensor belongs to. */
  group?: string;
  /**
   * Optional list of pieces when the tensor is split across shards. Each span
   * may name its shard via `shard` or `shardIndex`.
   */
  spans?: { shard?: number; shardIndex?: number; offset: number; size: number }[];
  /** Optional weight layout descriptor. */
  layout?: WeightLayout;
  /** Optional shape before any layout transformation (assumed — confirm). */
  originalShape?: number[];
}
@@ -3,7 +3,7 @@
3
3
  import { getDevice, hasFeature, FEATURES } from './device.js';
4
4
  import { allowReadback, trackAllocation } from './perf-guards.js';
5
5
  import { getUniformCache } from './uniform-cache.js';
6
- import { isBufferActive, releaseBuffer } from '../memory/buffer-pool.js';
6
+ import { isBufferActive, releaseBuffer, discardBuffer } from '../memory/buffer-pool.js';
7
7
  import { log } from '../debug/index.js';
8
8
  import { getRuntimeConfig } from '../config/runtime.js';
9
9
 
@@ -93,6 +93,9 @@ export class CommandRecorder {
93
93
 
94
94
 
95
95
  #initProfiling() {
96
+ let querySet = null;
97
+ let queryBuffer = null;
98
+ let readbackBuffer = null;
96
99
  try {
97
100
  const runtimeProfiler = getRuntimeConfig().shared?.debug?.profiler;
98
101
  if (!runtimeProfiler) {
@@ -119,25 +122,31 @@ export class CommandRecorder {
119
122
  didLogQueryFallback = true;
120
123
  }
121
124
 
122
- this.#querySet = this.device.createQuerySet({
125
+ querySet = this.device.createQuerySet({
123
126
  type: 'timestamp',
124
127
  count: this.#queryCapacity,
125
128
  });
126
129
 
127
130
  // Buffer to hold query results (8 bytes per timestamp = BigUint64)
128
- this.#queryBuffer = this.device.createBuffer({
131
+ queryBuffer = this.device.createBuffer({
129
132
  label: `${this.label}_query_buffer`,
130
133
  size: this.#queryCapacity * 8,
131
134
  usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
132
135
  });
133
136
 
134
137
  // Readback buffer
135
- this.#readbackBuffer = this.device.createBuffer({
138
+ readbackBuffer = this.device.createBuffer({
136
139
  label: `${this.label}_readback_buffer`,
137
140
  size: this.#queryCapacity * 8,
138
141
  usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
139
142
  });
143
+ this.#querySet = querySet;
144
+ this.#queryBuffer = queryBuffer;
145
+ this.#readbackBuffer = readbackBuffer;
140
146
  } catch (e) {
147
+ readbackBuffer?.destroy();
148
+ queryBuffer?.destroy();
149
+ querySet?.destroy();
141
150
  log.warn('CommandRecorder', `Failed to initialize profiling: ${e}`);
142
151
  this.#profilingEnabled = false;
143
152
  }
@@ -277,39 +286,57 @@ export class CommandRecorder {
277
286
  }
278
287
  }
279
288
 
289
+ #finalizeTrackedBuffers(buffersToDestroy, buffersToRelease, discardPooled) {
290
+ for (const buffer of buffersToDestroy) {
291
+ buffer.destroy();
292
+ }
293
+ for (const buffer of buffersToRelease) {
294
+ if (discardPooled) {
295
+ discardBuffer(buffer);
296
+ } else {
297
+ releaseBuffer(buffer);
298
+ }
299
+ }
300
+ getUniformCache().flushPendingDestruction();
301
+ }
302
+
303
+ #takeTrackedBuffers() {
304
+ const buffersToDestroy = this.#tempBuffers;
305
+ const buffersToRelease = this.#pooledBuffers;
306
+ this.#tempBuffers = [];
307
+ this.#pooledBuffers = [];
308
+ this.#tempBufferSet.clear();
309
+ this.#pooledBufferSet.clear();
310
+ return { buffersToDestroy, buffersToRelease };
311
+ }
312
+
280
313
 
281
314
  submit() {
282
315
  if (this.#submitted) {
283
316
  throw new Error('[CommandRecorder] Already submitted');
284
317
  }
285
318
 
286
- // Submit commands
287
319
  const submitStart = performance.now();
288
- this.device.queue.submit([this.#encoder.finish()]);
320
+ const { buffersToDestroy, buffersToRelease } = this.#takeTrackedBuffers();
321
+ try {
322
+ this.device.queue.submit([this.#encoder.finish()]);
323
+ } catch (error) {
324
+ this.#submitted = true;
325
+ this.#submitStartMs = submitStart;
326
+ this.#finalizeTrackedBuffers(buffersToDestroy, buffersToRelease, false);
327
+ this.#destroyProfilingResources();
328
+ throw error;
329
+ }
330
+
289
331
  this.#submitted = true;
290
332
  this.#submitStartMs = submitStart;
291
333
 
292
- const buffersToDestroy = this.#tempBuffers;
293
- const buffersToRelease = this.#pooledBuffers;
294
- this.#tempBuffers = [];
295
- this.#pooledBuffers = [];
296
- this.#tempBufferSet.clear();
297
- this.#pooledBufferSet.clear();
298
-
299
334
  this.#cleanupPromise = this.device.queue.onSubmittedWorkDone().then(() => {
300
335
  this.#submitLatencyMs = performance.now() - submitStart;
301
- // Destroy buffers created directly by the recorder
302
- for (const buffer of buffersToDestroy) {
303
- buffer.destroy();
304
- }
305
- // Release pooled buffers back to the pool
306
- for (const buffer of buffersToRelease) {
307
- releaseBuffer(buffer);
308
- }
309
- // Safe to destroy evicted uniform buffers now that GPU work is complete
310
- getUniformCache().flushPendingDestruction();
336
+ this.#finalizeTrackedBuffers(buffersToDestroy, buffersToRelease, false);
311
337
  }).catch((err) => {
312
338
  log.warn('CommandRecorder', `Deferred cleanup failed: ${ (err).message}`);
339
+ this.#finalizeTrackedBuffers(buffersToDestroy, buffersToRelease, true);
313
340
  });
314
341
  }
315
342
 
@@ -370,55 +397,53 @@ export class CommandRecorder {
370
397
  }
371
398
 
372
399
  if (this.#profileEntries.length === 0) {
400
+ this.#destroyProfilingResources();
373
401
  return {};
374
402
  }
375
403
 
376
- // Wait for GPU work to complete
377
- await this.device.queue.onSubmittedWorkDone();
404
+ let mapped = false;
378
405
 
379
- // Resolve queries to buffer
380
- const maxIndex = Math.max(...this.#profileEntries.map(e => e.endQueryIndex)) + 1;
381
- const resolveEncoder = this.device.createCommandEncoder({ label: 'profile_resolve' });
382
- resolveEncoder.resolveQuerySet(this.#querySet, 0, maxIndex, this.#queryBuffer, 0);
383
- resolveEncoder.copyBufferToBuffer(this.#queryBuffer, 0, this.#readbackBuffer, 0, maxIndex * 8);
384
- this.device.queue.submit([resolveEncoder.finish()]);
385
-
386
- if (!allowReadback('CommandRecorder.resolveProfileTimings')) {
387
- return null;
388
- }
389
-
390
- // Read back timestamps
391
- await this.#readbackBuffer.mapAsync(GPUMapMode.READ);
392
- const timestamps = new BigUint64Array(this.#readbackBuffer.getMappedRange());
406
+ try {
407
+ await this.device.queue.onSubmittedWorkDone();
393
408
 
394
- // Aggregate timings by label
395
-
396
- const timings = {};
409
+ const maxIndex = Math.max(...this.#profileEntries.map(e => e.endQueryIndex)) + 1;
410
+ const resolveEncoder = this.device.createCommandEncoder({ label: 'profile_resolve' });
411
+ resolveEncoder.resolveQuerySet(this.#querySet, 0, maxIndex, this.#queryBuffer, 0);
412
+ resolveEncoder.copyBufferToBuffer(this.#queryBuffer, 0, this.#readbackBuffer, 0, maxIndex * 8);
413
+ this.device.queue.submit([resolveEncoder.finish()]);
397
414
 
398
- for (const entry of this.#profileEntries) {
399
- const startNs = timestamps[entry.startQueryIndex];
400
- const endNs = timestamps[entry.endQueryIndex];
401
- const durationMs = Number(endNs - startNs) / 1_000_000;
415
+ if (!allowReadback('CommandRecorder.resolveProfileTimings')) {
416
+ return null;
417
+ }
402
418
 
403
- // Skip invalid timings
404
- if (durationMs < 0 || durationMs > 60000) {
405
- continue;
419
+ await this.#readbackBuffer.mapAsync(GPUMapMode.READ);
420
+ mapped = true;
421
+ const timestamps = new BigUint64Array(this.#readbackBuffer.getMappedRange());
422
+ const timings = {};
423
+
424
+ for (const entry of this.#profileEntries) {
425
+ const startNs = timestamps[entry.startQueryIndex];
426
+ const endNs = timestamps[entry.endQueryIndex];
427
+ const durationMs = Number(endNs - startNs) / 1_000_000;
428
+
429
+ if (durationMs < 0 || durationMs > 60000) {
430
+ continue;
431
+ }
432
+
433
+ if (timings[entry.label] !== undefined) {
434
+ timings[entry.label] += durationMs;
435
+ } else {
436
+ timings[entry.label] = durationMs;
437
+ }
406
438
  }
407
439
 
408
- // Aggregate by label
409
- if (timings[entry.label] !== undefined) {
410
- timings[entry.label] += durationMs;
411
- } else {
412
- timings[entry.label] = durationMs;
440
+ return timings;
441
+ } finally {
442
+ if (mapped && this.#readbackBuffer) {
443
+ this.#readbackBuffer.unmap();
413
444
  }
445
+ this.#destroyProfilingResources();
414
446
  }
415
-
416
- this.#readbackBuffer.unmap();
417
-
418
- // Clean up profiling resources after use
419
- this.#destroyProfilingResources();
420
-
421
- return timings;
422
447
  }
423
448
 
424
449
 
@@ -82,6 +82,7 @@ export function initDevice(): Promise<GPUDevice>;
82
82
 
83
83
  /**
84
84
  * Register an externally created GPU device for pipeline use.
85
+ * The active device epoch advances and loss handling is attached to the device.
85
86
  */
86
87
  export function setDevice(
87
88
  device: GPUDevice | null,
package/src/gpu/device.js CHANGED
@@ -28,6 +28,126 @@ function advanceDeviceEpoch() {
28
28
  deviceEpoch += 1;
29
29
  }
30
30
 
31
/**
 * Forget the active device and all state derived from it.
 *
 * Resets the module-level device, kernel capabilities, resolved platform
 * config and the platform-initialized flag. Callers in this module advance
 * the device epoch separately after calling this.
 */
function clearActiveDeviceState() {
  platformInitialized = false;
  resolvedPlatformConfig = null;
  kernelCapabilities = null;
  gpuDevice = null;
}
37
+
38
/**
 * Decide whether a value may be used as a bind-group `buffer` resource.
 *
 * Falsy values are rejected. Objects tagged `__dopplerFakeGPUBuffer === true`,
 * and `FakeBuffer` instances exposing numeric `size`/`usage` and a `destroy`
 * method, are accepted as test doubles. When no global `GPUBuffer` exists,
 * any other truthy value is accepted; otherwise it must be a `GPUBuffer`.
 *
 * @param {*} value
 * @returns {boolean}
 */
function isValidGPUBuffer(value) {
  if (!value) return false;
  if (value.__dopplerFakeGPUBuffer === true) return true;

  const looksLikeFakeBuffer = typeof value === 'object'
    && value.constructor?.name === 'FakeBuffer'
    && typeof value.size === 'number'
    && typeof value.usage === 'number'
    && typeof value.destroy === 'function';
  if (looksLikeFakeBuffer) return true;

  // Without a GPUBuffer global there is nothing to instanceof-check against.
  return typeof GPUBuffer === 'undefined' || value instanceof GPUBuffer;
}
59
+
60
/**
 * Check that `device` exposes the minimal WebGPU surface this module uses:
 * `createBuffer`, `createBindGroup`, `createCommandEncoder`,
 * `createShaderModule`, and a `queue` with a `submit` method.
 *
 * @param {*} device
 * @returns {boolean}
 */
function isUsableGPUDevice(device) {
  if (!device) {
    return false;
  }
  const requiredMethods = ['createBuffer', 'createBindGroup', 'createCommandEncoder', 'createShaderModule'];
  const hasMethods = requiredMethods.every((methodName) => typeof device[methodName] === 'function');
  if (!hasMethods) {
    return false;
  }
  return Boolean(device.queue) && typeof device.queue.submit === 'function';
}
71
+
72
/**
 * Produce a short type description of a bind-group buffer value for error
 * messages: 'null', 'undefined', 'GPUBuffer', the constructor name of other
 * objects (or 'object' when there is none), or the `typeof` of primitives
 * and functions.
 *
 * @param {*} value
 * @returns {string}
 */
function describeBindGroupBufferValue(value) {
  if (value === null || value === undefined) {
    return String(value);
  }
  const isRealBuffer = typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
  if (isRealBuffer) {
    return 'GPUBuffer';
  }
  const kind = typeof value;
  return kind === 'object' ? (value.constructor?.name || 'object') : kind;
}
81
+
82
/**
 * Reject bind-group descriptors whose buffer bindings hold unusable values.
 *
 * Only entries whose `resource` is an object with a `buffer` key are checked
 * (other resources, e.g. non-buffer bindings, are skipped). A missing or
 * non-array `entries` is treated as empty.
 *
 * @param {object|undefined} descriptor - Descriptor about to be passed to createBindGroup.
 * @throws {Error} On the first entry whose `buffer` fails isValidGPUBuffer.
 */
function validateBindGroupDescriptor(descriptor) {
  const label = descriptor?.label || 'unlabeled_bind_group';
  const entries = Array.isArray(descriptor?.entries) ? descriptor.entries : [];

  for (const entry of entries) {
    const resource = entry?.resource;
    const bindsBuffer = resource && typeof resource === 'object' && 'buffer' in resource;
    if (bindsBuffer && !isValidGPUBuffer(resource.buffer)) {
      throw new Error(
        `[${label}] binding ${entry.binding} requires a GPUBuffer; ` +
        `got ${describeBindGroupBufferValue(resource.buffer)}.`
      );
    }
  }
}
99
+
100
/**
 * Replace `device.createBindGroup` with a wrapper that validates buffer
 * bindings (via validateBindGroupDescriptor) before delegating to the
 * original method bound to `device`.
 *
 * Idempotent: a wrapped device is tagged with a non-enumerable
 * `__dopplerBindGroupValidationWrapped` marker and not wrapped twice.
 * Falsy values and device-like objects without a `createBindGroup` function
 * are returned unchanged.
 *
 * @param {GPUDevice|null|undefined} device
 * @returns {GPUDevice|null|undefined} The same `device`.
 */
function wrapDeviceCreateBindGroup(device) {
  if (!device || device.__dopplerBindGroupValidationWrapped) {
    return device;
  }
  // setDevice() calls this on any non-null device it receives; minimal
  // device-like objects without createBindGroup have nothing to wrap, so
  // leave them untouched instead of throwing a TypeError from `.bind`.
  if (typeof device.createBindGroup !== 'function') {
    return device;
  }
  const originalCreateBindGroup = device.createBindGroup.bind(device);
  device.createBindGroup = (descriptor) => {
    validateBindGroupDescriptor(descriptor);
    return originalCreateBindGroup(descriptor);
  };
  Object.defineProperty(device, '__dopplerBindGroupValidationWrapped', {
    value: true,
    configurable: true,
    enumerable: false,
    writable: false,
  });
  return device;
}
117
+
118
/**
 * Attach a `device.lost` listener that resets module state when the device
 * that is still active (`gpuDevice`) is lost.
 *
 * The listener does nothing if `device` is no longer the active device when
 * the promise settles; destroyDevice()/setDevice() clear or replace
 * `gpuDevice` synchronously, so their losses are ignored. Registration is
 * idempotent via a non-enumerable `__dopplerLossHandlerRegistered` marker.
 * Devices without a thenable `lost` are marked but get no listener.
 *
 * @param {GPUDevice|null|undefined} device
 * @returns {GPUDevice|null|undefined} The same `device`.
 */
function registerDeviceLostHandler(device) {
  if (!device || device.__dopplerLossHandlerRegistered) {
    return device;
  }

  if (device.lost && typeof device.lost.then === 'function') {
    const trackedDevice = device;
    device.lost.then((info) => {
      if (gpuDevice !== trackedDevice) {
        return;
      }
      // Optional chaining so a `lost` promise that resolves without a
      // GPUDeviceLostInfo (e.g. from a test double) is still reported as a
      // device loss rather than falling into the handler-failure path.
      const message = info?.message ?? 'unknown';
      const reason = info?.reason ?? 'unknown';
      log.error('GPU', `Device lost: ${message}, Reason: ${reason}`);
      clearActiveDeviceState();
      advanceDeviceEpoch();
    }).catch((error) => {
      if (gpuDevice !== trackedDevice) {
        return;
      }
      log.warn('GPU', 'Device lost handler failed: ' + (error?.message ?? error));
      clearActiveDeviceState();
      advanceDeviceEpoch();
    });
  }

  Object.defineProperty(device, '__dopplerLossHandlerRegistered', {
    value: true,
    configurable: true,
    enumerable: false,
    writable: false,
  });
  return device;
}
150
+
31
151
 
32
152
  export const FEATURES = ({
33
153
  SHADER_F16: 'shader-f16',
@@ -163,7 +283,11 @@ async function initializePlatformAndRegistry(adapter) {
163
283
  export async function initDevice() {
164
284
  // Return cached device if available
165
285
  if (gpuDevice) {
166
- return gpuDevice;
286
+ if (isUsableGPUDevice(gpuDevice)) {
287
+ return gpuDevice;
288
+ }
289
+ clearActiveDeviceState();
290
+ advanceDeviceEpoch();
167
291
  }
168
292
 
169
293
  if (!isWebGPUAvailable()) {
@@ -201,18 +325,10 @@ export async function initDevice() {
201
325
  if (!gpuDevice) {
202
326
  throw createDopplerError(ERROR_CODES.GPU_DEVICE_FAILED, 'Failed to create WebGPU device');
203
327
  }
328
+ wrapDeviceCreateBindGroup(gpuDevice);
329
+ registerDeviceLostHandler(gpuDevice);
204
330
  advanceDeviceEpoch();
205
331
 
206
- // Set up device lost handler
207
- gpuDevice.lost.then((info) => {
208
- log.error('GPU', 'Device lost: ' + info.message + ', Reason: ' + info.reason);
209
- gpuDevice = null;
210
- kernelCapabilities = null;
211
- resolvedPlatformConfig = null;
212
- platformInitialized = false;
213
- advanceDeviceEpoch();
214
- });
215
-
216
332
  // Wrap queue for submit tracking (when enabled)
217
333
  wrapQueueForTracking(gpuDevice.queue);
218
334
 
@@ -244,15 +360,14 @@ export async function initDevice() {
244
360
 
245
361
  export function setDevice(device, options = {}) {
246
362
  if (!device) {
247
- gpuDevice = null;
248
- kernelCapabilities = null;
249
- resolvedPlatformConfig = null;
250
- platformInitialized = false;
363
+ clearActiveDeviceState();
251
364
  advanceDeviceEpoch();
252
365
  return;
253
366
  }
254
367
 
255
368
  gpuDevice = device;
369
+ wrapDeviceCreateBindGroup(gpuDevice);
370
+ registerDeviceLostHandler(gpuDevice);
256
371
  advanceDeviceEpoch();
257
372
  wrapQueueForTracking(gpuDevice.queue);
258
373
 
@@ -314,10 +429,7 @@ export function isPlatformInitialized() {
314
429
/**
 * Destroy the active device (if any) and reset all device-derived state.
 *
 * State is cleared and the device epoch advanced even if `destroy()` throws,
 * so the module never keeps referencing a device that failed to tear down.
 * Any error from `destroy()` still propagates to the caller.
 */
export function destroyDevice() {
  if (!gpuDevice) {
    return;
  }
  try {
    gpuDevice.destroy();
  } finally {
    clearActiveDeviceState();
    advanceDeviceEpoch();
  }
}