@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -8,11 +8,67 @@ export function getTunerConfig() {
8
8
  return getRuntimeConfig().shared.tuner;
9
9
  }
10
10
 
11
+ function normalizeSignaturePart(value) {
12
+ return String(value ?? '')
13
+ .trim()
14
+ .replace(/[^a-zA-Z0-9]/g, '_');
15
+ }
16
+
17
+ function hasAdapterIdentity(info) {
18
+ if (!info || typeof info !== 'object') {
19
+ return false;
20
+ }
21
+ return ['vendor', 'architecture', 'device'].some((key) => {
22
+ const value = String(info[key] ?? '').trim().toLowerCase();
23
+ return value.length > 0 && value !== 'unknown';
24
+ });
25
+ }
26
+
27
+ function getFallbackSignature(capabilities) {
28
+ return [
29
+ capabilities?.hasF16 ? 'f16' : 'nof16',
30
+ capabilities?.hasSubgroups ? 'subgroups' : 'nosubgroups',
31
+ capabilities?.hasSubgroupsF16 ? 'subgroupsf16' : 'nosubgroupsf16',
32
+ capabilities?.hasTimestampQuery ? 'timestamp' : 'notimestamp',
33
+ `buf${Number.isFinite(capabilities?.maxBufferSize) ? capabilities.maxBufferSize : 'na'}`,
34
+ `wg${Number.isFinite(capabilities?.maxWorkgroupSize) ? capabilities.maxWorkgroupSize : 'na'}`,
35
+ `wgs${Number.isFinite(capabilities?.maxWorkgroupStorageSize) ? capabilities.maxWorkgroupStorageSize : 'na'}`,
36
+ ].join('_');
37
+ }
38
+
39
+ function isValidDeviceInfo(value) {
40
+ if (value == null) {
41
+ return true;
42
+ }
43
+ return typeof value === 'object';
44
+ }
45
+
46
+ function isValidTuneRecord(value) {
47
+ return !!value
48
+ && typeof value === 'object'
49
+ && Array.isArray(value.optimalWorkgroupSize)
50
+ && value.optimalWorkgroupSize.length === 3
51
+ && value.optimalWorkgroupSize.every((entry) => Number.isFinite(entry) && entry >= 0)
52
+ && Number.isFinite(value.optimalTileSize)
53
+ && value.optimalTileSize >= 0
54
+ && Number.isFinite(value.throughput)
55
+ && value.throughput >= 0
56
+ && Number.isFinite(value.timeMs)
57
+ && value.timeMs >= 0
58
+ && isValidDeviceInfo(value.deviceInfo);
59
+ }
60
+
11
61
 
12
62
  export function getDeviceSignature(capabilities) {
13
-
14
- const info = capabilities?.adapterInfo || { vendor: '', architecture: '', device: '' };
15
- return `${info.vendor}_${info.architecture}_${info.device}`.replace(/[^a-zA-Z0-9]/g, '_');
63
+ const info = capabilities?.adapterInfo;
64
+ if (hasAdapterIdentity(info)) {
65
+ return [
66
+ normalizeSignaturePart(info.vendor),
67
+ normalizeSignaturePart(info.architecture),
68
+ normalizeSignaturePart(info.device),
69
+ ].join('_');
70
+ }
71
+ return getFallbackSignature(capabilities);
16
72
  }
17
73
 
18
74
 
@@ -33,10 +89,21 @@ export function loadCache(capabilities) {
33
89
  const cached = localStorage.getItem(cacheKey);
34
90
  if (cached) {
35
91
  const data = JSON.parse(cached);
36
- return new Map(Object.entries(data));
92
+ if (!data || typeof data !== 'object' || Array.isArray(data)) {
93
+ throw new Error('Kernel tuner cache payload must be an object.');
94
+ }
95
+ const records = [];
96
+ for (const [key, value] of Object.entries(data)) {
97
+ if (!isValidTuneRecord(value)) {
98
+ throw new Error(`Kernel tuner cache record "${key}" is malformed.`);
99
+ }
100
+ records.push([key, value]);
101
+ }
102
+ return new Map(records);
37
103
  }
38
104
  } catch (e) {
39
105
  log.warn('KernelTuner', `Failed to load cache: ${e}`);
106
+ localStorage.removeItem(cacheKey);
40
107
  }
41
108
 
42
109
  return new Map();
@@ -1,6 +1,6 @@
1
1
 
2
2
 
