@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
package/src/gpu/kernels/silu.js
CHANGED
|
@@ -1,13 +1,26 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice } from '../device.js';
|
|
4
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
6
6
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
7
7
|
import { dispatch, recordDispatch } from './dispatch.js';
|
|
8
8
|
import { getPipelineFast, createUniformBufferWithView } from './utils.js';
|
|
9
9
|
import { selectRuleValue } from './rule-registry.js';
|
|
10
10
|
|
|
11
|
+
function destroyAfterSubmit(device, buffer) {
|
|
12
|
+
if (!buffer) {
|
|
13
|
+
return;
|
|
14
|
+
}
|
|
15
|
+
device.queue.onSubmittedWorkDone()
|
|
16
|
+
.then(() => {
|
|
17
|
+
buffer.destroy();
|
|
18
|
+
})
|
|
19
|
+
.catch(() => {
|
|
20
|
+
buffer.destroy();
|
|
21
|
+
});
|
|
22
|
+
}
|
|
23
|
+
|
|
11
24
|
function canUseF16(input) {
|
|
12
25
|
return input.dtype === 'f16';
|
|
13
26
|
}
|
|
@@ -47,6 +60,12 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
|
|
|
47
60
|
];
|
|
48
61
|
}
|
|
49
62
|
|
|
63
|
+
function cleanupRunResources(uniformBuffer, ownedOutput) {
|
|
64
|
+
if (ownedOutput) {
|
|
65
|
+
releaseBuffer(ownedOutput);
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
50
69
|
function planSiLUDispatch(device, size, useVec4) {
|
|
51
70
|
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
52
71
|
? device.limits.maxComputeWorkgroupsPerDimension
|
|
@@ -97,6 +116,7 @@ export async function runSiLU(
|
|
|
97
116
|
const inferredSize = size || (input.buffer.size / bytesPerElement);
|
|
98
117
|
const outputSize = inferredSize * bytesPerElement;
|
|
99
118
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
|
|
119
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
100
120
|
const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
|
|
101
121
|
|
|
102
122
|
// Create uniform buffer
|
|
@@ -116,17 +136,21 @@ export async function runSiLU(
|
|
|
116
136
|
// Create bind group using helper
|
|
117
137
|
const entries = createSiLUBindGroupEntries(uniformBuffer, input, output, gate);
|
|
118
138
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
139
|
+
try {
|
|
140
|
+
const bindGroup = device.createBindGroup({
|
|
141
|
+
label: 'silu_bind_group',
|
|
142
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
143
|
+
entries,
|
|
144
|
+
});
|
|
145
|
+
|
|
146
|
+
dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
|
|
147
|
+
return createTensor(output, input.dtype, [inferredSize], 'silu_output');
|
|
148
|
+
} catch (error) {
|
|
149
|
+
cleanupRunResources(null, ownedOutput);
|
|
150
|
+
throw error;
|
|
151
|
+
} finally {
|
|
152
|
+
destroyAfterSubmit(device, uniformBuffer);
|
|
153
|
+
}
|
|
130
154
|
}
|
|
131
155
|
|
|
132
156
|
|
|
@@ -148,6 +172,7 @@ export async function runSwiGLURowsplitBias(
|
|
|
148
172
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
149
173
|
const outputSize = numTokens * dim * bytesPerElement;
|
|
150
174
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'swiglu_output');
|
|
175
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
151
176
|
|
|
152
177
|
// Create uniform buffer
|
|
153
178
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -164,23 +189,27 @@ export async function runSwiGLURowsplitBias(
|
|
|
164
189
|
);
|
|
165
190
|
|
|
166
191
|
// Create bind group
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
192
|
+
try {
|
|
193
|
+
const bindGroup = device.createBindGroup({
|
|
194
|
+
label: 'swiglu_bind_group',
|
|
195
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
196
|
+
entries: [
|
|
197
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
198
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
199
|
+
{ binding: 2, resource: { buffer: bias.buffer } },
|
|
200
|
+
{ binding: 3, resource: { buffer: output } },
|
|
201
|
+
],
|
|
202
|
+
});
|
|
203
|
+
|
|
204
|
+
const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
|
|
205
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'swiglu');
|
|
206
|
+
return createTensor(output, input.dtype, [numTokens, dim], 'swiglu_output');
|
|
207
|
+
} catch (error) {
|
|
208
|
+
cleanupRunResources(null, ownedOutput);
|
|
209
|
+
throw error;
|
|
210
|
+
} finally {
|
|
211
|
+
destroyAfterSubmit(device, uniformBuffer);
|
|
212
|
+
}
|
|
184
213
|
}
|
|
185
214
|
|
|
186
215
|
|
|
@@ -202,6 +231,7 @@ export async function runSiLURowSplit(
|
|
|
202
231
|
|
|
203
232
|
const outputSize = numTokens * dim * bytesPerElement;
|
|
204
233
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_rowsplit_output');
|
|
234
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
205
235
|
|
|
206
236
|
// Create uniform buffer
|
|
207
237
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -218,24 +248,28 @@ export async function runSiLURowSplit(
|
|
|
218
248
|
);
|
|
219
249
|
|
|
220
250
|
// Bind group: provide a dummy gate buffer to satisfy the fixed layout
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
251
|
+
try {
|
|
252
|
+
const gateBuffer = input.buffer;
|
|
253
|
+
const bindGroup = device.createBindGroup({
|
|
254
|
+
label: 'silu_rowsplit_bind_group',
|
|
255
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
256
|
+
entries: [
|
|
257
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
258
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
259
|
+
{ binding: 2, resource: { buffer: output } },
|
|
260
|
+
{ binding: 3, resource: { buffer: gateBuffer } },
|
|
261
|
+
],
|
|
262
|
+
});
|
|
263
|
+
|
|
264
|
+
const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
|
|
265
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
|
|
266
|
+
return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
|
|
267
|
+
} catch (error) {
|
|
268
|
+
cleanupRunResources(null, ownedOutput);
|
|
269
|
+
throw error;
|
|
270
|
+
} finally {
|
|
271
|
+
uniformBuffer.destroy();
|
|
272
|
+
}
|
|
239
273
|
}
|
|
240
274
|
|
|
241
275
|
|
|
@@ -258,6 +292,7 @@ export async function recordSiLURowSplit(
|
|
|
258
292
|
|
|
259
293
|
const outputSize = numTokens * dim * bytesPerElement;
|
|
260
294
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_rowsplit_output');
|
|
295
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
261
296
|
|
|
262
297
|
// Uniform buffer
|
|
263
298
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -272,22 +307,28 @@ export async function recordSiLURowSplit(
|
|
|
272
307
|
recorder
|
|
273
308
|
);
|
|
274
309
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
310
|
+
try {
|
|
311
|
+
const gateBuffer = input.buffer;
|
|
312
|
+
const bindGroup = device.createBindGroup({
|
|
313
|
+
label: 'silu_rowsplit_bind_group',
|
|
314
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
315
|
+
entries: [
|
|
316
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
317
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
318
|
+
{ binding: 2, resource: { buffer: output } },
|
|
319
|
+
{ binding: 3, resource: { buffer: gateBuffer } },
|
|
320
|
+
],
|
|
321
|
+
});
|
|
322
|
+
|
|
323
|
+
const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
|
|
324
|
+
recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
|
|
325
|
+
return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
|
|
326
|
+
} catch (error) {
|
|
327
|
+
if (ownedOutput) {
|
|
328
|
+
releaseBuffer(ownedOutput);
|
|
329
|
+
}
|
|
330
|
+
throw error;
|
|
331
|
+
}
|
|
291
332
|
}
|
|
292
333
|
|
|
293
334
|
|
|
@@ -328,6 +369,7 @@ export async function recordSiLU(
|
|
|
328
369
|
const inferredSize = size || (input.buffer.size / bytesPerElement);
|
|
329
370
|
const outputSize = inferredSize * bytesPerElement;
|
|
330
371
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
|
|
372
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
331
373
|
const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
|
|
332
374
|
|
|
333
375
|
// Uniform buffer
|
|
@@ -346,13 +388,19 @@ export async function recordSiLU(
|
|
|
346
388
|
// Create bind group using helper
|
|
347
389
|
const entries = createSiLUBindGroupEntries(uniformBuffer, input, output, gate);
|
|
348
390
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
391
|
+
try {
|
|
392
|
+
const bindGroup = device.createBindGroup({
|
|
393
|
+
label: 'silu_bind_group',
|
|
394
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
395
|
+
entries,
|
|
396
|
+
});
|
|
397
|
+
|
|
398
|
+
recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
|
|
399
|
+
return createTensor(output, input.dtype, [inferredSize], 'silu_output');
|
|
400
|
+
} catch (error) {
|
|
401
|
+
if (ownedOutput) {
|
|
402
|
+
releaseBuffer(ownedOutput);
|
|
403
|
+
}
|
|
404
|
+
throw error;
|
|
405
|
+
}
|
|
358
406
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
|
|
2
2
|
import { getKernelCapabilities } from '../device.js';
|
|
3
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
4
|
import { createTensor } from '../tensor.js';
|
|
5
5
|
import { unifiedKernelWrapper } from './utils.js';
|
|
6
6
|
import { createPipeline, createUniformBufferWithView, createBindGroupWithValidation } from './utils.js';
|
|
@@ -20,23 +20,34 @@ function selectSoftmaxVariant(innerSize) {
|
|
|
20
20
|
|
|
21
21
|
async function _softmax(target, input, axis, options = {}) {
|
|
22
22
|
const { batchSize = 1, size, seqLen, temperature = 1.0, outputBuffer = null } = options;
|
|
23
|
+
if (input.dtype !== 'f32') {
|
|
24
|
+
throw new Error(`Softmax requires f32 input, got ${input.dtype}.`);
|
|
25
|
+
}
|
|
23
26
|
|
|
24
|
-
const bytesPerElement =
|
|
27
|
+
const bytesPerElement = 4;
|
|
25
28
|
const inferredSize = size || seqLen || (input.buffer.size / (batchSize * bytesPerElement));
|
|
26
29
|
const variant = selectSoftmaxVariant(inferredSize);
|
|
27
30
|
trace.kernels(`Softmax: size=${inferredSize}, variant=${variant}`);
|
|
28
31
|
|
|
29
32
|
const outputSize = batchSize * inferredSize * bytesPerElement;
|
|
30
33
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'softmax_output');
|
|
34
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
35
|
+
|
|
36
|
+
try {
|
|
37
|
+
await unifiedKernelWrapper(
|
|
38
|
+
'softmax', target, variant,
|
|
39
|
+
[input, output],
|
|
40
|
+
{ inner_size: inferredSize, outer_size: batchSize, temperature },
|
|
41
|
+
batchSize
|
|
42
|
+
);
|
|
43
|
+
} catch (error) {
|
|
44
|
+
if (ownedOutput) {
|
|
45
|
+
releaseBuffer(ownedOutput);
|
|
46
|
+
}
|
|
47
|
+
throw error;
|
|
48
|
+
}
|
|
31
49
|
|
|
32
|
-
|
|
33
|
-
'softmax', target, variant,
|
|
34
|
-
[input, output],
|
|
35
|
-
{ inner_size: inferredSize, outer_size: batchSize, temperature },
|
|
36
|
-
batchSize
|
|
37
|
-
);
|
|
38
|
-
|
|
39
|
-
return createTensor(output, input.dtype, [batchSize, inferredSize], 'softmax_output');
|
|
50
|
+
return createTensor(output, 'f32', [batchSize, inferredSize], 'softmax_output');
|
|
40
51
|
}
|
|
41
52
|
|
|
42
53
|
export async function runSoftmax(input, axis, options = {}) {
|
|
@@ -76,6 +87,7 @@ export async function runSoftmaxTopK(logits, numTokens, numExperts, topK, option
|
|
|
76
87
|
|
|
77
88
|
const indices = acquireBuffer(indicesSize, undefined, 'softmax_topk_indices');
|
|
78
89
|
const weights = acquireBuffer(weightsSize, undefined, 'softmax_topk_weights');
|
|
90
|
+
let completed = false;
|
|
79
91
|
|
|
80
92
|
const uniformBuffer = createUniformBufferWithView(
|
|
81
93
|
'softmax_topk_uniforms', 16,
|
|
@@ -88,19 +100,26 @@ export async function runSoftmaxTopK(logits, numTokens, numExperts, topK, option
|
|
|
88
100
|
null, device
|
|
89
101
|
);
|
|
90
102
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
103
|
+
try {
|
|
104
|
+
const bindGroup = await createBindGroupWithValidation(device, {
|
|
105
|
+
label: 'softmax_topk_bind_group',
|
|
106
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
107
|
+
entries: [
|
|
108
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
109
|
+
{ binding: 1, resource: { buffer: logits } },
|
|
110
|
+
{ binding: 2, resource: { buffer: indices } },
|
|
111
|
+
{ binding: 3, resource: { buffer: weights } },
|
|
112
|
+
],
|
|
113
|
+
}, `topk:${variant}`);
|
|
114
|
+
|
|
115
|
+
dispatchKernel(null, pipeline, bindGroup, numTokens, 'softmax_topk');
|
|
116
|
+
completed = true;
|
|
117
|
+
return { indices, weights };
|
|
118
|
+
} finally {
|
|
119
|
+
uniformBuffer.destroy();
|
|
120
|
+
if (!completed) {
|
|
121
|
+
releaseBuffer(indices);
|
|
122
|
+
releaseBuffer(weights);
|
|
123
|
+
}
|
|
124
|
+
}
|
|
106
125
|
}
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Split Q and Gate Kernel
|
|
3
|
+
*
|
|
4
|
+
* De-interleaves Q and Gate projections from q_proj output for attentionOutputGate models.
|
|
5
|
+
* Models like Qwen 3.5 store q_proj weights in per-head interleaved layout:
|
|
6
|
+
* rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
|
|
7
|
+
* rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
|
|
8
|
+
* This kernel separates the full matmul output into contiguous Q and Gate tensors.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import type { Tensor } from '../tensor.js';
|
|
12
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
13
|
+
|
|
14
|
+
/** Split Q and Gate options */
|
|
15
|
+
export interface SplitQGOptions {
|
|
16
|
+
numTokens: number;
|
|
17
|
+
numHeads: number;
|
|
18
|
+
headDim: number;
|
|
19
|
+
/** Pre-allocated Q output tensor */
|
|
20
|
+
qTensor?: Tensor | null;
|
|
21
|
+
/** Pre-allocated Gate output tensor */
|
|
22
|
+
gTensor?: Tensor | null;
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
/** Split Q and Gate result */
|
|
26
|
+
export interface SplitQGResult {
|
|
27
|
+
Q: Tensor;
|
|
28
|
+
G: Tensor;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
/**
|
|
32
|
+
* De-interleave Q and Gate from q_proj output.
|
|
33
|
+
*
|
|
34
|
+
* @param qgTensor - Full q_proj output [numTokens, numHeads * headDim * 2] (interleaved)
|
|
35
|
+
* @param options - Split configuration
|
|
36
|
+
* @returns Separate Q and Gate tensors, each [numTokens, numHeads * headDim]
|
|
37
|
+
*/
|
|
38
|
+
export declare function runSplitQG(
|
|
39
|
+
qgTensor: Tensor,
|
|
40
|
+
options: SplitQGOptions
|
|
41
|
+
): Promise<SplitQGResult>;
|
|
42
|
+
|
|
43
|
+
/**
|
|
44
|
+
* Record split Q and Gate (batched, no submit).
|
|
45
|
+
*/
|
|
46
|
+
export declare function recordSplitQG(
|
|
47
|
+
recorder: CommandRecorder,
|
|
48
|
+
qgTensor: Tensor,
|
|
49
|
+
options: SplitQGOptions
|
|
50
|
+
): Promise<SplitQGResult>;
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
5
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
6
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
7
|
+
|
|
8
|
+
async function _splitQG(target, qgTensor, options) {
|
|
9
|
+
const { numTokens, numHeads, headDim, qTensor = null, gTensor = null } = options;
|
|
10
|
+
const ownsQ = qTensor == null;
|
|
11
|
+
const ownsG = gTensor == null;
|
|
12
|
+
|
|
13
|
+
const outputDtype = qgTensor.dtype;
|
|
14
|
+
const pipelineVariant = selectRuleValue('splitQg', 'variant', { outputDtype });
|
|
15
|
+
const bytesPerElement = dtypeBytes(outputDtype);
|
|
16
|
+
const qSize = numHeads * headDim;
|
|
17
|
+
|
|
18
|
+
const qBuffer = qTensor?.buffer || acquireBuffer(numTokens * qSize * bytesPerElement, undefined, 'Q');
|
|
19
|
+
const gBuffer = gTensor?.buffer || acquireBuffer(numTokens * qSize * bytesPerElement, undefined, 'Q_gate');
|
|
20
|
+
|
|
21
|
+
try {
|
|
22
|
+
await unifiedKernelWrapper(
|
|
23
|
+
'split_qg', target, pipelineVariant,
|
|
24
|
+
[qgTensor, qBuffer, gBuffer],
|
|
25
|
+
{ num_tokens: numTokens, num_heads: numHeads, head_dim: headDim, _pad: 0 },
|
|
26
|
+
Math.ceil((numTokens * qSize) / WORKGROUP_SIZES.DEFAULT)
|
|
27
|
+
);
|
|
28
|
+
|
|
29
|
+
const Q = qTensor || createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q');
|
|
30
|
+
const G = gTensor || createTensor(gBuffer, outputDtype, [numTokens, qSize], 'Q_gate');
|
|
31
|
+
|
|
32
|
+
return { Q, G };
|
|
33
|
+
} catch (error) {
|
|
34
|
+
if (ownsQ) releaseBuffer(qBuffer);
|
|
35
|
+
if (ownsG) releaseBuffer(gBuffer);
|
|
36
|
+
throw error;
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
export async function runSplitQG(qgTensor, options) {
|
|
41
|
+
return _splitQG(null, qgTensor, options);
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
export async function recordSplitQG(recorder, qgTensor, options) {
|
|
45
|
+
return _splitQG(recorder, qgTensor, options);
|
|
46
|
+
}
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
// split_qg.wgsl
|
|
2
|
+
|
|
3
|
+
/**
|
|
4
|
+
* De-interleave Q and Gate projections from q_proj output for attentionOutputGate models.
|
|
5
|
+
*
|
|
6
|
+
* Models like Qwen 3.5 store q_proj weights with interleaved head layout:
|
|
7
|
+
* rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
|
|
8
|
+
* rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
|
|
9
|
+
*
|
|
10
|
+
* A single full matmul over all 2*qSize rows produces interleaved output:
|
|
11
|
+
* input[token, h*headDim*2 : h*headDim*2+headDim] = Q head h
|
|
12
|
+
* input[token, h*headDim*2+headDim : (h+1)*headDim*2] = Gate head h
|
|
13
|
+
*
|
|
14
|
+
* This kernel separates them into contiguous Q and G outputs:
|
|
15
|
+
* Q[token, h*headDim + dim] = input[token, h*headDim*2 + dim]
|
|
16
|
+
* G[token, h*headDim + dim] = input[token, h*headDim*2 + headDim + dim]
|
|
17
|
+
*
|
|
18
|
+
* Input layout (row-major): [numTokens, numHeads * headDim * 2]
|
|
19
|
+
* Output Q layout (row-major): [numTokens, numHeads * headDim]
|
|
20
|
+
* Output G layout (row-major): [numTokens, numHeads * headDim]
|
|
21
|
+
*/
|
|
22
|
+
|
|
23
|
+
struct Params {
|
|
24
|
+
num_tokens: u32,
|
|
25
|
+
num_heads: u32,
|
|
26
|
+
head_dim: u32,
|
|
27
|
+
_pad: u32,
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
31
|
+
|
|
32
|
+
@group(0) @binding(0) var<uniform> params: Params;
|
|
33
|
+
@group(0) @binding(1) var<storage, read> input: array<f32>;
|
|
34
|
+
@group(0) @binding(2) var<storage, read_write> Q: array<f32>;
|
|
35
|
+
@group(0) @binding(3) var<storage, read_write> G: array<f32>;
|
|
36
|
+
|
|
37
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
38
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
39
|
+
let idx = gid.x;
|
|
40
|
+
let q_size = params.num_heads * params.head_dim;
|
|
41
|
+
let total_elements = params.num_tokens * q_size;
|
|
42
|
+
|
|
43
|
+
if (idx >= total_elements) {
|
|
44
|
+
return;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
let token = idx / q_size;
|
|
48
|
+
let elem = idx % q_size;
|
|
49
|
+
let head = elem / params.head_dim;
|
|
50
|
+
let dim = elem % params.head_dim;
|
|
51
|
+
|
|
52
|
+
// Input is interleaved per head: [Q_h (headDim elems), G_h (headDim elems)]
|
|
53
|
+
let src_q = token * (q_size * 2u) + head * (params.head_dim * 2u) + dim;
|
|
54
|
+
let src_g = src_q + params.head_dim;
|
|
55
|
+
|
|
56
|
+
Q[idx] = input[src_q];
|
|
57
|
+
G[idx] = input[src_g];
|
|
58
|
+
}
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
// AUTO-GENERATED from src/gpu/kernels/split_qg.wgsl.
|
|
2
|
+
// Edit the source kernel and tools/configs/wgsl-variants.js, then run `npm run kernels:generate`.
|
|
3
|
+
// split_qg_f16.wgsl
|
|
4
|
+
|
|
5
|
+
/**
|
|
6
|
+
* De-interleave Q and Gate projections from q_proj output for attentionOutputGate models (f16).
|
|
7
|
+
*
|
|
8
|
+
* Models like Qwen 3.5 store q_proj weights with interleaved head layout:
|
|
9
|
+
* rows [h*headDim*2 : h*headDim*2+headDim] = Q for head h
|
|
10
|
+
* rows [h*headDim*2+headDim : (h+1)*headDim*2] = Gate for head h
|
|
11
|
+
*
|
|
12
|
+
* A single full matmul over all 2*qSize rows produces interleaved output:
|
|
13
|
+
* input[token, h*headDim*2 : h*headDim*2+headDim] = Q head h
|
|
14
|
+
* input[token, h*headDim*2+headDim : (h+1)*headDim*2] = Gate head h
|
|
15
|
+
*
|
|
16
|
+
* This kernel separates them into contiguous Q and G outputs:
|
|
17
|
+
* Q[token, h*headDim + dim] = input[token, h*headDim*2 + dim]
|
|
18
|
+
* G[token, h*headDim + dim] = input[token, h*headDim*2 + headDim + dim]
|
|
19
|
+
*
|
|
20
|
+
* Input layout (row-major): [numTokens, numHeads * headDim * 2]
|
|
21
|
+
* Output Q layout (row-major): [numTokens, numHeads * headDim]
|
|
22
|
+
* Output G layout (row-major): [numTokens, numHeads * headDim]
|
|
23
|
+
*/
|
|
24
|
+
|
|
25
|
+
enable f16;
|
|
26
|
+
|
|
27
|
+
struct Params {
|
|
28
|
+
num_tokens: u32,
|
|
29
|
+
num_heads: u32,
|
|
30
|
+
head_dim: u32,
|
|
31
|
+
_pad: u32,
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
35
|
+
|
|
36
|
+
@group(0) @binding(0) var<uniform> params: Params;
|
|
37
|
+
@group(0) @binding(1) var<storage, read> input: array<f16>;
|
|
38
|
+
@group(0) @binding(2) var<storage, read_write> Q: array<f16>;
|
|
39
|
+
@group(0) @binding(3) var<storage, read_write> G: array<f16>;
|
|
40
|
+
|
|
41
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
42
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
43
|
+
let idx = gid.x;
|
|
44
|
+
let q_size = params.num_heads * params.head_dim;
|
|
45
|
+
let total_elements = params.num_tokens * q_size;
|
|
46
|
+
|
|
47
|
+
if (idx >= total_elements) {
|
|
48
|
+
return;
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
let token = idx / q_size;
|
|
52
|
+
let elem = idx % q_size;
|
|
53
|
+
let head = elem / params.head_dim;
|
|
54
|
+
let dim = elem % params.head_dim;
|
|
55
|
+
|
|
56
|
+
// Input is interleaved per head: [Q_h (headDim elems), G_h (headDim elems)]
|
|
57
|
+
let src_q = token * (q_size * 2u) + head * (params.head_dim * 2u) + dim;
|
|
58
|
+
let src_g = src_q + params.head_dim;
|
|
59
|
+
|
|
60
|
+
Q[idx] = input[src_q];
|
|
61
|
+
G[idx] = input[src_g];
|
|
62
|
+
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
|
|
2
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
4
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
5
5
|
import { unifiedKernelWrapper } from './utils.js';
|
|
@@ -7,6 +7,9 @@ import { selectRuleValue } from './rule-registry.js';
|
|
|
7
7
|
|
|
8
8
|
async function _splitQKV(target, qkvTensor, options) {
|
|
9
9
|
const { numTokens, qSize, kSize, vSize, qTensor = null, kTensor = null, vTensor = null } = options;
|
|
10
|
+
const ownsQ = qTensor == null;
|
|
11
|
+
const ownsK = kTensor == null;
|
|
12
|
+
const ownsV = vTensor == null;
|
|
10
13
|
|
|
11
14
|
const outputDtype = qkvTensor.dtype;
|
|
12
15
|
const pipelineVariant = selectRuleValue('splitQkv', 'variant', { outputDtype });
|
|
@@ -18,18 +21,25 @@ async function _splitQKV(target, qkvTensor, options) {
|
|
|
18
21
|
|
|
19
22
|
const totalElements = numTokens * (qSize + kSize + vSize);
|
|
20
23
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
24
|
+
try {
|
|
25
|
+
await unifiedKernelWrapper(
|
|
26
|
+
'split_qkv', target, pipelineVariant,
|
|
27
|
+
[qkvTensor, qBuffer, kBuffer, vBuffer],
|
|
28
|
+
{ num_tokens: numTokens, q_size: qSize, k_size: kSize, v_size: vSize },
|
|
29
|
+
Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT)
|
|
30
|
+
);
|
|
31
|
+
|
|
32
|
+
const Q = qTensor || createTensor(qBuffer, outputDtype, [numTokens, qSize], 'Q');
|
|
33
|
+
const K = kTensor || createTensor(kBuffer, outputDtype, [numTokens, kSize], 'K');
|
|
34
|
+
const V = vTensor || createTensor(vBuffer, outputDtype, [numTokens, vSize], 'V');
|
|
35
|
+
|
|
36
|
+
return { Q, K, V };
|
|
37
|
+
} catch (error) {
|
|
38
|
+
if (ownsQ) releaseBuffer(qBuffer);
|
|
39
|
+
if (ownsK) releaseBuffer(kBuffer);
|
|
40
|
+
if (ownsV) releaseBuffer(vBuffer);
|
|
41
|
+
throw error;
|
|
42
|
+
}
|
|
33
43
|
}
|
|
34
44
|
|
|
35
45
|
export async function runSplitQKV(qkvTensor, options) {
|