@simulatte/doppler 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (316) hide show
  1. package/CHANGELOG.md +126 -0
  2. package/README.md +16 -23
  3. package/package.json +14 -1
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +26 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +8 -7
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.js +5 -36
  29. package/src/config/kernels/kernel-ref-digests.js +1 -1
  30. package/src/config/kernels/registry.js +14 -1
  31. package/src/config/kernels/registry.json +7 -5
  32. package/src/config/loader.d.ts +1 -1
  33. package/src/config/loader.js +12 -2
  34. package/src/config/merge-contract-check.js +59 -4
  35. package/src/config/merge-helpers.js +128 -7
  36. package/src/config/merge.d.ts +1 -0
  37. package/src/config/merge.js +10 -0
  38. package/src/config/param-validator.js +47 -2
  39. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  40. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  41. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  42. package/src/config/presets/kernel-paths/registry.json +29 -8
  43. package/src/config/presets/models/gemma2.json +2 -2
  44. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  45. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  46. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  47. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  48. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  49. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  50. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  51. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  52. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  53. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  54. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  55. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  56. package/src/config/runtime.js +6 -1
  57. package/src/config/schema/debug.schema.d.ts +5 -0
  58. package/src/config/schema/doppler.schema.js +16 -21
  59. package/src/config/schema/inference-defaults.schema.js +3 -3
  60. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  61. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  62. package/src/config/schema/manifest.schema.d.ts +2 -1
  63. package/src/config/schema/manifest.schema.js +16 -3
  64. package/src/config/training-defaults.js +30 -22
  65. package/src/converter/conversion-plan.js +94 -9
  66. package/src/converter/core.d.ts +7 -0
  67. package/src/converter/core.js +14 -9
  68. package/src/converter/execution-v0-manifest.js +4 -1
  69. package/src/converter/index.d.ts +1 -0
  70. package/src/converter/index.js +1 -0
  71. package/src/converter/manifest-inference.js +43 -12
  72. package/src/converter/parsers/diffusion.js +0 -3
  73. package/src/converter/quantization-info.js +35 -15
  74. package/src/converter/shard-packer.d.ts +1 -1
  75. package/src/converter/shard-packer.js +4 -1
  76. package/src/debug/config.js +123 -11
  77. package/src/debug/signals.js +7 -1
  78. package/src/debug/tensor.d.ts +2 -0
  79. package/src/debug/tensor.js +13 -2
  80. package/src/distribution/p2p-control-plane.js +52 -12
  81. package/src/distribution/p2p-observability.js +43 -7
  82. package/src/distribution/p2p-webrtc-browser.js +20 -0
  83. package/src/distribution/shard-delivery.js +77 -26
  84. package/src/formats/gguf/types.js +33 -16
  85. package/src/formats/rdrr/groups.d.ts +12 -4
  86. package/src/formats/rdrr/groups.js +3 -6
  87. package/src/formats/rdrr/parsing.js +39 -2
  88. package/src/formats/rdrr/types.d.ts +2 -1
  89. package/src/gpu/command-recorder.js +86 -61
  90. package/src/gpu/device.d.ts +1 -0
  91. package/src/gpu/device.js +73 -19
  92. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  93. package/src/gpu/kernel-tuner/cache.js +71 -4
  94. package/src/gpu/kernel-tuner/tuner.js +22 -4
  95. package/src/gpu/kernels/attention.js +15 -34
  96. package/src/gpu/kernels/backward/adam.js +62 -58
  97. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  98. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  99. package/src/gpu/kernels/cast.js +191 -149
  100. package/src/gpu/kernels/check-stop.js +33 -44
  101. package/src/gpu/kernels/conv2d.js +27 -17
  102. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  103. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  104. package/src/gpu/kernels/dequant.js +178 -126
  105. package/src/gpu/kernels/energy.d.ts +3 -21
  106. package/src/gpu/kernels/energy.js +111 -88
  107. package/src/gpu/kernels/feature-check.js +1 -1
  108. package/src/gpu/kernels/fused_ffn.js +84 -65
  109. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  110. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  111. package/src/gpu/kernels/gather.js +33 -15
  112. package/src/gpu/kernels/gelu.js +19 -11
  113. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  114. package/src/gpu/kernels/groupnorm.js +34 -23
  115. package/src/gpu/kernels/kv-quantize.js +5 -2
  116. package/src/gpu/kernels/layernorm.js +35 -19
  117. package/src/gpu/kernels/logit-merge.js +5 -3
  118. package/src/gpu/kernels/matmul.js +58 -39
  119. package/src/gpu/kernels/modulate.js +23 -15
  120. package/src/gpu/kernels/moe.js +221 -175
  121. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  122. package/src/gpu/kernels/relu.js +18 -10
  123. package/src/gpu/kernels/repeat_channels.js +25 -17
  124. package/src/gpu/kernels/residual.js +37 -27
  125. package/src/gpu/kernels/rmsnorm.js +57 -41
  126. package/src/gpu/kernels/rope.js +3 -0
  127. package/src/gpu/kernels/sample.js +27 -38
  128. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  129. package/src/gpu/kernels/scale.js +18 -11
  130. package/src/gpu/kernels/shader-cache.js +4 -2
  131. package/src/gpu/kernels/silu.js +120 -72
  132. package/src/gpu/kernels/softmax.js +44 -25
  133. package/src/gpu/kernels/split_qkv.js +23 -13
  134. package/src/gpu/kernels/transpose.js +18 -10
  135. package/src/gpu/kernels/transpose.wgsl +5 -3
  136. package/src/gpu/kernels/upsample2d.js +21 -13
  137. package/src/gpu/kernels/utils.js +20 -13
  138. package/src/gpu/partitioned-buffer-pool.js +10 -2
  139. package/src/gpu/perf-guards.js +2 -9
  140. package/src/gpu/profiler.js +27 -22
  141. package/src/gpu/readback-utils.d.ts +16 -0
  142. package/src/gpu/readback-utils.js +41 -0
  143. package/src/gpu/submit-tracker.js +13 -0
  144. package/src/gpu/uniform-cache.d.ts +1 -0
  145. package/src/gpu/uniform-cache.js +30 -9
  146. package/src/hotswap/intent-bundle.js +6 -0
  147. package/src/hotswap/manifest.d.ts +10 -1
  148. package/src/hotswap/manifest.js +12 -2
  149. package/src/hotswap/runtime.js +30 -8
  150. package/src/index-browser.d.ts +44 -0
  151. package/src/index-browser.js +14 -0
  152. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  153. package/src/inference/browser-harness-contract-helpers.js +28 -0
  154. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  155. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  156. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  157. package/src/inference/browser-harness-model-helpers.js +217 -0
  158. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  159. package/src/inference/browser-harness-report-helpers.js +42 -0
  160. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  161. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  162. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  163. package/src/inference/browser-harness-suite-helpers.js +268 -0
  164. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  165. package/src/inference/browser-harness-text-helpers.js +788 -0
  166. package/src/inference/browser-harness.d.ts +6 -0
  167. package/src/inference/browser-harness.js +130 -1996
  168. package/src/inference/kv-cache/base.js +140 -94
  169. package/src/inference/kv-cache/tiered.js +5 -3
  170. package/src/inference/moe-router.js +88 -56
  171. package/src/inference/multi-model-network.js +5 -3
  172. package/src/inference/network-evolution.d.ts +11 -2
  173. package/src/inference/network-evolution.js +20 -21
  174. package/src/inference/pipelines/context.d.ts +3 -0
  175. package/src/inference/pipelines/context.js +142 -2
  176. package/src/inference/pipelines/diffusion/helpers.js +7 -2
  177. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  178. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  179. package/src/inference/pipelines/diffusion/vae.js +3 -7
  180. package/src/inference/pipelines/energy/pipeline.js +27 -21
  181. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  182. package/src/inference/pipelines/energy/quintel.js +11 -0
  183. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  184. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  185. package/src/inference/pipelines/text/attention/projections.js +151 -101
  186. package/src/inference/pipelines/text/attention/record.js +62 -8
  187. package/src/inference/pipelines/text/attention/run.js +62 -8
  188. package/src/inference/pipelines/text/config.js +3 -4
  189. package/src/inference/pipelines/text/embed.js +2 -8
  190. package/src/inference/pipelines/text/execution-plan.js +41 -19
  191. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  192. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  193. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  194. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  195. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  196. package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
  197. package/src/inference/pipelines/text/generator-steps.js +298 -207
  198. package/src/inference/pipelines/text/generator.js +6 -23
  199. package/src/inference/pipelines/text/init.js +78 -20
  200. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  201. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  202. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  203. package/src/inference/pipelines/text/layer.js +3 -9
  204. package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
  205. package/src/inference/pipelines/text/linear-attention.js +80 -6
  206. package/src/inference/pipelines/text/logits/gpu.js +10 -5
  207. package/src/inference/pipelines/text/logits/index.js +10 -11
  208. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  209. package/src/inference/pipelines/text/logits/utils.js +9 -0
  210. package/src/inference/pipelines/text/lora-apply.js +50 -32
  211. package/src/inference/pipelines/text/model-load.js +279 -104
  212. package/src/inference/pipelines/text/moe-cache.js +5 -4
  213. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  214. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  215. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  216. package/src/inference/pipelines/text/ops.js +90 -90
  217. package/src/inference/pipelines/text/probes.js +9 -9
  218. package/src/inference/pipelines/text/weights.js +17 -7
  219. package/src/inference/pipelines/text.js +13 -1
  220. package/src/inference/speculative.d.ts +2 -2
  221. package/src/inference/speculative.js +4 -18
  222. package/src/inference/test-harness.d.ts +1 -1
  223. package/src/inference/test-harness.js +15 -5
  224. package/src/inference/tokenizer.d.ts +0 -5
  225. package/src/inference/tokenizer.js +4 -23
  226. package/src/inference/tokenizers/bpe.js +9 -0
  227. package/src/inference/tokenizers/bundled.js +20 -0
  228. package/src/inference/tokenizers/sentencepiece.js +12 -0
  229. package/src/loader/doppler-loader.js +38 -22
  230. package/src/loader/dtype-utils.js +3 -44
  231. package/src/loader/embedding-loader.js +7 -3
  232. package/src/loader/experts/expert-cache.js +13 -6
  233. package/src/loader/experts/expert-loader.js +10 -6
  234. package/src/loader/final-weights-loader.js +8 -4
  235. package/src/loader/layer-loader.js +2 -1
  236. package/src/loader/loader-state.js +2 -2
  237. package/src/loader/memory-monitor.js +8 -0
  238. package/src/loader/multi-model-loader.d.ts +14 -0
  239. package/src/loader/multi-model-loader.js +70 -24
  240. package/src/loader/shard-cache.js +81 -12
  241. package/src/loader/shard-resolver.js +25 -3
  242. package/src/loader/tensors/tensor-loader.js +209 -144
  243. package/src/loader/tensors/tensor-reader.js +76 -19
  244. package/src/loader/weight-downcast.js +1 -1
  245. package/src/memory/buffer-pool.d.ts +9 -1
  246. package/src/memory/buffer-pool.js +109 -44
  247. package/src/memory/unified-detect.js +1 -1
  248. package/src/rules/inference/kernel-path.rules.json +24 -8
  249. package/src/rules/rule-registry.js +25 -1
  250. package/src/storage/backends/opfs-store.js +68 -24
  251. package/src/storage/downloader.js +364 -83
  252. package/src/storage/index.d.ts +3 -0
  253. package/src/storage/index.js +3 -0
  254. package/src/storage/preflight.d.ts +2 -2
  255. package/src/storage/preflight.js +24 -2
  256. package/src/storage/quickstart-downloader.js +11 -5
  257. package/src/storage/registry.js +10 -4
  258. package/src/storage/reports.js +1 -1
  259. package/src/storage/shard-manager.d.ts +15 -1
  260. package/src/storage/shard-manager.js +51 -3
  261. package/src/storage/source-artifact-store.d.ts +52 -0
  262. package/src/storage/source-artifact-store.js +234 -0
  263. package/src/tooling/command-api-constants.d.ts +9 -0
  264. package/src/tooling/command-api-constants.js +9 -0
  265. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  266. package/src/tooling/command-api-family-normalizers.js +343 -0
  267. package/src/tooling/command-api-helpers.d.ts +25 -0
  268. package/src/tooling/command-api-helpers.js +262 -0
  269. package/src/tooling/command-api.js +16 -602
  270. package/src/tooling/command-envelope.js +4 -1
  271. package/src/tooling/command-runner-shared.js +52 -18
  272. package/src/tooling/lean-execution-contract.js +150 -3
  273. package/src/tooling/node-browser-command-runner.js +161 -271
  274. package/src/tooling/node-command-runner.js +29 -3
  275. package/src/tooling/node-converter.js +27 -1
  276. package/src/tooling/node-source-runtime.d.ts +1 -1
  277. package/src/tooling/node-source-runtime.js +84 -3
  278. package/src/tooling/node-webgpu.js +24 -21
  279. package/src/tooling/opfs-cache.js +21 -4
  280. package/src/tooling/runtime-input-composition.d.ts +38 -0
  281. package/src/tooling/runtime-input-composition.js +86 -0
  282. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  283. package/src/tooling/source-runtime-bundle.js +261 -34
  284. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  285. package/src/tooling/source-runtime-materializer.js +93 -0
  286. package/src/training/attention-backward.js +32 -17
  287. package/src/training/autograd.js +80 -52
  288. package/src/training/checkpoint-watch.d.ts +2 -1
  289. package/src/training/checkpoint-watch.js +39 -6
  290. package/src/training/checkpoint.js +40 -11
  291. package/src/training/clip.js +2 -1
  292. package/src/training/datasets/token-batch.js +20 -8
  293. package/src/training/distillation/checkpoint-watch.js +1 -0
  294. package/src/training/distillation/student-fixture.d.ts +22 -0
  295. package/src/training/distillation/student-fixture.js +846 -0
  296. package/src/training/distillation/suite-data.d.ts +45 -0
  297. package/src/training/distillation/suite-data.js +189 -0
  298. package/src/training/lora-pipeline.js +4 -7
  299. package/src/training/lora.js +26 -12
  300. package/src/training/loss.js +5 -6
  301. package/src/training/objectives/cross_entropy.js +2 -5
  302. package/src/training/objectives/distill_kd.js +4 -8
  303. package/src/training/objectives/distill_triplet.js +4 -8
  304. package/src/training/objectives/ul_stage2_base.js +4 -8
  305. package/src/training/operator-command.js +2 -0
  306. package/src/training/optimizer.js +19 -7
  307. package/src/training/runner.js +2 -1
  308. package/src/training/suite.js +18 -978
  309. package/src/training/tensor-factory.d.ts +9 -0
  310. package/src/training/tensor-factory.js +13 -0
  311. package/src/training/trainer.js +3 -5
  312. package/src/training/ul_dataset.js +3 -5
  313. package/src/training/workloads.js +70 -79
  314. package/src/version.js +1 -1
  315. package/tools/convert-safetensors-node.js +22 -16
  316. package/tools/doppler-cli.js +44 -25