3
- import { getDevice, getKernelCapabilities, getDeviceLimits } from '../device.js';
3
+ import { getDevice, getKernelCapabilities, getDeviceLimits, getDeviceEpoch } from '../device.js';
4
4
  import { getKernelThresholds } from '../../config/schema/index.js';
5
5
  import { GPUProfiler } from '../profiler.js';
6
6
  import {
@@ -36,21 +36,33 @@ export class KernelTuner {
36
36
 
37
37
  #cache;
38
38
 
39
+ #deviceEpoch;
40
+
39
41
  constructor() {
40
42
  this.#device = null;
41
43
  this.#profiler = null;
42
44
  this.#limits = null;
43
45
  this.#capabilities = null;
44
46
  this.#cache = new Map();
47
+ this.#deviceEpoch = -1;
45
48
  }
46
49
 
47
50
 
48
51
  async init() {
49
- this.#device = getDevice();
50
- if (!this.#device) {
52
+ const device = getDevice();
53
+ if (!device) {
51
54
  throw new Error('GPU device not initialized');
52
55
  }
53
56
 
57
+ const deviceEpoch = getDeviceEpoch();
58
+ if (this.#device === device && this.#deviceEpoch === deviceEpoch && this.#profiler) {
59
+ return;
60
+ }
61
+
62
+ this.destroy();
63
+
64
+ this.#device = device;
65
+ this.#deviceEpoch = deviceEpoch;
54
66
  this.#profiler = new GPUProfiler(this.#device);
55
67
  this.#limits = getDeviceLimits();
56
68
  this.#capabilities = getKernelCapabilities();
@@ -203,6 +215,12 @@ export class KernelTuner {
203
215
  if (this.#profiler) {
204
216
  this.#profiler.destroy();
205
217
  }
218
+ this.#profiler = null;
219
+ this.#device = null;
220
+ this.#limits = null;
221
+ this.#capabilities = null;
222
+ this.#cache = new Map();
223
+ this.#deviceEpoch = -1;
206
224
  }
207
225
  }
208
226
 
@@ -214,8 +232,8 @@ let globalTuner = null;
214
232
  export async function getKernelTuner() {
215
233
  if (!globalTuner) {
216
234
  globalTuner = new KernelTuner();
217
- await globalTuner.init();
218
235
  }
236
+ await globalTuner.init();
219
237
  return globalTuner;
220
238
  }
221
239
 
@@ -18,48 +18,29 @@ import { logKernelSelectionOnce } from '../kernel-selection-log.js';
18
18
  // Track if we've logged the attention tier selection (avoid spam)
19
19
  let loggedAttentionTier = false;
20
20
 
21
-
22
- let _chunkedMaxKVLen = null;
23
-
24
-
25
- function getChunkedMaxKVLen() {
26
- if (_chunkedMaxKVLen === null) {
27
- const config = getKernelConfig('attention', 'decode_chunked_f16kv');
28
- const maxKVLen = config.variantMetadata?.maxKVLen;
29
- if (!Number.isFinite(maxKVLen)) {
30
- throw new Error('Kernel config missing attention.decode_chunked_f16kv maxKVLen');
31
- }
32
- _chunkedMaxKVLen = maxKVLen;
21
+ function getRequiredVariantMaxKVLen(operation, variant, errorLabel) {
22
+ const config = getKernelConfig(operation, variant);
23
+ const maxKVLen = config.variantMetadata?.maxKVLen;
24
+ if (!Number.isFinite(maxKVLen)) {
25
+ throw new Error(`Kernel config missing ${errorLabel} maxKVLen`);
33
26
  }
34
- return _chunkedMaxKVLen;
27
+ return maxKVLen;
35
28
  }
36
29
 
37
- let _tieredMaxKVLen = null;
30
+ function getChunkedMaxKVLen() {
31
+ return getRequiredVariantMaxKVLen('attention', 'decode_chunked_f16kv', 'attention.decode_chunked_f16kv');
32
+ }
38
33
 
39
34
  function getTieredMaxKVLen() {
40
- if (_tieredMaxKVLen === null) {
41
- const config = getKernelConfig('attention_tiered', 'decode_tiered_f16');
42
- const maxKVLen = config.variantMetadata?.maxKVLen;
43
- if (!Number.isFinite(maxKVLen)) {
44
- throw new Error('Kernel config missing attention_tiered.decode_tiered_f16 maxKVLen');
45
- }
46
- _tieredMaxKVLen = maxKVLen;
47
- }
48
- return _tieredMaxKVLen;
35
+ return getRequiredVariantMaxKVLen('attention_tiered', 'decode_tiered_f16', 'attention_tiered.decode_tiered_f16');
49
36
  }
50
37
 
51
- let _tieredQuantMaxKVLen = null;
52
-
53
38
  function getTieredQuantMaxKVLen() {
54
- if (_tieredQuantMaxKVLen === null) {
55
- const config = getKernelConfig('attention_tiered_quant', 'decode_tiered_int8_f16kv');
56
- const maxKVLen = config.variantMetadata?.maxKVLen;
57
- if (!Number.isFinite(maxKVLen)) {
58
- throw new Error('Kernel config missing attention_tiered_quant.decode_tiered_int8_f16kv maxKVLen');
59
- }
60
- _tieredQuantMaxKVLen = maxKVLen;
61
- }
62
- return _tieredQuantMaxKVLen;
39
+ return getRequiredVariantMaxKVLen(
40
+ 'attention_tiered_quant',
41
+ 'decode_tiered_int8_f16kv',
42
+ 'attention_tiered_quant.decode_tiered_int8_f16kv'
43
+ );
63
44
  }
64
45
 
65
46
 
@@ -25,60 +25,62 @@ async function runAdamChunked(device, pipeline, params, grads, moment1, moment2,
25
25
  const gradsChunkBuffer = acquireBuffer(chunkBytes, undefined, 'adam_grads_chunk');
26
26
  const mChunkBuffer = acquireBuffer(chunkBytes, undefined, 'adam_m_chunk');
27
27
  const vChunkBuffer = acquireBuffer(chunkBytes, undefined, 'adam_v_chunk');
28
-
29
- const copyIn = device.createCommandEncoder();
30
- copyIn.copyBufferToBuffer(params.buffer, chunkOffsetBytes, paramsChunkBuffer, 0, chunkBytes);
31
- copyIn.copyBufferToBuffer(grads.buffer, chunkOffsetBytes, gradsChunkBuffer, 0, chunkBytes);
32
- copyIn.copyBufferToBuffer(moment1.buffer, chunkOffsetBytes, mChunkBuffer, 0, chunkBytes);
33
- copyIn.copyBufferToBuffer(moment2.buffer, chunkOffsetBytes, vChunkBuffer, 0, chunkBytes);
34
- device.queue.submit([copyIn.finish()]);
35
-
36
- const paramsChunk = createTensor(paramsChunkBuffer, params.dtype, [chunkCount], 'adam_params_chunk');
37
- const gradsChunk = createTensor(gradsChunkBuffer, grads.dtype, [chunkCount], 'adam_grads_chunk');
38
- const mChunk = createTensor(mChunkBuffer, moment1.dtype, [chunkCount], 'adam_m_chunk');
39
- const vChunk = createTensor(vChunkBuffer, moment2.dtype, [chunkCount], 'adam_v_chunk');
40
-
41
- const uniformBuffer = createUniformBufferWithView(
42
- 'adam_uniforms_chunk',
43
- 32,
44
- (view) => {
45
- view.setUint32(0, chunkCount, true);
46
- view.setUint32(4, options.step, true);
47
- view.setFloat32(8, options.lr, true);
48
- view.setFloat32(12, options.beta1, true);
49
- view.setFloat32(16, options.beta2, true);
50
- view.setFloat32(20, options.eps, true);
51
- },
52
- null,
53
- device
54
- );
55
-
56
- const bindGroup = device.createBindGroup({
57
- label: 'adam_bind_group_chunk',
58
- layout: pipeline.getBindGroupLayout(0),
59
- entries: [
60
- { binding: 0, resource: { buffer: uniformBuffer } },
61
- { binding: 1, resource: { buffer: paramsChunk.buffer } },
62
- { binding: 2, resource: { buffer: gradsChunk.buffer } },
63
- { binding: 3, resource: { buffer: mChunk.buffer } },
64
- { binding: 4, resource: { buffer: vChunk.buffer } },
65
- ],
66
- });
67
-
68
- const workgroups = Math.ceil(chunkCount / WORKGROUP_SIZES.DEFAULT);
69
- dispatch(device, pipeline, bindGroup, workgroups, 'adam_chunk');
70
- uniformBuffer.destroy();
71
-
72
- const copyOut = device.createCommandEncoder();
73
- copyOut.copyBufferToBuffer(paramsChunk.buffer, 0, params.buffer, chunkOffsetBytes, chunkBytes);
74
- copyOut.copyBufferToBuffer(mChunk.buffer, 0, moment1.buffer, chunkOffsetBytes, chunkBytes);
75
- copyOut.copyBufferToBuffer(vChunk.buffer, 0, moment2.buffer, chunkOffsetBytes, chunkBytes);
76
- device.queue.submit([copyOut.finish()]);
77
-
78
- releaseBuffer(paramsChunkBuffer);
79
- releaseBuffer(gradsChunkBuffer);
80
- releaseBuffer(mChunkBuffer);
81
- releaseBuffer(vChunkBuffer);
28
+ let uniformBuffer = null;
29
+ try {
30
+ const copyIn = device.createCommandEncoder();
31
+ copyIn.copyBufferToBuffer(params.buffer, chunkOffsetBytes, paramsChunkBuffer, 0, chunkBytes);
32
+ copyIn.copyBufferToBuffer(grads.buffer, chunkOffsetBytes, gradsChunkBuffer, 0, chunkBytes);
33
+ copyIn.copyBufferToBuffer(moment1.buffer, chunkOffsetBytes, mChunkBuffer, 0, chunkBytes);
34
+ copyIn.copyBufferToBuffer(moment2.buffer, chunkOffsetBytes, vChunkBuffer, 0, chunkBytes);
35
+ device.queue.submit([copyIn.finish()]);
36
+
37
+ const paramsChunk = createTensor(paramsChunkBuffer, params.dtype, [chunkCount], 'adam_params_chunk');
38
+ const gradsChunk = createTensor(gradsChunkBuffer, grads.dtype, [chunkCount], 'adam_grads_chunk');
39
+ const mChunk = createTensor(mChunkBuffer, moment1.dtype, [chunkCount], 'adam_m_chunk');
40
+ const vChunk = createTensor(vChunkBuffer, moment2.dtype, [chunkCount], 'adam_v_chunk');
41
+
42
+ uniformBuffer = createUniformBufferWithView(
43
+ 'adam_uniforms_chunk',
44
+ 32,
45
+ (view) => {
46
+ view.setUint32(0, chunkCount, true);
47
+ view.setUint32(4, options.step, true);
48
+ view.setFloat32(8, options.lr, true);
49
+ view.setFloat32(12, options.beta1, true);
50
+ view.setFloat32(16, options.beta2, true);
51
+ view.setFloat32(20, options.eps, true);
52
+ },
53
+ null,
54
+ device
55
+ );
56
+
57
+ const bindGroup = device.createBindGroup({
58
+ label: 'adam_bind_group_chunk',
59
+ layout: pipeline.getBindGroupLayout(0),
60
+ entries: [
61
+ { binding: 0, resource: { buffer: uniformBuffer } },
62
+ { binding: 1, resource: { buffer: paramsChunk.buffer } },
63
+ { binding: 2, resource: { buffer: gradsChunk.buffer } },
64
+ { binding: 3, resource: { buffer: mChunk.buffer } },
65
+ { binding: 4, resource: { buffer: vChunk.buffer } },
66
+ ],
67
+ });
68
+
69
+ const workgroups = Math.ceil(chunkCount / WORKGROUP_SIZES.DEFAULT);
70
+ dispatch(device, pipeline, bindGroup, workgroups, 'adam_chunk');
71
+
72
+ const copyOut = device.createCommandEncoder();
73
+ copyOut.copyBufferToBuffer(paramsChunk.buffer, 0, params.buffer, chunkOffsetBytes, chunkBytes);
74
+ copyOut.copyBufferToBuffer(mChunk.buffer, 0, moment1.buffer, chunkOffsetBytes, chunkBytes);
75
+ copyOut.copyBufferToBuffer(vChunk.buffer, 0, moment2.buffer, chunkOffsetBytes, chunkBytes);
76
+ device.queue.submit([copyOut.finish()]);
77
+ } finally {
78
+ uniformBuffer?.destroy();
79
+ releaseBuffer(paramsChunkBuffer);
80
+ releaseBuffer(gradsChunkBuffer);
81
+ releaseBuffer(mChunkBuffer);
82
+ releaseBuffer(vChunkBuffer);
83
+ }
82
84
  }
83
85
  }
84
86
 
@@ -143,10 +145,12 @@ export async function runAdam(
143
145
  ],
144
146
  });
145
147
 
146
- const workgroups = Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT);
147
- dispatch(device, pipeline, bindGroup, workgroups, 'adam');
148
-
149
- uniformBuffer.destroy();
148
+ try {
149
+ const workgroups = Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT);
150
+ dispatch(device, pipeline, bindGroup, workgroups, 'adam');
151
+ } finally {
152
+ uniformBuffer.destroy();
153
+ }
150
154
 
151
155
  return createTensor(params.buffer, params.dtype, [...params.shape], 'adam_params');
152
156
  }