@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -6,6 +6,7 @@ import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/
6
6
  import { createTensor } from '../gpu/tensor.js';
7
7
  import { attentionBackwardCpu } from './attention-backward.js';
8
8
  import { f16ToF32Array, f32ToF16Array } from '../inference/kv-cache/types.js';
9
+ import { createUploadedTensor } from './tensor-factory.js';
9
10
 
10
11
  export const OpType = {
11
12
  EMBED: 'embed',
@@ -35,6 +36,7 @@ export class AutogradTape {
35
36
  constructor(registry) {
36
37
  this.registry = registry;
37
38
  this.records = [];
39
+ this.retainedBuffers = new Set();
38
40
  }
39
41
 
40
42
  watch(tensor) {
@@ -43,6 +45,13 @@ export class AutogradTape {
43
45
 
44
46
  async record(op, fn, inputs, options = {}) {
45
47
  const output = await fn(...inputs);
48
+ if (Array.isArray(options.retainBuffers)) {
49
+ for (const buffer of options.retainBuffers) {
50
+ if (buffer) {
51
+ this.retainedBuffers.add(buffer);
52
+ }
53
+ }
54
+ }
46
55
  this.records.push({ op, inputs, output, options });
47
56
  return output;
48
57
  }
@@ -50,31 +59,40 @@ export class AutogradTape {
50
59
  async backward(gradOutput) {
51
60
  const grads = new Map();
52
61
  const seeds = this.normalizeBackwardSeeds(gradOutput);
53
- for (const seed of seeds) {
54
- await this.accumulateGrad(grads, seed.tensor, seed.grad);
55
- }
56
-
57
- for (let i = this.records.length - 1; i >= 0; i -= 1) {
58
- const record = this.records[i];
59
- const entry = this.registry.ops[record.op];
60
- if (!entry) {
61
- continue;
62
+ try {
63
+ for (const seed of seeds) {
64
+ await this.accumulateGrad(grads, seed.tensor, seed.grad);
62
65
  }
63
66
 
64
- const gradOut = grads.get(record.output);
65
- if (!gradOut) {
66
- continue;
67
- }
67
+ for (let i = this.records.length - 1; i >= 0; i -= 1) {
68
+ const record = this.records[i];
69
+ const entry = this.registry.ops[record.op];
70
+ if (!entry) {
71
+ continue;
72
+ }
73
+
74
+ const gradOut = grads.get(record.output);
75
+ if (!gradOut) {
76
+ continue;
77
+ }
68
78
 
69
- const gradsOut = await this.runBackward(entry.backward, record, gradOut);
70
- for (const { input, grad } of gradsOut) {
71
- if (input && grad) {
72
- await this.accumulateGrad(grads, input, grad);
79
+ const gradsOut = await this.runBackward(entry.backward, record, gradOut);
80
+ for (const { input, grad } of gradsOut) {
81
+ if (input && grad) {
82
+ await this.accumulateGrad(grads, input, grad);
83
+ }
73
84
  }
74
85
  }
75
- }
76
86
 
77
- return grads;
87
+ return grads;
88
+ } finally {
89
+ for (const buffer of this.retainedBuffers) {
90
+ try {
91
+ releaseBuffer(buffer);
92
+ } catch {}
93
+ }
94
+ this.retainedBuffers.clear();
95
+ }
78
96
  }
79
97
 
80
98
  isTensorLike(value) {
@@ -245,9 +263,7 @@ export class AutogradTape {
245
263
  expanded.set(gradRow.subarray(0, copyCount), rowOffset);
246
264
  const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
247
265
  const payload = dtype === 'f16' ? f32ToF16Array(expanded) : expanded;
248
- const outBuffer = acquireBuffer(payload.byteLength, undefined, 'row_slice_backward_output');
249
- uploadData(outBuffer, payload);
250
- return createTensor(outBuffer, dtype, [rows, cols], 'row_slice_backward_output');
266
+ return createUploadedTensor(payload, dtype, [rows, cols], 'row_slice_backward_output');
251
267
  }
252
268
 
253
269
  resolveSiluRowsplitGate(gateValue, activation) {
@@ -305,9 +321,7 @@ export class AutogradTape {
305
321
 
306
322
  const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
307
323
  const payload = dtype === 'f16' ? f32ToF16Array(output) : output;
308
- const outBuffer = acquireBuffer(payload.byteLength, undefined, 'silu_rowsplit_backward_output');
309
- uploadData(outBuffer, payload);
310
- return createTensor(outBuffer, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
324
+ return createUploadedTensor(payload, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
311
325
  }
312
326
 
313
327
  async accumulateLargeGradF32(existing, grad, size, shape) {
@@ -317,35 +331,49 @@ export class AutogradTape {
317
331
  }
318
332
  const bytesPerElement = 4;
319
333
  const outputBuffer = acquireBuffer(size * bytesPerElement, undefined, 'grad_accum_large_output');
320
-
321
- for (let offset = 0; offset < size; offset += MAX_RESIDUAL_ELEMENTS_PER_DISPATCH) {
322
- const chunkElements = Math.min(MAX_RESIDUAL_ELEMENTS_PER_DISPATCH, size - offset);
323
- const chunkBytes = chunkElements * bytesPerElement;
324
- const chunkOffsetBytes = offset * bytesPerElement;
325
-
326
- const aChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_a_chunk');
327
- const bChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_b_chunk');
328
- const copyIn = device.createCommandEncoder();
329
- copyIn.copyBufferToBuffer(existing.buffer, chunkOffsetBytes, aChunkBuffer, 0, chunkBytes);
330
- copyIn.copyBufferToBuffer(grad.buffer, chunkOffsetBytes, bChunkBuffer, 0, chunkBytes);
331
- device.queue.submit([copyIn.finish()]);
332
-
333
- const aChunk = createTensor(aChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_a_tensor');
334
- const bChunk = createTensor(bChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_b_tensor');
335
- const summedChunk = await runResidualAdd(aChunk, bChunk, chunkElements);
336
-
337
- const copyOut = device.createCommandEncoder();
338
- copyOut.copyBufferToBuffer(summedChunk.buffer, 0, outputBuffer, chunkOffsetBytes, chunkBytes);
339
- device.queue.submit([copyOut.finish()]);
340
-
341
- releaseBuffer(aChunkBuffer);
342
- releaseBuffer(bChunkBuffer);
343
- if (summedChunk?.buffer && summedChunk.buffer !== outputBuffer) {
344
- releaseBuffer(summedChunk.buffer);
334
+ try {
335
+ for (let offset = 0; offset < size; offset += MAX_RESIDUAL_ELEMENTS_PER_DISPATCH) {
336
+ const chunkElements = Math.min(MAX_RESIDUAL_ELEMENTS_PER_DISPATCH, size - offset);
337
+ const chunkBytes = chunkElements * bytesPerElement;
338
+ const chunkOffsetBytes = offset * bytesPerElement;
339
+
340
+ let aChunkBuffer = null;
341
+ let bChunkBuffer = null;
342
+ let summedChunkBuffer = null;
343
+ try {
344
+ aChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_a_chunk');
345
+ bChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_b_chunk');
346
+ const copyIn = device.createCommandEncoder();
347
+ copyIn.copyBufferToBuffer(existing.buffer, chunkOffsetBytes, aChunkBuffer, 0, chunkBytes);
348
+ copyIn.copyBufferToBuffer(grad.buffer, chunkOffsetBytes, bChunkBuffer, 0, chunkBytes);
349
+ device.queue.submit([copyIn.finish()]);
350
+
351
+ const aChunk = createTensor(aChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_a_tensor');
352
+ const bChunk = createTensor(bChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_b_tensor');
353
+ const summedChunk = await runResidualAdd(aChunk, bChunk, chunkElements);
354
+ summedChunkBuffer = summedChunk?.buffer ?? null;
355
+
356
+ const copyOut = device.createCommandEncoder();
357
+ copyOut.copyBufferToBuffer(summedChunk.buffer, 0, outputBuffer, chunkOffsetBytes, chunkBytes);
358
+ device.queue.submit([copyOut.finish()]);
359
+ } finally {
360
+ if (aChunkBuffer) {
361
+ releaseBuffer(aChunkBuffer);
362
+ }
363
+ if (bChunkBuffer) {
364
+ releaseBuffer(bChunkBuffer);
365
+ }
366
+ if (summedChunkBuffer && summedChunkBuffer !== outputBuffer) {
367
+ releaseBuffer(summedChunkBuffer);
368
+ }
369
+ }
345
370
  }
346
- }
347
371
 
348
- return createTensor(outputBuffer, 'f32', [...shape], 'grad_accum_large_output');
372
+ return createTensor(outputBuffer, 'f32', [...shape], 'grad_accum_large_output');
373
+ } catch (error) {
374
+ releaseBuffer(outputBuffer);
375
+ throw error;
376
+ }
349
377
  }
350
378
 
351
379
 
@@ -3,5 +3,6 @@ export declare function watchFinalizedCheckpoints(options: {
3
3
  manifestPath: string;
4
4
  pollIntervalMs?: number | null;
5
5
  stopWhenIdle?: boolean;
6
+ signal?: AbortSignal | null;
6
7
  onCheckpoint: (markerPath: string) => Promise<void> | void;
7
- }): Promise<{ ok: true; processedCount: number; manifestPath: string }>;
8
+ }): Promise<{ ok: true; processedCount: number; manifestPath: string; aborted?: boolean }>;
@@ -55,6 +55,36 @@ async function readProcessedManifest(manifestPath) {
55
55
  }
56
56
  }
57
57
 
58
+ function createWatchResult(processed, manifestPath, aborted = false) {
59
+ return {
60
+ ok: true,
61
+ processedCount: processed.size,
62
+ manifestPath,
63
+ aborted,
64
+ };
65
+ }
66
+
67
+ async function waitForPollInterval(pollIntervalMs, signal) {
68
+ if (!signal) {
69
+ await new Promise((resolvePromise) => setTimeout(resolvePromise, pollIntervalMs));
70
+ return true;
71
+ }
72
+ if (signal.aborted) {
73
+ return false;
74
+ }
75
+ return new Promise((resolvePromise) => {
76
+ const onAbort = () => {
77
+ clearTimeout(timer);
78
+ resolvePromise(false);
79
+ };
80
+ const timer = setTimeout(() => {
81
+ signal.removeEventListener('abort', onAbort);
82
+ resolvePromise(true);
83
+ }, pollIntervalMs);
84
+ signal.addEventListener('abort', onAbort, { once: true });
85
+ });
86
+ }
87
+
58
88
  export async function watchFinalizedCheckpoints(options) {
59
89
  const checkpointsDir = resolve(String(options.checkpointsDir));
60
90
  const manifestPath = resolve(String(options.manifestPath));
@@ -65,6 +95,7 @@ export async function watchFinalizedCheckpoints(options) {
65
95
  const onCheckpoint = typeof options.onCheckpoint === 'function'
66
96
  ? options.onCheckpoint
67
97
  : null;
98
+ const signal = options.signal ?? null;
68
99
  if (!onCheckpoint) {
69
100
  throw new Error('watchFinalizedCheckpoints requires onCheckpoint(markerPath).');
70
101
  }
@@ -72,6 +103,9 @@ export async function watchFinalizedCheckpoints(options) {
72
103
  const processed = await readProcessedManifest(manifestPath);
73
104
  let idlePolls = 0;
74
105
  for (;;) {
106
+ if (signal?.aborted) {
107
+ return createWatchResult(processed, manifestPath, true);
108
+ }
75
109
  const checkpointsExist = await ensureDirectoryExists(checkpointsDir);
76
110
  const markers = checkpointsExist
77
111
  ? await listCheckpointMarkers(checkpointsDir)
@@ -92,15 +126,14 @@ export async function watchFinalizedCheckpoints(options) {
92
126
  if (!sawNewMarker) {
93
127
  idlePolls += 1;
94
128
  if (stopWhenIdle && idlePolls > 0) {
95
- return {
96
- ok: true,
97
- processedCount: processed.size,
98
- manifestPath,
99
- };
129
+ return createWatchResult(processed, manifestPath);
100
130
  }
101
131
  } else {
102
132
  idlePolls = 0;
103
133
  }
104
- await new Promise((resolvePromise) => setTimeout(resolvePromise, pollIntervalMs));
134
+ const shouldContinue = await waitForPollInterval(pollIntervalMs, signal);
135
+ if (!shouldContinue) {
136
+ return createWatchResult(processed, manifestPath, true);
137
+ }
105
138
  }
106
139
  }
@@ -31,6 +31,13 @@ function openCheckpointDB(options = {}) {
31
31
  });
32
32
  }
33
33
 
34
+ function closeCheckpointDB(db) {
35
+ if (!db || typeof db.close !== 'function') {
36
+ return;
37
+ }
38
+ db.close();
39
+ }
40
+
34
41
  async function resolveNodeCheckpointPath(key, options = {}) {
35
42
  const [{ resolve, join, dirname }, { mkdir }] = await Promise.all([
36
43
  import('node:path'),
@@ -140,9 +147,15 @@ export async function saveCheckpoint(key, payload, options = {}) {
140
147
  const useNodeStore = isNodeRuntime() && typeof indexedDB === 'undefined';
141
148
  const nodePath = useNodeStore ? await resolveNodeCheckpointPath(key, options) : null;
142
149
  const browserStore = useNodeStore ? null : await openCheckpointDB(options);
143
- const previousData = useNodeStore
144
- ? await readNodeCheckpointRecord(nodePath)
145
- : await readCheckpointRecord(browserStore.db, browserStore.storeName, key);
150
+ let previousData;
151
+ try {
152
+ previousData = useNodeStore
153
+ ? await readNodeCheckpointRecord(nodePath)
154
+ : await readCheckpointRecord(browserStore.db, browserStore.storeName, key);
155
+ } catch (error) {
156
+ closeCheckpointDB(browserStore?.db);
157
+ throw error;
158
+ }
146
159
  const previousMetadata = previousData?.metadata || {};
147
160
  const previousLineage = previousMetadata.lineage || {};
148
161
  const previousCheckpointHash = options.priorCheckpointHash
@@ -194,13 +207,25 @@ export async function saveCheckpoint(key, payload, options = {}) {
194
207
 
195
208
  return new Promise((resolve, reject) => {
196
209
  const tx = browserStore.db.transaction(browserStore.storeName, 'readwrite');
197
- tx.oncomplete = () => resolve({
198
- key,
199
- path: null,
200
- metadata: data.metadata,
201
- data,
202
- });
203
- tx.onerror = () => reject(tx.error);
210
+ tx.oncomplete = () => {
211
+ closeCheckpointDB(browserStore.db);
212
+ resolve({
213
+ key,
214
+ path: null,
215
+ metadata: data.metadata,
216
+ data,
217
+ });
218
+ };
219
+ tx.onerror = () => {
220
+ const error = tx.error;
221
+ closeCheckpointDB(browserStore.db);
222
+ reject(error);
223
+ };
224
+ tx.onabort = () => {
225
+ const error = tx.error ?? new Error('Checkpoint transaction aborted');
226
+ closeCheckpointDB(browserStore.db);
227
+ reject(error);
228
+ };
204
229
  const store = tx.objectStore(browserStore.storeName);
205
230
  store.put(data, key);
206
231
  });
@@ -213,7 +238,11 @@ export async function loadCheckpoint(key, options = {}) {
213
238
  ? await readNodeCheckpointRecord(nodePath)
214
239
  : await (async () => {
215
240
  const { db, storeName } = await openCheckpointDB(options);
216
- return readCheckpointRecord(db, storeName, key);
241
+ try {
242
+ return await readCheckpointRecord(db, storeName, key);
243
+ } finally {
244
+ closeCheckpointDB(db);
245
+ }
217
246
  })();
218
247
 
219
248
  if (!data || !data.metadata || !options.expectedMetadata) {
@@ -12,7 +12,8 @@ async function readGradData(grad) {
12
12
  }
13
13
 
14
14
  export async function clipGradients(grads, config) {
15
- const maxNorm = config?.training?.gradient?.maxNorm;
15
+ const maxNorm = config?.training?.gradientClipping?.maxNorm
16
+ ?? config?.training?.gradient?.maxNorm;
16
17
  let sumSq = 0;
17
18
  let totalParamCount = 0;
18
19
 
@@ -1,5 +1,5 @@
1
1
 
2
- import { acquireBuffer, uploadData } from '../../memory/buffer-pool.js';
2
+ import { acquireBuffer, uploadData, releaseBuffer } from '../../memory/buffer-pool.js';
3
3
  import { createTensor } from '../../gpu/tensor.js';
4
4
 
5
5
  function flattenTokenBatch(samples, key) {
@@ -27,14 +27,26 @@ export function buildTokenBatch(samples) {
27
27
  }
28
28
 
29
29
  export function createTokenBatchTensors(batch) {
30
- const inputBuf = acquireBuffer(batch.inputFlat.byteLength, undefined, 'train_input_tokens');
31
- uploadData(inputBuf, batch.inputFlat);
30
+ let inputBuf = null;
31
+ let targetBuf = null;
32
+ try {
33
+ inputBuf = acquireBuffer(batch.inputFlat.byteLength, undefined, 'train_input_tokens');
34
+ uploadData(inputBuf, batch.inputFlat);
32
35
 
33
- const targetBuf = acquireBuffer(batch.targetFlat.byteLength, undefined, 'train_target_tokens');
34
- uploadData(targetBuf, batch.targetFlat);
36
+ targetBuf = acquireBuffer(batch.targetFlat.byteLength, undefined, 'train_target_tokens');
37
+ uploadData(targetBuf, batch.targetFlat);
35
38
 
36
- const input = createTensor(inputBuf, 'f32', [batch.inputFlat.length], 'train_input_tokens');
37
- const targets = createTensor(targetBuf, 'f32', [batch.targetFlat.length], 'train_target_tokens');
39
+ const input = createTensor(inputBuf, 'f32', [batch.inputFlat.length], 'train_input_tokens');
40
+ const targets = createTensor(targetBuf, 'f32', [batch.targetFlat.length], 'train_target_tokens');
38
41
 
39
- return { input, targets, offsets: batch.offsets };
42
+ return { input, targets, offsets: batch.offsets };
43
+ } catch (error) {
44
+ if (inputBuf) {
45
+ releaseBuffer(inputBuf);
46
+ }
47
+ if (targetBuf) {
48
+ releaseBuffer(targetBuf);
49
+ }
50
+ throw error;
51
+ }
40
52
  }
@@ -14,6 +14,7 @@ export async function watchDistillationCheckpoints(options) {
14
14
  manifestPath,
15
15
  pollIntervalMs: options.pollIntervalMs || 2000,
16
16
  stopWhenIdle: options.stopWhenIdle === true,
17
+ signal: options.signal ?? null,
17
18
  onCheckpoint: async (markerPath) => {
18
19
  const { marker } = await readDistillCheckpointMarker(markerPath);
19
20
  const reports = await evaluateDistillationCheckpoint({
@@ -0,0 +1,22 @@
1
+ export interface DistillStudentFixture {
2
+ config: Record<string, unknown>;
3
+ model: {
4
+ forward: (input: unknown, tape: unknown) => Promise<unknown>;
5
+ forwardDistill?: (
6
+ batch: unknown,
7
+ tape: unknown,
8
+ options?: Record<string, unknown>
9
+ ) => Promise<{ logits: unknown }>;
10
+ cleanupDistillStep?: () => void;
11
+ loraParams?: () => unknown[];
12
+ paramGroups?: () => Record<string, unknown[]>;
13
+ };
14
+ outputDim?: number;
15
+ embeddingDim?: number;
16
+ cleanup(): void;
17
+ }
18
+
19
+ export declare function createDistillStudentRuntimeModelFixture(
20
+ overrides?: Record<string, unknown>,
21
+ options?: Record<string, unknown>
22
+ ): Promise<DistillStudentFixture>;