@@ -2,7 +2,7 @@
2
2
 
3
3
  import { parseModelConfig } from './config.js';
4
4
  import { getDevice, getDeviceLimits, getKernelCapabilities } from '../../../gpu/device.js';
5
- import { acquireBuffer } from '../../../memory/buffer-pool.js';
5
+ import { acquireBuffer, releaseBuffer } from '../../../memory/buffer-pool.js';
6
6
  import { KVCache, SlidingWindowKVCache, TieredKVCache, BasisDecomposedPagedCache } from '../../kv-cache.js';
7
7
  import { Tokenizer } from '../../tokenizer.js';
8
8
  import { MoERouter } from '../../moe-router.js';
@@ -14,6 +14,10 @@ import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js'
14
14
  import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
15
15
  import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
16
16
  import { selectRuleValue } from '../../../rules/rule-registry.js';
17
+ import {
18
+ createSourceStorageContext,
19
+ getSourceRuntimeMetadata,
20
+ } from '../../../tooling/source-runtime-bundle.js';
17
21
 
18
22
  function resolveErrorMessage(error) {
19
23
  if (error && typeof error === 'object' && typeof error.message === 'string') {
@@ -56,12 +60,61 @@ function normalizeBaseUrl(baseUrl) {
56
60
  return baseUrl.replace(/\/$/, '');
57
61
  }
58
62
 
63
+ async function fetchBytes(url, offset = null, length = null) {
64
+ const headers = {};
65
+ if (Number.isFinite(offset) && Number.isFinite(length) && length > 0) {
66
+ const start = Math.max(0, Math.floor(offset));
67
+ const end = start + Math.max(0, Math.floor(length)) - 1;
68
+ headers.Range = `bytes=${start}-${end}`;
69
+ }
70
+ const response = await fetch(url, { headers });
71
+ if (!response.ok) {
72
+ throw new Error(`Failed to fetch ${url}: ${response.status}`);
73
+ }
74
+ return new Uint8Array(await response.arrayBuffer());
75
+ }
76
+
59
77
  function createRemoteStorageContext(baseUrl, manifest) {
60
78
  const root = normalizeBaseUrl(baseUrl);
61
79
  if (!root || !isRDRRManifest(manifest)) {
62
80
  return null;
63
81
  }
64
82
 
83
+ const sourceRuntime = getSourceRuntimeMetadata(manifest);
84
+ if (sourceRuntime) {
85
+ const readRange = async (relativePath, offset, length) => {
86
+ const filename = String(relativePath || '').replace(/^\/+/, '');
87
+ if (!filename) {
88
+ throw new Error('Direct-source artifact path is required.');
89
+ }
90
+ const url = `${root}/${filename}`;
91
+ return fetchBytes(url, offset, length);
92
+ };
93
+ const readText = async (relativePath) => {
94
+ const filename = String(relativePath || '').replace(/^\/+/, '');
95
+ if (!filename) return null;
96
+ const response = await fetch(`${root}/${filename}`);
97
+ if (!response.ok) {
98
+ throw new Error(`Failed to fetch ${filename} from ${root}: ${response.status}`);
99
+ }
100
+ return response.text();
101
+ };
102
+ const readBinary = async (relativePath) => {
103
+ const filename = String(relativePath || '').replace(/^\/+/, '');
104
+ if (!filename) {
105
+ throw new Error('Direct-source binary asset path is required.');
106
+ }
107
+ return fetchBytes(`${root}/${filename}`);
108
+ };
109
+ return createSourceStorageContext({
110
+ manifest,
111
+ readRange,
112
+ readText,
113
+ readBinary,
114
+ verifyHashes: true,
115
+ });
116
+ }
117
+
65
118
  return {
66
119
  async loadShard(index) {
67
120
  const shard = manifest.shards[index];
@@ -69,11 +122,7 @@ function createRemoteStorageContext(baseUrl, manifest) {
69
122
  if (!filename) {
70
123
  throw new Error(`Manifest shard ${index} is missing filename.`);
71
124
  }
72
- const response = await fetch(`${root}/${filename.replace(/^\/+/, '')}`);
73
- if (!response.ok) {
74
- throw new Error(`Failed to fetch shard ${index} from ${root}: ${response.status}`);
75
- }
76
- return new Uint8Array(await response.arrayBuffer());
125
+ return fetchBytes(`${root}/${filename.replace(/^\/+/, '')}`);
77
126
  },
78
127
  };
79
128
  }
@@ -326,20 +375,29 @@ export async function initRoPEFrequencies(config, useGPU) {
326
375
  // Upload to GPU if available
327
376
  const device = getDevice();
328
377
  if (device && useGPU) {
329
- const cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
330
- const sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
331
- device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
332
- device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
333
-
334
-
335
- let localCosBuffer;
336
-
337
- let localSinBuffer;
338
- if (localFreqs) {
339
- localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
340
- localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
341
- device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
342
- device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
378
+ let cosBuffer = null;
379
+ let sinBuffer = null;
380
+ let localCosBuffer = null;
381
+ let localSinBuffer = null;
382
+ try {
383
+ cosBuffer = acquireBuffer(globalFreqs.cos.byteLength, undefined, 'rope_cos');
384
+ sinBuffer = acquireBuffer(globalFreqs.sin.byteLength, undefined, 'rope_sin');
385
+ device.queue.writeBuffer(cosBuffer, 0, globalFreqs.cos.buffer, globalFreqs.cos.byteOffset, globalFreqs.cos.byteLength);
386
+ device.queue.writeBuffer(sinBuffer, 0, globalFreqs.sin.buffer, globalFreqs.sin.byteOffset, globalFreqs.sin.byteLength);
387
+
388
+ if (localFreqs) {
389
+ localCosBuffer = acquireBuffer(localFreqs.cos.byteLength, undefined, 'rope_local_cos');
390
+ localSinBuffer = acquireBuffer(localFreqs.sin.byteLength, undefined, 'rope_local_sin');
391
+ device.queue.writeBuffer(localCosBuffer, 0, localFreqs.cos.buffer, localFreqs.cos.byteOffset, localFreqs.cos.byteLength);
392
+ device.queue.writeBuffer(localSinBuffer, 0, localFreqs.sin.buffer, localFreqs.sin.byteOffset, localFreqs.sin.byteLength);
393
+ }
394
+ } catch (error) {
395
+ for (const buffer of [cosBuffer, sinBuffer, localCosBuffer, localSinBuffer]) {
396
+ if (buffer) {
397
+ releaseBuffer(buffer);
398
+ }
399
+ }
400
+ throw error;
343
401
  }
344
402
 
345
403
  log.debug(
@@ -78,6 +78,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
78
78
 
79
79
  const normalizedPolicy = resolveKernelPathPolicy(kernelPathPolicy);
80
80
  const hasSubgroups = capabilities?.hasSubgroups === true;
81
+ const hasF16 = capabilities?.hasF16 === true;
81
82
  const normalizedSource = normalizeKernelPathSource(kernelPathSource);
82
83
  const allowCapabilityAutoSelection = normalizedPolicy.mode === 'capability-aware'
83
84
  && normalizedPolicy.sourceScope.includes(normalizedSource);
@@ -85,6 +86,7 @@ export function resolveCapabilityKernelPathRef(configuredKernelPathRef, kernelPa
85
86
  return selectRuleValue('inference', 'kernelPath', 'autoSelect', {
86
87
  kernelPathRef: configuredKernelPathRef,
87
88
  hasSubgroups,
89
+ hasF16,
88
90
  allowCapabilityAutoSelection,
89
91
  });
90
92
  }
@@ -12,6 +12,8 @@
12
12
  * Snapshot of a tensor's statistics (no full data, just stats).
13
13
  */
14
14
  export interface TensorSnapshot {
15
+ ok: boolean;
16
+ error: string | null;
15
17
  shape: number[];
16
18
  dtype: string;
17
19
  stats: {
@@ -283,6 +283,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
283
283
  if (layer >= 0 && !kernelTrace.shouldTraceLayer(layer)) return;
284
284
 
285
285
  const output = await snapshotTensor(outputBuffer, outputShape);
286
+ if (!output.ok) {
287
+ throw new Error(`[TRACE] Failed to snapshot output for ${label}: ${output.error}`);
288
+ }
286
289
 
287
290
  // Snapshot inputs if provided (expensive - only do if tracing)
288
291
 
@@ -290,6 +293,9 @@ export async function traceStep(name, label, layer, outputBuffer, outputShape, o
290
293
  if (options?.inputs && options?.inputShapes) {
291
294
  for (let i = 0; i < options.inputs.length; i++) {
292
295
  const snap = await snapshotTensor(options.inputs[i], options.inputShapes[i]);
296
+ if (!snap.ok) {
297
+ throw new Error(`[TRACE] Failed to snapshot input ${i} for ${label}: ${snap.error}`);
298
+ }
293
299
  inputs.push(snap);
294
300
  }
295
301
  }
@@ -2,7 +2,7 @@
2
2
 
3
3
  import { log, trace } from '../../../debug/index.js';
4
4
  import { getDevice } from '../../../gpu/device.js';
5
- import { releaseBuffer } from '../../../memory/buffer-pool.js';
5
+ import { releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
6
6
  import { allowReadback } from '../../../gpu/perf-guards.js';
7
7
  import { createTensor } from '../../../gpu/tensor.js';
8
8
  import {
@@ -228,6 +228,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
228
228
  linearRuntime: context.linearAttentionRuntime ?? null,
229
229
  getWeightBuffer: (weight, label) => getWeightBuffer(weight, label),
230
230
  getNormWeightBuffer: (weight, label) => getNormWeightBuffer(weight, label, weightConfig, debugFlags),
231
+ debugProbes: context.debugProbes,
231
232
  recorder: recorder ?? null,
232
233
  });
233
234
  } else {
@@ -314,14 +315,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
314
315
  if (allowReadback(`layer.attn-out.${layerIdx}`)) {
315
316
  try {
316
317
  const sampleSize = Math.min(128, attnOutput.buffer.size);
317
- const staging = device.createBuffer({ size: sampleSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
318
- const enc = device.createCommandEncoder();
319
- enc.copyBufferToBuffer(attnOutput.buffer, 0, staging, 0, sampleSize);
320
- device.queue.submit([enc.finish()]);
321
- await staging.mapAsync(GPUMapMode.READ);
322
- const data = new Float32Array(staging.getMappedRange().slice(0));
323
- staging.unmap();
324
- staging.destroy();
318
+ const data = new Float32Array(await readBuffer(attnOutput.buffer, sampleSize));
325
319
  let maxAbs = 0;
326
320
  for (let i = 0; i < data.length; i++) {
327
321
  const abs = Math.abs(data[i]);
@@ -3,6 +3,7 @@ import type { Tensor } from '../../../gpu/tensor.js';
3
3
  import type { WeightBuffer } from '../../../gpu/weight-buffer.js';
4
4
  import type { CommandRecorder } from '../../../gpu/command-recorder.js';
5
5
  import type { LinearNormMode } from '../../../config/schema/index.js';
6
+ import type { ProbeConfigSchema } from '../../../config/schema/index.js';
6
7
 
7
8
  export interface LinearLayerRuntimeState {
8
9
  layerIdx: number;
@@ -67,6 +68,7 @@ export interface RunLinearAttentionLayerOptions {
67
68
  weight: GPUBuffer | Float32Array | ArrayBuffer,
68
69
  label: string
69
70
  ) => GPUBuffer;
71
+ debugProbes?: ProbeConfigSchema[] | null;
70
72
  recorder?: CommandRecorder | null;
71
73
  }
72
74
 
@@ -74,6 +76,14 @@ export declare function hasLinearAttentionLayers(layerTypes: unknown): boolean;
74
76
 
75
77
  export declare function createLinearAttentionRuntime(): LinearAttentionRuntime;
76
78
 
79
+ export declare function inferLinearNormMode(
80
+ weight: { size?: number; dtype?: string } | GPUBuffer | WeightBuffer | ArrayBufferView | ArrayBuffer | null | undefined,
81
+ projectionLayout: {
82
+ headVDim: number;
83
+ valueDim: number;
84
+ }
85
+ ): LinearNormMode | null;
86
+
77
87
  export declare function resetLinearAttentionRuntime(
78
88
  runtime: LinearAttentionRuntime | null | undefined
79
89
  ): LinearAttentionRuntime;
@@ -4,6 +4,7 @@ import { readBuffer, releaseBuffer, uploadData, acquireBuffer } from '../../../m
4
4
  import { log } from '../../../debug/index.js';
5
5
  import { decodeReadback } from './debug-utils/index.js';
6
6
  import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
7
+ import { runProbes } from './probes.js';
7
8
 
8
9
  const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
9
10
  const QK_L2NORM_EPS = 1e-6;
@@ -173,9 +174,22 @@ function inferLinearNormModeFromWeight(weight, projectionLayout) {
173
174
  if (weight instanceof ArrayBuffer) {
174
175
  return classify(Math.trunc(weight.byteLength / Float32Array.BYTES_PER_ELEMENT));
175
176
  }
177
+ const explicitDtype = typeof weight?.dtype === 'string' ? weight.dtype.toLowerCase() : null;
178
+ const trackedDtype = isGpuBuffer(weight) ? String(getBufferDtype(weight) ?? '').toLowerCase() : '';
179
+ const bytesPerElement = bytesFromDtype(explicitDtype || trackedDtype || null);
180
+ const sizedElements = Number.isFinite(weight?.size)
181
+ ? Math.trunc(Number(weight.size) / bytesPerElement)
182
+ : null;
183
+ if (sizedElements && Number(weight.size) % bytesPerElement === 0) {
184
+ return classify(sizedElements);
185
+ }
176
186
  return null;
177
187
  }
178
188
 
189
+ export function inferLinearNormMode(weight, projectionLayout) {
190
+ return inferLinearNormModeFromWeight(weight, projectionLayout);
191
+ }
192
+
179
193
  function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, layerIdx) {
180
194
  const configuredMode = normalizeLinearNormMode(configNormMode);
181
195
  const inferredMode = inferLinearNormModeFromWeight(normWeight, projectionLayout);
@@ -185,7 +199,15 @@ function resolveLinearNormMode(configNormMode, normWeight, projectionLayout, lay
185
199
  `but norm.weight shape implies "${inferredMode}".`
186
200
  );
187
201
  }
188
- return configuredMode ?? inferredMode ?? 'shared';
202
+ if (configuredMode) {
203
+ return configuredMode;
204
+ }
205
+ if (inferredMode) {
206
+ return inferredMode;
207
+ }
208
+ throw new Error(
209
+ `linear_attention layer ${layerIdx} requires explicit linearNormMode or a norm.weight shape that resolves it.`
210
+ );
189
211
  }
190
212
 
191
213
  async function readWeightAsF32(weight, expectedElements, label) {
@@ -395,10 +417,17 @@ async function createLayerRuntimeState(
395
417
 
396
418
  let convKernelSize = toPositiveInt(config.linearConvKernelDim) ?? null;
397
419
  if (isWeightBuffer(convKernel) && Array.isArray(convKernel.shape) && convKernel.shape.length >= 3) {
398
- convKernelSize = toPositiveInt(convKernel.shape[2]) ?? convKernelSize;
420
+ const shapeKernelSize = toPositiveInt(convKernel.shape[2]) ?? null;
421
+ if (convKernelSize != null && shapeKernelSize != null && convKernelSize !== shapeKernelSize) {
422
+ throw new Error(
423
+ `linear_attention layer ${layerIdx} declares linearConvKernelDim=${convKernelSize}, ` +
424
+ `but conv1d weight shape implies ${shapeKernelSize}.`
425
+ );
426
+ }
427
+ convKernelSize = shapeKernelSize ?? convKernelSize;
399
428
  }
400
429
  if (!convKernelSize) {
401
- convKernelSize = 4;
430
+ throw new Error(`linear_attention layer ${layerIdx} requires linearConvKernelDim.`);
402
431
  }
403
432
 
404
433
  const convWeight = await readWeightAsF32(
@@ -435,6 +464,11 @@ async function createLayerRuntimeState(
435
464
  const recurrentState = new Float32Array(
436
465
  projectionLayout.numVHeads * projectionLayout.headKDim * projectionLayout.headVDim
437
466
  );
467
+ const rmsNormEps = Number(config.rmsNormEps);
468
+ if (!Number.isFinite(rmsNormEps) || rmsNormEps <= 0) {
469
+ throw new Error(`linear_attention layer ${layerIdx} requires a positive rmsNormEps.`);
470
+ }
471
+
438
472
  const layerState = {
439
473
  layerIdx,
440
474
  seqLen: currentSeqLen,
@@ -452,7 +486,7 @@ async function createLayerRuntimeState(
452
486
  vSize: projectionLayout.vSize,
453
487
  qRep: projectionLayout.qRep,
454
488
  normMode,
455
- rmsNormEps: Number(config.rmsNormEps) || 1e-6,
489
+ rmsNormEps,
456
490
  convWeight,
457
491
  dtBias,
458
492
  aNegExp,
@@ -681,13 +715,13 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
681
715
  const normWeightBuffer = getNormWeightBuffer(layerWeights.inputNorm, `L${layerIdx}.linear_input_norm`);
682
716
  try {
683
717
  if (recorder) {
684
- normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, Number(config.rmsNormEps) || 1e-6, {
718
+ normedTensor = await recordRMSNorm(recorder, inputTensor, normWeightBuffer, layerState.rmsNormEps, {
685
719
  batchSize: numTokens,
686
720
  hiddenSize,
687
721
  rmsNormWeightOffset: config.rmsNormWeightOffset,
688
722
  });
689
723
  } else {
690
- normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, Number(config.rmsNormEps) || 1e-6, {
724
+ normedTensor = await runRMSNorm(inputTensor, normWeightBuffer, layerState.rmsNormEps, {
691
725
  batchSize: numTokens,
692
726
  hiddenSize,
693
727
  rmsNormWeightOffset: config.rmsNormWeightOffset,
@@ -755,6 +789,38 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
755
789
  });
756
790
 
757
791
  try {
792
+ await runProbes('linear_qkv_proj', qkvTensor.buffer, {
793
+ layerIdx,
794
+ numTokens,
795
+ hiddenSize: projectionLayout.convDim,
796
+ probes: options.debugProbes,
797
+ recorder,
798
+ dtype: qkvTensor.dtype,
799
+ });
800
+ await runProbes('linear_z_proj', zTensor.buffer, {
801
+ layerIdx,
802
+ numTokens,
803
+ hiddenSize: projectionLayout.valueDim,
804
+ probes: options.debugProbes,
805
+ recorder,
806
+ dtype: zTensor.dtype,
807
+ });
808
+ await runProbes('linear_a_proj', aTensor.buffer, {
809
+ layerIdx,
810
+ numTokens,
811
+ hiddenSize: projectionLayout.numVHeads,
812
+ probes: options.debugProbes,
813
+ recorder,
814
+ dtype: aTensor.dtype,
815
+ });
816
+ await runProbes('linear_b_proj', bTensor.buffer, {
817
+ layerIdx,
818
+ numTokens,
819
+ hiddenSize: projectionLayout.numVHeads,
820
+ probes: options.debugProbes,
821
+ recorder,
822
+ dtype: bTensor.dtype,
823
+ });
758
824
  const coreTensor = await runLinearAttentionCoreGPU(
759
825
  qkvTensor,
760
826
  zTensor,
@@ -768,6 +834,14 @@ export async function runLinearAttentionLayer(inputTensor, layerWeights, options
768
834
  recorder,
769
835
  }
770
836
  );
837
+ await runProbes('linear_core_out', coreTensor.buffer, {
838
+ layerIdx,
839
+ numTokens,
840
+ hiddenSize: projectionLayout.valueDim,
841
+ probes: options.debugProbes,
842
+ recorder,
843
+ dtype: coreTensor.dtype,
844
+ });
771
845
  layerState.seqLen = currentSeqLen + numTokens;
772
846
  const outProjWeight = getWeightBuffer(layerWeights.oProj, `L${layerIdx}.linear_out_proj`);
773
847
  try {
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, getKernelCapabilities } from '../../../../gpu/device.js';
4
- import { acquireBuffer, releaseBuffer, readBuffer } from '../../../../memory/buffer-pool.js';
4
+ import { acquireBuffer, releaseBuffer } from '../../../../memory/buffer-pool.js';
5
5
  import { runMatmul, runRMSNorm } from '../../../../gpu/kernel-selector.js';
6
6
  import { recordMatmul } from '../../../../gpu/kernels/matmul.js';
7
7
  import { recordRMSNorm } from '../../../../gpu/kernels/rmsnorm.js';
@@ -13,6 +13,7 @@ import { getRuntimeConfig } from '../../../../config/runtime.js';
13
13
  import { selectRuleValue } from '../../../../rules/rule-registry.js';
14
14
  import { runProbes } from '../probes.js';
15
15
  import { f16BufferToF32 } from './cpu.js';
16
+ import { readBufferWithCleanup } from './utils.js';
16
17
 
17
18
  function shouldForceStableF32Logits(config, inputDtype) {
18
19
  // Small Gemma-family checkpoints can overflow in pure F16 logits path after RMSNorm offset.
@@ -187,14 +188,18 @@ export async function computeChunkedLogitsGPU(
187
188
  }
188
189
 
189
190
  const logitsBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: logitsTensor.dtype });
190
- const chunkLogitsData = await readBuffer(logitsTensor.buffer, numTokens * rowCount * logitsBytes);
191
+ const chunkLogitsData = await readBufferWithCleanup(
192
+ logitsTensor.buffer,
193
+ numTokens * rowCount * logitsBytes,
194
+ () => {
195
+ releaseBuffer(logitsTensor.buffer);
196
+ releaseBuffer(weightBuffer.buffer);
197
+ }
198
+ );
191
199
  const chunkLogits = logitsTensor.dtype === 'f16'
192
200
  ? f16BufferToF32(chunkLogitsData)
193
201
  : new Float32Array(chunkLogitsData);
194
202
  writeChunkLogits(logits, chunkLogits, numTokens, vocabSize, rowOffset, rowCount);
195
-
196
- releaseBuffer(logitsTensor.buffer);
197
- releaseBuffer(weightBuffer.buffer);
198
203
  }
199
204
 
200
205
  return logits;
@@ -7,7 +7,7 @@ export { rmsNormCPU, matmulCPU, applySoftcapping, f16ToF32, f16BufferToF32 } fro
7
7
  export { computeLogitsGPU, recordLogitsGPU, computeChunkedLogitsGPU, resolveCpuWeightDims, resolveLmHeadChunkRows, extractLmHeadChunk, writeChunkLogits } from './gpu.js';
8
8
 
9
9
  // Re-export utilities
10
- export { extractLastPositionLogits, finalizeLogits } from './utils.js';
10
+ export { extractLastPositionLogits, finalizeLogits, readBufferWithCleanup } from './utils.js';
11
11
 
12
12
  // Imports for computeLogits orchestrator
13
13
  import { getDevice } from '../../../../gpu/device.js';
@@ -20,7 +20,7 @@ import { log, trace, isTraceEnabled } from '../../../../debug/index.js';
20
20
  import { runProbes } from '../probes.js';
21
21
  import { rmsNormCPU, matmulCPU, f16BufferToF32 } from './cpu.js';
22
22
  import { resolveCpuWeightDims, computeChunkedLogitsGPU } from './gpu.js';
23
- import { finalizeLogits } from './utils.js';
23
+ import { finalizeLogits, readBufferWithCleanup } from './utils.js';
24
24
  import { getRuntimeConfig } from '../../../../config/runtime.js';
25
25
  import { selectRuleValue } from '../../../../rules/rule-registry.js';
26
26
 
@@ -288,15 +288,14 @@ export async function computeLogits(
288
288
  // 4. Read back logits
289
289
  const logitsBytes = selectRuleValue('shared', 'dtype', 'bytesFromDtype', { dtype: logitsTensor.dtype });
290
290
  const logitsReadSize = matmulRows * matmulVocabSize * logitsBytes;
291
- const logitsData = await readBuffer(logitsTensor.buffer, logitsReadSize);
292
-
293
- // Cleanup
294
- if (inputBufferOwned) releaseBuffer(inputBuffer);
295
- releaseBuffer(normedTensor.buffer);
296
- if (matmulInputOwned) releaseBuffer(matmulInputTensor.buffer);
297
- releaseBuffer(logitsTensor.buffer);
298
- if (!getNormWeightBuffer && !(finalNorm instanceof GPUBuffer)) releaseBuffer(normWeightBuffer);
299
- if (lmHeadBufferOwned) releaseBuffer(lmHeadGPU);
291
+ const logitsData = await readBufferWithCleanup(logitsTensor.buffer, logitsReadSize, () => {
292
+ if (inputBufferOwned) releaseBuffer(inputBuffer);
293
+ releaseBuffer(normedTensor.buffer);
294
+ if (matmulInputOwned) releaseBuffer(matmulInputTensor.buffer);
295
+ releaseBuffer(logitsTensor.buffer);
296
+ if (!getNormWeightBuffer && !(finalNorm instanceof GPUBuffer)) releaseBuffer(normWeightBuffer);
297
+ if (lmHeadBufferOwned) releaseBuffer(lmHeadGPU);
298
+ });
300
299
 
301
300
  const rawLogits = logitsTensor.dtype === 'f16'
302
301
  ? f16BufferToF32(logitsData)
@@ -25,6 +25,13 @@ export function extractLastPositionLogits(
25
25
  vocabSize: number
26
26
  ): Float32Array;
27
27
 
28
+ export function readBufferWithCleanup(
29
+ buffer: GPUBuffer,
30
+ byteLength: number,
31
+ cleanup?: (() => void) | null,
32
+ reader?: ((buffer: GPUBuffer, byteLength: number) => Promise<ArrayBuffer>) | null
33
+ ): Promise<ArrayBuffer>;
34
+
28
35
  /**
29
36
  * Finalize logits by applying padding and softcapping.
30
37
  *
@@ -1,5 +1,6 @@
1
1
 
2
2
 
3
+ import { readBuffer } from '../../../../memory/buffer-pool.js';
3
4
  import { runProbes } from '../probes.js';
4
5
  import { applySoftcapping } from './cpu.js';
5
6
 
@@ -19,6 +20,14 @@ export function extractLastPositionLogits(
19
20
  return lastPosLogits;
20
21
  }
21
22
 
23
+ export async function readBufferWithCleanup(buffer, byteLength, cleanup, reader = readBuffer) {
24
+ try {
25
+ return await reader(buffer, byteLength);
26
+ } finally {
27
+ cleanup?.();
28
+ }
29
+ }
30
+
22
31
 
23
32
  export async function finalizeLogits(
24
33
  rawLogits,
@@ -17,42 +17,60 @@ export async function applyLoRA(input, baseOutput, lora, dims, getWeightBuffer,
17
17
 
18
18
  const aBuf = getWeightBuffer(lora.a, 'lora_a');
19
19
  const bBuf = getWeightBuffer(lora.b, 'lora_b');
20
- const ownsA = !(lora.a instanceof GPUBuffer) && !isWeightBuffer(lora.a);
21
- const ownsB = !(lora.b instanceof GPUBuffer) && !isWeightBuffer(lora.b);
22
-
23
- const loraIntermediate = recorder
24
- ? await recordMatmul(recorder, input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath })
25
- : await runMatmul(input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath });
20
+ const ownsA = !(typeof GPUBuffer !== 'undefined' && lora.a instanceof GPUBuffer) && !isWeightBuffer(lora.a);
21
+ const ownsB = !(typeof GPUBuffer !== 'undefined' && lora.b instanceof GPUBuffer) && !isWeightBuffer(lora.b);
22
+ // Extract underlying GPUBuffer for WeightBuffers
23
+ const aBufGPU = isWeightBuffer(aBuf) ? aBuf.buffer : aBuf;
24
+ const bBufGPU = isWeightBuffer(bBuf) ? bBuf.buffer : bBuf;
25
+ let loraIntermediate = null;
26
+ let loraOutput = null;
27
+ let scaled = null;
28
+ try {
29
+ loraIntermediate = recorder
30
+ ? await recordMatmul(recorder, input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath })
31
+ : await runMatmul(input, aBuf, M, rank, K, { transposeB: 'auto', role: 'lora_a', kernelPath });
26
32
 
27
- const loraOutput = recorder
28
- ? await recordMatmul(recorder, loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath })
29
- : await runMatmul(loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath });
33
+ loraOutput = recorder
34
+ ? await recordMatmul(recorder, loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath })
35
+ : await runMatmul(loraIntermediate, bBuf, M, N, rank, { transposeB: 'auto', role: 'lora_b', kernelPath });
30
36
 
31
- const scaled = recorder
32
- ? await recordScale(recorder, loraOutput, lora.scale, { outputBuffer: null })
33
- : await runScale(loraOutput, lora.scale, { outputBuffer: null });
37
+ scaled = recorder
38
+ ? await recordScale(recorder, loraOutput, lora.scale, { outputBuffer: null })
39
+ : await runScale(loraOutput, lora.scale, { outputBuffer: null });
34
40
 
35
- const combined = recorder
36
- ? await recordResidualAdd(recorder, baseOutput, scaled, M * N)
37
- : await runResidualAdd(baseOutput, scaled, M * N);
41
+ const combined = recorder
42
+ ? await recordResidualAdd(recorder, baseOutput, scaled, M * N)
43
+ : await runResidualAdd(baseOutput, scaled, M * N);
38
44
 
39
- // Extract underlying GPUBuffer for WeightBuffers
40
- const aBufGPU = isWeightBuffer(aBuf) ? aBuf.buffer : aBuf;
41
- const bBufGPU = isWeightBuffer(bBuf) ? bBuf.buffer : bBuf;
45
+ if (recorder) {
46
+ recorder.trackTemporaryBuffer(loraIntermediate.buffer);
47
+ recorder.trackTemporaryBuffer(loraOutput.buffer);
48
+ recorder.trackTemporaryBuffer(scaled.buffer);
49
+ if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
50
+ if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
51
+ } else {
52
+ releaseBuffer(loraIntermediate.buffer);
53
+ releaseBuffer(loraOutput.buffer);
54
+ releaseBuffer(scaled.buffer);
55
+ if (ownsA) releaseBuffer(aBufGPU);
56
+ if (ownsB) releaseBuffer(bBufGPU);
57
+ }
42
58
 
43
- if (recorder) {
44
- recorder.trackTemporaryBuffer(loraIntermediate.buffer);
45
- recorder.trackTemporaryBuffer(loraOutput.buffer);
46
- recorder.trackTemporaryBuffer(scaled.buffer);
47
- if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
48
- if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
49
- } else {
50
- releaseBuffer(loraIntermediate.buffer);
51
- releaseBuffer(loraOutput.buffer);
52
- releaseBuffer(scaled.buffer);
53
- if (ownsA) releaseBuffer(aBufGPU);
54
- if (ownsB) releaseBuffer(bBufGPU);
59
+ return combined;
60
+ } catch (error) {
61
+ if (recorder) {
62
+ if (loraIntermediate) recorder.trackTemporaryBuffer(loraIntermediate.buffer);
63
+ if (loraOutput) recorder.trackTemporaryBuffer(loraOutput.buffer);
64
+ if (scaled) recorder.trackTemporaryBuffer(scaled.buffer);
65
+ if (ownsA) recorder.trackTemporaryBuffer(aBufGPU);
66
+ if (ownsB) recorder.trackTemporaryBuffer(bBufGPU);
67
+ } else {
68
+ if (loraIntermediate) releaseBuffer(loraIntermediate.buffer);
69
+ if (loraOutput) releaseBuffer(loraOutput.buffer);
70
+ if (scaled) releaseBuffer(scaled.buffer);
71
+ if (ownsA) releaseBuffer(aBufGPU);
72
+ if (ownsB) releaseBuffer(bBufGPU);
73
+ }
74
+ throw error;
55
75
  }
56
-
57
- return combined;
58
76
  }