@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -0,0 +1,45 @@
1
/**
 * Language/pair filters resolved by `resolveDistillDataScope`. Each list is
 * deduplicated and normalized, or null when that dimension is unrestricted;
 * the `*Set` fields mirror the lists for constant-time membership checks.
 */
export interface DistillDataScope {
  sourceLangs: Array<string> | null;
  targetLangs: Array<string> | null;
  /** Allowed directions in normalized `"src->tgt"` form. */
  pairAllowlist: Array<string> | null;
  sourceLangSet: Set<string> | null;
  targetLangSet: Set<string> | null;
  pairAllowlistSet: Set<string> | null;
  /** When true, `encodeDistillRow` throws on rows without a consistent pair/src/tgt. */
  strictPairContract: boolean;
}

/** Normalized distillation row as produced by `encodeDistillRow`. */
export interface DistillSample {
  index?: number;
  /** `"src->tgt"`, or `"unknown"` when no direction could be resolved. */
  direction?: string | null;
  sourceLang?: string | null;
  targetLang?: string | null;
  source?: string | null;
  targetPos?: string | null;
  targetNeg?: string | null;
}

/** Trims `value` to a string; returns null for null/undefined/blank input. */
export declare function normalizeOptionalString(value: unknown): string | null;

/** Dataset-path variant of `normalizeOptionalString`. */
export declare function normalizeDistillDatasetPath(value: unknown): string | null;

/** Resolves filters from `options` (preferred) and `trainingConfig.distill`. */
export declare function resolveDistillDataScope(
  options?: Record<string, unknown>,
  trainingConfig?: Record<string, unknown> | null,
): DistillDataScope;

/**
 * Normalizes one dataset row. Returns null when required text is missing or
 * the row falls outside `scope`; throws on `scope.strictPairContract` violations.
 */
export declare function encodeDistillRow(
  record: Record<string, unknown> | null | undefined,
  index: number,
  scope?: DistillDataScope | null,
): DistillSample | null;

/** Counts samples per `direction`, bucketing missing directions under "unknown". */
export declare function summarizeDirectionCounts(
  samples: (Record<string, unknown> | null | undefined)[],
): Record<string, number>;

/** Builds the "Translate from X to Y:" prompt for a sample. */
export declare function buildDistillPrompt(
  sample: Record<string, unknown> | null | undefined,
): string;

/** `buildDistillPrompt(sample)` followed by the trimmed candidate, if any. */
export declare function buildDistillCandidatePrompt(
  sample: Record<string, unknown> | null | undefined,
  candidate: unknown,
): string;
@@ -0,0 +1,189 @@
1
/**
 * Coerces a value to a trimmed string, treating null/undefined and
 * whitespace-only input as absent.
 * @param {unknown} value
 * @returns {string|null} The trimmed text, or null when nothing remains.
 */
export function normalizeOptionalString(value) {
  if (value == null) {
    return null;
  }
  const text = String(value).trim();
  return text.length > 0 ? text : null;
}
6
+
7
/**
 * Normalizes a dataset path option: trimmed text, or null when absent/blank.
 * @param {unknown} value
 * @returns {string|null}
 */
export function normalizeDistillDatasetPath(value) {
  const datasetPath = normalizeOptionalString(value);
  return datasetPath;
}
10
+
11
/**
 * Normalizes a language identifier for allowlist matching and prompt naming.
 *
 * English and Spanish collapse to their bare codes ('en' / 'es') regardless of
 * region or spelling ('en-US', 'en_GB', 'eng', 'English', 'es-MX', 'spa',
 * 'Spanish'). Every other value is lower-cased with '_' rewritten to '-'.
 *
 * Only the primary subtag is inspected. A plain prefix test would also fold
 * unrelated languages such as 'est' (Estonian) or 'esperanto' into Spanish.
 *
 * @param {unknown} value
 * @returns {string|null} The normalized code, or null for empty input.
 */
