@simulatte/doppler 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355) hide show
  1. package/CHANGELOG.md +145 -0
  2. package/README.md +16 -23
  3. package/package.json +30 -32
  4. package/src/adapters/adapter-registry.js +12 -1
  5. package/src/adapters/lora-loader.js +23 -6
  6. package/src/bridge/extension-client.d.ts +5 -0
  7. package/src/bridge/extension-client.js +40 -0
  8. package/src/bridge/index.d.ts +2 -1
  9. package/src/bridge/index.js +6 -4
  10. package/src/browser/browser-converter.js +31 -1
  11. package/src/browser/file-picker.js +6 -0
  12. package/src/browser/safetensors-parser-browser.js +84 -1
  13. package/src/browser/shard-io-browser.js +2 -2
  14. package/src/browser/tensor-source-download.js +8 -2
  15. package/src/browser/tensor-source-http.d.ts +1 -0
  16. package/src/browser/tensor-source-http.js +5 -1
  17. package/src/client/doppler-api.browser.js +20 -4
  18. package/src/client/doppler-api.js +19 -3
  19. package/src/client/doppler-provider/generation.js +12 -0
  20. package/src/client/doppler-provider/model-manager.d.ts +10 -0
  21. package/src/client/doppler-provider/model-manager.js +91 -19
  22. package/src/client/doppler-provider/source-runtime.d.ts +2 -1
  23. package/src/client/doppler-provider/source-runtime.js +132 -13
  24. package/src/client/doppler-registry.json +5 -20
  25. package/src/config/backward-registry-loader.js +17 -2
  26. package/src/config/execution-v0-contract-check.js +113 -15
  27. package/src/config/kernel-path-contract-check.js +57 -29
  28. package/src/config/kernel-path-loader.d.ts +5 -0
  29. package/src/config/kernel-path-loader.js +18 -36
  30. package/src/config/kernels/kernel-ref-digests.js +1 -1
  31. package/src/config/kernels/registry.js +14 -1
  32. package/src/config/kernels/registry.json +81 -5
  33. package/src/config/loader.d.ts +1 -1
  34. package/src/config/loader.js +15 -2
  35. package/src/config/merge-contract-check.js +66 -4
  36. package/src/config/merge-helpers.js +128 -7
  37. package/src/config/merge.d.ts +1 -0
  38. package/src/config/merge.js +10 -0
  39. package/src/config/param-validator.js +47 -2
  40. package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
  41. package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
  42. package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
  43. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  44. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  45. package/src/config/presets/kernel-paths/registry.json +43 -8
  46. package/src/config/presets/models/gemma2.json +3 -2
  47. package/src/config/presets/models/gemma3.json +2 -0
  48. package/src/config/presets/models/qwen3.json +4 -3
  49. package/src/config/presets/models/qwen3_5.json +16 -0
  50. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
  51. package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
  52. package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
  53. package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
  54. package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
  55. package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
  56. package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
  57. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
  58. package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
  59. package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
  60. package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
  61. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  62. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  63. package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
  64. package/src/config/runtime.js +6 -1
  65. package/src/config/schema/conversion.schema.d.ts +1 -0
  66. package/src/config/schema/debug.schema.d.ts +5 -0
  67. package/src/config/schema/doppler.schema.js +16 -21
  68. package/src/config/schema/inference-defaults.schema.js +3 -3
  69. package/src/config/schema/kernel-path.schema.d.ts +5 -1
  70. package/src/config/schema/kernel-thresholds.schema.js +12 -4
  71. package/src/config/schema/manifest.schema.d.ts +3 -2
  72. package/src/config/schema/manifest.schema.js +17 -4
  73. package/src/config/schema/storage.schema.js +1 -1
  74. package/src/config/training-defaults.js +30 -22
  75. package/src/converter/conversion-plan.js +104 -11
  76. package/src/converter/core.d.ts +7 -0
  77. package/src/converter/core.js +16 -9
  78. package/src/converter/execution-v0-manifest.js +4 -1
  79. package/src/converter/index.d.ts +1 -0
  80. package/src/converter/index.js +1 -0
  81. package/src/converter/manifest-inference.js +50 -29
  82. package/src/converter/parsers/diffusion.js +0 -3
  83. package/src/converter/parsers/transformer.js +4 -0
  84. package/src/converter/quantization-info.js +40 -16
  85. package/src/converter/quantizer.js +19 -12
  86. package/src/converter/rope-config.js +8 -6
  87. package/src/converter/shard-packer.d.ts +1 -1
  88. package/src/converter/shard-packer.js +4 -1
  89. package/src/converter/tokenizer-utils.d.ts +1 -0
  90. package/src/converter/tokenizer-utils.js +4 -1
  91. package/src/debug/config.js +123 -11
  92. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  93. package/src/debug/signals.js +7 -1
  94. package/src/debug/tensor.d.ts +2 -0
  95. package/src/debug/tensor.js +13 -2
  96. package/src/distribution/p2p-control-plane.js +52 -12
  97. package/src/distribution/p2p-observability.js +43 -7
  98. package/src/distribution/p2p-webrtc-browser.js +20 -0
  99. package/src/distribution/shard-delivery.js +83 -27
  100. package/src/formats/gguf/types.js +33 -16
  101. package/src/formats/rdrr/groups.d.ts +12 -4
  102. package/src/formats/rdrr/groups.js +3 -6
  103. package/src/formats/rdrr/parsing.d.ts +4 -0
  104. package/src/formats/rdrr/parsing.js +53 -3
  105. package/src/formats/rdrr/types.d.ts +2 -1
  106. package/src/gpu/command-recorder.js +86 -61
  107. package/src/gpu/device.d.ts +1 -0
  108. package/src/gpu/device.js +73 -19
  109. package/src/gpu/kernel-tuner/benchmarks.js +326 -316
  110. package/src/gpu/kernel-tuner/cache.js +71 -4
  111. package/src/gpu/kernel-tuner/tuner.js +22 -4
  112. package/src/gpu/kernels/attention.js +15 -34
  113. package/src/gpu/kernels/backward/adam.js +62 -58
  114. package/src/gpu/kernels/backward/attention_backward.js +257 -169
  115. package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
  116. package/src/gpu/kernels/cast.js +191 -149
  117. package/src/gpu/kernels/check-stop.js +33 -44
  118. package/src/gpu/kernels/conv2d.js +27 -17
  119. package/src/gpu/kernels/cross_entropy_loss.js +21 -15
  120. package/src/gpu/kernels/depthwise_conv2d.js +36 -26
  121. package/src/gpu/kernels/dequant.js +178 -126
  122. package/src/gpu/kernels/energy.d.ts +3 -21
  123. package/src/gpu/kernels/energy.js +111 -88
  124. package/src/gpu/kernels/feature-check.js +1 -1
  125. package/src/gpu/kernels/fused_ffn.js +84 -65
  126. package/src/gpu/kernels/fused_matmul_residual.js +56 -33
  127. package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
  128. package/src/gpu/kernels/gather.js +33 -15
  129. package/src/gpu/kernels/gelu.js +19 -11
  130. package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
  131. package/src/gpu/kernels/groupnorm.js +34 -23
  132. package/src/gpu/kernels/index.d.ts +8 -0
  133. package/src/gpu/kernels/index.js +6 -0
  134. package/src/gpu/kernels/kv-quantize.js +5 -2
  135. package/src/gpu/kernels/layernorm.js +35 -19
  136. package/src/gpu/kernels/logit-merge.js +5 -3
  137. package/src/gpu/kernels/matmul-selection.js +47 -4
  138. package/src/gpu/kernels/matmul.d.ts +2 -0
  139. package/src/gpu/kernels/matmul.js +59 -40
  140. package/src/gpu/kernels/modulate.js +23 -15
  141. package/src/gpu/kernels/moe.js +221 -175
  142. package/src/gpu/kernels/pixel_shuffle.js +22 -14
  143. package/src/gpu/kernels/relu.js +18 -10
  144. package/src/gpu/kernels/repeat_channels.js +25 -17
  145. package/src/gpu/kernels/residual.js +37 -27
  146. package/src/gpu/kernels/rmsnorm.js +66 -43
  147. package/src/gpu/kernels/rope.js +3 -0
  148. package/src/gpu/kernels/sample.js +27 -38
  149. package/src/gpu/kernels/sana_linear_attention.js +18 -10
  150. package/src/gpu/kernels/scale.js +18 -11
  151. package/src/gpu/kernels/shader-cache.js +4 -2
  152. package/src/gpu/kernels/silu.js +120 -72
  153. package/src/gpu/kernels/softmax.js +44 -25
  154. package/src/gpu/kernels/split_qg.d.ts +50 -0
  155. package/src/gpu/kernels/split_qg.js +46 -0
  156. package/src/gpu/kernels/split_qg.wgsl +58 -0
  157. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  158. package/src/gpu/kernels/split_qkv.js +23 -13
  159. package/src/gpu/kernels/transpose.js +18 -10
  160. package/src/gpu/kernels/transpose.wgsl +5 -3
  161. package/src/gpu/kernels/upsample2d.js +21 -13
  162. package/src/gpu/kernels/utils.js +20 -13
  163. package/src/gpu/partitioned-buffer-pool.js +10 -2
  164. package/src/gpu/perf-guards.js +2 -9
  165. package/src/gpu/profiler.js +27 -22
  166. package/src/gpu/readback-utils.d.ts +16 -0
  167. package/src/gpu/readback-utils.js +41 -0
  168. package/src/gpu/submit-tracker.js +13 -0
  169. package/src/gpu/uniform-cache.d.ts +1 -0
  170. package/src/gpu/uniform-cache.js +30 -9
  171. package/src/gpu/weight-buffer.d.ts +1 -1
  172. package/src/gpu/weight-buffer.js +1 -1
  173. package/src/hotswap/intent-bundle.js +6 -0
  174. package/src/hotswap/manifest.d.ts +10 -1
  175. package/src/hotswap/manifest.js +12 -2
  176. package/src/hotswap/runtime.js +30 -8
  177. package/src/index-browser.d.ts +44 -0
  178. package/src/index-browser.js +14 -0
  179. package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
  180. package/src/inference/browser-harness-contract-helpers.js +28 -0
  181. package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
  182. package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
  183. package/src/inference/browser-harness-model-helpers.d.ts +16 -0
  184. package/src/inference/browser-harness-model-helpers.js +217 -0
  185. package/src/inference/browser-harness-report-helpers.d.ts +7 -0
  186. package/src/inference/browser-harness-report-helpers.js +42 -0
  187. package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
  188. package/src/inference/browser-harness-runtime-helpers.js +415 -0
  189. package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
  190. package/src/inference/browser-harness-suite-helpers.js +268 -0
  191. package/src/inference/browser-harness-text-helpers.d.ts +27 -0
  192. package/src/inference/browser-harness-text-helpers.js +788 -0
  193. package/src/inference/browser-harness.d.ts +8 -0
  194. package/src/inference/browser-harness.js +149 -1996
  195. package/src/inference/kv-cache/base.js +140 -94
  196. package/src/inference/kv-cache/tiered.js +5 -3
  197. package/src/inference/moe-router.js +88 -56
  198. package/src/inference/multi-model-network.js +5 -3
  199. package/src/inference/network-evolution.d.ts +11 -2
  200. package/src/inference/network-evolution.js +20 -21
  201. package/src/inference/pipelines/context.d.ts +3 -0
  202. package/src/inference/pipelines/context.js +142 -2
  203. package/src/inference/pipelines/diffusion/helpers.js +10 -2
  204. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  205. package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
  206. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  207. package/src/inference/pipelines/diffusion/vae.js +3 -7
  208. package/src/inference/pipelines/energy/pipeline.js +27 -21
  209. package/src/inference/pipelines/energy/quintel.d.ts +5 -0
  210. package/src/inference/pipelines/energy/quintel.js +11 -0
  211. package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
  212. package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
  213. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  214. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  215. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  216. package/src/inference/pipelines/text/attention/projections.js +192 -112
  217. package/src/inference/pipelines/text/attention/record.js +77 -14
  218. package/src/inference/pipelines/text/attention/run.js +112 -14
  219. package/src/inference/pipelines/text/config.js +17 -4
  220. package/src/inference/pipelines/text/embed.js +2 -8
  221. package/src/inference/pipelines/text/execution-plan.js +46 -23
  222. package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
  223. package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
  224. package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
  225. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
  226. package/src/inference/pipelines/text/execution-v0.js +62 -1013
  227. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  228. package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
  229. package/src/inference/pipelines/text/generator-steps.js +340 -221
  230. package/src/inference/pipelines/text/generator.js +56 -40
  231. package/src/inference/pipelines/text/init.d.ts +13 -0
  232. package/src/inference/pipelines/text/init.js +94 -25
  233. package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
  234. package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
  235. package/src/inference/pipelines/text/kernel-trace.js +6 -0
  236. package/src/inference/pipelines/text/layer.js +4 -9
  237. package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
  238. package/src/inference/pipelines/text/linear-attention.js +113 -9
  239. package/src/inference/pipelines/text/logits/gpu.js +12 -7
  240. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  241. package/src/inference/pipelines/text/logits/index.js +13 -12
  242. package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
  243. package/src/inference/pipelines/text/logits/utils.js +9 -0
  244. package/src/inference/pipelines/text/lora-apply.js +50 -32
  245. package/src/inference/pipelines/text/model-load.js +282 -104
  246. package/src/inference/pipelines/text/moe-cache.js +5 -4
  247. package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
  248. package/src/inference/pipelines/text/moe-cpu.js +42 -38
  249. package/src/inference/pipelines/text/moe-gpu.js +110 -86
  250. package/src/inference/pipelines/text/ops.js +90 -90
  251. package/src/inference/pipelines/text/probes.js +9 -9
  252. package/src/inference/pipelines/text/sampling.js +52 -6
  253. package/src/inference/pipelines/text/weights.js +17 -7
  254. package/src/inference/pipelines/text.js +13 -1
  255. package/src/inference/speculative.d.ts +2 -2
  256. package/src/inference/speculative.js +4 -18
  257. package/src/inference/test-harness.d.ts +1 -1
  258. package/src/inference/test-harness.js +17 -7
  259. package/src/inference/tokenizer.d.ts +0 -5
  260. package/src/inference/tokenizer.js +4 -23
  261. package/src/inference/tokenizers/bpe.js +9 -0
  262. package/src/inference/tokenizers/bundled.js +20 -0
  263. package/src/inference/tokenizers/sentencepiece.js +12 -0
  264. package/src/loader/doppler-loader.js +38 -22
  265. package/src/loader/dtype-utils.js +3 -44
  266. package/src/loader/embedding-loader.js +7 -3
  267. package/src/loader/experts/expert-cache.js +13 -6
  268. package/src/loader/experts/expert-loader.js +10 -6
  269. package/src/loader/final-weights-loader.js +10 -4
  270. package/src/loader/layer-loader.js +2 -1
  271. package/src/loader/loader-state.js +2 -2
  272. package/src/loader/memory-monitor.js +8 -0
  273. package/src/loader/multi-model-loader.d.ts +14 -0
  274. package/src/loader/multi-model-loader.js +70 -24
  275. package/src/loader/shard-cache.js +84 -14
  276. package/src/loader/shard-resolver.js +25 -3
  277. package/src/loader/tensors/tensor-loader.js +214 -144
  278. package/src/loader/tensors/tensor-reader.js +76 -19
  279. package/src/loader/weight-downcast.js +1 -1
  280. package/src/memory/buffer-pool.d.ts +9 -1
  281. package/src/memory/buffer-pool.js +109 -44
  282. package/src/memory/unified-detect.js +1 -1
  283. package/src/rules/inference/dtype.rules.json +5 -0
  284. package/src/rules/inference/kernel-path.rules.json +24 -8
  285. package/src/rules/kernels/split-qg.rules.json +6 -0
  286. package/src/rules/rule-registry.js +27 -1
  287. package/src/storage/backends/opfs-store.js +68 -24
  288. package/src/storage/downloader.js +365 -83
  289. package/src/storage/index.d.ts +3 -0
  290. package/src/storage/index.js +3 -0
  291. package/src/storage/preflight.d.ts +2 -2
  292. package/src/storage/preflight.js +24 -2
  293. package/src/storage/quickstart-downloader.js +11 -5
  294. package/src/storage/registry.js +10 -4
  295. package/src/storage/reports.js +1 -1
  296. package/src/storage/shard-manager.d.ts +15 -1
  297. package/src/storage/shard-manager.js +55 -6
  298. package/src/storage/source-artifact-store.d.ts +52 -0
  299. package/src/storage/source-artifact-store.js +234 -0
  300. package/src/tooling/command-api-constants.d.ts +9 -0
  301. package/src/tooling/command-api-constants.js +9 -0
  302. package/src/tooling/command-api-family-normalizers.d.ts +9 -0
  303. package/src/tooling/command-api-family-normalizers.js +343 -0
  304. package/src/tooling/command-api-helpers.d.ts +25 -0
  305. package/src/tooling/command-api-helpers.js +262 -0
  306. package/src/tooling/command-api.js +16 -602
  307. package/src/tooling/command-envelope.js +4 -1
  308. package/src/tooling/command-runner-shared.js +52 -18
  309. package/src/tooling/conversion-config-materializer.js +3 -5
  310. package/src/tooling/lean-execution-contract.js +150 -3
  311. package/src/tooling/node-browser-command-runner.js +161 -271
  312. package/src/tooling/node-command-runner.js +29 -3
  313. package/src/tooling/node-converter.js +30 -1
  314. package/src/tooling/node-source-runtime.d.ts +1 -1
  315. package/src/tooling/node-source-runtime.js +120 -3
  316. package/src/tooling/node-webgpu.js +24 -21
  317. package/src/tooling/opfs-cache.js +21 -4
  318. package/src/tooling/runtime-input-composition.d.ts +38 -0
  319. package/src/tooling/runtime-input-composition.js +86 -0
  320. package/src/tooling/source-runtime-bundle.d.ts +40 -5
  321. package/src/tooling/source-runtime-bundle.js +261 -34
  322. package/src/tooling/source-runtime-materializer.d.ts +6 -0
  323. package/src/tooling/source-runtime-materializer.js +93 -0
  324. package/src/training/attention-backward.js +32 -17
  325. package/src/training/autograd.js +80 -52
  326. package/src/training/checkpoint-watch.d.ts +2 -1
  327. package/src/training/checkpoint-watch.js +39 -6
  328. package/src/training/checkpoint.js +40 -11
  329. package/src/training/clip.js +2 -1
  330. package/src/training/datasets/token-batch.js +20 -8
  331. package/src/training/distillation/checkpoint-watch.js +1 -0
  332. package/src/training/distillation/student-fixture.d.ts +22 -0
  333. package/src/training/distillation/student-fixture.js +846 -0
  334. package/src/training/distillation/suite-data.d.ts +45 -0
  335. package/src/training/distillation/suite-data.js +189 -0
  336. package/src/training/lora-pipeline.js +4 -7
  337. package/src/training/lora.js +26 -12
  338. package/src/training/loss.js +5 -6
  339. package/src/training/objectives/cross_entropy.js +2 -5
  340. package/src/training/objectives/distill_kd.js +4 -8
  341. package/src/training/objectives/distill_triplet.js +4 -8
  342. package/src/training/objectives/ul_stage2_base.js +4 -8
  343. package/src/training/operator-command.js +2 -0
  344. package/src/training/optimizer.js +19 -7
  345. package/src/training/runner.js +2 -1
  346. package/src/training/suite.js +18 -978
  347. package/src/training/tensor-factory.d.ts +9 -0
  348. package/src/training/tensor-factory.js +13 -0
  349. package/src/training/trainer.js +3 -5
  350. package/src/training/ul_dataset.js +3 -5
  351. package/src/training/workloads.js +70 -79
  352. package/src/types/model.d.ts +5 -0
  353. package/src/version.js +1 -1
  354. package/tools/convert-safetensors-node.js +22 -16
  355. package/tools/doppler-cli.js +50 -26
