@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355) hide show
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -58,36 +58,46 @@ async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
58
58
  device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
59
59
  }
60
60
 
61
- await unifiedKernelWrapper(
62
- 'depthwise_conv2d',
63
- target,
64
- variant,
65
- [input, weightBuffer, biasBuffer, output],
66
- {
67
- channels,
68
- height,
69
- width,
70
- out_height: outHeight,
71
- out_width: outWidth,
72
- kernel_h: kernelH,
73
- kernel_w: kernelW,
74
- stride,
75
- pad,
76
- _pad0: 0,
77
- _pad1: 0,
78
- },
79
- [Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
80
- );
61
+ try {
62
+ await unifiedKernelWrapper(
63
+ 'depthwise_conv2d',
64
+ target,
65
+ variant,
66
+ [input, weightBuffer, biasBuffer, output],
67
+ {
68
+ channels,
69
+ height,
70
+ width,
71
+ out_height: outHeight,
72
+ out_width: outWidth,
73
+ kernel_h: kernelH,
74
+ kernel_w: kernelW,
75
+ stride,
76
+ pad,
77
+ _pad0: 0,
78
+ _pad1: 0,
79
+ },
80
+ [Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
81
+ );
81
82
 
82
- if (tempBias) {
83
- if (recorder) {
84
- recorder.trackTemporaryBuffer(tempBias);
85
- } else {
83
+ if (tempBias) {
84
+ if (recorder) {
85
+ recorder.trackTemporaryBuffer(tempBias);
86
+ } else {
87
+ releaseBuffer(tempBias);
88
+ }
89
+ }
90
+
91
+ return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
92
+ } catch (error) {
93
+ if (tempBias) {
86
94
  releaseBuffer(tempBias);
87
95
  }
96
+ if (!outputBuffer) {
97
+ releaseBuffer(output);
98
+ }
99
+ throw error;
88
100
  }
89
-
90
- return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
91
101
  }
92
102
 
93
103
  export async function runDepthwiseConv2D(input, weight, bias, options = {}) {
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../device.js';
4
- import { acquireBuffer } from '../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
5
5
  import { createTensor } from '../tensor.js';
6
6
  import { GPU_LIMITS, TILE_SIZES, WORKGROUP_SIZES } from './constants.js';
7
7
  import { Q6K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, Q8_0_BLOCK_SIZE } from '../../loader/quantization-constants.js';
@@ -69,6 +69,17 @@ export function createDequantBindGroupLayout() {
69
69
  ]);
70
70
  }
71
71
 
72
+ function cleanupDequantResources(uniformBuffer, ownedBuffers) {
73
+ if (uniformBuffer) {
74
+ releaseUniformBuffer(uniformBuffer);
75
+ }
76
+ for (const buffer of ownedBuffers) {
77
+ if (buffer) {
78
+ releaseBuffer(buffer);
79
+ }
80
+ }
81
+ }
82
+
72
83
 
73
84
  export async function dequantize(
74
85
  quantized,
@@ -76,12 +87,17 @@ export async function dequantize(
76
87
  options = {}
77
88
  ) {
78
89
  const device = getDevice();
90
+ const capabilities = getKernelCapabilities();
79
91
  const {
80
92
  outputOffset = 0,
81
93
  outputBuffer = null,
82
94
  outputDtype = 'f32',
83
95
  } = options;
84
96
 
97
+ if (outputDtype === 'f16' && capabilities?.hasF16 !== true) {
98
+ throw new Error('[dequantize] f16 output requires shader-f16 support.');
99
+ }
100
+
85
101
  // Select kernel
86
102
  const variant = selectDequantKernel({ ...options, outputDtype });
87
103
  const pipeline = await getPipelineFast('dequant', variant);
@@ -92,7 +108,8 @@ export async function dequantize(
92
108
  const outputSize = numBlocks * QK_K * bytesPerElem;
93
109
 
94
110
  // Create output buffer if not provided
95
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'dequant_output');
111
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'dequant_output');
112
+ const output = outputBuffer || ownedOutput;
96
113
 
97
114
  // Create uniform buffer
98
115
  const uniformBuffer = createUniformBufferWithView(
@@ -108,21 +125,24 @@ export async function dequantize(
108
125
  device
109
126
  );
110
127
 
111
- // Create bind group
112
- const bindGroup = device.createBindGroup({
113
- label: 'dequant_bind_group',
114
- layout: pipeline.getBindGroupLayout(0),
115
- entries: [
116
- { binding: 0, resource: { buffer: uniformBuffer } },
117
- { binding: 1, resource: { buffer: quantized } },
118
- { binding: 2, resource: { buffer: output } },
119
- ],
120
- });
121
-
122
- const workgroups = calculateDequantWorkgroups(variant, numBlocks);
123
- dispatch(device, pipeline, bindGroup, workgroups, 'dequant');
128
+ try {
129
+ const bindGroup = device.createBindGroup({
130
+ label: 'dequant_bind_group',
131
+ layout: pipeline.getBindGroupLayout(0),
132
+ entries: [
133
+ { binding: 0, resource: { buffer: uniformBuffer } },
134
+ { binding: 1, resource: { buffer: quantized } },
135
+ { binding: 2, resource: { buffer: output } },
136
+ ],
137
+ });
138
+
139
+ const workgroups = calculateDequantWorkgroups(variant, numBlocks);
140
+ dispatch(device, pipeline, bindGroup, workgroups, 'dequant');
141
+ } catch (error) {
142
+ cleanupDequantResources(uniformBuffer, [ownedOutput]);
143
+ throw error;
144
+ }
124
145
 
125
- // Release uniform buffer back to cache (or destroy if not cached)
126
146
  releaseUniformBuffer(uniformBuffer);
127
147
 
128
148
 
@@ -140,7 +160,11 @@ export async function dequantizeRowwise(
140
160
  options = {}
141
161
  ) {
142
162
  const device = getDevice();
163
+ const capabilities = getKernelCapabilities();
143
164
  const { outputBuffer = null, outputDtype = 'f16' } = options;
165
+ if (outputDtype === 'f16' && capabilities?.hasF16 !== true) {
166
+ throw new Error('[dequantizeRowwise] f16 output requires shader-f16 support.');
167
+ }
144
168
  const finalOutputDtype = selectSharedRuleValue('shared', 'dtype', 'f16OrF32FromDtype', { dtype: outputDtype });
145
169
  const pipelineVariant = selectKernelRuleValue(
146
170
  'dequant',
@@ -157,7 +181,8 @@ export async function dequantizeRowwise(
157
181
  const bytesPerElem = finalOutputDtype === 'f16' ? 2 : 4;
158
182
  const outputSize = rows * K * bytesPerElem;
159
183
 
160
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'dequant_rowwise_output');
184
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'dequant_rowwise_output');
185
+ const output = outputBuffer || ownedOutput;
161
186
 
162
187
  const uniformBuffer = createUniformBufferWithView(
163
188
  'dequant_rowwise_uniforms',
@@ -172,18 +197,23 @@ export async function dequantizeRowwise(
172
197
  device
173
198
  );
174
199
 
175
- const bindGroup = device.createBindGroup({
176
- label: 'dequant_rowwise_bind_group',
177
- layout: pipeline.getBindGroupLayout(0),
178
- entries: [
179
- { binding: 0, resource: { buffer: uniformBuffer } },
180
- { binding: 1, resource: { buffer: quantized } },
181
- { binding: 2, resource: { buffer: output } },
182
- ],
183
- });
184
-
185
- const workgroups = [numBlocks, 1, 1];
186
- dispatch(device, pipeline, bindGroup, workgroups, 'dequant_rowwise');
200
+ try {
201
+ const bindGroup = device.createBindGroup({
202
+ label: 'dequant_rowwise_bind_group',
203
+ layout: pipeline.getBindGroupLayout(0),
204
+ entries: [
205
+ { binding: 0, resource: { buffer: uniformBuffer } },
206
+ { binding: 1, resource: { buffer: quantized } },
207
+ { binding: 2, resource: { buffer: output } },
208
+ ],
209
+ });
210
+
211
+ const workgroups = [numBlocks, 1, 1];
212
+ dispatch(device, pipeline, bindGroup, workgroups, 'dequant_rowwise');
213
+ } catch (error) {
214
+ cleanupDequantResources(uniformBuffer, [ownedOutput]);
215
+ throw error;
216
+ }
187
217
 
188
218
  releaseUniformBuffer(uniformBuffer);
189
219
 
@@ -208,7 +238,8 @@ export async function dequantizeMXFP4(
208
238
  const pipeline = await getPipelineFast('dequant', 'mxfp4');
209
239
 
210
240
  const outputSize = totalElements * 4; // F32 output
211
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'mxfp4_dequant_output');
241
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'mxfp4_dequant_output');
242
+ const output = outputBuffer || ownedOutput;
212
243
 
213
244
  // Create uniform buffer
214
245
  const uniformBuffer = createUniformBufferWithView(
@@ -224,26 +255,29 @@ export async function dequantizeMXFP4(
224
255
  device
225
256
  );
226
257
 
227
- // Create bind group
228
- const bindGroup = device.createBindGroup({
229
- label: 'mxfp4_dequant_bind_group',
230
- layout: pipeline.getBindGroupLayout(0),
231
- entries: [
232
- { binding: 0, resource: { buffer: uniformBuffer } },
233
- { binding: 1, resource: { buffer: blocks } },
234
- { binding: 2, resource: { buffer: scales } },
235
- { binding: 3, resource: { buffer: output } },
236
- ],
237
- });
238
-
239
- const workgroups = Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT);
240
-
241
- const dispatchSize = [
242
- Math.min(workgroups, GPU_LIMITS.MAX_WORKGROUPS),
243
- Math.max(1, Math.ceil(workgroups / GPU_LIMITS.MAX_WORKGROUPS)),
244
- 1,
245
- ];
246
- dispatch(device, pipeline, bindGroup, dispatchSize, 'mxfp4_dequant');
258
+ try {
259
+ const bindGroup = device.createBindGroup({
260
+ label: 'mxfp4_dequant_bind_group',
261
+ layout: pipeline.getBindGroupLayout(0),
262
+ entries: [
263
+ { binding: 0, resource: { buffer: uniformBuffer } },
264
+ { binding: 1, resource: { buffer: blocks } },
265
+ { binding: 2, resource: { buffer: scales } },
266
+ { binding: 3, resource: { buffer: output } },
267
+ ],
268
+ });
269
+
270
+ const workgroups = Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT);
271
+ const dispatchSize = [
272
+ Math.min(workgroups, GPU_LIMITS.MAX_WORKGROUPS),
273
+ Math.max(1, Math.ceil(workgroups / GPU_LIMITS.MAX_WORKGROUPS)),
274
+ 1,
275
+ ];
276
+ dispatch(device, pipeline, bindGroup, dispatchSize, 'mxfp4_dequant');
277
+ } catch (error) {
278
+ cleanupDequantResources(uniformBuffer, [ownedOutput]);
279
+ throw error;
280
+ }
247
281
 
248
282
  releaseUniformBuffer(uniformBuffer);
249
283
 
@@ -284,7 +318,8 @@ export async function dequantizeMXFP4Expert(
284
318
  const totalOutput = outDim * numGroups * 32;
285
319
  const bytesPerElement = outputDtype === 'f16' ? 2 : 4;
286
320
  const outputSize = totalOutput * bytesPerElement;
287
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'mxfp4_expert_output');
321
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'mxfp4_expert_output');
322
+ const output = outputBuffer || ownedOutput;
288
323
 
289
324
  // Create uniform buffer
290
325
  const uniformBuffer = createUniformBufferWithView(
@@ -301,26 +336,29 @@ export async function dequantizeMXFP4Expert(
301
336
  device
302
337
  );
303
338
 
304
- // Create bind group
305
- const bindGroup = device.createBindGroup({
306
- label: 'mxfp4_expert_bind_group',
307
- layout: pipeline.getBindGroupLayout(0),
308
- entries: [
309
- { binding: 0, resource: { buffer: uniformBuffer } },
310
- { binding: 1, resource: { buffer: blocks } },
311
- { binding: 2, resource: { buffer: scales } },
312
- { binding: 3, resource: { buffer: output } },
313
- ],
314
- });
315
-
316
- const workgroups = Math.ceil(totalOutput / WORKGROUP_SIZES.DEFAULT);
317
-
318
- const dispatchSize = [
319
- Math.min(workgroups, GPU_LIMITS.MAX_WORKGROUPS),
320
- Math.max(1, Math.ceil(workgroups / GPU_LIMITS.MAX_WORKGROUPS)),
321
- 1,
322
- ];
323
- dispatch(device, pipeline, bindGroup, dispatchSize, 'mxfp4_expert');
339
+ try {
340
+ const bindGroup = device.createBindGroup({
341
+ label: 'mxfp4_expert_bind_group',
342
+ layout: pipeline.getBindGroupLayout(0),
343
+ entries: [
344
+ { binding: 0, resource: { buffer: uniformBuffer } },
345
+ { binding: 1, resource: { buffer: blocks } },
346
+ { binding: 2, resource: { buffer: scales } },
347
+ { binding: 3, resource: { buffer: output } },
348
+ ],
349
+ });
350
+
351
+ const workgroups = Math.ceil(totalOutput / WORKGROUP_SIZES.DEFAULT);
352
+ const dispatchSize = [
353
+ Math.min(workgroups, GPU_LIMITS.MAX_WORKGROUPS),
354
+ Math.max(1, Math.ceil(workgroups / GPU_LIMITS.MAX_WORKGROUPS)),
355
+ 1,
356
+ ];
357
+ dispatch(device, pipeline, bindGroup, dispatchSize, 'mxfp4_expert');
358
+ } catch (error) {
359
+ cleanupDequantResources(uniformBuffer, [ownedOutput]);
360
+ throw error;
361
+ }
324
362
 
325
363
  releaseUniformBuffer(uniformBuffer);
326
364
 
@@ -350,7 +388,8 @@ export async function dequantizeQ6K(
350
388
  const outputSize = numBlocks * QK_K * bytesPerElem;
351
389
 
352
390
  // Create output buffer if not provided
353
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'q6k_dequant_output');
391
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'q6k_dequant_output');
392
+ const output = outputBuffer || ownedOutput;
354
393
 
355
394
  // Calculate workgroups for 2D dispatch
356
395
  const maxWorkgroups = GPU_LIMITS.MAX_WORKGROUPS;
@@ -370,26 +409,28 @@ export async function dequantizeQ6K(
370
409
  device
371
410
  );
372
411
 
373
- // Create bind group
374
- const bindGroup = device.createBindGroup({
375
- label: 'q6k_dequant_bind_group',
376
- layout: pipeline.getBindGroupLayout(0),
377
- entries: [
378
- { binding: 0, resource: { buffer: uniformBuffer } },
379
- { binding: 1, resource: { buffer: quantized } },
380
- { binding: 2, resource: { buffer: output } },
381
- ],
382
- });
383
-
384
- // One workgroup per block, handle 2D dispatch for large counts
385
-
386
- const workgroups = [
387
- workgroupsX,
388
- numBlocks > maxWorkgroups ? Math.ceil(numBlocks / maxWorkgroups) : 1,
389
- 1
390
- ];
391
-
392
- dispatch(device, pipeline, bindGroup, workgroups, 'q6k_dequant');
412
+ try {
413
+ const bindGroup = device.createBindGroup({
414
+ label: 'q6k_dequant_bind_group',
415
+ layout: pipeline.getBindGroupLayout(0),
416
+ entries: [
417
+ { binding: 0, resource: { buffer: uniformBuffer } },
418
+ { binding: 1, resource: { buffer: quantized } },
419
+ { binding: 2, resource: { buffer: output } },
420
+ ],
421
+ });
422
+
423
+ const workgroups = [
424
+ workgroupsX,
425
+ numBlocks > maxWorkgroups ? Math.ceil(numBlocks / maxWorkgroups) : 1,
426
+ 1
427
+ ];
428
+
429
+ dispatch(device, pipeline, bindGroup, workgroups, 'q6k_dequant');
430
+ } catch (error) {
431
+ cleanupDequantResources(uniformBuffer, [ownedOutput]);
432
+ throw error;
433
+ }
393
434
 
394
435
  releaseUniformBuffer(uniformBuffer);
395
436
 
@@ -419,7 +460,8 @@ export async function dequantizeQ8_0(
419
460
  const outputSize = numBlocks * Q8_0_BLOCK_SIZE * bytesPerElem;
420
461
 
421
462
  // Create output buffer if not provided
422
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'q8_0_dequant_output');
463
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'q8_0_dequant_output');
464
+ const output = outputBuffer || ownedOutput;
423
465
 
424
466
  // Calculate workgroups for 2D dispatch
425
467
  const maxWorkgroups = GPU_LIMITS.MAX_WORKGROUPS;
@@ -439,26 +481,28 @@ export async function dequantizeQ8_0(
439
481
  device
440
482
  );
441
483
 
442
- // Create bind group
443
- const bindGroup = device.createBindGroup({
444
- label: 'q8_0_dequant_bind_group',
445
- layout: pipeline.getBindGroupLayout(0),
446
- entries: [
447
- { binding: 0, resource: { buffer: uniformBuffer } },
448
- { binding: 1, resource: { buffer: quantized } },
449
- { binding: 2, resource: { buffer: output } },
450
- ],
451
- });
452
-
453
- // One workgroup per block, handle 2D dispatch for large counts
454
-
455
- const workgroups = [
456
- workgroupsX,
457
- numBlocks > maxWorkgroups ? Math.ceil(numBlocks / maxWorkgroups) : 1,
458
- 1
459
- ];
460
-
461
- dispatch(device, pipeline, bindGroup, workgroups, 'q8_0_dequant');
484
+ try {
485
+ const bindGroup = device.createBindGroup({
486
+ label: 'q8_0_dequant_bind_group',
487
+ layout: pipeline.getBindGroupLayout(0),
488
+ entries: [
489
+ { binding: 0, resource: { buffer: uniformBuffer } },
490
+ { binding: 1, resource: { buffer: quantized } },
491
+ { binding: 2, resource: { buffer: output } },
492
+ ],
493
+ });
494
+
495
+ const workgroups = [
496
+ workgroupsX,
497
+ numBlocks > maxWorkgroups ? Math.ceil(numBlocks / maxWorkgroups) : 1,
498
+ 1
499
+ ];
500
+
501
+ dispatch(device, pipeline, bindGroup, workgroups, 'q8_0_dequant');
502
+ } catch (error) {
503
+ cleanupDequantResources(uniformBuffer, [ownedOutput]);
504
+ throw error;
505
+ }
462
506
 
463
507
  releaseUniformBuffer(uniformBuffer);
464
508
 
@@ -491,7 +535,8 @@ export async function recordDequantize(
491
535
  const outputSize = numBlocks * QK_K * bytesPerElem;
492
536
 
493
537
  // Output buffer
494
- const output = outputBuffer || acquireBuffer(outputSize, undefined, 'dequant_output');
538
+ const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'dequant_output');
539
+ const output = outputBuffer || ownedOutput;
495
540
 
496
541
  // Uniform buffer
497
542
  const uniformBuffer = createUniformBufferWithView(
@@ -505,18 +550,25 @@ export async function recordDequantize(
505
550
  );
506
551
 
507
552
  // Bind group
508
- const bindGroup = device.createBindGroup({
509
- label: 'dequant_bind_group',
510
- layout: pipeline.getBindGroupLayout(0),
511
- entries: [
512
- { binding: 0, resource: { buffer: uniformBuffer } },
513
- { binding: 1, resource: { buffer: quantized } },
514
- { binding: 2, resource: { buffer: output } },
515
- ],
516
- });
517
-
518
- const workgroups = calculateDequantWorkgroups(variant, numBlocks);
519
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'dequant');
553
+ try {
554
+ const bindGroup = device.createBindGroup({
555
+ label: 'dequant_bind_group',
556
+ layout: pipeline.getBindGroupLayout(0),
557
+ entries: [
558
+ { binding: 0, resource: { buffer: uniformBuffer } },
559
+ { binding: 1, resource: { buffer: quantized } },
560
+ { binding: 2, resource: { buffer: output } },
561
+ ],
562
+ });
563
+
564
+ const workgroups = calculateDequantWorkgroups(variant, numBlocks);
565
+ recordDispatch(recorder, pipeline, bindGroup, workgroups, 'dequant');
566
+ } catch (error) {
567
+ if (ownedOutput) {
568
+ releaseBuffer(ownedOutput);
569
+ }
570
+ throw error;
571
+ }
520
572
 
521
573
 
522
574
  const dtype = selectSharedRuleValue('shared', 'dtype', 'f16OrF32FromDtype', { dtype: outputDtype });
@@ -16,6 +16,7 @@ export interface EnergyUpdateOptions {
16
16
  export interface EnergyQuintelUpdateOptions {
17
17
  count?: number;
18
18
  size?: number;
19
+ flags?: number;
19
20
  stepSize?: number;
20
21
  gradientScale?: number;
21
22
  countDiff?: number;
@@ -26,48 +27,29 @@ export interface EnergyQuintelUpdateOptions {
26
27
  centerTarget?: number;
27
28
  clampMin?: number;
28
29
  clampMax?: number;
29
- rules?: {
30
- mirrorX?: boolean;
31
- mirrorY?: boolean;
32
- diagonal?: boolean;
33
- count?: boolean;
34
- center?: boolean;
35
- };
36
30
  }
37
31
 
38
32
  export interface EnergyQuintelReduceOptions {
39
33
  count?: number;
40
34
  size?: number;
35
+ flags?: number;
41
36
  symmetryWeight?: number;
42
37
  centerWeight?: number;
43
38
  binarizeWeight?: number;
44
39
  centerTarget?: number;
45
- rules?: {
46
- mirrorX?: boolean;
47
- mirrorY?: boolean;
48
- diagonal?: boolean;
49
- count?: boolean;
50
- center?: boolean;
51
- };
52
40
  outputBuffer?: GPUBuffer | null;
53
41
  }
54
42
 
55
43
  export interface EnergyQuintelGradOptions {
56
44
  count?: number;
57
45
  size?: number;
46
+ flags?: number;
58
47
  countDiff?: number;
59
48
  symmetryWeight?: number;
60
49
  countWeight?: number;
61
50
  centerWeight?: number;
62
51
  binarizeWeight?: number;
63
52
  centerTarget?: number;
64
- rules?: {
65
- mirrorX?: boolean;
66
- mirrorY?: boolean;
67
- diagonal?: boolean;
68
- count?: boolean;
69
- center?: boolean;
70
- };
71
53
  outputBuffer?: GPUBuffer | null;
72
54
  }
73
55