@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
3
|
import { unifiedKernelWrapper } from './utils.js';
|
|
4
4
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -32,23 +32,31 @@ async function _repeatChannels(target, input, options = {}) {
|
|
|
32
32
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
33
33
|
const outputSize = outChannels * height * width * bytesPerElement;
|
|
34
34
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
|
|
35
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
35
36
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
37
|
+
try {
|
|
38
|
+
await unifiedKernelWrapper(
|
|
39
|
+
'repeat_channels',
|
|
40
|
+
target,
|
|
41
|
+
variant,
|
|
42
|
+
[input, output],
|
|
43
|
+
{
|
|
44
|
+
in_channels: inChannels,
|
|
45
|
+
height,
|
|
46
|
+
width,
|
|
47
|
+
repeats,
|
|
48
|
+
_pad0: 0,
|
|
49
|
+
},
|
|
50
|
+
[Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
51
|
+
);
|
|
52
|
+
|
|
53
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
|
|
54
|
+
} catch (error) {
|
|
55
|
+
if (ownedOutput) {
|
|
56
|
+
releaseBuffer(ownedOutput);
|
|
57
|
+
}
|
|
58
|
+
throw error;
|
|
59
|
+
}
|
|
52
60
|
}
|
|
53
61
|
|
|
54
62
|
export async function runRepeatChannels(input, options = {}) {
|
|
@@ -14,16 +14,15 @@ struct Uniforms {
|
|
|
14
14
|
|
|
15
15
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
16
16
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
17
|
-
let idx = gid.x;
|
|
18
17
|
let spatial = u.height * u.width;
|
|
19
18
|
let out_channels = u.in_channels * u.repeats;
|
|
20
|
-
let
|
|
21
|
-
|
|
19
|
+
let spatial_idx = gid.x;
|
|
20
|
+
let out_channel = gid.y;
|
|
21
|
+
if (out_channel >= out_channels || spatial_idx >= spatial) {
|
|
22
22
|
return;
|
|
23
23
|
}
|
|
24
24
|
|
|
25
|
-
let out_channel = idx / spatial;
|
|
26
25
|
let channel = out_channel / u.repeats;
|
|
27
|
-
let
|
|
26
|
+
let idx = out_channel * spatial + spatial_idx;
|
|
28
27
|
output[idx] = input[channel * spatial + spatial_idx];
|
|
29
28
|
}
|
|
@@ -16,16 +16,15 @@ struct Uniforms {
|
|
|
16
16
|
|
|
17
17
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
18
18
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
19
|
-
let idx = gid.x;
|
|
20
19
|
let spatial = u.height * u.width;
|
|
21
20
|
let out_channels = u.in_channels * u.repeats;
|
|
22
|
-
let
|
|
23
|
-
|
|
21
|
+
let spatial_idx = gid.x;
|
|
22
|
+
let out_channel = gid.y;
|
|
23
|
+
if (out_channel >= out_channels || spatial_idx >= spatial) {
|
|
24
24
|
return;
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
-
let out_channel = idx / spatial;
|
|
28
27
|
let channel = out_channel / u.repeats;
|
|
29
|
-
let
|
|
28
|
+
let idx = out_channel * spatial + spatial_idx;
|
|
30
29
|
output[idx] = input[channel * spatial + spatial_idx];
|
|
31
30
|
}
|
|
@@ -63,9 +63,26 @@ function cleanupTemps(temps, recorder) {
|
|
|
63
63
|
}
|
|
64
64
|
}
|
|
65
65
|
|
|
66
|
+
function planResidualDispatch(target, size, elementsPerWorkgroup) {
|
|
67
|
+
const device = target?.device;
|
|
68
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
69
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
70
|
+
: 65535;
|
|
71
|
+
const dispatchStride = Math.min(size, maxPerDim * elementsPerWorkgroup);
|
|
72
|
+
return {
|
|
73
|
+
dispatchStride,
|
|
74
|
+
workgroups: [
|
|
75
|
+
Math.ceil(dispatchStride / elementsPerWorkgroup),
|
|
76
|
+
Math.ceil(size / dispatchStride),
|
|
77
|
+
1,
|
|
78
|
+
],
|
|
79
|
+
};
|
|
80
|
+
}
|
|
81
|
+
|
|
66
82
|
async function _residualAdd(target, a, b, size, options = {}) {
|
|
67
83
|
const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
|
|
68
84
|
const { useVec4 = true, outputBuffer = null } = options;
|
|
85
|
+
const ownsOutput = outputBuffer == null;
|
|
69
86
|
|
|
70
87
|
const { a: aAligned, b: bAligned, temps } = await alignResidualInputs(a, b, recorder);
|
|
71
88
|
const outputDtype = inferOutputDtype(aAligned, bAligned);
|
|
@@ -75,19 +92,28 @@ async function _residualAdd(target, a, b, size, options = {}) {
|
|
|
75
92
|
const outputSize = size * bytesPerElement;
|
|
76
93
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'residual_output');
|
|
77
94
|
|
|
78
|
-
const
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
await unifiedKernelWrapper(
|
|
83
|
-
'residual', target, variant,
|
|
84
|
-
[aAligned, bAligned, output],
|
|
85
|
-
{ size },
|
|
86
|
-
workgroups
|
|
95
|
+
const dispatchPlan = planResidualDispatch(
|
|
96
|
+
target,
|
|
97
|
+
size,
|
|
98
|
+
useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
|
|
87
99
|
);
|
|
88
100
|
|
|
89
|
-
|
|
90
|
-
|
|
101
|
+
try {
|
|
102
|
+
await unifiedKernelWrapper(
|
|
103
|
+
'residual', target, variant,
|
|
104
|
+
[aAligned, bAligned, output],
|
|
105
|
+
{ size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
|
|
106
|
+
dispatchPlan.workgroups
|
|
107
|
+
);
|
|
108
|
+
return createTensor(output, outputDtype, [size], 'residual_output');
|
|
109
|
+
} catch (error) {
|
|
110
|
+
if (ownsOutput) {
|
|
111
|
+
releaseBuffer(output);
|
|
112
|
+
}
|
|
113
|
+
throw error;
|
|
114
|
+
} finally {
|
|
115
|
+
cleanupTemps(temps, recorder);
|
|
116
|
+
}
|
|
91
117
|
}
|
|
92
118
|
|
|
93
119
|
async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
|
|
@@ -96,18 +122,38 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
|
|
|
96
122
|
|
|
97
123
|
const { bias: biasAligned, temps } = await alignBiasTensor(data, bias, recorder);
|
|
98
124
|
const variant = selectBiasAddVariant(data.dtype, biasAligned.dtype);
|
|
99
|
-
|
|
100
|
-
const
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
125
|
+
const device = target?.device;
|
|
126
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
127
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
128
|
+
: 65535;
|
|
129
|
+
const tokenStride = Math.min(numTokens, maxPerDim);
|
|
130
|
+
|
|
131
|
+
const workgroups = [
|
|
132
|
+
Math.ceil(dim / WORKGROUP_SIZES.DEFAULT),
|
|
133
|
+
tokenStride,
|
|
134
|
+
Math.ceil(numTokens / tokenStride),
|
|
135
|
+
];
|
|
136
|
+
|
|
137
|
+
try {
|
|
138
|
+
await unifiedKernelWrapper(
|
|
139
|
+
'bias_add', target, variant,
|
|
140
|
+
[data, biasAligned],
|
|
141
|
+
{
|
|
142
|
+
num_tokens: numTokens,
|
|
143
|
+
dim,
|
|
144
|
+
data_offset: dataOffset,
|
|
145
|
+
bias_offset: biasOffset,
|
|
146
|
+
token_stride: tokenStride,
|
|
147
|
+
_pad0: 0,
|
|
148
|
+
_pad1: 0,
|
|
149
|
+
_pad2: 0,
|
|
150
|
+
},
|
|
151
|
+
workgroups
|
|
152
|
+
);
|
|
153
|
+
return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
|
|
154
|
+
} finally {
|
|
155
|
+
cleanupTemps(temps, recorder);
|
|
156
|
+
}
|
|
111
157
|
}
|
|
112
158
|
|
|
113
159
|
export async function runResidualAdd(a, b, size, options = {}) {
|
|
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE: u32 = 256u;
|
|
|
23
23
|
|
|
24
24
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
25
25
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
26
|
-
let
|
|
26
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
27
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
27
28
|
if (idx >= u.size) {
|
|
28
29
|
return;
|
|
29
30
|
}
|
|
@@ -35,7 +36,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
35
36
|
// This avoids requiring a different bind group layout with read_write on 'a'
|
|
36
37
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
37
38
|
fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
38
|
-
let
|
|
39
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
40
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
39
41
|
if (idx >= u.size) {
|
|
40
42
|
return;
|
|
41
43
|
}
|
|
@@ -45,7 +47,8 @@ fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
45
47
|
// Fused residual + scale: output = a + scale * b
|
|
46
48
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
47
49
|
fn add_scaled(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
48
|
-
let
|
|
50
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
51
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
49
52
|
if (idx >= u.size) {
|
|
50
53
|
return;
|
|
51
54
|
}
|
|
@@ -27,7 +27,8 @@ override WORKGROUP_SIZE: u32 = 256u;
|
|
|
27
27
|
|
|
28
28
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
29
29
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
30
|
-
let
|
|
30
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
31
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
31
32
|
if (idx >= u.size) {
|
|
32
33
|
return;
|
|
33
34
|
}
|
|
@@ -25,7 +25,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
|
|
|
25
25
|
// Vectorized version for better throughput
|
|
26
26
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
27
27
|
fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
28
|
-
let
|
|
28
|
+
let dispatch_stride = max(u._pad1, 4u);
|
|
29
|
+
let idx = gid.y * dispatch_stride + gid.x * 4u;
|
|
29
30
|
let size = u.size;
|
|
30
31
|
|
|
31
32
|
if (idx >= size) {
|
|
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
|
|
|
23
23
|
// Vectorized version for better throughput
|
|
24
24
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
25
25
|
fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
26
|
-
let
|
|
26
|
+
let dispatch_stride = max(u._pad1, 4u);
|
|
27
|
+
let idx = gid.y * dispatch_stride + gid.x * 4u;
|
|
27
28
|
let size = u.size;
|
|
28
29
|
|
|
29
30
|
if (idx >= size) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getKernelCapabilities } from '../device.js';
|
|
4
|
-
import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor } from '../tensor.js';
|
|
6
6
|
import { getKernelThresholds, padToQ4KBlock } from '../../config/schema/index.js';
|
|
7
7
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -58,6 +58,36 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
58
58
|
return 'f32';
|
|
59
59
|
}
|
|
60
60
|
|
|
61
|
+
/**
 * Guards the RMSNorm kernels against a weight that did not resolve to a GPU buffer.
 *
 * Returns without effect when `weightBuffer` is truthy and either the `GPUBuffer`
 * global is not defined (any truthy value is then accepted) or `weightBuffer` is an
 * instance of it. Otherwise throws an Error that names the weight's label and the
 * runtime types of both arguments.
 *
 * @param {*} weight - Original weight value; only its `label` and constructor name are read, for the message.
 * @param {*} weightBuffer - The value to validate.
 * @param {number|null|undefined} hiddenSize - Reported in the message; printed as 'unknown' when nullish.
 * @throws {Error} When `weightBuffer` is falsy, or is not a `GPUBuffer` while that global exists.
 */
function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
|
|
62
|
+
const isGpuBuffer = weightBuffer && (
|
|
63
|
+
// Without a GPUBuffer global there is nothing to instanceof against, so any truthy buffer passes.
typeof GPUBuffer === 'undefined'
|
|
64
|
+
? true
|
|
65
|
+
: weightBuffer instanceof GPUBuffer
|
|
66
|
+
);
|
|
67
|
+
if (isGpuBuffer) {
|
|
68
|
+
return;
|
|
69
|
+
}
|
|
70
|
+
const weightLabel = weight?.label ?? 'unknown';
|
|
71
|
+
// Spell out null/undefined explicitly; otherwise prefer the constructor name and fall back to typeof.
const weightType = weight === null ? 'null' : weight === undefined ? 'undefined' : weight.constructor?.name || typeof weight;
|
|
72
|
+
const bufferType = weightBuffer === null ? 'null' : weightBuffer === undefined ? 'undefined' : weightBuffer.constructor?.name || typeof weightBuffer;
|
|
73
|
+
throw new Error(
|
|
74
|
+
`[rmsnorm] weight "${weightLabel}" requires a GPUBuffer ` +
|
|
75
|
+
`(weightType=${weightType}, bufferType=${bufferType}, hiddenSize=${hiddenSize ?? 'unknown'}).`
|
|
76
|
+
);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/**
 * Plans a 2D workgroup grid for RMSNorm, which launches one workgroup per token.
 *
 * Tokens are laid out along X up to the per-dimension workgroup limit and then
 * wrap onto additional Y rows. The returned `tokenStride` is the row width that
 * the caller passes to the shader so it can rebuild the linear token index as
 * `y * tokenStride + x`.
 *
 * @param {{ device?: { limits?: { maxComputeWorkgroupsPerDimension?: number } } } | null | undefined} target
 *   Object whose `device.limits` supplies the per-dimension limit; when the limit
 *   is missing or not finite, 65535 is used.
 * @param {number} numTokens - Number of tokens (expected to be a non-negative integer).
 * @returns {{ tokenStride: number, workgroups: [number, number, number] }}
 *   Row width and the `[x, y, 1]` workgroup counts. For `numTokens <= 0` the plan
 *   is `[1, 0, 1]`, i.e. a dispatch that launches no workgroups.
 */
function planRMSNormDispatch(target, numTokens) {
  const limit = target?.device?.limits?.maxComputeWorkgroupsPerDimension;
  const maxPerDim = Number.isFinite(limit) ? limit : 65535;

  // Clamp to at least 1: a zero stride would make the row count below 0 / 0 = NaN,
  // which is not a valid workgroup count.
  const tokenStride = Math.max(1, Math.min(numTokens, maxPerDim));

  // Never report a negative row count; zero rows means nothing is dispatched.
  const rows = Math.max(0, Math.ceil(numTokens / tokenStride));

  return {
    tokenStride,
    workgroups: [tokenStride, rows, 1],
  };
}
|
|
90
|
+
|
|
61
91
|
export function selectRMSNormKernel(options = {}, isF16 = false) {
|
|
62
92
|
const { residual = null, hiddenSize = null } = options;
|
|
63
93
|
const { smallThreshold } = getKernelThresholds().rmsnorm;
|
|
@@ -82,27 +112,46 @@ export async function runRMSNorm(
|
|
|
82
112
|
const variant = selectRMSNormKernel(options, isF16);
|
|
83
113
|
const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
|
|
84
114
|
const normWeightBuffer = getBuffer(weight);
|
|
115
|
+
assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
|
|
85
116
|
const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
|
|
86
117
|
|
|
87
118
|
const bytesPerElement = isF16 ? 2 : 4;
|
|
88
119
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
89
120
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
90
121
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
|
|
122
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
123
|
+
const dispatchPlan = planRMSNormDispatch(null, batchSize);
|
|
91
124
|
|
|
92
125
|
// Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
|
|
93
|
-
const residualBuf = residual?.buffer || input
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
126
|
+
const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
|
|
127
|
+
|
|
128
|
+
try {
|
|
129
|
+
await unifiedKernelWrapper(
|
|
130
|
+
'rmsnorm',
|
|
131
|
+
null,
|
|
132
|
+
variant,
|
|
133
|
+
[input, normWeightBuffer, outputBuf, residualBuf],
|
|
134
|
+
{
|
|
135
|
+
hidden_size: inferredHiddenSize,
|
|
136
|
+
num_tokens: batchSize,
|
|
137
|
+
eps,
|
|
138
|
+
has_residual: residual ? 1 : 0,
|
|
139
|
+
token_stride: dispatchPlan.tokenStride,
|
|
140
|
+
_pad0: 0,
|
|
141
|
+
_pad1: 0,
|
|
142
|
+
_pad2: 0,
|
|
143
|
+
},
|
|
144
|
+
dispatchPlan.workgroups,
|
|
145
|
+
{ RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
|
|
146
|
+
);
|
|
147
|
+
|
|
148
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
|
|
149
|
+
} catch (error) {
|
|
150
|
+
if (ownedOutput) {
|
|
151
|
+
releaseBuffer(ownedOutput);
|
|
152
|
+
}
|
|
153
|
+
throw error;
|
|
154
|
+
}
|
|
106
155
|
}
|
|
107
156
|
|
|
108
157
|
export async function recordRMSNorm(
|
|
@@ -117,24 +166,43 @@ export async function recordRMSNorm(
|
|
|
117
166
|
const variant = selectRMSNormKernel(options, isF16);
|
|
118
167
|
const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
|
|
119
168
|
const normWeightBuffer = getBuffer(weight);
|
|
169
|
+
assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
|
|
120
170
|
const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
|
|
121
171
|
|
|
122
172
|
const bytesPerElement = isF16 ? 2 : 4;
|
|
123
173
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
124
174
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
125
175
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
|
|
126
|
-
|
|
127
|
-
const
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
176
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
177
|
+
const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
|
|
178
|
+
|
|
179
|
+
const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
|
|
180
|
+
|
|
181
|
+
try {
|
|
182
|
+
await unifiedKernelWrapper(
|
|
183
|
+
'rmsnorm',
|
|
184
|
+
recorder,
|
|
185
|
+
variant,
|
|
186
|
+
[input, normWeightBuffer, outputBuf, residualBuf],
|
|
187
|
+
{
|
|
188
|
+
hidden_size: inferredHiddenSize,
|
|
189
|
+
num_tokens: batchSize,
|
|
190
|
+
eps,
|
|
191
|
+
has_residual: residual ? 1 : 0,
|
|
192
|
+
token_stride: dispatchPlan.tokenStride,
|
|
193
|
+
_pad0: 0,
|
|
194
|
+
_pad1: 0,
|
|
195
|
+
_pad2: 0,
|
|
196
|
+
},
|
|
197
|
+
dispatchPlan.workgroups,
|
|
198
|
+
{ RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
|
|
199
|
+
);
|
|
200
|
+
|
|
201
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
|
|
202
|
+
} catch (error) {
|
|
203
|
+
if (ownedOutput) {
|
|
204
|
+
releaseBuffer(ownedOutput);
|
|
205
|
+
}
|
|
206
|
+
throw error;
|
|
207
|
+
}
|
|
140
208
|
}
|
|
@@ -39,6 +39,10 @@ struct Uniforms {
|
|
|
39
39
|
num_tokens: u32, // Number of tokens to process
|
|
40
40
|
eps: f32, // Epsilon for numerical stability (typically 1e-5 or 1e-6)
|
|
41
41
|
has_residual: u32, // Runtime flag: 1 = add residual after norm
|
|
42
|
+
token_stride: u32, // Workgroup rows per dispatch row
|
|
43
|
+
_pad0: u32,
|
|
44
|
+
_pad1: u32,
|
|
45
|
+
_pad2: u32,
|
|
42
46
|
}
|
|
43
47
|
|
|
44
48
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -82,6 +86,10 @@ fn should_add_residual() -> bool {
|
|
|
82
86
|
return HAS_RESIDUAL || (u.has_residual != 0u);
|
|
83
87
|
}
|
|
84
88
|
|
|
89
|
+
// Rebuilds a linear token index from a 2D workgroup id: row (wg_id.y) times
// u.token_stride plus column (wg_id.x). The stride is clamped to at least 1, so a
// zero token_stride uniform yields wg_id.y + wg_id.x (just wg_id.x for a single-row dispatch).
fn token_index(wg_id: vec3<u32>) -> u32 {
|
|
90
|
+
return wg_id.y * max(u.token_stride, 1u) + wg_id.x;
|
|
91
|
+
}
|
|
92
|
+
|
|
85
93
|
// =============================================================================
|
|
86
94
|
// Main Entry Point
|
|
87
95
|
// =============================================================================
|
|
@@ -93,7 +101,7 @@ fn main(
|
|
|
93
101
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
94
102
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
95
103
|
) {
|
|
96
|
-
let token_idx = wg_id
|
|
104
|
+
let token_idx = token_index(wg_id);
|
|
97
105
|
let thread_idx = local_id.x;
|
|
98
106
|
let size = u.size;
|
|
99
107
|
|
|
@@ -163,7 +171,7 @@ fn main_small(
|
|
|
163
171
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
164
172
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
165
173
|
) {
|
|
166
|
-
let token_idx = wg_id
|
|
174
|
+
let token_idx = token_index(wg_id);
|
|
167
175
|
let thread_idx = local_id.x;
|
|
168
176
|
let size = u.size;
|
|
169
177
|
|
|
@@ -219,7 +227,7 @@ fn main_cached(
|
|
|
219
227
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
220
228
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
221
229
|
) {
|
|
222
|
-
let token_idx = wg_id
|
|
230
|
+
let token_idx = token_index(wg_id);
|
|
223
231
|
let thread_idx = local_id.x;
|
|
224
232
|
let size = u.size;
|
|
225
233
|
|
|
@@ -288,7 +296,7 @@ fn main_subgroup(
|
|
|
288
296
|
@builtin(subgroup_invocation_id) sg_lane: u32,
|
|
289
297
|
@builtin(subgroup_size) sg_size: u32,
|
|
290
298
|
) {
|
|
291
|
-
let token_idx = wg_id
|
|
299
|
+
let token_idx = token_index(wg_id);
|
|
292
300
|
let thread_idx = local_id.x;
|
|
293
301
|
let size = u.size;
|
|
294
302
|
|
|
@@ -362,7 +370,7 @@ fn main_small_subgroup(
|
|
|
362
370
|
@builtin(subgroup_invocation_id) sg_lane: u32,
|
|
363
371
|
@builtin(subgroup_size) sg_size: u32,
|
|
364
372
|
) {
|
|
365
|
-
let token_idx = wg_id
|
|
373
|
+
let token_idx = token_index(wg_id);
|
|
366
374
|
let thread_idx = local_id.x;
|
|
367
375
|
let size = u.size;
|
|
368
376
|
|
|
@@ -414,4 +422,4 @@ fn main_small_subgroup(
|
|
|
414
422
|
}
|
|
415
423
|
output[base_offset + thread_idx] = result;
|
|
416
424
|
}
|
|
417
|
-
}
|
|
425
|
+
}
|
|
@@ -20,6 +20,10 @@ struct Uniforms {
|
|
|
20
20
|
num_tokens: u32, // Number of tokens to process
|
|
21
21
|
eps: f32, // Epsilon for numerical stability
|
|
22
22
|
has_residual: u32, // 1 if residual input provided, 0 otherwise
|
|
23
|
+
token_stride: u32, // Workgroup rows per dispatch row
|
|
24
|
+
_pad0: u32,
|
|
25
|
+
_pad1: u32,
|
|
26
|
+
_pad2: u32,
|
|
23
27
|
}
|
|
24
28
|
|
|
25
29
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -47,6 +51,10 @@ fn load_weight(idx: u32) -> f32 {
|
|
|
47
51
|
return bitcast<f32>(weight[idx]);
|
|
48
52
|
}
|
|
49
53
|
|
|
54
|
+
// Rebuilds a linear token index from a 2D workgroup id: row (wg_id.y) times
// u.token_stride plus column (wg_id.x). The stride is clamped to at least 1, so a
// zero token_stride uniform yields wg_id.y + wg_id.x (just wg_id.x for a single-row dispatch).
fn token_index(wg_id: vec3<u32>) -> u32 {
|
|
55
|
+
return wg_id.y * max(u.token_stride, 1u) + wg_id.x;
|
|
56
|
+
}
|
|
57
|
+
|
|
50
58
|
// Main RMSNorm kernel - one workgroup per token
|
|
51
59
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
52
60
|
fn main(
|
|
@@ -54,7 +62,7 @@ fn main(
|
|
|
54
62
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
55
63
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
56
64
|
) {
|
|
57
|
-
let token_idx = wg_id
|
|
65
|
+
let token_idx = token_index(wg_id);
|
|
58
66
|
let thread_idx = local_id.x;
|
|
59
67
|
let size = u.size;
|
|
60
68
|
|
|
@@ -121,7 +129,7 @@ fn rmsnorm_small_f16(
|
|
|
121
129
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
122
130
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
123
131
|
) {
|
|
124
|
-
let token_idx = wg_id
|
|
132
|
+
let token_idx = token_index(wg_id);
|
|
125
133
|
let thread_idx = local_id.x;
|
|
126
134
|
let size = u.size;
|
|
127
135
|
|
package/src/gpu/kernels/rope.js
CHANGED
|
@@ -13,18 +13,29 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
|
|
|
13
13
|
const {
|
|
14
14
|
numHeads = 1,
|
|
15
15
|
headDim = 64,
|
|
16
|
+
rotaryDim = headDim,
|
|
17
|
+
interleaved = false,
|
|
16
18
|
ropeTheta = ropeDefaults.defaultTheta,
|
|
17
19
|
} = options;
|
|
18
20
|
|
|
19
21
|
if (headDim % 2 !== 0) {
|
|
20
22
|
throw new Error(`RoPE headDim must be even, got ${headDim}`);
|
|
21
23
|
}
|
|
24
|
+
if (rotaryDim % 2 !== 0) {
|
|
25
|
+
throw new Error(`RoPE rotaryDim must be even, got ${rotaryDim}`);
|
|
26
|
+
}
|
|
27
|
+
if (rotaryDim <= 0 || rotaryDim > headDim) {
|
|
28
|
+
throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
|
|
29
|
+
}
|
|
30
|
+
if (input.dtype === 'f16' && (rotaryDim !== headDim || interleaved)) {
|
|
31
|
+
throw new Error('RoPE f16 kernel requires rotaryDim === headDim and interleaved === false.');
|
|
32
|
+
}
|
|
22
33
|
|
|
23
34
|
const caps = getKernelCapabilities();
|
|
24
35
|
const useF16 = input.dtype === 'f16' && caps.hasF16;
|
|
25
36
|
const variant = selectRuleValue('rope', 'variant', { useF16 });
|
|
26
37
|
|
|
27
|
-
const halfDim =
|
|
38
|
+
const halfDim = rotaryDim / 2;
|
|
28
39
|
const workgroups = Math.ceil((seqLen * numHeads * halfDim) / WORKGROUP_SIZES.DEFAULT);
|
|
29
40
|
|
|
30
41
|
await unifiedKernelWrapper(
|
|
@@ -34,9 +45,11 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
|
|
|
34
45
|
seq_len: seqLen,
|
|
35
46
|
num_heads: numHeads,
|
|
36
47
|
head_dim: headDim,
|
|
48
|
+
rotary_dim: rotaryDim,
|
|
37
49
|
start_pos: options.startPos ?? ropeDefaults.defaultStartPos,
|
|
38
50
|
rope_base: ropeTheta,
|
|
39
51
|
rope_scale: 1.0,
|
|
52
|
+
interleaved: interleaved ? 1 : 0,
|
|
40
53
|
},
|
|
41
54
|
workgroups
|
|
42
55
|
);
|