function normalizeLangCode(value) {
  const normalized = normalizeOptionalString(value);
  if (!normalized) return null;
  const compact = normalized.toLowerCase().replace(/_/g, '-');
  // Primary subtag = text before the first region/encoding separator,
  // e.g. 'en-us', 'en us', 'english (us)', 'en.utf-8'.
  const primary = compact.split(/[-\s(.@]/, 1)[0];
  switch (primary) {
    case 'en':
    case 'eng':
    case 'english':
      return 'en';
    case 'es':
    case 'spa':
    case 'spanish':
    case 'espanol':
    case 'español':
      return 'es';
    default:
      return compact;
  }
}
19
+
20
/**
 * Normalizes a translation pair such as "en-es", "EN_us -> es" or "en->es"
 * into the canonical "src->tgt" form with normalized language codes.
 * @param {unknown} value
 * @returns {string|null} The canonical pair, or null if it is not exactly two sides.
 */
function normalizePairDirection(value) {
  const raw = normalizeOptionalString(value);
  if (raw === null) return null;
  const cleaned = raw.toLowerCase().replace(/_/g, '-').replace(/\s+/g, '');
  // Prefer the explicit arrow so regional codes like "en-us->es" survive.
  const separator = cleaned.includes('->') ? '->' : '-';
  const sides = cleaned.split(separator).filter((side) => side.length > 0);
  if (sides.length !== 2) return null;
  const [from, to] = sides;
  return `${normalizeLangCode(from) || from}->${normalizeLangCode(to) || to}`;
}
30
+
31
/**
 * Accepts an array or a comma-separated string and returns its trimmed,
 * non-blank entries.
 * @param {unknown} value
 * @returns {string[]|null} Null for unsupported input or when no entries remain.
 */
function normalizeOptionalStringArray(value) {
  let entries;
  if (Array.isArray(value)) {
    entries = value;
  } else if (typeof value === 'string') {
    entries = value.split(',');
  } else {
    return null;
  }
  const cleaned = [];
  for (const entry of entries) {
    const text = normalizeOptionalString(entry);
    if (text) cleaned.push(text);
  }
  return cleaned.length === 0 ? null : cleaned;
}
42
+
43
/**
 * Builds a deduplicated list of normalized language codes, preserving the
 * order of first occurrence.
 * @param {unknown} value Array or comma-separated string.
 * @returns {string[]|null} Null when no usable codes remain.
 */
function normalizeDistillLanguageAllowlist(value) {
  const entries = normalizeOptionalStringArray(value);
  if (entries === null) return null;
  const unique = new Set();
  for (const entry of entries) {
    const code = normalizeLangCode(entry);
    if (code) unique.add(code);
  }
  return unique.size > 0 ? Array.from(unique) : null;
}
52
+
53
/**
 * Builds a deduplicated list of canonical "src->tgt" pairs, preserving the
 * order of first occurrence; malformed pairs are dropped.
 * @param {unknown} value Array or comma-separated string.
 * @returns {string[]|null} Null when no usable pairs remain.
 */
function normalizeDistillPairAllowlist(value) {
  const entries = normalizeOptionalStringArray(value);
  if (!entries) return null;
  const canonical = entries
    .map((entry) => normalizePairDirection(entry))
    .filter((pair) => Boolean(pair));
  const pairs = Array.from(new Set(canonical));
  return pairs.length > 0 ? pairs : null;
}
62
+
63
/**
 * Resolves distillation row filters. Explicit `options` values take
 * precedence over `trainingConfig.distill`; each list is normalized and also
 * exposed as a Set for membership checks.
 * @param {Record<string, unknown>} [options]
 * @param {Record<string, unknown>|null} [trainingConfig]
 * @returns {object} The resolved scope (see DistillDataScope).
 */
export function resolveDistillDataScope(options = {}, trainingConfig = null) {
  const distill = trainingConfig?.distill || {};
  const pick = (optionKey, configKey) => options[optionKey] ?? distill[configKey] ?? null;
  const toSet = (list) => (list ? new Set(list) : null);

  const sourceLangs = normalizeDistillLanguageAllowlist(pick('distillSourceLangs', 'sourceLangs'));
  const targetLangs = normalizeDistillLanguageAllowlist(pick('distillTargetLangs', 'targetLangs'));
  const pairAllowlist = normalizeDistillPairAllowlist(pick('distillPairAllowlist', 'pairAllowlist'));
  // Only a literal `true` from either source enables the strict contract.
  const strictPairContract = options.strictPairContract === true
    || distill.strictPairContract === true;

  return {
    sourceLangs,
    targetLangs,
    pairAllowlist,
    sourceLangSet: toSet(sourceLangs),
    targetLangSet: toSet(targetLangs),
    pairAllowlistSet: toSet(pairAllowlist),
    strictPairContract,
  };
}
88
+
89
/**
 * Derives a row's "src->tgt" direction: an explicit `pair` wins, otherwise
 * `src_lang` combined with `tgt_lang` (falling back to `lang`).
 * @param {object|null|undefined} record
 * @returns {string|null}
 */
function resolveDistillDirection(record) {
  const fromPair = normalizePairDirection(record?.pair);
  if (fromPair) return fromPair;
  const src = normalizeLangCode(record?.src_lang);
  const tgt = normalizeLangCode(record?.tgt_lang || record?.lang);
  return src && tgt ? `${src}->${tgt}` : null;
}
99
+
100
/**
 * Returns the first non-blank trimmed string found under `keys`, in order.
 * @param {object|null|undefined} record
 * @param {string[]} keys Candidate field names, highest priority first.
 * @returns {string|null}
 */
function resolveStringCandidate(record, keys) {
  let found = null;
  keys.some((key) => {
    found = normalizeOptionalString(record?.[key]);
    return found !== null;
  });
  return found;
}
107
+
108
/**
 * Validates and normalizes one distillation dataset row.
 *
 * Text fields accept aliases (`source`/`query`, `target_pos`/`target`/`pos`,
 * `target_neg`/`neg`). The direction comes from `pair`, else from
 * `src_lang` + `tgt_lang`/`lang`, else it is the placeholder 'unknown'.
 *
 * @param {object|null|undefined} record Raw dataset row.
 * @param {number} index Row index, copied into the result.
 * @param {object|null} [scope] Filters from `resolveDistillDataScope`.
 * @returns {object|null} The normalized sample, or null when the row lacks
 *   source/target text or is excluded by `scope`.
 * @throws {Error} Under `scope.strictPairContract` when src/tgt or pair is
 *   missing, or when they disagree.
 */
export function encodeDistillRow(record, index, scope = null) {
  if (!record || typeof record !== 'object') return null;
  const source = resolveStringCandidate(record, ['source', 'query']);
  const targetPos = resolveStringCandidate(record, ['target_pos', 'target', 'pos']);
  const targetNeg = resolveStringCandidate(record, ['target_neg', 'neg']);
  if (!source || !targetPos) return null;

  const sourceLangRaw = normalizeLangCode(record.src_lang);
  const targetLangRaw = normalizeLangCode(record.tgt_lang || record.lang);
  const pairDirection = normalizePairDirection(record.pair);
  const sourceTargetDirection = (
    sourceLangRaw && targetLangRaw
      ? `${sourceLangRaw}->${targetLangRaw}`
      : null
  );

  if (scope?.strictPairContract === true) {
    if (!sourceLangRaw || !targetLangRaw) {
      throw new Error('strictPairContract requires src_lang and tgt_lang/lang on each row.');
    }
    if (!pairDirection) {
      throw new Error('strictPairContract requires pair on each row.');
    }
    if (pairDirection !== sourceTargetDirection) {
      throw new Error(`pair "${record.pair}" does not match src/tgt "${sourceLangRaw}-${targetLangRaw}".`);
    }
  }

  // An explicit pair wins over the src/tgt fields when both are present.
  const resolvedDirection = pairDirection || sourceTargetDirection || resolveDistillDirection(record);
  const direction = resolvedDirection || 'unknown';
  // Backfill missing languages only from a real "src->tgt" direction; splitting
  // the 'unknown' placeholder would otherwise report sourceLang 'unknown'.
  const [directionSourceLang, directionTargetLang] = resolvedDirection
    ? resolvedDirection.split('->')
    : [];
  const sourceLang = sourceLangRaw || normalizeLangCode(directionSourceLang);
  const targetLang = targetLangRaw || normalizeLangCode(directionTargetLang);

  if (scope?.sourceLangSet && (!sourceLang || !scope.sourceLangSet.has(sourceLang))) {
    return null;
  }
  if (scope?.targetLangSet && (!targetLang || !scope.targetLangSet.has(targetLang))) {
    return null;
  }
  if (scope?.pairAllowlistSet && !scope.pairAllowlistSet.has(direction)) {
    return null;
  }

  return {
    index,
    direction,
    sourceLang: sourceLang || null,
    targetLang: targetLang || null,
    source,
    targetPos,
    targetNeg: targetNeg || null,
  };
}
157
+
158
/**
 * Counts samples per `direction`; samples without one are bucketed under
 * 'unknown'.
 *
 * Counting happens in a Map so that direction strings matching inherited
 * object members ('constructor', '__proto__', ...) cannot corrupt or drop
 * counts, as they would with a plain `{}` accumulator.
 *
 * @param {Iterable<object|null|undefined>} samples
 * @returns {Record<string, number>} Plain object of direction -> count, in
 *   first-seen order.
 */
export function summarizeDirectionCounts(samples) {
  const counts = new Map();
  for (const sample of samples) {
    // String() keeps object-key semantics (5 and '5' share one bucket).
    const key = String(sample?.direction || 'unknown');
    counts.set(key, (counts.get(key) ?? 0) + 1);
  }
  return Object.fromEntries(counts);
}
166
+
167
/**
 * Maps a language code to the display name used in prompts. Codes without
 * a known name are returned as-is; empty input yields 'target'.
 * @param {unknown} langCode
 * @returns {string}
 */
function resolveLanguageName(langCode) {
  const code = normalizeLangCode(langCode);
  switch (code) {
    case 'en':
      return 'English';
    case 'es':
      return 'Spanish';
    default:
      return code || 'target';
  }
}
173
+
174
/**
 * Builds the translation prompt for a sample from its "src->tgt" direction
 * and source text. Missing sides fall back to 'source' / 'target'.
 * @param {object|null|undefined} sample
 * @returns {string}
 */
export function buildDistillPrompt(sample) {
  const [srcPart, tgtPart] = String(sample?.direction || '').trim().split('->');
  const srcName = resolveLanguageName(normalizeLangCode(srcPart) || srcPart || 'source');
  const tgtName = resolveLanguageName(normalizeLangCode(tgtPart) || tgtPart || 'target');
  const sourceText = String(sample?.source || '').trim();
  return `Translate from ${srcName} to ${tgtName}:\n${sourceText}\nTranslation:`;
}
184
+
185
/**
 * Appends a trimmed candidate translation to the sample's prompt, separated
 * by a single space. Blank or falsy candidates leave the prompt unchanged.
 * @param {object|null|undefined} sample
 * @param {unknown} candidate
 * @returns {string}
 */
export function buildDistillCandidatePrompt(sample, candidate) {
  const prompt = buildDistillPrompt(sample);
  const completion = String(candidate || '').trim();
  if (completion.length === 0) {
    return prompt;
  }
  return `${prompt} ${completion}`;
}
@@ -3,7 +3,6 @@ import { join, resolve } from 'node:path';
3
3
 
4
4
  import { loadBackwardRegistry } from '../config/backward-registry-loader.js';
5
5
  import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/buffer-pool.js';
6
- import { createTensor } from '../gpu/tensor.js';
7
6
  import { runMatmul } from '../gpu/kernels/index.js';
8
7
  import { runResidualAdd } from '../gpu/kernels/residual.js';
9
8
  import { parseJsonl } from './datasets/jsonl.js';
@@ -27,6 +26,7 @@ import {
27
26
  } from './operator-artifacts.js';
28
27
  import { watchFinalizedCheckpoints } from './checkpoint-watch.js';
29
28
  import { loadLoRAFromManifest } from '../adapters/lora-loader.js';
29
+ import { createUploadedTensor } from './tensor-factory.js';
30
30
 
31
31
  function stableSortObject(value) {
32
32
  if (Array.isArray(value)) {
@@ -48,16 +48,12 @@ function stableJson(value) {
48
48
 
49
49
  function makeTensorFromFloat32(values, shape, label) {
50
50
  const data = values instanceof Float32Array ? values : new Float32Array(values);
51
- const buffer = acquireBuffer(data.byteLength, undefined, label);
52
- uploadData(buffer, data);
53
- return createTensor(buffer, 'f32', [...shape], label);
51
+ return createUploadedTensor(data, 'f32', shape, label);
54
52
  }
55
53
 
56
54
  function makeTensorFromUint32(values, shape, label) {
57
55
  const data = values instanceof Uint32Array ? values : new Uint32Array(values);
58
- const buffer = acquireBuffer(data.byteLength, undefined, label);
59
- uploadData(buffer, data);
60
- return createTensor(buffer, 'u32', [...shape], label);
56
+ return createUploadedTensor(data, 'u32', shape, label);
61
57
  }
62
58
 
63
59
  function releaseTensor(tensor) {
@@ -709,6 +705,7 @@ export async function watchLoraCheckpoints(options) {
709
705
  manifestPath: join(options.runRoot, 'scoreboard', 'watch-manifest.json'),
710
706
  pollIntervalMs: options.pollIntervalMs || 2000,
711
707
  stopWhenIdle: options.stopWhenIdle === true,
708
+ signal: options.signal ?? null,
712
709
  onCheckpoint: async (markerPath) => {
713
710
  const raw = await readFile(markerPath, 'utf8');
714
711
  const marker = JSON.parse(raw);
@@ -12,18 +12,32 @@ export class LoraAdapter {
12
12
  const aBytes = tensorBytes([inDim, rank], dtype);
13
13
  const bBytes = tensorBytes([rank, outDim], dtype);
14
14
 
15
- this.A = createTensor(
16
- acquireBuffer(aBytes, BufferUsage.STORAGE, 'lora_A'),
17
- dtype,
18
- [inDim, rank],
19
- 'lora_A'
20
- );
21
- this.B = createTensor(
22
- acquireBuffer(bBytes, BufferUsage.STORAGE, 'lora_B'),
23
- dtype,
24
- [rank, outDim],
25
- 'lora_B'
26
- );
15
+ let aBuffer = null;
16
+ let bBuffer = null;
17
+ try {
18
+ aBuffer = acquireBuffer(aBytes, BufferUsage.STORAGE, 'lora_A');
19
+ bBuffer = acquireBuffer(bBytes, BufferUsage.STORAGE, 'lora_B');
20
+ this.A = createTensor(
21
+ aBuffer,
22
+ dtype,
23
+ [inDim, rank],
24
+ 'lora_A'
25
+ );
26
+ this.B = createTensor(
27
+ bBuffer,
28
+ dtype,
29
+ [rank, outDim],
30
+ 'lora_B'
31
+ );
32
+ } catch (error) {
33
+ if (aBuffer) {
34
+ releaseBuffer(aBuffer);
35
+ }
36
+ if (bBuffer) {
37
+ releaseBuffer(bBuffer);
38
+ }
39
+ throw error;
40
+ }
27
41
  this.alpha = alpha;
28
42
  this.rank = rank;
29
43
  }
@@ -1,6 +1,5 @@
1
1
 
2
2
  import { runSoftmax, runCrossEntropyLoss, castF16ToF32 } from '../gpu/kernels/index.js';
3
- import { releaseBuffer } from '../memory/buffer-pool.js';
4
3
  import { OpType } from './autograd.js';
5
4
 
6
5
  export async function crossEntropyLoss(logits, targets, config, tape) {
@@ -25,13 +24,13 @@ export async function crossEntropyLoss(logits, targets, config, tape) {
25
24
  OpType.SOFTMAX,
26
25
  (input) => runSoftmax(input, -1, { batchSize: numTokens, size: vocabSize }),
27
26
  [logitsF32],
28
- { rows: numTokens, cols: vocabSize }
27
+ {
28
+ rows: numTokens,
29
+ cols: vocabSize,
30
+ retainBuffers: logitsF32 !== logits ? [logitsF32.buffer] : [],
31
+ }
29
32
  );
30
33
 
31
- if (logitsF32 !== logits) {
32
- releaseBuffer(logitsF32.buffer);
33
- }
34
-
35
34
  return tape.record(
36
35
  OpType.CROSS_ENTROPY,
37
36
  (input, target) => runCrossEntropyLoss(input, target, { numTokens, vocabSize }),
@@ -1,15 +1,12 @@
1
1
  import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
2
- import { acquireBuffer, uploadData } from '../../memory/buffer-pool.js';
3
- import { createTensor } from '../../gpu/tensor.js';
4
2
  import { createTrainingObjective } from './base.js';
3
+ import { createUploadedTensor } from '../tensor-factory.js';
5
4
 
6
5
  function createLossGradient(loss, lossScale) {
7
6
  const lossElements = loss.shape.reduce((acc, value) => acc * value, 1);
8
7
  const gradData = new Float32Array(lossElements);
9
8
  gradData.fill(lossScale);
10
- const gradBuf = acquireBuffer(gradData.byteLength, undefined, 'loss_grad_output');
11
- uploadData(gradBuf, gradData);
12
- return createTensor(gradBuf, 'f32', [...loss.shape], 'loss_grad_output');
9
+ return createUploadedTensor(gradData, 'f32', loss.shape, 'loss_grad_output');
13
10
  }
14
11
 
15
12
  export function createCrossEntropyObjective(options = {}) {
@@ -1,8 +1,8 @@
1
1
  import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
2
2
  import { createTrainingObjective } from './base.js';
3
- import { acquireBuffer, readBuffer, uploadData } from '../../memory/buffer-pool.js';
4
- import { createTensor } from '../../gpu/tensor.js';
3
+ import { readBuffer } from '../../memory/buffer-pool.js';
5
4
  import { f16ToF32Array, f32ToF16Array } from '../../inference/kv-cache/types.js';
5
+ import { createUploadedTensor } from '../tensor-factory.js';
6
6
 
7
7
  const EPS = 1e-8;
8
8
 
@@ -31,9 +31,7 @@ function createLossGradient(loss, lossScale) {
31
31
  const lossElements = loss.shape.reduce((acc, value) => acc * value, 1);
32
32
  const gradData = new Float32Array(lossElements);
33
33
  gradData.fill(lossScale);
34
- const gradBuf = acquireBuffer(gradData.byteLength, undefined, 'distill_kd_loss_grad_output');
35
- uploadData(gradBuf, gradData);
36
- return createTensor(gradBuf, 'f32', [...loss.shape], 'distill_kd_loss_grad_output');
34
+ return createUploadedTensor(gradData, 'f32', loss.shape, 'distill_kd_loss_grad_output');
37
35
  }
38
36
 
39
37
  function createGradientTensor(values, shape, dtype, label) {
@@ -42,9 +40,7 @@ function createGradientTensor(values, shape, dtype, label) {
42
40
  const payload = tensorDtype === 'f16'
43
41
  ? f32ToF16Array(floatValues)
44
42
  : floatValues;
45
- const gradBuf = acquireBuffer(payload.byteLength, undefined, label);
46
- uploadData(gradBuf, payload);
47
- return createTensor(gradBuf, tensorDtype, [...shape], label);
43
+ return createUploadedTensor(payload, tensorDtype, shape, label);
48
44
  }
49
45
 
50
46
  async function readLogitsRows(logitsTensor) {
@@ -1,8 +1,8 @@
1
1
  import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
2
2
  import { createTrainingObjective } from './base.js';
3
- import { acquireBuffer, readBuffer, uploadData } from '../../memory/buffer-pool.js';
4
- import { createTensor } from '../../gpu/tensor.js';
3
+ import { readBuffer } from '../../memory/buffer-pool.js';
5
4
  import { f16ToF32Array, f32ToF16Array } from '../../inference/kv-cache/types.js';
5
+ import { createUploadedTensor } from '../tensor-factory.js';
6
6
 
7
7
  function toFinite(value, fallback) {
8
8
  const parsed = Number(value);
@@ -29,9 +29,7 @@ function createLossGradient(loss, lossScale) {
29
29
  const lossElements = loss.shape.reduce((acc, value) => acc * value, 1);
30
30
  const gradData = new Float32Array(lossElements);
31
31
  gradData.fill(lossScale);
32
- const gradBuf = acquireBuffer(gradData.byteLength, undefined, 'distill_triplet_loss_grad_output');
33
- uploadData(gradBuf, gradData);
34
- return createTensor(gradBuf, 'f32', [...loss.shape], 'distill_triplet_loss_grad_output');
32
+ return createUploadedTensor(gradData, 'f32', loss.shape, 'distill_triplet_loss_grad_output');
35
33
  }
36
34
 
37
35
  function createGradientTensor(values, shape, dtype, label) {
@@ -40,9 +38,7 @@ function createGradientTensor(values, shape, dtype, label) {
40
38
  const payload = tensorDtype === 'f16'
41
39
  ? f32ToF16Array(floatValues)
42
40
  : floatValues;
43
- const gradBuf = acquireBuffer(payload.byteLength, undefined, label);
44
- uploadData(gradBuf, payload);
45
- return createTensor(gradBuf, tensorDtype, [...shape], label);
41
+ return createUploadedTensor(payload, tensorDtype, shape, label);
46
42
  }
47
43
 
48
44
  async function readLogitsRows(logitsTensor) {
@@ -1,7 +1,7 @@
1
1
  import { crossEntropyLoss as defaultCrossEntropyLoss } from '../loss.js';
2
2
  import { createTrainingObjective } from './base.js';
3
- import { acquireBuffer, uploadData, releaseBuffer } from '../../memory/buffer-pool.js';
4
- import { createTensor } from '../../gpu/tensor.js';
3
+ import { releaseBuffer } from '../../memory/buffer-pool.js';
4
+ import { createUploadedTensor } from '../tensor-factory.js';
5
5
 
6
6
  function sigmoid(value) {
7
7
  return 1 / (1 + Math.exp(-value));
@@ -9,17 +9,13 @@ function sigmoid(value) {
9
9
 
10
10
  function createF32Tensor(values, shape, label) {
11
11
  const data = values instanceof Float32Array ? values : new Float32Array(values);
12
- const buffer = acquireBuffer(data.byteLength, undefined, label);
13
- uploadData(buffer, data);
14
- return createTensor(buffer, 'f32', [...shape], label);
12
+ return createUploadedTensor(data, 'f32', shape, label);
15
13
  }
16
14
 
17
15
  function createU32TokenTensor(values, shape, label) {
18
16
  const data = values instanceof Uint32Array ? values : new Uint32Array(values);
19
- const buffer = acquireBuffer(data.byteLength, undefined, label);
20
- uploadData(buffer, data);
21
17
  // Token targets are consumed as raw u32 bytes by loss kernels.
22
- return createTensor(buffer, 'f32', [...shape], label);
18
+ return createUploadedTensor(data, 'f32', shape, label);
23
19
  }
24
20
 
25
21
  function releaseTensor(tensor) {
@@ -316,6 +316,7 @@ async function runDistillCommand(request) {
316
316
  layout: runArtifacts.layout,
317
317
  pollIntervalMs: request.pollIntervalMs || null,
318
318
  stopWhenIdle: request.stopWhenIdle === true,
319
+ signal: request.signal ?? null,
319
320
  })),
320
321
  };
321
322
  }
@@ -378,6 +379,7 @@ async function runLoraCommand(request) {
378
379
  runRoot: resolve(String(request.runRoot)),
379
380
  pollIntervalMs: request.pollIntervalMs || null,
380
381
  stopWhenIdle: request.stopWhenIdle === true,
382
+ signal: request.signal ?? null,
381
383
  })),
382
384
  };
383
385
  }
@@ -1,4 +1,4 @@
1
- import { acquireBuffer, BufferUsage } from '../memory/buffer-pool.js';
1
+ import { acquireBuffer, releaseBuffer, BufferUsage } from '../memory/buffer-pool.js';
2
2
  import { createTensor, tensorBytes } from '../gpu/tensor.js';
3
3
  import { runAdam } from '../gpu/kernels/backward/adam.js';
4
4
 
@@ -72,12 +72,24 @@ export class AdamOptimizer {
72
72
  let entry = this.state.get(param);
73
73
  if (!entry) {
74
74
  const bytes = tensorBytes(param.shape, param.dtype);
75
- const mBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_m');
76
- const vBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_v');
77
- entry = {
78
- m: createTensor(mBuf, param.dtype, [...param.shape], 'adam_m'),
79
- v: createTensor(vBuf, param.dtype, [...param.shape], 'adam_v'),
80
- };
75
+ let mBuf = null;
76
+ let vBuf = null;
77
+ try {
78
+ mBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_m');
79
+ vBuf = acquireBuffer(bytes, BufferUsage.STORAGE, 'adam_v');
80
+ entry = {
81
+ m: createTensor(mBuf, param.dtype, [...param.shape], 'adam_m'),
82
+ v: createTensor(vBuf, param.dtype, [...param.shape], 'adam_v'),
83
+ };
84
+ } catch (error) {
85
+ if (mBuf) {
86
+ releaseBuffer(mBuf);
87
+ }
88
+ if (vBuf) {
89
+ releaseBuffer(vBuf);
90
+ }
91
+ throw error;
92
+ }
81
93
  this.state.set(param, entry);
82
94
  }
83
95
  return entry;
@@ -617,7 +617,6 @@ function buildExpectedCheckpointMetadata(metadata) {
617
617
  'configHash',
618
618
  'datasetHash',
619
619
  'tokenizerHash',
620
- 'optimizerHash',
621
620
  'runtimePresetId',
622
621
  'kernelPathId',
623
622
  ]) {
@@ -845,6 +844,8 @@ export class TrainingRunner {
845
844
  }
846
845
 
847
846
  async run(model, dataset, options = {}) {
847
+ this.lastCheckpoint = null;
848
+ this.lastArtifact = null;
848
849
  const {
849
850
  epochs = 1,
850
851
  batchSize = 1,