@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316)
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -14,11 +14,13 @@ import {
14
14
  parseGroupLayerIndex,
15
15
  sortGroupIds,
16
16
  } from '../formats/rdrr/index.js';
17
+ import { computeHash } from '../storage/shard-manager.js';
17
18
 
18
- const PLACEHOLDER_HASH = '0'.repeat(64);
19
19
  export const DIRECT_SOURCE_RUNTIME_MODE = 'direct-source';
20
20
  export const DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION = 1;
21
21
  export const DIRECT_SOURCE_RUNTIME_SCHEMA = `direct-source/v${DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION}`;
22
// The two values used for the `pathSemantics` field of the direct-source runtime metadata.
// In this file, metadata is written with 'runtime-local'; when metadata is read back, any value
// other than 'artifact-relative' is normalized to 'runtime-local'.
+ export const DIRECT_SOURCE_PATH_RUNTIME_LOCAL = 'runtime-local';
23
+ export const DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE = 'artifact-relative';
22
24
 
23
25
  function toPathKey(value) {
24
26
  return String(value || '').trim().replace(/\\/g, '/');
@@ -38,6 +40,31 @@ function toUint8Chunk(value, label) {
38
40
  return value instanceof Uint8Array ? value : new Uint8Array(toArrayBuffer(value, label));
39
41
  }
40
42
 
43
// Returns the UTF-8 bytes of `value` as a Uint8Array.
// `value` is coerced with String(); null and undefined are encoded as the empty string.
+ function encodeUtf8(value) {
44
+ return new TextEncoder().encode(String(value ?? ''));
45
+ }
46
+
47
// Normalizes a hash-algorithm name to one of two values: 'blake3' or 'sha256'.
// The input is stringified, trimmed and lowercased; only an exact 'blake3' match selects blake3.
// Every other input (empty, nullish, or an unrecognized name) resolves to 'sha256' without an error.
+ function normalizeHashAlgorithm(value) {
48
+ const normalized = String(value || '').trim().toLowerCase();
49
+ return normalized === 'blake3' ? 'blake3' : 'sha256';
50
+ }
51
+
52
// Normalizes an optional hex digest.
//   value - candidate digest; null/undefined or a blank string means "no digest".
//   label - text placed at the start of the error message to identify the offending field.
// Returns null when no digest is present, otherwise the trimmed, lowercased 64-character digest.
// Throws Error when a non-empty value is not exactly 64 hex characters.
+ function normalizeHashString(value, label) {
53
+ if (value == null) return null;
54
// The value is lowercased before validation, so uppercase hex input is accepted
// and returned in lowercase even though the error text says "lowercase".
+ const normalized = String(value).trim().toLowerCase();
55
+ if (!normalized) return null;
56
+ if (!/^[a-f0-9]{64}$/.test(normalized)) {
57
+ throw new Error(`${label} must be a 64-character lowercase hex digest.`);
58
+ }
59
+ return normalized;
60
+ }
61
+
62
// Normalizes an asset "kind" label: stringified, trimmed and lowercased.
// Falsy or blank input yields 'unknown'; no other validation is applied to the label.
+ function normalizeAssetKind(value) {
63
+ const normalized = String(value || '').trim().toLowerCase();
64
+ if (!normalized) return 'unknown';
65
+ return normalized;
66
+ }
67
+
41
68
  function normalizePositiveInteger(value, label) {
42
69
  const parsed = Number(value);
43
70
  if (!Number.isFinite(parsed) || parsed < 0) {
@@ -66,7 +93,12 @@ async function resolveSourceFiles(tensors, sourceFiles, resolveSourceSize) {
66
93
  const path = toPathKey(entry?.path);
67
94
  if (!path) continue;
68
95
  const size = normalizePositiveInteger(entry?.size, `source file size (${path})`);
69
- fileMap.set(path, { path, size });
96
+ fileMap.set(path, {
97
+ path,
98
+ size,
99
+ hash: normalizeHashString(entry?.hash, `source file hash (${path})`),
100
+ hashAlgorithm: normalizeHashAlgorithm(entry?.hashAlgorithm),
101
+ });
70
102
  }
71
103
 
72
104
  for (const tensor of tensors) {
@@ -106,7 +138,7 @@ function buildSourceShards(sourceFiles, hashAlgorithm) {
106
138
  index,
107
139
  filename,
108
140
  size: file.size,
109
- hash: PLACEHOLDER_HASH,
141
+ hash: file.hash ?? '',
110
142
  hashAlgorithm,
111
143
  offset,
112
144
  });
@@ -115,6 +147,8 @@ function buildSourceShards(sourceFiles, hashAlgorithm) {
115
147
  path: file.path,
116
148
  filename,
117
149
  size: file.size,
150
+ hash: file.hash ?? '',
151
+ hashAlgorithm,
118
152
  });
119
153
  offset += file.size;
120
154
  }
@@ -203,7 +237,7 @@ function buildSourceGroups(tensorLocations, modelType) {
203
237
  version: '1.0.0',
204
238
  shards: Array.from(entry.shards).sort((left, right) => left - right),
205
239
  tensors: [...entry.tensors].sort((left, right) => left.localeCompare(right)),
206
- hash: PLACEHOLDER_HASH,
240
+ hash: '',
207
241
  ...(Number.isInteger(layerIndex) ? { layerIndex } : {}),
208
242
  ...(Number.isInteger(expertIndex) ? { expertIndex } : {}),
209
243
  };
@@ -212,6 +246,170 @@ function buildSourceGroups(tensorLocations, modelType) {
212
246
  return groups;
213
247
  }
214
248
 
249
// Computes and stores a hash on every group in `groups` (mutates each group's `hash` in place).
//   groups          - object keyed by group id; a nullish value is treated as an empty object.
//   tensorLocations - object keyed by tensor name, read for shard/offset/size/dtype/shape/layout.
//   hashAlgorithm   - forwarded unchanged to computeHash.
// The hashed input is a JSON description of the group and its tensors' locations; tensor bytes
// are not read here. Resolves with no value once every group has been processed.
+ async function assignGroupHashes(groups, tensorLocations, hashAlgorithm) {
250
+ const groupIds = sortGroupIds(Object.keys(groups ?? {}));
251
// Groups are hashed one at a time, in the order produced by sortGroupIds.
+ for (const groupId of groupIds) {
252
+ const group = groups[groupId];
253
+ if (!group) continue;
254
+ const tensors = Array.isArray(group.tensors) ? group.tensors : [];
255
// Missing fields are written as explicit nulls so every payload has the same set of keys.
+ const payload = {
256
+ groupId,
257
+ type: group.type ?? null,
258
+ version: group.version ?? null,
259
+ layerIndex: Number.isInteger(group.layerIndex) ? group.layerIndex : null,
260
+ expertIndex: Number.isInteger(group.expertIndex) ? group.expertIndex : null,
261
+ tensors: tensors.map((tensorName) => {
262
// A tensor name with no entry in tensorLocations still contributes a record (all-null fields).
+ const location = tensorLocations?.[tensorName] ?? null;
263
+ return {
264
+ name: tensorName,
265
+ shard: location?.shard ?? null,
266
+ offset: location?.offset ?? null,
267
+ size: location?.size ?? null,
268
+ dtype: location?.dtype ?? null,
269
+ shape: Array.isArray(location?.shape) ? location.shape : null,
270
+ layout: location?.layout ?? null,
271
+ };
272
+ }),
273
+ };
274
+ group.hash = await computeHash(encodeUtf8(JSON.stringify(payload)), hashAlgorithm);
275
+ }
276
+ }
277
+
278
// Normalizes one auxiliary-file descriptor into { path, size, hash, hashAlgorithm, kind }.
//   entry                - raw descriptor; read via optional chaining, so it may be nullish.
//   defaultHashAlgorithm - used when the entry carries no hashAlgorithm of its own.
// Returns null when the entry has no usable path. The size and hash helpers may throw on
// invalid values; `hash` is null when the entry has no digest.
+ function normalizeAuxiliaryFileEntry(entry, defaultHashAlgorithm) {
279
+ const path = toPathKey(entry?.path);
280
+ if (!path) return null;
281
+ return {
282
+ path,
283
+ size: normalizePositiveInteger(entry?.size, `source auxiliary file size (${path})`),
284
+ hash: normalizeHashString(entry?.hash, `source auxiliary file hash (${path})`),
285
+ hashAlgorithm: normalizeHashAlgorithm(entry?.hashAlgorithm ?? defaultHashAlgorithm),
286
+ kind: normalizeAssetKind(entry?.kind),
287
+ };
288
+ }
289
+
290
// Normalizes a list of auxiliary-file descriptors.
//   auxiliaryFiles       - expected to be an array; any non-array input is treated as empty.
//   defaultHashAlgorithm - passed through to normalizeAuxiliaryFileEntry for each entry.
// Entries without a usable path are dropped. Returns a new array sorted by path with localeCompare.
+ function normalizeAuxiliaryFiles(auxiliaryFiles, defaultHashAlgorithm) {
291
+ const normalized = [];
292
+ for (const entry of Array.isArray(auxiliaryFiles) ? auxiliaryFiles : []) {
293
+ const resolved = normalizeAuxiliaryFileEntry(entry, defaultHashAlgorithm);
294
+ if (resolved) normalized.push(resolved);
295
+ }
296
+ normalized.sort((left, right) => left.path.localeCompare(right.path));
297
+ return normalized;
298
+ }
299
+
300
// Builds the direct-source runtime metadata object describing source files, auxiliary files,
// tokenizer paths and the invariants that hold for them.
//   options        - reads sourceKind, tokenizerJsonPath, tokenizerConfigPath, tokenizerModelPath.
//   manifest       - only manifest.modelType is read (for invariants.manifestFamily).
//   shardSources   - entries with index/path/filename/size/hash/hashAlgorithm, copied into sourceFiles.
//   auxiliaryFiles - stored in the result as-is (same array reference, not copied).
//   hashAlgorithm  - recorded as the top-level hashAlgorithm.
// Returns a new plain object; none of the inputs are modified.
+ function buildSourceRuntimeMetadata(options, manifest, shardSources, auxiliaryFiles, hashAlgorithm) {
// Each tokenizer path is kept only when it is a non-blank string; otherwise it is recorded as null.
+ const tokenizerJsonPath = typeof options.tokenizerJsonPath === 'string' && options.tokenizerJsonPath.trim()
302
+ ? toPathKey(options.tokenizerJsonPath)
303
+ : null;
304
+ const tokenizerConfigPath = typeof options.tokenizerConfigPath === 'string' && options.tokenizerConfigPath.trim()
305
+ ? toPathKey(options.tokenizerConfigPath)
306
+ : null;
307
+ const tokenizerModelPath = typeof options.tokenizerModelPath === 'string' && options.tokenizerModelPath.trim()
308
+ ? toPathKey(options.tokenizerModelPath)
309
+ : null;
// A digest counts as present only when it is a non-empty string. `every` is true for an empty
// array, so a bundle with no auxiliary files is treated as having full auxiliary digests.
+ const hasFullSourceDigests = shardSources.every((entry) => typeof entry.hash === 'string' && entry.hash.length > 0);
311
+ const hasFullAuxDigests = auxiliaryFiles.every((entry) => typeof entry.hash === 'string' && entry.hash.length > 0);
312
+
313
+ return {
314
+ mode: DIRECT_SOURCE_RUNTIME_MODE,
315
+ schema: DIRECT_SOURCE_RUNTIME_SCHEMA,
316
+ schemaVersion: DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
317
+ sourceKind: typeof options.sourceKind === 'string' && options.sourceKind.trim()
318
+ ? String(options.sourceKind).trim().toLowerCase()
319
+ : null,
320
+ hashAlgorithm,
321
+ pathSemantics: DIRECT_SOURCE_PATH_RUNTIME_LOCAL,
322
+ sourceFileCount: shardSources.length,
323
+ auxiliaryFileCount: auxiliaryFiles.length,
324
+ sourceFiles: shardSources.map((entry) => ({
325
+ index: entry.index,
326
+ path: entry.path,
327
+ filename: entry.filename,
328
+ size: entry.size,
329
+ hash: entry.hash,
330
+ hashAlgorithm: entry.hashAlgorithm,
331
+ })),
332
+ auxiliaryFiles,
333
+ tokenizer: {
334
+ jsonPath: tokenizerJsonPath,
335
+ configPath: tokenizerConfigPath,
336
+ modelPath: tokenizerModelPath,
337
+ },
338
+ invariants: {
339
+ tensorIdentity: 'tensor.name',
340
+ shardIdentity: 'sourceFiles[index].path',
341
+ byteOffsets: 'shard-relative bytes',
// hashSemantics needs both source and auxiliary digests to be complete;
// cacheKeying below depends on the source digests only.
+ hashSemantics: hasFullSourceDigests && hasFullAuxDigests
343
+ ? 'sourceFiles[*].hash digests raw source files; auxiliaryFiles[*].hash digests config/index/tokenizer assets'
344
+ : 'source digests are incomplete; persist a materialized direct-source manifest before release claims',
345
+ cacheKeying: hasFullSourceDigests ? 'path:size:hash' : 'path:size',
// True when a tokenizer JSON or model path is set; the config path alone does not count.
+ tokenizerAssetsCovered: tokenizerJsonPath != null || tokenizerModelPath != null,
347
+ manifestFamily: manifest?.modelType ?? null,
348
+ },
349
+ };
350
+ }
351
+
352
// Reads and normalizes the direct-source runtime metadata stored at manifest.metadata.sourceRuntime.
//   manifest - manifest object; may be nullish or lack the metadata.
// Returns null when the metadata is missing, is not an object, or its `mode` is not the
// direct-source mode. Otherwise returns a new object with mode, schema, schemaVersion, sourceKind,
// hashAlgorithm, pathSemantics, sourceFiles, auxiliaryFiles and tokenizer.
// The counts and `invariants` fields of the stored metadata are not part of the returned object.
// The hash and size helpers may throw on malformed entries (normalizeHashString throws on a bad
// digest; normalizePositiveInteger is only partly visible in this view - TODO confirm its errors).
+ export function getSourceRuntimeMetadata(manifest) {
353
+ const metadata = manifest?.metadata?.sourceRuntime;
354
+ if (!metadata || typeof metadata !== 'object') {
355
+ return null;
356
+ }
357
+ if (metadata.mode !== DIRECT_SOURCE_RUNTIME_MODE) {
358
+ return null;
359
+ }
360
+
361
+ const hashAlgorithm = normalizeHashAlgorithm(metadata.hashAlgorithm);
362
+ const sourceFiles = Array.isArray(metadata.sourceFiles)
363
+ ? metadata.sourceFiles
364
+ .map((entry) => {
365
+ const path = toPathKey(entry?.path);
// Entries without a path map to null and are removed by the filter(Boolean) below.
+ if (!path) return null;
367
+ return {
// A missing index is treated as 0 before normalization.
+ index: normalizePositiveInteger(entry?.index ?? 0, `source runtime sourceFiles index (${path})`),
369
+ path,
370
+ filename: typeof entry?.filename === 'string' && entry.filename.trim()
371
+ ? entry.filename.trim()
372
+ : null,
373
+ size: normalizePositiveInteger(entry?.size, `source runtime sourceFiles size (${path})`),
374
+ hash: normalizeHashString(entry?.hash, `source runtime sourceFiles hash (${path})`),
// Per-file algorithm falls back to the metadata-level algorithm resolved above.
+ hashAlgorithm: normalizeHashAlgorithm(entry?.hashAlgorithm ?? hashAlgorithm),
376
+ };
377
+ })
378
+ .filter(Boolean)
379
+ .sort((left, right) => left.index - right.index)
380
+ : [];
381
+ const auxiliaryFiles = normalizeAuxiliaryFiles(metadata.auxiliaryFiles, hashAlgorithm);
// Tokenizer paths are kept only when they are non-blank strings; a missing or non-object
// tokenizer section yields three null paths.
+ const tokenizer = metadata.tokenizer && typeof metadata.tokenizer === 'object'
383
+ ? {
384
+ jsonPath: typeof metadata.tokenizer.jsonPath === 'string' && metadata.tokenizer.jsonPath.trim()
385
+ ? toPathKey(metadata.tokenizer.jsonPath)
386
+ : null,
387
+ configPath: typeof metadata.tokenizer.configPath === 'string' && metadata.tokenizer.configPath.trim()
388
+ ? toPathKey(metadata.tokenizer.configPath)
389
+ : null,
390
+ modelPath: typeof metadata.tokenizer.modelPath === 'string' && metadata.tokenizer.modelPath.trim()
391
+ ? toPathKey(metadata.tokenizer.modelPath)
392
+ : null,
393
+ }
394
+ : { jsonPath: null, configPath: null, modelPath: null };
395
+
396
+ return {
// mode/schema/schemaVersion are taken from this module's constants, not from the stored values.
+ mode: DIRECT_SOURCE_RUNTIME_MODE,
398
+ schema: DIRECT_SOURCE_RUNTIME_SCHEMA,
399
+ schemaVersion: DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
400
+ sourceKind: typeof metadata.sourceKind === 'string' && metadata.sourceKind.trim()
401
+ ? String(metadata.sourceKind).trim().toLowerCase()
402
+ : null,
403
+ hashAlgorithm,
// Any pathSemantics value other than 'artifact-relative' is reported as 'runtime-local'.
+ pathSemantics: metadata.pathSemantics === DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE
405
+ ? DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE
406
+ : DIRECT_SOURCE_PATH_RUNTIME_LOCAL,
407
+ sourceFiles,
408
+ auxiliaryFiles,
409
+ tokenizer,
410
+ };
411
+ }
412
+
215
413
  function resolveModelQuantization(options, tensorLocations) {
216
414
  const sourceQuantization = options.sourceQuantization
217
415
  ? normalizeQuantTag(options.sourceQuantization)
@@ -270,14 +468,15 @@ export async function buildSourceRuntimeBundle(options = {}) {
270
468
  throw new Error('source runtime bundle: tensors[] is required.');
271
469
  }
272
470
 
273
- const hashAlgorithmRaw = String(options.hashAlgorithm || '').trim().toLowerCase();
274
- const hashAlgorithm = hashAlgorithmRaw === 'blake3' ? 'blake3' : 'sha256';
471
+ const hashAlgorithm = normalizeHashAlgorithm(options.hashAlgorithm);
275
472
  const sourceFiles = await resolveSourceFiles(tensors, options.sourceFiles, options.resolveSourceSize);
276
473
  const { shards, shardSources } = buildSourceShards(sourceFiles, hashAlgorithm);
277
474
  const shardIndexByPath = new Map(shardSources.map((entry) => [entry.path, entry.index]));
278
475
  const tensorLocations = buildSourceTensorLocations(tensors, shardIndexByPath, modelType);
279
476
  const groups = buildSourceGroups(tensorLocations, modelType);
477
+ await assignGroupHashes(groups, tensorLocations, hashAlgorithm);
280
478
  const { quantizationInfo, manifestQuantization } = resolveModelQuantization(options, tensorLocations);
479
+ const auxiliaryFiles = normalizeAuxiliaryFiles(options.auxiliaryFiles, hashAlgorithm);
281
480
 
282
481
  const model = {
283
482
  name: options.modelName || modelId,
@@ -316,19 +515,13 @@ export async function buildSourceRuntimeBundle(options = {}) {
316
515
  if (!manifest.metadata || typeof manifest.metadata !== 'object') {
317
516
  manifest.metadata = {};
318
517
  }
319
- manifest.metadata.sourceRuntime = {
320
- mode: DIRECT_SOURCE_RUNTIME_MODE,
321
- schema: DIRECT_SOURCE_RUNTIME_SCHEMA,
322
- schemaVersion: DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
323
- sourceFileCount: shardSources.length,
324
- invariants: {
325
- tensorIdentity: 'tensor.name',
326
- shardIdentity: 'sourcePath -> shard index',
327
- byteOffsets: 'shard-relative bytes',
328
- hashSemantics: 'placeholder shard/group hashes; verifyHashes must be false',
329
- cacheKeying: 'sourcePath:size',
330
- },
331
- };
518
+ manifest.metadata.sourceRuntime = buildSourceRuntimeMetadata(
519
+ options,
520
+ manifest,
521
+ shardSources,
522
+ auxiliaryFiles,
523
+ hashAlgorithm
524
+ );
332
525
 
333
526
  return {
334
527
  manifest,
@@ -357,7 +550,10 @@ export function createSourceStorageContext(options = {}) {
357
550
  throw new Error('source storage context: manifest is required.');
358
551
  }
359
552
 
360
- const shardSources = Array.isArray(options.shardSources) ? options.shardSources : null;
553
+ const sourceRuntime = getSourceRuntimeMetadata(manifest);
554
+ const shardSources = Array.isArray(options.shardSources) && options.shardSources.length > 0
555
+ ? options.shardSources
556
+ : (sourceRuntime?.sourceFiles ?? null);
361
557
  if (!shardSources || shardSources.length === 0) {
362
558
  throw new Error('source storage context: shardSources[] is required.');
363
559
  }
@@ -376,17 +572,15 @@ export function createSourceStorageContext(options = {}) {
376
572
  const readBinary = typeof options.readBinary === 'function'
377
573
  ? options.readBinary
378
574
  : null;
379
- const tokenizerJsonPath = options.tokenizerJsonPath ?? null;
380
- const tokenizerModelPath = options.tokenizerModelPath ?? null;
575
+ const auxiliaryFileMap = new Map(
576
+ (sourceRuntime?.auxiliaryFiles ?? []).map((entry) => [entry.path, entry])
577
+ );
578
+ const tokenizerJsonPath = options.tokenizerJsonPath ?? sourceRuntime?.tokenizer?.jsonPath ?? null;
579
+ const tokenizerModelPath = options.tokenizerModelPath ?? sourceRuntime?.tokenizer?.modelPath ?? null;
381
580
  const verifyHashes = options.verifyHashes === true;
382
- if (verifyHashes) {
383
- throw new Error(
384
- 'source storage context: verifyHashes=true is not supported for direct-source manifests. ' +
385
- 'Convert to persisted RDRR shards first when hash verification is required.'
386
- );
387
- }
581
+ const allowRangeFastPath = verifyHashes !== true;
388
582
 
389
- const loadShardRange = async (index, offset = 0, length = null) => {
583
+ const loadShardRange = allowRangeFastPath ? async (index, offset = 0, length = null) => {
390
584
  const { sourcePath, shardSize } = resolveSourceEntry(index, manifest, shardSources);
391
585
  const start = normalizePositiveInteger(offset, `shard offset (${index})`);
392
586
  const maxLength = Math.max(0, shardSize - start);
@@ -398,14 +592,19 @@ export function createSourceStorageContext(options = {}) {
398
592
  }
399
593
  const payload = await readRange(sourcePath, start, requested);
400
594
  return toArrayBuffer(payload, `readRange(${sourcePath})`);
401
- };
595
+ } : null;
402
596
 
403
597
  const loadShard = async (index) => {
404
598
  const { shardSize } = resolveSourceEntry(index, manifest, shardSources);
405
- return loadShardRange(index, 0, shardSize);
599
+ if (loadShardRange) {
600
+ return loadShardRange(index, 0, shardSize);
601
+ }
602
+ const { sourcePath } = resolveSourceEntry(index, manifest, shardSources);
603
+ const payload = await readRange(sourcePath, 0, shardSize);
604
+ return toArrayBuffer(payload, `readRange(${sourcePath})`);
406
605
  };
407
606
 
408
- const streamShardRange = async function* (index, offset = 0, length = null, streamOptions = {}) {
607
+ const streamShardRange = allowRangeFastPath ? async function* (index, offset = 0, length = null, streamOptions = {}) {
409
608
  const { sourcePath, shardSize } = resolveSourceEntry(index, manifest, shardSources);
410
609
  const start = normalizePositiveInteger(offset, `shard stream offset (${index})`);
411
610
  const maxLength = Math.max(0, shardSize - start);
@@ -441,18 +640,35 @@ export function createSourceStorageContext(options = {}) {
441
640
  break;
442
641
  }
443
642
  }
444
- };
643
+ } : null;
445
644
 
446
645
  const loadTokenizerJson = readText && tokenizerJsonPath
447
646
  ? async () => {
448
647
  const raw = await readText(tokenizerJsonPath);
449
648
  if (typeof raw === 'string') {
649
+ if (verifyHashes) {
650
+ const descriptor = auxiliaryFileMap.get(tokenizerJsonPath);
651
+ if (descriptor?.hash) {
652
+ const computedHash = await computeHash(encodeUtf8(raw), descriptor.hashAlgorithm);
653
+ if (computedHash !== descriptor.hash) {
654
+ throw new Error(
655
+ `Tokenizer asset hash mismatch for ${tokenizerJsonPath}. ` +
656
+ `Expected ${descriptor.hash}, got ${computedHash}.`
657
+ );
658
+ }
659
+ }
660
+ }
450
661
  return JSON.parse(raw);
451
662
  }
663
+ if (verifyHashes && raw && typeof raw === 'object') {
664
+ throw new Error(
665
+ `readText(${tokenizerJsonPath}) must return the original JSON string when verifyHashes=true.`
666
+ );
667
+ }
452
668
  if (raw && typeof raw === 'object') {
453
669
  return raw;
454
670
  }
455
- return null;
671
+ throw new Error(`readText(${tokenizerJsonPath}) did not return tokenizer JSON data.`);
456
672
  }
457
673
  : null;
458
674
 
@@ -466,6 +682,17 @@ export function createSourceStorageContext(options = {}) {
466
682
  }
467
683
  const raw = await readBinary(targetPath);
468
684
  const buffer = toArrayBuffer(raw, `readBinary(${targetPath})`);
685
+ if (verifyHashes) {
686
+ const descriptor = auxiliaryFileMap.get(targetPath);
687
+ if (descriptor?.hash) {
688
+ const computedHash = await computeHash(new Uint8Array(buffer), descriptor.hashAlgorithm);
689
+ if (computedHash !== descriptor.hash) {
690
+ throw new Error(
691
+ `Binary asset hash mismatch for ${targetPath}. Expected ${descriptor.hash}, got ${computedHash}.`
692
+ );
693
+ }
694
+ }
695
+ }
469
696
  if (buffer.byteLength <= 0) {
470
697
  throw new Error(`readBinary(${targetPath}) returned an empty tokenizer model payload.`);
471
698
  }
@@ -0,0 +1,6 @@
1
+ import type { RDRRManifest } from '../formats/rdrr/index.js';
2
+
3
/**
 * Produces a copy of a direct-source manifest whose `metadata.sourceRuntime`
 * paths (source files, auxiliary files, tokenizer assets) are rewritten to be
 * relative to `artifactDir`.
 *
 * @param manifest Manifest that carries `metadata.sourceRuntime`.
 * @param artifactDir Directory the rewritten paths are made relative to.
 * @returns A cloned manifest with the rewritten source runtime metadata.
 * @throws Error when the source runtime metadata is missing, when
 *   `artifactDir` is blank, or when a referenced path does not live inside
 *   `artifactDir`.
 */
+ export declare function materializeSourceRuntimeManifest(
4
+ manifest: RDRRManifest,
5
+ artifactDir: string
6
+ ): RDRRManifest;
@@ -0,0 +1,93 @@
1
+ import path from 'node:path';
2
+
3
+ import {
4
+ DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE,
5
+ DIRECT_SOURCE_RUNTIME_MODE,
6
+ DIRECT_SOURCE_RUNTIME_SCHEMA,
7
+ DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION,
8
+ getSourceRuntimeMetadata,
9
+ } from './source-runtime-bundle.js';
10
+
11
/**
 * Deep-copies a JSON-compatible value.
 *
 * Uses the platform `structuredClone` when the runtime provides it and falls
 * back to a JSON serialize/parse round trip otherwise.
 *
 * @param {*} value - Value to copy.
 * @returns {*} A copy that shares no nested references with `value`.
 */
function cloneJsonValue(value) {
  const hasNativeClone = typeof structuredClone === 'function';
  return hasNativeClone
    ? structuredClone(value)
    : JSON.parse(JSON.stringify(value));
}
17
+
18
/**
 * Converts a file path into a forward-slash path relative to `artifactDir`.
 *
 * Both `value` and `artifactDir` are resolved with `path.resolve`, so relative
 * inputs are interpreted against the current working directory.
 *
 * @param {*} value - Path of the file to relativize; coerced to a trimmed string.
 * @param {string} artifactDir - Directory the file must live inside.
 * @param {string} label - Human-readable name of the entry, used in error messages.
 * @returns {string} The path of the file relative to `artifactDir`, using `/` separators.
 * @throws {Error} When `value` is blank, or when the file is `artifactDir`
 *   itself or is located outside of it.
 */
function toRelativeArtifactPath(value, artifactDir, label) {
  const raw = String(value || '').trim();
  if (!raw) {
    throw new Error(`${label} path is required.`);
  }
  const resolvedArtifactDir = path.resolve(artifactDir);
  const resolvedTarget = path.resolve(raw);
  const nativeRelativePath = path.relative(resolvedArtifactDir, resolvedTarget);
  const relativePath = nativeRelativePath.replace(/\\/g, '/');
  // On Windows, path.relative() between two different drive roots returns an
  // absolute path (e.g. "D:\\x"), which never starts with "..". Treat that as
  // escaping artifactDir as well.
  const escapesArtifactDir = relativePath === '..'
    || relativePath.startsWith('../')
    || path.isAbsolute(nativeRelativePath);
  if (!relativePath || escapesArtifactDir) {
    throw new Error(
      `${label} "${raw}" must live inside artifactDir "${resolvedArtifactDir}" for a persisted direct-source manifest.`
    );
  }
  return relativePath;
}
33
+
34
/**
 * Builds a persisted form of a direct-source manifest in which every path in
 * `metadata.sourceRuntime` is expressed relative to `artifactDir`.
 *
 * The input manifest is deep-cloned first; the rewritten metadata is written
 * onto the clone, which is returned.
 *
 * @param {object} manifest - Manifest carrying `metadata.sourceRuntime`.
 * @param {string} artifactDir - Directory that all referenced files must live inside.
 * @returns {object} Cloned manifest with artifact-relative source runtime metadata.
 * @throws {Error} When the source runtime metadata is missing, `artifactDir`
 *   is blank, the metadata has no `sourceFiles` array, or any referenced path
 *   falls outside `artifactDir`.
 */
export function materializeSourceRuntimeManifest(manifest, artifactDir) {
  const sourceRuntime = getSourceRuntimeMetadata(manifest);
  if (!sourceRuntime) {
    throw new Error('materializeSourceRuntimeManifest requires manifest.metadata.sourceRuntime.');
  }
  const resolvedArtifactDir = String(artifactDir || '').trim();
  if (!resolvedArtifactDir) {
    throw new Error('materializeSourceRuntimeManifest requires artifactDir.');
  }
  // Metadata written by older versions may lack these fields; fail with a
  // descriptive error for the mandatory one and tolerate the optional ones.
  if (!Array.isArray(sourceRuntime.sourceFiles)) {
    throw new Error(
      'materializeSourceRuntimeManifest requires manifest.metadata.sourceRuntime.sourceFiles[].'
    );
  }
  const auxiliaryFiles = sourceRuntime.auxiliaryFiles ?? [];
  const tokenizer = sourceRuntime.tokenizer ?? {};

  const relativize = (value, label) => toRelativeArtifactPath(value, resolvedArtifactDir, label);
  const relativizeOptional = (value, label) => (value ? relativize(value, label) : null);

  const nextManifest = cloneJsonValue(manifest);
  if (!nextManifest.metadata || typeof nextManifest.metadata !== 'object') {
    nextManifest.metadata = {};
  }
  // nextManifest is already a deep copy, so its sourceRuntime object can be
  // updated in place without touching the caller's manifest.
  const existingMetadata = nextManifest.metadata.sourceRuntime;
  const sourceMetadata = existingMetadata && typeof existingMetadata === 'object'
    ? existingMetadata
    : {};

  sourceMetadata.mode = DIRECT_SOURCE_RUNTIME_MODE;
  sourceMetadata.schema = DIRECT_SOURCE_RUNTIME_SCHEMA;
  sourceMetadata.schemaVersion = DIRECT_SOURCE_RUNTIME_SCHEMA_VERSION;
  sourceMetadata.hashAlgorithm = sourceRuntime.hashAlgorithm;
  sourceMetadata.pathSemantics = DIRECT_SOURCE_PATH_ARTIFACT_RELATIVE;
  sourceMetadata.sourceFiles = sourceRuntime.sourceFiles.map((entry) => ({
    index: entry.index,
    filename: entry.filename ?? null,
    path: relativize(entry.path, `source runtime source file ${entry.index}`),
    size: entry.size,
    hash: entry.hash,
    hashAlgorithm: entry.hashAlgorithm,
  }));
  sourceMetadata.auxiliaryFiles = auxiliaryFiles.map((entry) => ({
    path: relativize(entry.path, `source runtime auxiliary file ${entry.kind}`),
    size: entry.size,
    hash: entry.hash,
    hashAlgorithm: entry.hashAlgorithm,
    kind: entry.kind,
  }));
  sourceMetadata.tokenizer = {
    jsonPath: relativizeOptional(tokenizer.jsonPath, 'source runtime tokenizer json'),
    configPath: relativizeOptional(tokenizer.configPath, 'source runtime tokenizer config'),
    modelPath: relativizeOptional(tokenizer.modelPath, 'source runtime tokenizer model'),
  };
  nextManifest.metadata.sourceRuntime = sourceMetadata;
  return nextManifest;
}
@@ -1,6 +1,7 @@
1
- import { acquireBuffer, uploadData, readBuffer } from '../memory/buffer-pool.js';
1
+ import { acquireBuffer, uploadData, readBuffer, releaseBuffer } from '../memory/buffer-pool.js';
2
2
  import { createTensor, tensorBytes } from '../gpu/tensor.js';
3
3
  import { f16ToF32Array } from '../inference/kv-cache/types.js';
4
+ import { createUploadedTensor } from './tensor-factory.js';
4
5
 
5
6
  function toFloat32(buffer, dtype) {
6
7
  if (dtype === 'f16') {
@@ -67,9 +68,7 @@ export async function buildAttentionSoftmaxCache(q, k, options) {
67
68
  const kData = toFloat32(kBuf, k.dtype);
68
69
  const sData = computeSoftmax(qData, kData, options);
69
70
  const { seqLen, numHeads } = options;
70
- const outBuf = acquireBuffer(tensorBytes([numHeads, seqLen, seqLen], 'f32'), undefined, 'attn_softmax_cache');
71
- uploadData(outBuf, sData);
72
- return createTensor(outBuf, 'f32', [numHeads, seqLen, seqLen], 'attn_softmax_cache');
71
+ return createUploadedTensor(sData, 'f32', [numHeads, seqLen, seqLen], 'attn_softmax_cache');
73
72
  }
74
73
 
75
74
  export async function attentionBackwardCpu(
@@ -201,17 +200,33 @@ export async function attentionBackwardCpu(
201
200
  }
202
201
  }
203
202
 
204
- const qBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_q');
205
- const kBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_k');
206
- const vBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_v');
207
-
208
- uploadData(qBufOut, dQ);
209
- uploadData(kBufOut, dK);
210
- uploadData(vBufOut, dV);
211
-
212
- return {
213
- gradQ: createTensor(qBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_q'),
214
- gradK: createTensor(kBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_k'),
215
- gradV: createTensor(vBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_v'),
216
- };
203
+ let qBufOut = null;
204
+ let kBufOut = null;
205
+ let vBufOut = null;
206
+ try {
207
+ qBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_q');
208
+ kBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_k');
209
+ vBufOut = acquireBuffer(tensorBytes([seqLen, numHeads, headDim], 'f32'), undefined, 'attn_backward_v');
210
+
211
+ uploadData(qBufOut, dQ);
212
+ uploadData(kBufOut, dK);
213
+ uploadData(vBufOut, dV);
214
+
215
+ return {
216
+ gradQ: createTensor(qBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_q'),
217
+ gradK: createTensor(kBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_k'),
218
+ gradV: createTensor(vBufOut, 'f32', [seqLen, numHeads, headDim], 'attn_grad_v'),
219
+ };
220
+ } catch (error) {
221
+ if (qBufOut) {
222
+ releaseBuffer(qBufOut);
223
+ }
224
+ if (kBufOut) {
225
+ releaseBuffer(kBufOut);
226
+ }
227
+ if (vBufOut) {
228
+ releaseBuffer(vBufOut);
229
+ }
230
+ throw error;
231
+ }
217
232
  }