@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -26,8 +26,8 @@ struct Uniforms {
|
|
|
26
26
|
start_pos: u32, // Starting position (for decode)
|
|
27
27
|
rope_base: f32, // Base frequency (default 10000)
|
|
28
28
|
rope_scale: f32, // Scaling factor for extended context
|
|
29
|
-
|
|
30
|
-
|
|
29
|
+
rotary_dim: u32, // Rotary slice within head_dim
|
|
30
|
+
interleaved: u32, // 1 = adjacent pairs, 0 = rotate-half
|
|
31
31
|
}
|
|
32
32
|
|
|
33
33
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -46,7 +46,8 @@ fn main(
|
|
|
46
46
|
let start_pos = u.start_pos;
|
|
47
47
|
|
|
48
48
|
// Global thread index (one thread per complex pair)
|
|
49
|
-
let
|
|
49
|
+
let rotary_dim = u.rotary_dim;
|
|
50
|
+
let half_dim = rotary_dim / 2u;
|
|
50
51
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
51
52
|
let idx = global_id.x;
|
|
52
53
|
|
|
@@ -68,16 +69,18 @@ fn main(
|
|
|
68
69
|
|
|
69
70
|
// Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
|
|
70
71
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
71
|
-
let
|
|
72
|
-
let
|
|
72
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
73
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
74
|
+
let x0 = input[base_idx + first_idx];
|
|
75
|
+
let x1 = input[base_idx + second_idx];
|
|
73
76
|
|
|
74
77
|
// Apply rotation
|
|
75
78
|
let y0 = x0 * cos_val - x1 * sin_val;
|
|
76
79
|
let y1 = x0 * sin_val + x1 * cos_val;
|
|
77
80
|
|
|
78
81
|
// Write back
|
|
79
|
-
input[base_idx +
|
|
80
|
-
input[base_idx +
|
|
82
|
+
input[base_idx + first_idx] = y0;
|
|
83
|
+
input[base_idx + second_idx] = y1;
|
|
81
84
|
}
|
|
82
85
|
|
|
83
86
|
// Compute frequencies on-the-fly (no precomputation needed)
|
|
@@ -91,9 +94,10 @@ fn rope_compute_freqs(
|
|
|
91
94
|
let start_pos = u.start_pos;
|
|
92
95
|
let rope_base = u.rope_base;
|
|
93
96
|
let rope_scale = u.rope_scale;
|
|
97
|
+
let rotary_dim = u.rotary_dim;
|
|
94
98
|
|
|
95
99
|
let idx = global_id.x;
|
|
96
|
-
let half_dim =
|
|
100
|
+
let half_dim = rotary_dim / 2u;
|
|
97
101
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
98
102
|
|
|
99
103
|
if (idx >= total_pairs) {
|
|
@@ -109,7 +113,7 @@ fn rope_compute_freqs(
|
|
|
109
113
|
let actual_pos = f32(start_pos + pos) / rope_scale;
|
|
110
114
|
|
|
111
115
|
// Compute frequency: 1 / (base^(2*pair_idx/head_dim))
|
|
112
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
116
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
113
117
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
114
118
|
let theta = actual_pos * freq;
|
|
115
119
|
|
|
@@ -118,12 +122,14 @@ fn rope_compute_freqs(
|
|
|
118
122
|
|
|
119
123
|
// Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
|
|
120
124
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
121
|
-
let
|
|
122
|
-
let
|
|
125
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
126
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
127
|
+
let x0 = input[base_idx + first_idx];
|
|
128
|
+
let x1 = input[base_idx + second_idx];
|
|
123
129
|
|
|
124
130
|
// Apply rotation
|
|
125
|
-
input[base_idx +
|
|
126
|
-
input[base_idx +
|
|
131
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
132
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
127
133
|
}
|
|
128
134
|
|
|
129
135
|
// Apply RoPE to both Q and K in one pass
|
|
@@ -138,10 +144,11 @@ fn rope_qk(
|
|
|
138
144
|
let start_pos = u.start_pos;
|
|
139
145
|
let rope_base = u.rope_base;
|
|
140
146
|
let rope_scale = u.rope_scale;
|
|
147
|
+
let rotary_dim = u.rotary_dim;
|
|
141
148
|
|
|
142
149
|
let idx = global_id.x;
|
|
143
150
|
// Each thread handles one Q-K pair at one dimension pair
|
|
144
|
-
let half_dim =
|
|
151
|
+
let half_dim = rotary_dim / 2u;
|
|
145
152
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
146
153
|
|
|
147
154
|
if (idx >= total_pairs) {
|
|
@@ -156,7 +163,7 @@ fn rope_qk(
|
|
|
156
163
|
let actual_pos = f32(start_pos + pos) / rope_scale;
|
|
157
164
|
|
|
158
165
|
// Compute frequency
|
|
159
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
166
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
160
167
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
161
168
|
let theta = actual_pos * freq;
|
|
162
169
|
|
|
@@ -168,16 +175,18 @@ fn rope_qk(
|
|
|
168
175
|
let k_base_idx = q_base_idx + head_dim; // K starts after Q
|
|
169
176
|
|
|
170
177
|
// Process Q
|
|
171
|
-
let
|
|
172
|
-
let
|
|
173
|
-
input[q_base_idx +
|
|
174
|
-
input[q_base_idx +
|
|
178
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
179
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
180
|
+
let q0 = input[q_base_idx + first_idx];
|
|
181
|
+
let q1 = input[q_base_idx + second_idx];
|
|
182
|
+
input[q_base_idx + first_idx] = q0 * cos_val - q1 * sin_val;
|
|
183
|
+
input[q_base_idx + second_idx] = q0 * sin_val + q1 * cos_val;
|
|
175
184
|
|
|
176
185
|
// Process K
|
|
177
|
-
let k0 = input[k_base_idx +
|
|
178
|
-
let k1 = input[k_base_idx +
|
|
179
|
-
input[k_base_idx +
|
|
180
|
-
input[k_base_idx +
|
|
186
|
+
let k0 = input[k_base_idx + first_idx];
|
|
187
|
+
let k1 = input[k_base_idx + second_idx];
|
|
188
|
+
input[k_base_idx + first_idx] = k0 * cos_val - k1 * sin_val;
|
|
189
|
+
input[k_base_idx + second_idx] = k0 * sin_val + k1 * cos_val;
|
|
181
190
|
}
|
|
182
191
|
|
|
183
192
|
// Precompute frequency table (run once at init)
|
|
@@ -190,9 +199,10 @@ fn precompute_freqs(
|
|
|
190
199
|
let seq_len = u.seq_len; // maxSeqLen for precomputation
|
|
191
200
|
let rope_base = u.rope_base;
|
|
192
201
|
let rope_scale = u.rope_scale;
|
|
202
|
+
let rotary_dim = u.rotary_dim;
|
|
193
203
|
|
|
194
204
|
let idx = global_id.x;
|
|
195
|
-
let half_dim =
|
|
205
|
+
let half_dim = rotary_dim / 2u;
|
|
196
206
|
let total_elements = seq_len * half_dim;
|
|
197
207
|
|
|
198
208
|
if (idx >= total_elements) {
|
|
@@ -203,7 +213,7 @@ fn precompute_freqs(
|
|
|
203
213
|
let dim_idx = idx % half_dim;
|
|
204
214
|
|
|
205
215
|
let actual_pos = f32(pos) / rope_scale;
|
|
206
|
-
let exponent = f32(dim_idx * 2u) / f32(
|
|
216
|
+
let exponent = f32(dim_idx * 2u) / f32(rotary_dim);
|
|
207
217
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
208
218
|
let theta = actual_pos * freq;
|
|
209
219
|
|
|
@@ -218,6 +228,7 @@ fn rope_ntk_scaled(
|
|
|
218
228
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
219
229
|
) {
|
|
220
230
|
let head_dim = u.head_dim;
|
|
231
|
+
let rotary_dim = u.rotary_dim;
|
|
221
232
|
let num_heads = u.num_heads;
|
|
222
233
|
let seq_len = u.seq_len;
|
|
223
234
|
let start_pos = u.start_pos;
|
|
@@ -225,7 +236,7 @@ fn rope_ntk_scaled(
|
|
|
225
236
|
let rope_scale = u.rope_scale;
|
|
226
237
|
|
|
227
238
|
let idx = global_id.x;
|
|
228
|
-
let half_dim =
|
|
239
|
+
let half_dim = rotary_dim / 2u;
|
|
229
240
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
230
241
|
|
|
231
242
|
if (idx >= total_pairs) {
|
|
@@ -234,7 +245,7 @@ fn rope_ntk_scaled(
|
|
|
234
245
|
|
|
235
246
|
// NTK scaling: increase base proportionally to scale factor
|
|
236
247
|
// This preserves high-frequency components better than linear interpolation
|
|
237
|
-
rope_base = rope_base * pow(rope_scale, f32(
|
|
248
|
+
rope_base = rope_base * pow(rope_scale, f32(rotary_dim) / (f32(rotary_dim) - 2.0));
|
|
238
249
|
|
|
239
250
|
let pos = idx / (num_heads * half_dim);
|
|
240
251
|
let remainder = idx % (num_heads * half_dim);
|
|
@@ -243,7 +254,7 @@ fn rope_ntk_scaled(
|
|
|
243
254
|
|
|
244
255
|
let actual_pos = f32(start_pos + pos);
|
|
245
256
|
|
|
246
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
257
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
247
258
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
248
259
|
let theta = actual_pos * freq;
|
|
249
260
|
|
|
@@ -251,11 +262,13 @@ fn rope_ntk_scaled(
|
|
|
251
262
|
let sin_val = sin(theta);
|
|
252
263
|
|
|
253
264
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
254
|
-
let
|
|
255
|
-
let
|
|
265
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
266
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
267
|
+
let x0 = input[base_idx + first_idx];
|
|
268
|
+
let x1 = input[base_idx + second_idx];
|
|
256
269
|
|
|
257
|
-
input[base_idx +
|
|
258
|
-
input[base_idx +
|
|
270
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
271
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
259
272
|
}
|
|
260
273
|
|
|
261
274
|
// YaRN-style RoPE with attention scaling
|
|
@@ -265,6 +278,7 @@ fn rope_yarn(
|
|
|
265
278
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
266
279
|
) {
|
|
267
280
|
let head_dim = u.head_dim;
|
|
281
|
+
let rotary_dim = u.rotary_dim;
|
|
268
282
|
let num_heads = u.num_heads;
|
|
269
283
|
let seq_len = u.seq_len;
|
|
270
284
|
let start_pos = u.start_pos;
|
|
@@ -272,7 +286,7 @@ fn rope_yarn(
|
|
|
272
286
|
let rope_scale = u.rope_scale;
|
|
273
287
|
|
|
274
288
|
let idx = global_id.x;
|
|
275
|
-
let half_dim =
|
|
289
|
+
let half_dim = rotary_dim / 2u;
|
|
276
290
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
277
291
|
|
|
278
292
|
if (idx >= total_pairs) {
|
|
@@ -292,7 +306,7 @@ fn rope_yarn(
|
|
|
292
306
|
let alpha: f32 = 1.0;
|
|
293
307
|
|
|
294
308
|
// Compute original frequency
|
|
295
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
309
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
296
310
|
let orig_freq = 1.0 / pow(rope_base, exponent);
|
|
297
311
|
|
|
298
312
|
// Compute wavelength
|
|
@@ -300,8 +314,8 @@ fn rope_yarn(
|
|
|
300
314
|
|
|
301
315
|
// Interpolation factor based on wavelength
|
|
302
316
|
var ramp: f32;
|
|
303
|
-
let low_wavelength = f32(
|
|
304
|
-
let high_wavelength = f32(
|
|
317
|
+
let low_wavelength = f32(rotary_dim) / beta_fast;
|
|
318
|
+
let high_wavelength = f32(rotary_dim) / beta_slow;
|
|
305
319
|
|
|
306
320
|
if (wavelength < low_wavelength) {
|
|
307
321
|
ramp = 0.0; // No interpolation for high frequencies
|
|
@@ -320,9 +334,11 @@ fn rope_yarn(
|
|
|
320
334
|
let sin_val = sin(theta);
|
|
321
335
|
|
|
322
336
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
323
|
-
let
|
|
324
|
-
let
|
|
337
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
338
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
339
|
+
let x0 = input[base_idx + first_idx];
|
|
340
|
+
let x1 = input[base_idx + second_idx];
|
|
325
341
|
|
|
326
|
-
input[base_idx +
|
|
327
|
-
input[base_idx +
|
|
342
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
343
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
328
344
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice, getKernelCapabilities } from '../device.js';
|
|
4
|
-
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, readBufferSlice, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
6
6
|
import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout } from './utils.js';
|
|
7
7
|
import { allowReadback } from '../perf-guards.js';
|
|
@@ -156,18 +156,19 @@ function ensureOutputBufferSize(outputBuffer, minBytes, label) {
|
|
|
156
156
|
}
|
|
157
157
|
}
|
|
158
158
|
|
|
159
|
-
function readTokenFromOutput(
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
size: 4,
|
|
163
|
-
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
|
|
164
|
-
});
|
|
165
|
-
|
|
166
|
-
const copyEncoder = device.createCommandEncoder({ label: `${label}_copy` });
|
|
167
|
-
copyEncoder.copyBufferToBuffer(outputBuffer, outputIndex * 4, stagingBuffer, 0, 4);
|
|
168
|
-
device.queue.submit([copyEncoder.finish()]);
|
|
159
|
+
async function readTokenFromOutput(outputBuffer, outputIndex) {
|
|
160
|
+
return new Uint32Array(await readBufferSlice(outputBuffer, outputIndex * 4, 4))[0];
|
|
161
|
+
}
|
|
169
162
|
|
|
170
|
-
|
|
163
|
+
function cleanupRunResources(uniformBuffer, ownedBuffers) {
|
|
164
|
+
if (uniformBuffer) {
|
|
165
|
+
uniformBuffer.destroy();
|
|
166
|
+
}
|
|
167
|
+
for (const buffer of ownedBuffers) {
|
|
168
|
+
if (buffer) {
|
|
169
|
+
releaseBuffer(buffer);
|
|
170
|
+
}
|
|
171
|
+
}
|
|
171
172
|
}
|
|
172
173
|
|
|
173
174
|
async function executeArgmaxRun(logits, vocabSize, options) {
|
|
@@ -238,20 +239,14 @@ async function executeArgmaxRun(logits, vocabSize, options) {
|
|
|
238
239
|
|
|
239
240
|
device.queue.submit([encoder.finish()]);
|
|
240
241
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
releaseBuffer(tempLogits);
|
|
249
|
-
releaseBuffer(tempIndices);
|
|
250
|
-
if (ownsOutputBuffer) {
|
|
251
|
-
releaseBuffer(outputBuffer);
|
|
242
|
+
try {
|
|
243
|
+
return await readTokenFromOutput(outputBuffer, outputIndex);
|
|
244
|
+
} finally {
|
|
245
|
+
cleanupRunResources(
|
|
246
|
+
uniformBuffer,
|
|
247
|
+
[tempLogits, tempIndices, ownsOutputBuffer ? outputBuffer : null]
|
|
248
|
+
);
|
|
252
249
|
}
|
|
253
|
-
|
|
254
|
-
return tokenId;
|
|
255
250
|
}
|
|
256
251
|
|
|
257
252
|
async function executeArgmaxRecord(recorder, logits, vocabSize, options) {
|
|
@@ -428,20 +423,14 @@ export async function runGPUSample(
|
|
|
428
423
|
|
|
429
424
|
device.queue.submit([encoder.finish()]);
|
|
430
425
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
releaseBuffer(topkLogits);
|
|
439
|
-
releaseBuffer(topkIndices);
|
|
440
|
-
if (ownsOutputBuffer) {
|
|
441
|
-
releaseBuffer(outputBuffer);
|
|
426
|
+
try {
|
|
427
|
+
return await readTokenFromOutput(outputBuffer, outputIndex);
|
|
428
|
+
} finally {
|
|
429
|
+
cleanupRunResources(
|
|
430
|
+
uniformBuffer,
|
|
431
|
+
[topkLogits, topkIndices, ownsOutputBuffer ? outputBuffer : null]
|
|
432
|
+
);
|
|
442
433
|
}
|
|
443
|
-
|
|
444
|
-
return tokenId;
|
|
445
434
|
}
|
|
446
435
|
|
|
447
436
|
|
|
@@ -29,7 +29,6 @@ async function runSummary(target, query, key, value, summaryBuffer, uniforms, va
|
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
|
|
32
|
-
const outputSize = uniforms.num_tokens * uniforms.hidden_size;
|
|
33
32
|
await unifiedKernelWrapper(
|
|
34
33
|
'sana_linear_attention_apply',
|
|
35
34
|
target,
|
|
@@ -45,7 +44,7 @@ async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, va
|
|
|
45
44
|
_pad1: 0,
|
|
46
45
|
_pad2: 0,
|
|
47
46
|
},
|
|
48
|
-
Math.ceil(
|
|
47
|
+
[Math.ceil(uniforms.hidden_size / WORKGROUP_SIZES.DEFAULT), uniforms.num_tokens, 1]
|
|
49
48
|
);
|
|
50
49
|
}
|
|
51
50
|
|
|
@@ -65,6 +64,8 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
|
|
|
65
64
|
outputBuffer = null,
|
|
66
65
|
summaryBuffer = null,
|
|
67
66
|
} = options;
|
|
67
|
+
const ownsSummary = summaryBuffer == null;
|
|
68
|
+
const ownsOutput = outputBuffer == null;
|
|
68
69
|
|
|
69
70
|
if (
|
|
70
71
|
!Number.isFinite(numHeads) ||
|
|
@@ -99,18 +100,24 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
|
|
|
99
100
|
eps,
|
|
100
101
|
};
|
|
101
102
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
103
|
+
try {
|
|
104
|
+
await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
|
|
105
|
+
await runApply(target, query, temporarySummary, output, uniforms, variant);
|
|
106
|
+
return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
|
|
107
|
+
} catch (error) {
|
|
108
|
+
if (ownsOutput) {
|
|
109
|
+
releaseBuffer(output);
|
|
110
|
+
}
|
|
111
|
+
throw error;
|
|
112
|
+
} finally {
|
|
113
|
+
if (ownsSummary) {
|
|
114
|
+
if (recorder) {
|
|
115
|
+
recorder.trackTemporaryBuffer(temporarySummary);
|
|
116
|
+
} else {
|
|
117
|
+
releaseBuffer(temporarySummary);
|
|
118
|
+
}
|
|
110
119
|
}
|
|
111
120
|
}
|
|
112
|
-
|
|
113
|
-
return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
|
|
114
121
|
}
|
|
115
122
|
|
|
116
123
|
export async function runSanaLinearAttention(query, key, value, options = {}) {
|
|
@@ -18,14 +18,13 @@ struct Uniforms {
|
|
|
18
18
|
|
|
19
19
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
20
20
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
21
|
-
let
|
|
22
|
-
let
|
|
23
|
-
if (
|
|
21
|
+
let hidden = gid.x;
|
|
22
|
+
let token = gid.y;
|
|
23
|
+
if (token >= u.num_tokens || hidden >= u.hidden_size) {
|
|
24
24
|
return;
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
-
let
|
|
28
|
-
let hidden = idx - token * u.hidden_size;
|
|
27
|
+
let idx = token * u.hidden_size + hidden;
|
|
29
28
|
let head = hidden / u.head_dim;
|
|
30
29
|
let dim = hidden - head * u.head_dim;
|
|
31
30
|
let rows_per_head = u.head_dim + 1u;
|
|
@@ -20,14 +20,13 @@ struct Uniforms {
|
|
|
20
20
|
|
|
21
21
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
22
22
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
23
|
-
let
|
|
24
|
-
let
|
|
25
|
-
if (
|
|
23
|
+
let hidden = gid.x;
|
|
24
|
+
let token = gid.y;
|
|
25
|
+
if (token >= u.num_tokens || hidden >= u.hidden_size) {
|
|
26
26
|
return;
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
-
let
|
|
30
|
-
let hidden = idx - token * u.hidden_size;
|
|
29
|
+
let idx = token * u.hidden_size + hidden;
|
|
31
30
|
let head = hidden / u.head_dim;
|
|
32
31
|
let dim = hidden - head * u.head_dim;
|
|
33
32
|
let rows_per_head = u.head_dim + 1u;
|
|
@@ -33,6 +33,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
33
33
|
|
|
34
34
|
var acc: f32 = 0.0;
|
|
35
35
|
for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
|
|
36
|
+
let query_value = query[token * u.hidden_size + hidden_base + col];
|
|
36
37
|
let key_idx = token * u.hidden_size + hidden_base + col;
|
|
37
38
|
let key_value = max(key[key_idx], 0.0);
|
|
38
39
|
let value_value = select(
|
|
@@ -40,6 +41,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
40
41
|
1.0,
|
|
41
42
|
row == u.head_dim
|
|
42
43
|
);
|
|
44
|
+
if (u.hidden_size == 0u) {
|
|
45
|
+
acc = acc + query_value;
|
|
46
|
+
}
|
|
43
47
|
acc = acc + value_value * key_value;
|
|
44
48
|
}
|
|
45
49
|
|
|
@@ -35,6 +35,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
35
35
|
|
|
36
36
|
var acc: f32 = 0.0;
|
|
37
37
|
for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
|
|
38
|
+
let query_value = f32(query[token * u.hidden_size + hidden_base + col]);
|
|
38
39
|
let key_idx = token * u.hidden_size + hidden_base + col;
|
|
39
40
|
let key_value = max(f32(key[key_idx]), 0.0);
|
|
40
41
|
let value_value = select(
|
|
@@ -42,6 +43,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
42
43
|
1.0,
|
|
43
44
|
row == u.head_dim
|
|
44
45
|
);
|
|
46
|
+
if (u.hidden_size == 0u) {
|
|
47
|
+
acc = acc + query_value;
|
|
48
|
+
}
|
|
45
49
|
acc = acc + value_value * key_value;
|
|
46
50
|
}
|
|
47
51
|
|
package/src/gpu/kernels/scale.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
3
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
4
4
|
import { unifiedKernelWrapper } from './utils.js';
|
|
@@ -6,6 +6,7 @@ import { selectRuleValue } from './rule-registry.js';
|
|
|
6
6
|
|
|
7
7
|
async function _scale(target, input, scale, options = {}) {
|
|
8
8
|
const { count, outputBuffer = null, inplace = false } = options;
|
|
9
|
+
const ownsOutput = !inplace && outputBuffer == null;
|
|
9
10
|
|
|
10
11
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
11
12
|
const inferredCount = count ?? Math.floor(input.buffer.size / bytesPerElement);
|
|
@@ -16,16 +17,22 @@ async function _scale(target, input, scale, options = {}) {
|
|
|
16
17
|
|
|
17
18
|
const bindings = inplace ? [outputBuf, outputBuf] : [input, outputBuf];
|
|
18
19
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
20
|
+
try {
|
|
21
|
+
await unifiedKernelWrapper(
|
|
22
|
+
'scale',
|
|
23
|
+
target,
|
|
24
|
+
variant,
|
|
25
|
+
bindings,
|
|
26
|
+
{ size: inferredCount, scale },
|
|
27
|
+
Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
|
|
28
|
+
);
|
|
29
|
+
return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
|
|
30
|
+
} catch (error) {
|
|
31
|
+
if (ownsOutput) {
|
|
32
|
+
releaseBuffer(outputBuf);
|
|
33
|
+
}
|
|
34
|
+
throw error;
|
|
35
|
+
}
|
|
29
36
|
}
|
|
30
37
|
|
|
31
38
|
export async function runScale(input, scale, options = {}) {
|
|
@@ -138,8 +138,10 @@ export async function compileShader(
|
|
|
138
138
|
code: source,
|
|
139
139
|
});
|
|
140
140
|
|
|
141
|
-
// Check for compilation errors
|
|
142
|
-
const compilationInfo =
|
|
141
|
+
// Check for compilation errors (getCompilationInfo not available in all WebGPU providers)
|
|
142
|
+
const compilationInfo = typeof module.getCompilationInfo === 'function'
|
|
143
|
+
? await module.getCompilationInfo()
|
|
144
|
+
: { messages: [] };
|
|
143
145
|
if (compilationInfo.messages.length > 0) {
|
|
144
146
|
for (const msg of compilationInfo.messages) {
|
|
145
147
|
if (msg.type === 'error') {
|
|
@@ -16,6 +16,7 @@ export interface SiLUOptions extends OutputBufferOptions {
|
|
|
16
16
|
size?: number | null;
|
|
17
17
|
gate?: Tensor | null;
|
|
18
18
|
gateActivation?: 'silu' | 'sigmoid';
|
|
19
|
+
inputActivation?: 'silu' | 'identity';
|
|
19
20
|
useVec4?: boolean;
|
|
20
21
|
biasOffset?: number;
|
|
21
22
|
swigluLimit: number | null;
|