@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -1,33 +1,29 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
|
|
5
|
-
const values = new Uint32Array(1);
|
|
6
|
-
crypto.getRandomValues(values);
|
|
7
|
-
return values[0] / 4294967296;
|
|
1
|
+
function requireRandomSource(random) {
|
|
2
|
+
if (typeof random !== 'function') {
|
|
3
|
+
throw new Error('network evolution requires an explicit random() source.');
|
|
8
4
|
}
|
|
9
|
-
|
|
10
|
-
return fallbackRandomState / 4294967296;
|
|
5
|
+
return random;
|
|
11
6
|
}
|
|
12
7
|
|
|
13
|
-
export const mutateGenome = (genome, mutationRate = 0.1) => {
|
|
8
|
+
export const mutateGenome = (genome, mutationRate = 0.1, random = null) => {
|
|
9
|
+
const sample = requireRandomSource(random);
|
|
14
10
|
|
|
15
11
|
const mutated = JSON.parse(JSON.stringify(genome));
|
|
16
|
-
if (
|
|
12
|
+
if (sample() < mutationRate) {
|
|
17
13
|
|
|
18
14
|
const types = ['chain', 'tree', 'mesh', 'dag'];
|
|
19
|
-
mutated.topology.type = types[Math.floor(
|
|
15
|
+
mutated.topology.type = types[Math.floor(sample() * types.length)];
|
|
20
16
|
}
|
|
21
17
|
|
|
22
18
|
for (const node of mutated.nodes) {
|
|
23
|
-
if (
|
|
24
|
-
node.temperature = Math.min(1, Math.max(0, node.temperature + (
|
|
19
|
+
if (sample() < mutationRate && typeof node.temperature === 'number') {
|
|
20
|
+
node.temperature = Math.min(1, Math.max(0, node.temperature + (sample() - 0.5) * 0.2));
|
|
25
21
|
}
|
|
26
22
|
}
|
|
27
23
|
|
|
28
24
|
for (const edge of mutated.edges) {
|
|
29
|
-
if (
|
|
30
|
-
edge.weight = Math.min(1, Math.max(0, edge.weight + (
|
|
25
|
+
if (sample() < mutationRate) {
|
|
26
|
+
edge.weight = Math.min(1, Math.max(0, edge.weight + (sample() - 0.5) * 0.4));
|
|
31
27
|
}
|
|
32
28
|
}
|
|
33
29
|
|
|
@@ -35,8 +31,9 @@ export const mutateGenome = (genome, mutationRate = 0.1) => {
|
|
|
35
31
|
};
|
|
36
32
|
|
|
37
33
|
|
|
38
|
-
export const crossoverGenome = (a, b) => {
|
|
39
|
-
|
|
34
|
+
export const crossoverGenome = (a, b, random = null) => {
|
|
35
|
+
const sample = requireRandomSource(random);
|
|
36
|
+
return sample() < 0.5 ? JSON.parse(JSON.stringify(a)) : JSON.parse(JSON.stringify(b));
|
|
40
37
|
};
|
|
41
38
|
|
|
42
39
|
|
|
@@ -48,7 +45,9 @@ export async function evolveNetwork(config) {
|
|
|
48
45
|
mutationRate = 0.1,
|
|
49
46
|
evaluate,
|
|
50
47
|
randomGenome,
|
|
48
|
+
random,
|
|
51
49
|
} = config;
|
|
50
|
+
const sample = requireRandomSource(random);
|
|
52
51
|
|
|
53
52
|
let population = Array.from({ length: populationSize }, () => randomGenome());
|
|
54
53
|
|
|
@@ -63,9 +62,9 @@ export async function evolveNetwork(config) {
|
|
|
63
62
|
const offspring = [];
|
|
64
63
|
|
|
65
64
|
while (offspring.length < populationSize - eliteCount) {
|
|
66
|
-
const parentA = scored[Math.floor(
|
|
67
|
-
const parentB = scored[Math.floor(
|
|
68
|
-
const child = mutateGenome(crossoverGenome(parentA, parentB), mutationRate);
|
|
65
|
+
const parentA = scored[Math.floor(sample() * scored.length)].genome;
|
|
66
|
+
const parentB = scored[Math.floor(sample() * scored.length)].genome;
|
|
67
|
+
const child = mutateGenome(crossoverGenome(parentA, parentB, sample), mutationRate, sample);
|
|
69
68
|
offspring.push(child);
|
|
70
69
|
}
|
|
71
70
|
|
|
@@ -8,6 +8,8 @@ export type PipelineContextOptions = {
|
|
|
8
8
|
assignProgress?: boolean;
|
|
9
9
|
};
|
|
10
10
|
|
|
11
|
+
export declare function restorePipelineContexts(target: Record<string, unknown>): boolean;
|
|
12
|
+
|
|
11
13
|
export declare function applyPipelineContexts(
|
|
12
14
|
target: Record<string, unknown>,
|
|
13
15
|
contexts?: Record<string, unknown>,
|
|
@@ -15,4 +17,5 @@ export declare function applyPipelineContexts(
|
|
|
15
17
|
): {
|
|
16
18
|
runtimeConfig: Record<string, unknown>;
|
|
17
19
|
sharedDebug: Record<string, unknown> | null | undefined;
|
|
20
|
+
restore: () => void;
|
|
18
21
|
};
|
|
@@ -1,8 +1,115 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import {
|
|
2
|
+
getDevice,
|
|
3
|
+
getKernelCapabilities,
|
|
4
|
+
getPlatformConfig,
|
|
5
|
+
setDevice,
|
|
6
|
+
} from '../../gpu/device.js';
|
|
2
7
|
import { applyDebugConfig, setGPUDevice } from '../../debug/index.js';
|
|
3
8
|
import { getRuntimeConfig, setRuntimeConfig } from '../../config/runtime.js';
|
|
9
|
+
import {
|
|
10
|
+
getLogLevel,
|
|
11
|
+
getTrace,
|
|
12
|
+
isSilentMode,
|
|
13
|
+
setLogLevel,
|
|
14
|
+
setSilentMode,
|
|
15
|
+
setTrace,
|
|
16
|
+
} from '../../debug/config.js';
|
|
17
|
+
import {
|
|
18
|
+
gpuDevice as debugGpuDevice,
|
|
19
|
+
traceBreakOnAnomaly,
|
|
20
|
+
traceLayerFilter,
|
|
21
|
+
traceMaxDecodeSteps,
|
|
22
|
+
} from '../../debug/config.js';
|
|
23
|
+
|
|
24
|
+
const RESTORE_PIPELINE_CONTEXTS = Symbol('restorePipelineContexts');
|
|
25
|
+
|
|
26
|
+
function captureTargetField(target, key) {
|
|
27
|
+
return {
|
|
28
|
+
present: Object.prototype.hasOwnProperty.call(target, key),
|
|
29
|
+
value: target[key],
|
|
30
|
+
};
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
function restoreTargetField(target, key, snapshot) {
|
|
34
|
+
if (snapshot.present) {
|
|
35
|
+
target[key] = snapshot.value;
|
|
36
|
+
return;
|
|
37
|
+
}
|
|
38
|
+
delete target[key];
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
function captureDebugState() {
|
|
42
|
+
return {
|
|
43
|
+
logLevel: getLogLevel(),
|
|
44
|
+
traceCategories: getTrace(),
|
|
45
|
+
traceLayers: [...traceLayerFilter],
|
|
46
|
+
traceMaxDecodeSteps,
|
|
47
|
+
traceBreakOnAnomaly,
|
|
48
|
+
silentMode: isSilentMode(),
|
|
49
|
+
gpuDevice: debugGpuDevice,
|
|
50
|
+
};
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
function restoreDebugState(snapshot) {
|
|
54
|
+
if (snapshot.silentMode !== isSilentMode()) {
|
|
55
|
+
setSilentMode(snapshot.silentMode);
|
|
56
|
+
}
|
|
57
|
+
if (getLogLevel() !== snapshot.logLevel) {
|
|
58
|
+
setLogLevel(snapshot.logLevel);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
const traceCategories = getTrace();
|
|
62
|
+
const traceChanged = traceCategories.length !== snapshot.traceCategories.length
|
|
63
|
+
|| traceCategories.some((category, idx) => category !== snapshot.traceCategories[idx])
|
|
64
|
+
|| traceLayerFilter.length !== snapshot.traceLayers.length
|
|
65
|
+
|| traceLayerFilter.some((layer, idx) => layer !== snapshot.traceLayers[idx])
|
|
66
|
+
|| traceMaxDecodeSteps !== snapshot.traceMaxDecodeSteps
|
|
67
|
+
|| traceBreakOnAnomaly !== snapshot.traceBreakOnAnomaly;
|
|
68
|
+
|
|
69
|
+
if (traceChanged) {
|
|
70
|
+
if (snapshot.traceCategories.length > 0) {
|
|
71
|
+
setTrace(snapshot.traceCategories.join(','), {
|
|
72
|
+
layers: snapshot.traceLayers.length > 0 ? snapshot.traceLayers : undefined,
|
|
73
|
+
maxDecodeSteps: snapshot.traceMaxDecodeSteps > 0 ? snapshot.traceMaxDecodeSteps : undefined,
|
|
74
|
+
breakOnAnomaly: snapshot.traceBreakOnAnomaly,
|
|
75
|
+
});
|
|
76
|
+
} else {
|
|
77
|
+
setTrace(false);
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
setGPUDevice(snapshot.gpuDevice ?? null);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
export function restorePipelineContexts(target) {
|
|
85
|
+
const restore = target?.[RESTORE_PIPELINE_CONTEXTS];
|
|
86
|
+
if (typeof restore !== 'function') {
|
|
87
|
+
return false;
|
|
88
|
+
}
|
|
89
|
+
delete target[RESTORE_PIPELINE_CONTEXTS];
|
|
90
|
+
restore();
|
|
91
|
+
return true;
|
|
92
|
+
}
|
|
4
93
|
|
|
5
94
|
export function applyPipelineContexts(target, contexts = {}, options = {}) {
|
|
95
|
+
restorePipelineContexts(target);
|
|
96
|
+
|
|
97
|
+
const previousRuntimeConfig = getRuntimeConfig();
|
|
98
|
+
const previousDevice = getDevice();
|
|
99
|
+
const previousPlatformConfig = getPlatformConfig();
|
|
100
|
+
const previousAdapterInfo = previousDevice
|
|
101
|
+
? (getKernelCapabilities().adapterInfo ?? null)
|
|
102
|
+
: null;
|
|
103
|
+
const previousDebugState = captureDebugState();
|
|
104
|
+
const targetSnapshot = {
|
|
105
|
+
gpuContext: captureTargetField(target, 'gpuContext'),
|
|
106
|
+
useGPU: captureTargetField(target, 'useGPU'),
|
|
107
|
+
memoryContext: captureTargetField(target, 'memoryContext'),
|
|
108
|
+
storageContext: captureTargetField(target, 'storageContext'),
|
|
109
|
+
baseUrl: captureTargetField(target, 'baseUrl'),
|
|
110
|
+
_onProgress: captureTargetField(target, '_onProgress'),
|
|
111
|
+
};
|
|
112
|
+
|
|
6
113
|
const runtimeConfig = contexts.runtimeConfig
|
|
7
114
|
? setRuntimeConfig(contexts.runtimeConfig)
|
|
8
115
|
: getRuntimeConfig();
|
|
@@ -40,5 +147,38 @@ export function applyPipelineContexts(target, contexts = {}, options = {}) {
|
|
|
40
147
|
target._onProgress = contexts.onProgress;
|
|
41
148
|
}
|
|
42
149
|
|
|
43
|
-
|
|
150
|
+
let restored = false;
|
|
151
|
+
const restore = () => {
|
|
152
|
+
if (restored) {
|
|
153
|
+
return;
|
|
154
|
+
}
|
|
155
|
+
restored = true;
|
|
156
|
+
delete target[RESTORE_PIPELINE_CONTEXTS];
|
|
157
|
+
|
|
158
|
+
setRuntimeConfig(previousRuntimeConfig);
|
|
159
|
+
if (previousDevice) {
|
|
160
|
+
setDevice(previousDevice, {
|
|
161
|
+
platformConfig: previousPlatformConfig,
|
|
162
|
+
adapterInfo: previousAdapterInfo,
|
|
163
|
+
});
|
|
164
|
+
} else {
|
|
165
|
+
setDevice(null);
|
|
166
|
+
}
|
|
167
|
+
restoreDebugState(previousDebugState);
|
|
168
|
+
restoreTargetField(target, 'gpuContext', targetSnapshot.gpuContext);
|
|
169
|
+
restoreTargetField(target, 'useGPU', targetSnapshot.useGPU);
|
|
170
|
+
restoreTargetField(target, 'memoryContext', targetSnapshot.memoryContext);
|
|
171
|
+
restoreTargetField(target, 'storageContext', targetSnapshot.storageContext);
|
|
172
|
+
restoreTargetField(target, 'baseUrl', targetSnapshot.baseUrl);
|
|
173
|
+
restoreTargetField(target, '_onProgress', targetSnapshot._onProgress);
|
|
174
|
+
};
|
|
175
|
+
|
|
176
|
+
Object.defineProperty(target, RESTORE_PIPELINE_CONTEXTS, {
|
|
177
|
+
value: restore,
|
|
178
|
+
configurable: true,
|
|
179
|
+
enumerable: false,
|
|
180
|
+
writable: false,
|
|
181
|
+
});
|
|
182
|
+
|
|
183
|
+
return { runtimeConfig, sharedDebug, restore };
|
|
44
184
|
}
|
|
@@ -54,8 +54,13 @@ export function createDiffusionIndexBuffer(device, indices, label) {
|
|
|
54
54
|
size: indices.byteLength,
|
|
55
55
|
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
|
|
56
56
|
});
|
|
57
|
-
|
|
58
|
-
|
|
57
|
+
try {
|
|
58
|
+
device.queue.writeBuffer(buffer, 0, indices);
|
|
59
|
+
return buffer;
|
|
60
|
+
} catch (error) {
|
|
61
|
+
buffer.destroy();
|
|
62
|
+
throw error;
|
|
63
|
+
}
|
|
59
64
|
}
|
|
60
65
|
|
|
61
66
|
export function expectDiffusionWeight(weight, label) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { getDevice, getKernelCapabilities } from '../../../gpu/device.js';
|
|
2
2
|
import { log, trace } from '../../../debug/index.js';
|
|
3
3
|
import { registerPipeline } from '../registry.js';
|
|
4
|
-
import { applyPipelineContexts } from '../context.js';
|
|
4
|
+
import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
|
|
5
5
|
import { createInitializedPipeline } from '../factory.js';
|
|
6
6
|
import { createRng, sampleNormal } from '../rng.js';
|
|
7
7
|
import { initializeDiffusion } from './init.js';
|
|
@@ -52,6 +52,18 @@ function generateLatents(width, height, channels, latentScale, seed) {
|
|
|
52
52
|
return { latents, latentWidth, latentHeight };
|
|
53
53
|
}
|
|
54
54
|
|
|
55
|
+
function generateNoiseVector(size, seed) {
|
|
56
|
+
if (!Number.isFinite(size) || size <= 0) {
|
|
57
|
+
throw new Error(`generateNoiseVector requires a positive size, got ${size}.`);
|
|
58
|
+
}
|
|
59
|
+
const out = new Float32Array(size);
|
|
60
|
+
const rand = createRng(seed ?? createRandomSeed());
|
|
61
|
+
for (let i = 0; i < size; i++) {
|
|
62
|
+
out[i] = sampleNormal(rand);
|
|
63
|
+
}
|
|
64
|
+
return out;
|
|
65
|
+
}
|
|
66
|
+
|
|
55
67
|
function extractTokenSet(tokensByEncoder, key) {
|
|
56
68
|
const output = {};
|
|
57
69
|
for (const [name, entry] of Object.entries(tokensByEncoder || {})) {
|
|
@@ -195,13 +207,10 @@ async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep,
|
|
|
195
207
|
const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
|
|
196
208
|
const noise = isFinalStep
|
|
197
209
|
? null
|
|
198
|
-
:
|
|
199
|
-
|
|
200
|
-
runtime.latent.height,
|
|
201
|
-
runtime.latent.channels,
|
|
202
|
-
runtime.latent.scale,
|
|
210
|
+
: generateNoiseVector(
|
|
211
|
+
sample.length,
|
|
203
212
|
(options.seedBase ?? createRandomSeed()) + stepIndex + 1
|
|
204
|
-
)
|
|
213
|
+
);
|
|
205
214
|
const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
|
|
206
215
|
return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
|
|
207
216
|
}
|
|
@@ -310,6 +319,7 @@ export class DiffusionPipeline {
|
|
|
310
319
|
this.vaeWeights = null;
|
|
311
320
|
this.textEncoderWeights = null;
|
|
312
321
|
this.transformerWeights = null;
|
|
322
|
+
restorePipelineContexts(this);
|
|
313
323
|
}
|
|
314
324
|
|
|
315
325
|
async ensureVaeWeights() {
|
|
@@ -299,26 +299,26 @@ function resolveModulationSegments(weight, hiddenSize, fallbackSegments, resolve
|
|
|
299
299
|
if (Number.isInteger(segments) && segments > 0) {
|
|
300
300
|
return segments;
|
|
301
301
|
}
|
|
302
|
-
|
|
303
|
-
'
|
|
304
|
-
`
|
|
302
|
+
throw new Error(
|
|
303
|
+
`Modulation segments mismatch for ${name || 'unknown'}: rows=${rows}, hidden=${hiddenSize}, ` +
|
|
304
|
+
`expected an integer multiple instead of falling back to ${fallbackSegments}.`
|
|
305
305
|
);
|
|
306
306
|
}
|
|
307
|
-
|
|
307
|
+
throw new Error(
|
|
308
|
+
`Modulation tensor "${name || 'unknown'}" is missing shape metadata. ` +
|
|
309
|
+
`Runtime cannot fall back to ${fallbackSegments} segments.`
|
|
310
|
+
);
|
|
308
311
|
}
|
|
309
312
|
|
|
310
313
|
function resolveModulationOffsets(segments, hiddenSize) {
|
|
311
|
-
if (segments
|
|
314
|
+
if (segments === 9) {
|
|
312
315
|
return {
|
|
313
316
|
attn: { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 },
|
|
314
317
|
attn2: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
|
|
315
318
|
ff: { scale: hiddenSize * 6, shift: hiddenSize * 7, gate: hiddenSize * 8 },
|
|
316
319
|
};
|
|
317
320
|
}
|
|
318
|
-
if (segments
|
|
319
|
-
if (segments !== 6) {
|
|
320
|
-
log.warn('Diffusion', `Unexpected modulation segment count=${segments}; using 6-segment layout.`);
|
|
321
|
-
}
|
|
321
|
+
if (segments === 6) {
|
|
322
322
|
const attn = { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 };
|
|
323
323
|
return {
|
|
324
324
|
attn,
|
|
@@ -326,7 +326,7 @@ function resolveModulationOffsets(segments, hiddenSize) {
|
|
|
326
326
|
ff: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
|
|
327
327
|
};
|
|
328
328
|
}
|
|
329
|
-
throw new Error(`Unsupported modulation segments=${segments} (expected
|
|
329
|
+
throw new Error(`Unsupported modulation segments=${segments} (expected 6 or 9).`);
|
|
330
330
|
}
|
|
331
331
|
|
|
332
332
|
async function buildModulation(timeText, weight, bias, hiddenSize, segments, runtime, matmul, weightName, ops) {
|
|
@@ -80,3 +80,8 @@ export declare function projectContext(
|
|
|
80
80
|
): Promise<Tensor>;
|
|
81
81
|
|
|
82
82
|
export declare function assertClipHiddenActivationSupported(config: { hidden_act?: string }): void;
|
|
83
|
+
|
|
84
|
+
export declare function resolveGemma2WeightRoot(
|
|
85
|
+
weights: Map<string, any>,
|
|
86
|
+
prefix?: string
|
|
87
|
+
): string;
|
|
@@ -723,8 +723,19 @@ function buildGemma2LayerTypes(layerCount, slidingWindow) {
|
|
|
723
723
|
));
|
|
724
724
|
}
|
|
725
725
|
|
|
726
|
-
function
|
|
727
|
-
const
|
|
726
|
+
export function resolveGemma2WeightRoot(weights, prefix = 'text_encoder') {
|
|
727
|
+
const nestedRoot = `${prefix}.model`;
|
|
728
|
+
if (weights?.has(`${nestedRoot}.embed_tokens.weight`)) {
|
|
729
|
+
return nestedRoot;
|
|
730
|
+
}
|
|
731
|
+
if (weights?.has(`${prefix}.embed_tokens.weight`)) {
|
|
732
|
+
return prefix;
|
|
733
|
+
}
|
|
734
|
+
return nestedRoot;
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
function getGemma2LayerWeight(weights, weightRoot, layerIdx, suffix, required = true) {
|
|
738
|
+
const key = `${weightRoot}.layers.${layerIdx}.${suffix}`;
|
|
728
739
|
const weight = weights.get(key) || null;
|
|
729
740
|
if (!weight && required) {
|
|
730
741
|
throw new Error(`Missing Gemma2 diffusion weight "${key}".`);
|
|
@@ -805,8 +816,9 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
|
|
|
805
816
|
const tokenIds = normalizeTokens(tokens, options.maxLength ?? resolved.maxPositionEmbeddings, padTokenId);
|
|
806
817
|
const numTokens = tokenIds.length;
|
|
807
818
|
const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);
|
|
819
|
+
const weightRoot = resolveGemma2WeightRoot(weights, prefix);
|
|
808
820
|
|
|
809
|
-
const embedKey = `${
|
|
821
|
+
const embedKey = `${weightRoot}.embed_tokens.weight`;
|
|
810
822
|
const embedWeight = expectDiffusionWeight(
|
|
811
823
|
weights.get(embedKey),
|
|
812
824
|
embedKey
|
|
@@ -837,16 +849,16 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
|
|
|
837
849
|
const layerWeights = new Map();
|
|
838
850
|
for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
|
|
839
851
|
layerWeights.set(`layer_${layerIdx}`, {
|
|
840
|
-
inputNorm: getGemma2LayerWeight(weights,
|
|
841
|
-
qProj: getGemma2LayerWeight(weights,
|
|
842
|
-
kProj: getGemma2LayerWeight(weights,
|
|
843
|
-
vProj: getGemma2LayerWeight(weights,
|
|
844
|
-
oProj: getGemma2LayerWeight(weights,
|
|
845
|
-
postAttentionNorm: getGemma2LayerWeight(weights,
|
|
846
|
-
preFeedforwardNorm: getGemma2LayerWeight(weights,
|
|
847
|
-
gate: getGemma2LayerWeight(weights,
|
|
848
|
-
up: getGemma2LayerWeight(weights,
|
|
849
|
-
down: getGemma2LayerWeight(weights,
|
|
852
|
+
inputNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'input_layernorm.weight'),
|
|
853
|
+
qProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.q_proj.weight'),
|
|
854
|
+
kProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.k_proj.weight'),
|
|
855
|
+
vProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.v_proj.weight'),
|
|
856
|
+
oProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.o_proj.weight'),
|
|
857
|
+
postAttentionNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'post_attention_layernorm.weight'),
|
|
858
|
+
preFeedforwardNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'pre_feedforward_layernorm.weight'),
|
|
859
|
+
gate: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.gate_proj.weight'),
|
|
860
|
+
up: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.up_proj.weight'),
|
|
861
|
+
down: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.down_proj.weight'),
|
|
850
862
|
});
|
|
851
863
|
}
|
|
852
864
|
|
|
@@ -910,10 +922,10 @@ async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, optio
|
|
|
910
922
|
numTokens * resolved.hiddenSize,
|
|
911
923
|
context
|
|
912
924
|
);
|
|
913
|
-
hidden = createTensor(output
|
|
925
|
+
hidden = createTensor(output, activationDtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
|
|
914
926
|
}
|
|
915
927
|
|
|
916
|
-
const finalNormKey = `${
|
|
928
|
+
const finalNormKey = `${weightRoot}.norm.weight`;
|
|
917
929
|
const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
|
|
918
930
|
const final = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
|
|
919
931
|
batchSize: numTokens,
|
|
@@ -118,13 +118,9 @@ function resolveAttentionHeadShape(channels, config) {
|
|
|
118
118
|
headDim: channels / configuredNumHeads,
|
|
119
119
|
};
|
|
120
120
|
}
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
return {
|
|
125
|
-
numHeads: Math.max(1, channels / headDim),
|
|
126
|
-
headDim,
|
|
127
|
-
};
|
|
121
|
+
throw new Error(
|
|
122
|
+
`VAE attention requires explicit compatible attention_head_dim or num_attention_heads for channels=${channels}.`
|
|
123
|
+
);
|
|
128
124
|
}
|
|
129
125
|
|
|
130
126
|
function createBiasTensor(weight, label, fallbackDtype = 'f16') {
|
|
@@ -16,10 +16,10 @@ import { log, trace } from '../../../debug/index.js';
|
|
|
16
16
|
import { DEFAULT_ENERGY_CONFIG } from '../../../config/schema/energy.schema.js';
|
|
17
17
|
import { f32ToF16Array, f16ToF32Array } from '../../kv-cache/types.js';
|
|
18
18
|
import { registerPipeline } from '../registry.js';
|
|
19
|
-
import { applyPipelineContexts } from '../context.js';
|
|
19
|
+
import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
|
|
20
20
|
import { createInitializedPipeline } from '../factory.js';
|
|
21
21
|
import { createRng, sampleNormal } from '../rng.js';
|
|
22
|
-
import { mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
|
|
22
|
+
import { buildQuintelKernelFlags, mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
function generateRandomArray(count, mode, seed, scale) {
|
|
@@ -140,24 +140,28 @@ async function createEnergyTensor(device, data, dtype, shape, label) {
|
|
|
140
140
|
const byteLength = data.byteLength;
|
|
141
141
|
const alignedSize = Math.ceil(byteLength / 4) * 4;
|
|
142
142
|
const buffer = acquireBuffer(alignedSize, undefined, label);
|
|
143
|
+
try {
|
|
144
|
+
let payload = data;
|
|
145
|
+
if (alignedSize !== byteLength) {
|
|
146
|
+
const padded = new Uint8Array(alignedSize);
|
|
147
|
+
const view = data instanceof ArrayBuffer
|
|
148
|
+
? new Uint8Array(data)
|
|
149
|
+
: new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
150
|
+
padded.set(view);
|
|
151
|
+
payload = padded;
|
|
152
|
+
}
|
|
143
153
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
const
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
device.queue.writeBuffer(buffer, 0, payload);
|
|
155
|
-
const tensor = createTensor(buffer, dtype, shape, label);
|
|
156
|
-
const expectedBytes = tensorBytes(shape, dtype);
|
|
157
|
-
if (expectedBytes !== byteLength) {
|
|
158
|
-
log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
|
|
154
|
+
device.queue.writeBuffer(buffer, 0, payload);
|
|
155
|
+
const tensor = createTensor(buffer, dtype, shape, label);
|
|
156
|
+
const expectedBytes = tensorBytes(shape, dtype);
|
|
157
|
+
if (expectedBytes !== byteLength) {
|
|
158
|
+
log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
|
|
159
|
+
}
|
|
160
|
+
return tensor;
|
|
161
|
+
} catch (error) {
|
|
162
|
+
releaseBuffer(buffer);
|
|
163
|
+
throw error;
|
|
159
164
|
}
|
|
160
|
-
return tensor;
|
|
161
165
|
}
|
|
162
166
|
|
|
163
167
|
async function readTensorToFloat32(tensor) {
|
|
@@ -202,6 +206,7 @@ export class EnergyPipeline {
|
|
|
202
206
|
|
|
203
207
|
async unload() {
|
|
204
208
|
this.manifest = null;
|
|
209
|
+
restorePipelineContexts(this);
|
|
205
210
|
}
|
|
206
211
|
|
|
207
212
|
async generate(request = {}) {
|
|
@@ -336,6 +341,7 @@ export class EnergyPipeline {
|
|
|
336
341
|
const centerWeight = Number.isFinite(weights.center) ? weights.center : 1.0;
|
|
337
342
|
const binarizeWeight = Number.isFinite(weights.binarize) ? weights.binarize : 0.0;
|
|
338
343
|
const centerTarget = Number.isFinite(quintelConfig.centerTarget) ? quintelConfig.centerTarget : 1.0;
|
|
344
|
+
const flags = buildQuintelKernelFlags(rules, binarizeWeight);
|
|
339
345
|
const energyHistory = [];
|
|
340
346
|
const stepTimesMs = [];
|
|
341
347
|
let lastEnergy = null;
|
|
@@ -387,11 +393,11 @@ export class EnergyPipeline {
|
|
|
387
393
|
await runEnergyQuintelReduce(stateTensor, {
|
|
388
394
|
count: elementCount,
|
|
389
395
|
size,
|
|
396
|
+
flags,
|
|
390
397
|
symmetryWeight,
|
|
391
398
|
centerWeight,
|
|
392
399
|
binarizeWeight,
|
|
393
400
|
centerTarget,
|
|
394
|
-
rules,
|
|
395
401
|
outputBuffer: reduceBuffer,
|
|
396
402
|
});
|
|
397
403
|
|
|
@@ -447,13 +453,13 @@ export class EnergyPipeline {
|
|
|
447
453
|
await runEnergyQuintelGrad(stateTensor, {
|
|
448
454
|
count: elementCount,
|
|
449
455
|
size,
|
|
456
|
+
flags,
|
|
450
457
|
countDiff: safeCountDiff,
|
|
451
458
|
symmetryWeight,
|
|
452
459
|
countWeight,
|
|
453
460
|
centerWeight,
|
|
454
461
|
binarizeWeight,
|
|
455
462
|
centerTarget,
|
|
456
|
-
rules,
|
|
457
463
|
outputBuffer: gradBuffer,
|
|
458
464
|
});
|
|
459
465
|
|
|
@@ -471,6 +477,7 @@ export class EnergyPipeline {
|
|
|
471
477
|
await runEnergyQuintelUpdate(stateTensor, {
|
|
472
478
|
count: elementCount,
|
|
473
479
|
size,
|
|
480
|
+
flags,
|
|
474
481
|
stepSize,
|
|
475
482
|
gradientScale,
|
|
476
483
|
countDiff: safeCountDiff,
|
|
@@ -481,7 +488,6 @@ export class EnergyPipeline {
|
|
|
481
488
|
centerTarget,
|
|
482
489
|
clampMin,
|
|
483
490
|
clampMax,
|
|
484
|
-
rules,
|
|
485
491
|
});
|
|
486
492
|
}
|
|
487
493
|
|
|
@@ -84,4 +84,9 @@ export function mergeQuintelConfig(
|
|
|
84
84
|
override?: Partial<QuintelEnergyConfig> | null
|
|
85
85
|
): QuintelEnergyConfig;
|
|
86
86
|
|
|
87
|
+
export function buildQuintelKernelFlags(
|
|
88
|
+
rules: Partial<QuintelRuleConfig> | null | undefined,
|
|
89
|
+
binarizeWeight?: number
|
|
90
|
+
): number;
|
|
91
|
+
|
|
87
92
|
export function runQuintelEnergyLoop(options: QuintelEnergyLoopOptions): QuintelEnergyLoopResult;
|
|
@@ -22,6 +22,17 @@ export function mergeQuintelConfig(base, override) {
|
|
|
22
22
|
};
|
|
23
23
|
}
|
|
24
24
|
|
|
25
|
+
export function buildQuintelKernelFlags(rules, binarizeWeight) {
|
|
26
|
+
let flags = 0;
|
|
27
|
+
if (rules?.mirrorX) flags |= 1;
|
|
28
|
+
if (rules?.mirrorY) flags |= 2;
|
|
29
|
+
if (rules?.diagonal) flags |= 4;
|
|
30
|
+
if (rules?.count) flags |= 8;
|
|
31
|
+
if (rules?.center) flags |= 16;
|
|
32
|
+
if (Number.isFinite(binarizeWeight) && binarizeWeight !== 0) flags |= 32;
|
|
33
|
+
return flags >>> 0;
|
|
34
|
+
}
|
|
35
|
+
|
|
25
36
|
function applyPairEnergy(state, gradients, indexA, indexB, weight) {
|
|
26
37
|
const diff = state[indexA] - state[indexB];
|
|
27
38
|
const energy = weight * diff * diff;
|
|
@@ -5,7 +5,7 @@ import { runEnergyEval, runEnergyUpdate } from '../../../gpu/kernels/index.js';
|
|
|
5
5
|
import { log } from '../../../debug/index.js';
|
|
6
6
|
import { f16ToF32Array, f32ToF16Array } from '../../kv-cache/types.js';
|
|
7
7
|
import { registerPipeline } from '../registry.js';
|
|
8
|
-
import { applyPipelineContexts } from '../context.js';
|
|
8
|
+
import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
|
|
9
9
|
import { createInitializedPipeline } from '../factory.js';
|
|
10
10
|
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
11
11
|
|
|
@@ -165,19 +165,22 @@ async function createFeatureTensor(device, values, dtype, label) {
|
|
|
165
165
|
const byteLength = payload.byteLength;
|
|
166
166
|
const alignedSize = Math.ceil(byteLength / 4) * 4;
|
|
167
167
|
const buffer = acquireBuffer(alignedSize, undefined, label);
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
168
|
+
try {
|
|
169
|
+
if (alignedSize === byteLength) {
|
|
170
|
+
device.queue.writeBuffer(buffer, 0, payload);
|
|
171
|
+
} else {
|
|
172
|
+
const bytes = payload instanceof Uint16Array
|
|
173
|
+
? new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength)
|
|
174
|
+
: new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength);
|
|
175
|
+
const padded = new Uint8Array(alignedSize);
|
|
176
|
+
padded.set(bytes);
|
|
177
|
+
device.queue.writeBuffer(buffer, 0, padded);
|
|
178
|
+
}
|
|
179
|
+
return createTensor(buffer, dtype, [values.length], label);
|
|
180
|
+
} catch (error) {
|
|
181
|
+
releaseBuffer(buffer);
|
|
182
|
+
throw error;
|
|
178
183
|
}
|
|
179
|
-
|
|
180
|
-
return createTensor(buffer, dtype, [values.length], label);
|
|
181
184
|
}
|
|
182
185
|
|
|
183
186
|
async function readTensorF32(tensor) {
|
|
@@ -307,6 +310,7 @@ export class EnergyRowHeadPipeline {
|
|
|
307
310
|
this.manifest = null;
|
|
308
311
|
this.model = null;
|
|
309
312
|
this.stats = {};
|
|
313
|
+
restorePipelineContexts(this);
|
|
310
314
|
}
|
|
311
315
|
|
|
312
316
|
async scoreRows(request = {}) {
|