@@ -1,33 +1,29 @@
1
- let fallbackRandomState = (Date.now() >>> 0) || 0x6d2b79f5;
2
-
3
- function unseededRandom() {
4
- if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
5
- const values = new Uint32Array(1);
6
- crypto.getRandomValues(values);
7
- return values[0] / 4294967296;
1
+ function requireRandomSource(random) {
2
+ if (typeof random !== 'function') {
3
+ throw new Error('network evolution requires an explicit random() source.');
8
4
  }
9
- fallbackRandomState = (fallbackRandomState + 0x6d2b79f5) >>> 0;
10
- return fallbackRandomState / 4294967296;
5
+ return random;
11
6
  }
12
7
 
13
- export const mutateGenome = (genome, mutationRate = 0.1) => {
8
+ export const mutateGenome = (genome, mutationRate = 0.1, random = null) => {
9
+ const sample = requireRandomSource(random);
14
10
 
15
11
  const mutated = JSON.parse(JSON.stringify(genome));
16
- if (unseededRandom() < mutationRate) {
12
+ if (sample() < mutationRate) {
17
13
 
18
14
  const types = ['chain', 'tree', 'mesh', 'dag'];
19
- mutated.topology.type = types[Math.floor(unseededRandom() * types.length)];
15
+ mutated.topology.type = types[Math.floor(sample() * types.length)];
20
16
  }
21
17
 
22
18
  for (const node of mutated.nodes) {
23
- if (unseededRandom() < mutationRate && typeof node.temperature === 'number') {
24
- node.temperature = Math.min(1, Math.max(0, node.temperature + (unseededRandom() - 0.5) * 0.2));
19
+ if (sample() < mutationRate && typeof node.temperature === 'number') {
20
+ node.temperature = Math.min(1, Math.max(0, node.temperature + (sample() - 0.5) * 0.2));
25
21
  }
26
22
  }
27
23
 
28
24
  for (const edge of mutated.edges) {
29
- if (unseededRandom() < mutationRate) {
30
- edge.weight = Math.min(1, Math.max(0, edge.weight + (unseededRandom() - 0.5) * 0.4));
25
+ if (sample() < mutationRate) {
26
+ edge.weight = Math.min(1, Math.max(0, edge.weight + (sample() - 0.5) * 0.4));
31
27
  }
32
28
  }
33
29
 
@@ -35,8 +31,9 @@ export const mutateGenome = (genome, mutationRate = 0.1) => {
35
31
  };
36
32
 
37
33
 
38
- export const crossoverGenome = (a, b) => {
39
- return unseededRandom() < 0.5 ? JSON.parse(JSON.stringify(a)) : JSON.parse(JSON.stringify(b));
34
+ export const crossoverGenome = (a, b, random = null) => {
35
+ const sample = requireRandomSource(random);
36
+ return sample() < 0.5 ? JSON.parse(JSON.stringify(a)) : JSON.parse(JSON.stringify(b));
40
37
  };
41
38
 
42
39
 
@@ -48,7 +45,9 @@ export async function evolveNetwork(config) {
48
45
  mutationRate = 0.1,
49
46
  evaluate,
50
47
  randomGenome,
48
+ random,
51
49
  } = config;
50
+ const sample = requireRandomSource(random);
52
51
 
53
52
  let population = Array.from({ length: populationSize }, () => randomGenome());
54
53
 
@@ -63,9 +62,9 @@ export async function evolveNetwork(config) {
63
62
  const offspring = [];
64
63
 
65
64
  while (offspring.length < populationSize - eliteCount) {
66
- const parentA = scored[Math.floor(unseededRandom() * scored.length)].genome;
67
- const parentB = scored[Math.floor(unseededRandom() * scored.length)].genome;
68
- const child = mutateGenome(crossoverGenome(parentA, parentB), mutationRate);
65
+ const parentA = scored[Math.floor(sample() * scored.length)].genome;
66
+ const parentB = scored[Math.floor(sample() * scored.length)].genome;
67
+ const child = mutateGenome(crossoverGenome(parentA, parentB, sample), mutationRate, sample);
69
68
  offspring.push(child);
70
69
  }
71
70
 
@@ -8,6 +8,8 @@ export type PipelineContextOptions = {
8
8
  assignProgress?: boolean;
9
9
  };
10
10
 
11
+ export declare function restorePipelineContexts(target: Record<string, unknown>): boolean;
12
+
11
13
  export declare function applyPipelineContexts(
12
14
  target: Record<string, unknown>,
13
15
  contexts?: Record<string, unknown>,
@@ -15,4 +17,5 @@ export declare function applyPipelineContexts(
15
17
  ): {
16
18
  runtimeConfig: Record<string, unknown>;
17
19
  sharedDebug: Record<string, unknown> | null | undefined;
20
+ restore: () => void;
18
21
  };
@@ -1,8 +1,115 @@
1
- import { getDevice, setDevice } from '../../gpu/device.js';
1
+ import {
2
+ getDevice,
3
+ getKernelCapabilities,
4
+ getPlatformConfig,
5
+ setDevice,
6
+ } from '../../gpu/device.js';
2
7
  import { applyDebugConfig, setGPUDevice } from '../../debug/index.js';
3
8
  import { getRuntimeConfig, setRuntimeConfig } from '../../config/runtime.js';
9
+ import {
10
+ getLogLevel,
11
+ getTrace,
12
+ isSilentMode,
13
+ setLogLevel,
14
+ setSilentMode,
15
+ setTrace,
16
+ } from '../../debug/config.js';
17
+ import {
18
+ gpuDevice as debugGpuDevice,
19
+ traceBreakOnAnomaly,
20
+ traceLayerFilter,
21
+ traceMaxDecodeSteps,
22
+ } from '../../debug/config.js';
23
+
24
+ const RESTORE_PIPELINE_CONTEXTS = Symbol('restorePipelineContexts');
25
+
26
+ function captureTargetField(target, key) {
27
+ return {
28
+ present: Object.prototype.hasOwnProperty.call(target, key),
29
+ value: target[key],
30
+ };
31
+ }
32
+
33
+ function restoreTargetField(target, key, snapshot) {
34
+ if (snapshot.present) {
35
+ target[key] = snapshot.value;
36
+ return;
37
+ }
38
+ delete target[key];
39
+ }
40
+
41
+ function captureDebugState() {
42
+ return {
43
+ logLevel: getLogLevel(),
44
+ traceCategories: getTrace(),
45
+ traceLayers: [...traceLayerFilter],
46
+ traceMaxDecodeSteps,
47
+ traceBreakOnAnomaly,
48
+ silentMode: isSilentMode(),
49
+ gpuDevice: debugGpuDevice,
50
+ };
51
+ }
52
+
53
+ function restoreDebugState(snapshot) {
54
+ if (snapshot.silentMode !== isSilentMode()) {
55
+ setSilentMode(snapshot.silentMode);
56
+ }
57
+ if (getLogLevel() !== snapshot.logLevel) {
58
+ setLogLevel(snapshot.logLevel);
59
+ }
60
+
61
+ const traceCategories = getTrace();
62
+ const traceChanged = traceCategories.length !== snapshot.traceCategories.length
63
+ || traceCategories.some((category, idx) => category !== snapshot.traceCategories[idx])
64
+ || traceLayerFilter.length !== snapshot.traceLayers.length
65
+ || traceLayerFilter.some((layer, idx) => layer !== snapshot.traceLayers[idx])
66
+ || traceMaxDecodeSteps !== snapshot.traceMaxDecodeSteps
67
+ || traceBreakOnAnomaly !== snapshot.traceBreakOnAnomaly;
68
+
69
+ if (traceChanged) {
70
+ if (snapshot.traceCategories.length > 0) {
71
+ setTrace(snapshot.traceCategories.join(','), {
72
+ layers: snapshot.traceLayers.length > 0 ? snapshot.traceLayers : undefined,
73
+ maxDecodeSteps: snapshot.traceMaxDecodeSteps > 0 ? snapshot.traceMaxDecodeSteps : undefined,
74
+ breakOnAnomaly: snapshot.traceBreakOnAnomaly,
75
+ });
76
+ } else {
77
+ setTrace(false);
78
+ }
79
+ }
80
+
81
+ setGPUDevice(snapshot.gpuDevice ?? null);
82
+ }
83
+
84
+ export function restorePipelineContexts(target) {
85
+ const restore = target?.[RESTORE_PIPELINE_CONTEXTS];
86
+ if (typeof restore !== 'function') {
87
+ return false;
88
+ }
89
+ delete target[RESTORE_PIPELINE_CONTEXTS];
90
+ restore();
91
+ return true;
92
+ }
4
93
 
5
94
  export function applyPipelineContexts(target, contexts = {}, options = {}) {
95
+ restorePipelineContexts(target);
96
+
97
+ const previousRuntimeConfig = getRuntimeConfig();
98
+ const previousDevice = getDevice();
99
+ const previousPlatformConfig = getPlatformConfig();
100
+ const previousAdapterInfo = previousDevice
101
+ ? (getKernelCapabilities().adapterInfo ?? null)
102
+ : null;
103
+ const previousDebugState = captureDebugState();
104
+ const targetSnapshot = {
105
+ gpuContext: captureTargetField(target, 'gpuContext'),
106
+ useGPU: captureTargetField(target, 'useGPU'),
107
+ memoryContext: captureTargetField(target, 'memoryContext'),
108
+ storageContext: captureTargetField(target, 'storageContext'),
109
+ baseUrl: captureTargetField(target, 'baseUrl'),
110
+ _onProgress: captureTargetField(target, '_onProgress'),
111
+ };
112
+
6
113
  const runtimeConfig = contexts.runtimeConfig
7
114
  ? setRuntimeConfig(contexts.runtimeConfig)
8
115
  : getRuntimeConfig();
@@ -40,5 +147,38 @@ export function applyPipelineContexts(target, contexts = {}, options = {}) {
40
147
  target._onProgress = contexts.onProgress;
41
148
  }
42
149
 
43
- return { runtimeConfig, sharedDebug };
150
+ let restored = false;
151
+ const restore = () => {
152
+ if (restored) {
153
+ return;
154
+ }
155
+ restored = true;
156
+ delete target[RESTORE_PIPELINE_CONTEXTS];
157
+
158
+ setRuntimeConfig(previousRuntimeConfig);
159
+ if (previousDevice) {
160
+ setDevice(previousDevice, {
161
+ platformConfig: previousPlatformConfig,
162
+ adapterInfo: previousAdapterInfo,
163
+ });
164
+ } else {
165
+ setDevice(null);
166
+ }
167
+ restoreDebugState(previousDebugState);
168
+ restoreTargetField(target, 'gpuContext', targetSnapshot.gpuContext);
169
+ restoreTargetField(target, 'useGPU', targetSnapshot.useGPU);
170
+ restoreTargetField(target, 'memoryContext', targetSnapshot.memoryContext);
171
+ restoreTargetField(target, 'storageContext', targetSnapshot.storageContext);
172
+ restoreTargetField(target, 'baseUrl', targetSnapshot.baseUrl);
173
+ restoreTargetField(target, '_onProgress', targetSnapshot._onProgress);
174
+ };
175
+
176
+ Object.defineProperty(target, RESTORE_PIPELINE_CONTEXTS, {
177
+ value: restore,
178
+ configurable: true,
179
+ enumerable: false,
180
+ writable: false,
181
+ });
182
+
183
+ return { runtimeConfig, sharedDebug, restore };
44
184
  }
@@ -54,8 +54,13 @@ export function createDiffusionIndexBuffer(device, indices, label) {
54
54
  size: indices.byteLength,
55
55
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
56
56
  });
57
- device.queue.writeBuffer(buffer, 0, indices);
58
- return buffer;
57
+ try {
58
+ device.queue.writeBuffer(buffer, 0, indices);
59
+ return buffer;
60
+ } catch (error) {
61
+ buffer.destroy();
62
+ throw error;
63
+ }
59
64
  }
60
65
 
61
66
  export function expectDiffusionWeight(weight, label) {
@@ -84,6 +89,9 @@ export function normalizeDiffusionMatmulLocationDtype(dtype) {
84
89
  return normalized;
85
90
  }
86
91
 
92
+ // Artifact-derived dtype inference: determines actual storage dtype from buffer byte size.
93
+ // This is NOT a config-bypass — it reads physical buffer dimensions (artifact-derived config),
94
+ // which is a valid merge layer per the config merge contract.
87
95
  export function inferDiffusionMatmulDtypeFromBuffer(weight, N, K, preferred) {
88
96
  const buffer = getBuffer(weight);
89
97
  if (!buffer || !Number.isFinite(N) || !Number.isFinite(K)) return preferred;
@@ -1,7 +1,7 @@
1
1
  import { getDevice, getKernelCapabilities } from '../../../gpu/device.js';
2
2
  import { log, trace } from '../../../debug/index.js';
3
3
  import { registerPipeline } from '../registry.js';
4
- import { applyPipelineContexts } from '../context.js';
4
+ import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
5
5
  import { createInitializedPipeline } from '../factory.js';
6
6
  import { createRng, sampleNormal } from '../rng.js';
7
7
  import { initializeDiffusion } from './init.js';
@@ -319,6 +319,7 @@ export class DiffusionPipeline {
319
319
  this.vaeWeights = null;
320
320
  this.textEncoderWeights = null;
321
321
  this.transformerWeights = null;
322
+ restorePipelineContexts(this);
322
323
  }
323
324
 
324
325
  async ensureVaeWeights() {
@@ -299,26 +299,26 @@ function resolveModulationSegments(weight, hiddenSize, fallbackSegments, resolve
299
299
  if (Number.isInteger(segments) && segments > 0) {
300
300
  return segments;
301
301
  }
302
- log.warn(
303
- 'Diffusion',
304
- `Modulation segments mismatch for ${name || 'unknown'}: rows=${rows}, hidden=${hiddenSize}, fallback=${fallbackSegments}`
302
+ throw new Error(
303
+ `Modulation segments mismatch for ${name || 'unknown'}: rows=${rows}, hidden=${hiddenSize}, ` +
304
+ `expected an integer multiple instead of falling back to ${fallbackSegments}.`
305
305
  );
306
306
  }
307
- return fallbackSegments;
307
+ throw new Error(
308
+ `Modulation tensor "${name || 'unknown'}" is missing shape metadata. ` +
309
+ `Runtime cannot fall back to ${fallbackSegments} segments.`
310
+ );
308
311
  }
309
312
 
310
313
  function resolveModulationOffsets(segments, hiddenSize) {
311
- if (segments >= 9) {
314
+ if (segments === 9) {
312
315
  return {
313
316
  attn: { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 },
314
317
  attn2: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
315
318
  ff: { scale: hiddenSize * 6, shift: hiddenSize * 7, gate: hiddenSize * 8 },
316
319
  };
317
320
  }
318
- if (segments >= 6) {
319
- if (segments !== 6) {
320
- log.warn('Diffusion', `Unexpected modulation segment count=${segments}; using 6-segment layout.`);
321
- }
321
+ if (segments === 6) {
322
322
  const attn = { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 };
323
323
  return {
324
324
  attn,
@@ -326,7 +326,7 @@ function resolveModulationOffsets(segments, hiddenSize) {
326
326
  ff: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
327
327
  };
328
328
  }
329
- throw new Error(`Unsupported modulation segments=${segments} (expected >= 6).`);
329
+ throw new Error(`Unsupported modulation segments=${segments} (expected 6 or 9).`);
330
330
  }
331
331
 
332
332
  async function buildModulation(timeText, weight, bias, hiddenSize, segments, runtime, matmul, weightName, ops) {
@@ -45,6 +45,8 @@ import { processLayerGPU } from '../text/layer.js';
45
45
 
46
46
  const QUICK_GELU_ALPHA = 1.702;
47
47
  const SUPPORTED_CLIP_HIDDEN_ACTIVATIONS = new Set(['gelu', 'quick_gelu']);
48
+ // Standard CLIP hidden activation per OpenAI CLIP specification.
49
+ const DEFAULT_CLIP_HIDDEN_ACT = 'gelu';
48
50
 
49
51
  function padTokens(tokens, maxLength, padTokenId) {
50
52
  if (!Number.isFinite(maxLength) || maxLength <= 0) {
@@ -100,11 +102,15 @@ function createVectorTensor(device, data, dtype, label) {
100
102
  return createTensor(buffer, dtype, [1, length], label);
101
103
  }
102
104
 
105
+ // Conservative fallback dtype for diffusion bias tensors when no dtype
106
+ // metadata is available. F32 avoids precision loss in bias additions.
107
+ const DEFAULT_BIAS_DTYPE = 'f32';
108
+
103
109
  function resolveBiasDtype(weight, weightsEntry, key) {
104
110
  if (weight && weight.dtype) return weight.dtype;
105
111
  const locationDtype = weightsEntry?.dtypes?.get(key);
106
112
  const mapped = normalizeDiffusionLocationDtype(locationDtype);
107
- return mapped || 'f32';
113
+ return mapped || DEFAULT_BIAS_DTYPE;
108
114
  }
109
115
 
110
116
  function createBiasTensorWithDtype(weight, weightsEntry, key, size, label) {
@@ -145,7 +151,7 @@ function createKernelOps(recorder) {
145
151
  }
146
152
 
147
153
  function resolveClipHiddenActivation(config) {
148
- const hiddenAct = config?.hidden_act ?? 'gelu';
154
+ const hiddenAct = config?.hidden_act ?? DEFAULT_CLIP_HIDDEN_ACT;
149
155
  if (!SUPPORTED_CLIP_HIDDEN_ACTIVATIONS.has(hiddenAct)) {
150
156
  throw new Error(
151
157
  `Unsupported CLIP hidden_act "${hiddenAct}". ` +
@@ -118,13 +118,9 @@ function resolveAttentionHeadShape(channels, config) {
118
118
  headDim: channels / configuredNumHeads,
119
119
  };
120
120
  }
121
-
122
- const fallbackHeadDims = [64, 40, 32, 24, 20, 16, 12, 10, 8, 6, 5, 4, 3, 2, 1];
123
- const headDim = fallbackHeadDims.find((candidate) => candidate <= channels && channels % candidate === 0) || 1;
124
- return {
125
- numHeads: Math.max(1, channels / headDim),
126
- headDim,
127
- };
121
+ throw new Error(
122
+ `VAE attention requires explicit compatible attention_head_dim or num_attention_heads for channels=${channels}.`
123
+ );
128
124
  }
129
125
 
130
126
  function createBiasTensor(weight, label, fallbackDtype = 'f16') {
@@ -16,10 +16,10 @@ import { log, trace } from '../../../debug/index.js';
16
16
  import { DEFAULT_ENERGY_CONFIG } from '../../../config/schema/energy.schema.js';
17
17
  import { f32ToF16Array, f16ToF32Array } from '../../kv-cache/types.js';
18
18
  import { registerPipeline } from '../registry.js';
19
- import { applyPipelineContexts } from '../context.js';
19
+ import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
20
20
  import { createInitializedPipeline } from '../factory.js';
21
21
  import { createRng, sampleNormal } from '../rng.js';
22
- import { mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
22
+ import { buildQuintelKernelFlags, mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
23
23
 
24
24
 
25
25
  function generateRandomArray(count, mode, seed, scale) {
@@ -140,24 +140,28 @@ async function createEnergyTensor(device, data, dtype, shape, label) {
140
140
  const byteLength = data.byteLength;
141
141
  const alignedSize = Math.ceil(byteLength / 4) * 4;
142
142
  const buffer = acquireBuffer(alignedSize, undefined, label);
143
+ try {
144
+ let payload = data;
145
+ if (alignedSize !== byteLength) {
146
+ const padded = new Uint8Array(alignedSize);
147
+ const view = data instanceof ArrayBuffer
148
+ ? new Uint8Array(data)
149
+ : new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
150
+ padded.set(view);
151
+ payload = padded;
152
+ }
143
153
 
144
- let payload = data;
145
- if (alignedSize !== byteLength) {
146
- const padded = new Uint8Array(alignedSize);
147
- const view = data instanceof ArrayBuffer
148
- ? new Uint8Array(data)
149
- : new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
150
- padded.set(view);
151
- payload = padded;
152
- }
153
-
154
- device.queue.writeBuffer(buffer, 0, payload);
155
- const tensor = createTensor(buffer, dtype, shape, label);
156
- const expectedBytes = tensorBytes(shape, dtype);
157
- if (expectedBytes !== byteLength) {
158
- log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
154
+ device.queue.writeBuffer(buffer, 0, payload);
155
+ const tensor = createTensor(buffer, dtype, shape, label);
156
+ const expectedBytes = tensorBytes(shape, dtype);
157
+ if (expectedBytes !== byteLength) {
158
+ log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
159
+ }
160
+ return tensor;
161
+ } catch (error) {
162
+ releaseBuffer(buffer);
163
+ throw error;
159
164
  }
160
- return tensor;
161
165
  }
162
166
 
163
167
  async function readTensorToFloat32(tensor) {
@@ -202,6 +206,7 @@ export class EnergyPipeline {
202
206
 
203
207
  async unload() {
204
208
  this.manifest = null;
209
+ restorePipelineContexts(this);
205
210
  }
206
211
 
207
212
  async generate(request = {}) {
@@ -336,6 +341,7 @@ export class EnergyPipeline {
336
341
  const centerWeight = Number.isFinite(weights.center) ? weights.center : 1.0;
337
342
  const binarizeWeight = Number.isFinite(weights.binarize) ? weights.binarize : 0.0;
338
343
  const centerTarget = Number.isFinite(quintelConfig.centerTarget) ? quintelConfig.centerTarget : 1.0;
344
+ const flags = buildQuintelKernelFlags(rules, binarizeWeight);
339
345
  const energyHistory = [];
340
346
  const stepTimesMs = [];
341
347
  let lastEnergy = null;
@@ -387,11 +393,11 @@ export class EnergyPipeline {
387
393
  await runEnergyQuintelReduce(stateTensor, {
388
394
  count: elementCount,
389
395
  size,
396
+ flags,
390
397
  symmetryWeight,
391
398
  centerWeight,
392
399
  binarizeWeight,
393
400
  centerTarget,
394
- rules,
395
401
  outputBuffer: reduceBuffer,
396
402
  });
397
403
 
@@ -447,13 +453,13 @@ export class EnergyPipeline {
447
453
  await runEnergyQuintelGrad(stateTensor, {
448
454
  count: elementCount,
449
455
  size,
456
+ flags,
450
457
  countDiff: safeCountDiff,
451
458
  symmetryWeight,
452
459
  countWeight,
453
460
  centerWeight,
454
461
  binarizeWeight,
455
462
  centerTarget,
456
- rules,
457
463
  outputBuffer: gradBuffer,
458
464
  });
459
465
 
@@ -471,6 +477,7 @@ export class EnergyPipeline {
471
477
  await runEnergyQuintelUpdate(stateTensor, {
472
478
  count: elementCount,
473
479
  size,
480
+ flags,
474
481
  stepSize,
475
482
  gradientScale,
476
483
  countDiff: safeCountDiff,
@@ -481,7 +488,6 @@ export class EnergyPipeline {
481
488
  centerTarget,
482
489
  clampMin,
483
490
  clampMax,
484
- rules,
485
491
  });
486
492
  }
487
493
 
@@ -84,4 +84,9 @@ export function mergeQuintelConfig(
84
84
  override?: Partial<QuintelEnergyConfig> | null
85
85
  ): QuintelEnergyConfig;
86
86
 
87
/**
 * Derives the kernel flag bitmask from the quintel rule toggles and the
 * binarize weight. `rules` may be null/undefined, in which case no rule
 * bits are set. The bit layout matches the JS implementation in
 * quintel.js — NOTE(review): keep the two in sync when adding flags.
 *
 * @param rules - Partial rule configuration; each truthy toggle sets a bit.
 * @param binarizeWeight - Optional weight; a finite non-zero value sets the binarize bit.
 * @returns The packed flags as an unsigned integer.
 */
export function buildQuintelKernelFlags(
  rules: Partial<QuintelRuleConfig> | null | undefined,
  binarizeWeight?: number
): number;
91
+
87
92
  export function runQuintelEnergyLoop(options: QuintelEnergyLoopOptions): QuintelEnergyLoopResult;
@@ -22,6 +22,17 @@ export function mergeQuintelConfig(base, override) {
22
22
  };
23
23
  }
24
24
 
25
/**
 * Pack the quintel rule toggles and binarize activity into one kernel
 * flag bitmask: bit 0 mirrorX, bit 1 mirrorY, bit 2 diagonal, bit 3
 * count, bit 4 center, bit 5 binarize (finite, non-zero weight).
 * Returned as an unsigned 32-bit integer.
 */
export function buildQuintelKernelFlags(rules, binarizeWeight) {
  const r = rules || {};
  const binarizeActive = Number.isFinite(binarizeWeight) && binarizeWeight !== 0;
  const mask =
    (r.mirrorX ? 1 : 0) |
    (r.mirrorY ? 2 : 0) |
    (r.diagonal ? 4 : 0) |
    (r.count ? 8 : 0) |
    (r.center ? 16 : 0) |
    (binarizeActive ? 32 : 0);
  return mask >>> 0;
}
35
+
25
36
  function applyPairEnergy(state, gradients, indexA, indexB, weight) {
26
37
  const diff = state[indexA] - state[indexB];
27
38
  const energy = weight * diff * diff;
@@ -5,7 +5,7 @@ import { runEnergyEval, runEnergyUpdate } from '../../../gpu/kernels/index.js';
5
5
  import { log } from '../../../debug/index.js';
6
6
  import { f16ToF32Array, f32ToF16Array } from '../../kv-cache/types.js';
7
7
  import { registerPipeline } from '../registry.js';
8
- import { applyPipelineContexts } from '../context.js';
8
+ import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
9
9
  import { createInitializedPipeline } from '../factory.js';
10
10
  import { selectRuleValue } from '../../../rules/rule-registry.js';
11
11
 
@@ -165,19 +165,22 @@ async function createFeatureTensor(device, values, dtype, label) {
165
165
  const byteLength = payload.byteLength;
166
166
  const alignedSize = Math.ceil(byteLength / 4) * 4;
167
167
  const buffer = acquireBuffer(alignedSize, undefined, label);
168
-
169
- if (alignedSize === byteLength) {
170
- device.queue.writeBuffer(buffer, 0, payload);
171
- } else {
172
- const bytes = payload instanceof Uint16Array
173
- ? new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength)
174
- : new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength);
175
- const padded = new Uint8Array(alignedSize);
176
- padded.set(bytes);
177
- device.queue.writeBuffer(buffer, 0, padded);
168
+ try {
169
+ if (alignedSize === byteLength) {
170
+ device.queue.writeBuffer(buffer, 0, payload);
171
+ } else {
172
+ const bytes = payload instanceof Uint16Array
173
+ ? new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength)
174
+ : new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength);
175
+ const padded = new Uint8Array(alignedSize);
176
+ padded.set(bytes);
177
+ device.queue.writeBuffer(buffer, 0, padded);
178
+ }
179
+ return createTensor(buffer, dtype, [values.length], label);
180
+ } catch (error) {
181
+ releaseBuffer(buffer);
182
+ throw error;
178
183
  }
179
-
180
- return createTensor(buffer, dtype, [values.length], label);
181
184
  }
182
185
 
183
186
  async function readTensorF32(tensor) {
@@ -307,6 +310,7 @@ export class EnergyRowHeadPipeline {
307
310
  this.manifest = null;
308
311
  this.model = null;
309
312
  this.stats = {};
313
+ restorePipelineContexts(this);
310
314
  }
311
315
 
312
316
  async scoreRows(request = {}) {
@@ -84,20 +84,35 @@ function parseStructuredJSONObject(rawText) {
84
84
/**
 * Resolve the effective runtime limits for the structured-JSON head.
 *
 * The manifest MUST provide `inference.structuredJsonHead`; the runtime
 * config may override individual fields. Each of maxTokens, temperature
 * and maxOutputChars must resolve to a finite number from one of the two
 * sources — there are no built-in defaults, and a missing value throws.
 *
 * maxTokens is floored and clamped to >= 1; maxOutputChars is floored
 * and clamped to >= 4096; temperature is coerced with Number().
 */
function resolveStructuredRuntime(manifest, runtimeConfig) {
  const modelCfg = isObj(manifest?.inference?.structuredJsonHead)
    ? manifest.inference.structuredJsonHead
    : null;
  if (!modelCfg) {
    throw new Error('StructuredJsonHeadPipeline: manifest.inference.structuredJsonHead is required.');
  }
  const runtimeCfg = isObj(runtimeConfig?.inference?.structuredJsonHead)
    ? runtimeConfig.inference.structuredJsonHead
    : {};

  // Runtime value wins when finite; otherwise fall back to the manifest
  // value when finite; otherwise null (caught by the checks below).
  const pick = (key, normalize) => {
    if (Number.isFinite(runtimeCfg[key])) return normalize(runtimeCfg[key]);
    if (Number.isFinite(modelCfg[key])) return normalize(modelCfg[key]);
    return null;
  };

  const maxTokens = pick('maxTokens', (value) => Math.max(1, Math.floor(value)));
  const temperature = pick('temperature', (value) => Number(value));
  const maxOutputChars = pick('maxOutputChars', (value) => Math.max(4096, Math.floor(value)));

  if (!Number.isFinite(maxTokens)) {
    throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.maxTokens is required.');
  }
  if (!Number.isFinite(temperature)) {
    throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.temperature is required.');
  }
  if (!Number.isFinite(maxOutputChars)) {
    throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.maxOutputChars is required.');
  }
  return {
    maxTokens,
    temperature,
    maxOutputChars,
  };
}
103
118
 
@@ -0,0 +1,12 @@
1
import type { Tensor } from '../../../../gpu/tensor.js';

/** Result of preparing the attention output for the o-projection matmul. */
export interface AttentionProjectionInputResult {
  /** Tensor to feed into the o-projection (either the input or its f16 cast). */
  oProjInput: Tensor;
  /**
   * The freshly created f16 cast when one was made, otherwise null.
   * NOTE(review): presumably the caller is responsible for releasing it —
   * confirm against the call sites.
   */
  oProjInputTemp: Tensor | null;
}

/**
 * Casts the attention output to f16 via `castToF16` when the matmul
 * expects f16 input and the tensor is not already f16; otherwise returns
 * the tensor unchanged with a null temporary.
 */
export function prepareAttentionProjectionInput(
  attnForProjection: Tensor,
  matmulOutputDtype: string,
  castToF16: (tensor: Tensor) => Promise<Tensor>
): Promise<AttentionProjectionInputResult>;
@@ -0,0 +1,8 @@
1
/**
 * Prepare the attention output tensor for the o-projection matmul.
 *
 * If the matmul consumes f16 and the tensor is not already f16, cast it
 * with `castToF16` and surface the cast both as the projection input and
 * as a temporary (so the caller can track it). Otherwise the tensor is
 * passed through untouched and the temporary slot is null.
 */
export async function prepareAttentionProjectionInput(attnForProjection, matmulOutputDtype, castToF16) {
  const needsCast = matmulOutputDtype === 'f16' && attnForProjection.dtype !== 'f16';
  if (!needsCast) {
    return { oProjInput: attnForProjection, oProjInputTemp: null };
  }
  const casted = await castToF16(attnForProjection);
  return { oProjInput: casted, oProjInputTemp: casted };
}