@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice } from '../device.js';
|
|
4
|
-
import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor } from '../tensor.js';
|
|
6
6
|
import { getBuffer } from '../weight-buffer.js';
|
|
7
7
|
import { dispatch, recordDispatch } from './dispatch.js';
|
|
@@ -91,7 +91,8 @@ export async function runMatmulRMSNormFused(
|
|
|
91
91
|
// Output buffer: [1, N] - size depends on dtype
|
|
92
92
|
const bytesPerElement = dtype === 'f16' ? 2 : 4;
|
|
93
93
|
const outputSize = N * bytesPerElement;
|
|
94
|
-
const
|
|
94
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
|
|
95
|
+
const output = outputBuffer || ownedOutput;
|
|
95
96
|
|
|
96
97
|
// Create uniform buffer (8 u32/f32 = 32 bytes, padded for alignment)
|
|
97
98
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -110,36 +111,44 @@ export async function runMatmulRMSNormFused(
|
|
|
110
111
|
);
|
|
111
112
|
|
|
112
113
|
// Create placeholder for residual if not provided
|
|
114
|
+
const ownsResidualBuffer = !residual;
|
|
113
115
|
const residualBuffer = residual || device.createBuffer({
|
|
114
116
|
label: 'matmul_rmsnorm_residual_placeholder',
|
|
115
117
|
size: 4,
|
|
116
118
|
usage: GPUBufferUsage.STORAGE,
|
|
117
119
|
});
|
|
118
120
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
121
|
+
try {
|
|
122
|
+
const bindGroup = device.createBindGroup({
|
|
123
|
+
label: 'matmul_rmsnorm_fused_bind_group',
|
|
124
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
125
|
+
entries: [
|
|
126
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
127
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
128
|
+
{ binding: 2, resource: { buffer: weightBuffer } },
|
|
129
|
+
{ binding: 3, resource: { buffer: normWeightBuffer } },
|
|
130
|
+
{ binding: 4, resource: { buffer: output } },
|
|
131
|
+
{ binding: 5, resource: { buffer: residualBuffer } },
|
|
132
|
+
],
|
|
133
|
+
});
|
|
134
|
+
|
|
135
|
+
const workgroups = 1;
|
|
136
|
+
const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
|
|
137
|
+
dispatch(device, pipeline, bindGroup, workgroups, dispatchLabel);
|
|
138
|
+
} catch (error) {
|
|
139
|
+
uniformBuffer.destroy();
|
|
140
|
+
if (ownsResidualBuffer) {
|
|
141
|
+
residualBuffer.destroy();
|
|
142
|
+
}
|
|
143
|
+
if (ownedOutput) {
|
|
144
|
+
releaseBuffer(ownedOutput);
|
|
145
|
+
}
|
|
146
|
+
throw error;
|
|
147
|
+
}
|
|
139
148
|
|
|
140
149
|
// Cleanup
|
|
141
150
|
uniformBuffer.destroy();
|
|
142
|
-
if (
|
|
151
|
+
if (ownsResidualBuffer) residualBuffer.destroy();
|
|
143
152
|
|
|
144
153
|
// Output dtype matches input dtype
|
|
145
154
|
return createTensor(output, input.dtype, [1, N], 'matmul_rmsnorm_fused_output');
|
|
@@ -199,7 +208,8 @@ export async function recordMatmulRMSNormFused(
|
|
|
199
208
|
// Output buffer - size depends on dtype
|
|
200
209
|
const bytesPerElement = dtype === 'f16' ? 2 : 4;
|
|
201
210
|
const outputSize = N * bytesPerElement;
|
|
202
|
-
const
|
|
211
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
|
|
212
|
+
const output = outputBuffer || ownedOutput;
|
|
203
213
|
|
|
204
214
|
// Uniform buffer via recorder (8 u32/f32 = 32 bytes, padded for alignment)
|
|
205
215
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -217,35 +227,42 @@ export async function recordMatmulRMSNormFused(
|
|
|
217
227
|
);
|
|
218
228
|
|
|
219
229
|
// Placeholder for residual
|
|
230
|
+
const ownsResidualBuffer = !residual;
|
|
220
231
|
const residualBuffer = residual || device.createBuffer({
|
|
221
232
|
label: 'matmul_rmsnorm_residual_placeholder',
|
|
222
233
|
size: 4,
|
|
223
234
|
usage: GPUBufferUsage.STORAGE,
|
|
224
235
|
});
|
|
225
236
|
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
237
|
+
try {
|
|
238
|
+
const bindGroup = device.createBindGroup({
|
|
239
|
+
label: 'matmul_rmsnorm_fused_bind_group',
|
|
240
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
241
|
+
entries: [
|
|
242
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
243
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
244
|
+
{ binding: 2, resource: { buffer: weightBuffer } },
|
|
245
|
+
{ binding: 3, resource: { buffer: normWeightBuffer } },
|
|
246
|
+
{ binding: 4, resource: { buffer: output } },
|
|
247
|
+
{ binding: 5, resource: { buffer: residualBuffer } },
|
|
248
|
+
],
|
|
249
|
+
});
|
|
250
|
+
|
|
251
|
+
const workgroups = 1;
|
|
252
|
+
const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
|
|
253
|
+
recordDispatch(recorder, pipeline, bindGroup, workgroups, dispatchLabel);
|
|
254
|
+
} catch (error) {
|
|
255
|
+
if (ownsResidualBuffer) {
|
|
256
|
+
residualBuffer.destroy();
|
|
257
|
+
}
|
|
258
|
+
if (ownedOutput) {
|
|
259
|
+
releaseBuffer(ownedOutput);
|
|
260
|
+
}
|
|
261
|
+
throw error;
|
|
262
|
+
}
|
|
246
263
|
|
|
247
264
|
// Track placeholder for cleanup
|
|
248
|
-
if (
|
|
265
|
+
if (ownsResidualBuffer) {
|
|
249
266
|
recorder.trackTemporaryBuffer(residualBuffer);
|
|
250
267
|
}
|
|
251
268
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { getKernelCapabilities } from '../device.js';
|
|
2
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { WORKGROUP_SIZES, VEC4_ELEMENTS_PER_WG } from './constants.js';
|
|
4
4
|
import { unifiedKernelWrapper } from './utils.js';
|
|
5
5
|
import { trace } from '../../debug/index.js';
|
|
@@ -26,7 +26,6 @@ async function _gather(
|
|
|
26
26
|
options = {}
|
|
27
27
|
) {
|
|
28
28
|
const {
|
|
29
|
-
useVec4 = true,
|
|
30
29
|
outputBuffer = null,
|
|
31
30
|
embeddingDtype,
|
|
32
31
|
outputDtype,
|
|
@@ -43,9 +42,22 @@ async function _gather(
|
|
|
43
42
|
if (outputDtype == null) {
|
|
44
43
|
throw new Error('[Gather] outputDtype is required.');
|
|
45
44
|
}
|
|
45
|
+
if (embeddingDtype === 'f16' && !caps.hasF16) {
|
|
46
|
+
throw new Error('[Gather] embeddingDtype=f16 requires shader-f16 support.');
|
|
47
|
+
}
|
|
48
|
+
if (outputDtype === 'f16' && !caps.hasF16) {
|
|
49
|
+
throw new Error('[Gather] outputDtype=f16 requires shader-f16 support.');
|
|
50
|
+
}
|
|
46
51
|
|
|
47
|
-
const
|
|
48
|
-
const
|
|
52
|
+
const requestedVec4 = options.useVec4;
|
|
53
|
+
const wantsVec4 = requestedVec4 ?? true;
|
|
54
|
+
if (requestedVec4 === true && hiddenSize % 4 !== 0) {
|
|
55
|
+
throw new Error('[Gather] useVec4=true requires hiddenSize to be divisible by 4.');
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
const useF16Input = embeddingDtype === 'f16';
|
|
59
|
+
const useF16Output = outputDtype === 'f16';
|
|
60
|
+
const useVec4 = wantsVec4 && hiddenSize % 4 === 0;
|
|
49
61
|
|
|
50
62
|
trace.embed(
|
|
51
63
|
`Gather: numTokens=${numTokens}, hiddenSize=${hiddenSize}, vocabSize=${vocabSize}, ` +
|
|
@@ -64,6 +76,7 @@ async function _gather(
|
|
|
64
76
|
const paddedHiddenSize = padToQ4KBlock(hiddenSize);
|
|
65
77
|
const outputSize = numTokens * paddedHiddenSize * bytesPerElement;
|
|
66
78
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gather_output');
|
|
79
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
67
80
|
|
|
68
81
|
const uniforms = {
|
|
69
82
|
num_tokens: numTokens,
|
|
@@ -82,16 +95,22 @@ async function _gather(
|
|
|
82
95
|
? Math.ceil((numTokens * hiddenSize) / VEC4_ELEMENTS_PER_WG)
|
|
83
96
|
: Math.ceil((numTokens * hiddenSize) / WORKGROUP_SIZES.DEFAULT));
|
|
84
97
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
98
|
+
try {
|
|
99
|
+
await unifiedKernelWrapper(
|
|
100
|
+
'gather',
|
|
101
|
+
target,
|
|
102
|
+
variant,
|
|
103
|
+
[indices, embeddings, output],
|
|
104
|
+
uniforms,
|
|
105
|
+
workgroups
|
|
106
|
+
);
|
|
107
|
+
return createTensor(output, actualDtype, [numTokens, hiddenSize], 'gather_output');
|
|
108
|
+
} catch (error) {
|
|
109
|
+
if (ownedOutput) {
|
|
110
|
+
releaseBuffer(ownedOutput);
|
|
111
|
+
}
|
|
112
|
+
throw error;
|
|
113
|
+
}
|
|
95
114
|
}
|
|
96
115
|
|
|
97
116
|
export async function runGather(
|
|
@@ -116,4 +135,3 @@ export async function recordGather(
|
|
|
116
135
|
) {
|
|
117
136
|
return _gather(recorder, indices, embeddings, numTokens, hiddenSize, vocabSize, options);
|
|
118
137
|
}
|
|
119
|
-
|
package/src/gpu/kernels/gelu.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
|
|
2
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
4
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
5
5
|
import { unifiedKernelWrapper } from './utils.js';
|
|
@@ -26,16 +26,24 @@ async function _gelu(target, input, options = {}) {
|
|
|
26
26
|
const outputSize = inferredSize * bytesPerElement;
|
|
27
27
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gelu_output');
|
|
28
28
|
const gateBuffer = gate ?? input;
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
29
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
30
|
+
|
|
31
|
+
try {
|
|
32
|
+
await unifiedKernelWrapper(
|
|
33
|
+
'gelu', target, variant,
|
|
34
|
+
[input, output, gateBuffer],
|
|
35
|
+
{ size: inferredSize, rowsplit_dim: 0 },
|
|
36
|
+
Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT),
|
|
37
|
+
overrides
|
|
38
|
+
);
|
|
39
|
+
|
|
40
|
+
return createTensor(output, input.dtype, [inferredSize], 'gelu_output');
|
|
41
|
+
} catch (error) {
|
|
42
|
+
if (ownedOutput) {
|
|
43
|
+
releaseBuffer(ownedOutput);
|
|
44
|
+
}
|
|
45
|
+
throw error;
|
|
46
|
+
}
|
|
39
47
|
}
|
|
40
48
|
|
|
41
49
|
export async function runGeLU(input, options = {}) {
|
|
@@ -55,33 +55,43 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
|
|
|
55
55
|
device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
|
|
56
56
|
}
|
|
57
57
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
58
|
+
try {
|
|
59
|
+
await unifiedKernelWrapper(
|
|
60
|
+
'grouped_pointwise_conv2d',
|
|
61
|
+
target,
|
|
62
|
+
variant,
|
|
63
|
+
[input, weightBuffer, biasBuffer, output],
|
|
64
|
+
{
|
|
65
|
+
in_channels: inChannels,
|
|
66
|
+
out_channels: outChannels,
|
|
67
|
+
height,
|
|
68
|
+
width,
|
|
69
|
+
groups,
|
|
70
|
+
_pad0: 0,
|
|
71
|
+
_pad1: 0,
|
|
72
|
+
_pad2: 0,
|
|
73
|
+
},
|
|
74
|
+
[Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
75
|
+
);
|
|
76
|
+
|
|
77
|
+
if (tempBias) {
|
|
78
|
+
if (recorder) {
|
|
79
|
+
recorder.trackTemporaryBuffer(tempBias);
|
|
80
|
+
} else {
|
|
81
|
+
releaseBuffer(tempBias);
|
|
82
|
+
}
|
|
83
|
+
}
|
|
75
84
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
} else {
|
|
85
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
|
|
86
|
+
} catch (error) {
|
|
87
|
+
if (tempBias) {
|
|
80
88
|
releaseBuffer(tempBias);
|
|
81
89
|
}
|
|
90
|
+
if (!outputBuffer) {
|
|
91
|
+
releaseBuffer(output);
|
|
92
|
+
}
|
|
93
|
+
throw error;
|
|
82
94
|
}
|
|
83
|
-
|
|
84
|
-
return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
|
|
85
95
|
}
|
|
86
96
|
|
|
87
97
|
export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
|
|
@@ -17,6 +17,9 @@ function validateOptions(options) {
|
|
|
17
17
|
if (!Number.isFinite(numGroups) || numGroups <= 0) {
|
|
18
18
|
throw new Error('GroupNorm requires numGroups > 0.');
|
|
19
19
|
}
|
|
20
|
+
if (channels % numGroups !== 0) {
|
|
21
|
+
throw new Error('GroupNorm requires channels to be divisible by numGroups.');
|
|
22
|
+
}
|
|
20
23
|
if (!Number.isFinite(eps)) {
|
|
21
24
|
throw new Error('GroupNorm requires eps.');
|
|
22
25
|
}
|
|
@@ -44,34 +47,42 @@ async function _groupNorm(target, input, weight, bias, options = {}) {
|
|
|
44
47
|
|
|
45
48
|
const statsSize = numGroups * 2 * 4;
|
|
46
49
|
const statsBuffer = acquireBuffer(statsSize, undefined, 'groupnorm_stats');
|
|
47
|
-
|
|
48
|
-
await unifiedKernelWrapper(
|
|
49
|
-
'groupnorm_stats',
|
|
50
|
-
target,
|
|
51
|
-
statsVariant,
|
|
52
|
-
[input, statsBuffer],
|
|
53
|
-
uniforms,
|
|
54
|
-
numGroups
|
|
55
|
-
);
|
|
56
|
-
|
|
57
50
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
58
51
|
const outputSize = channels * height * width * bytesPerElement;
|
|
59
|
-
const
|
|
52
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'groupnorm_output');
|
|
53
|
+
const output = outputBuffer || ownedOutput;
|
|
60
54
|
|
|
61
|
-
|
|
62
|
-
|
|
55
|
+
try {
|
|
56
|
+
await unifiedKernelWrapper(
|
|
57
|
+
'groupnorm_stats',
|
|
58
|
+
target,
|
|
59
|
+
statsVariant,
|
|
60
|
+
[input, statsBuffer],
|
|
61
|
+
uniforms,
|
|
62
|
+
numGroups
|
|
63
|
+
);
|
|
63
64
|
|
|
64
|
-
|
|
65
|
-
|
|
65
|
+
const weightBuffer = getBuffer(weight);
|
|
66
|
+
const biasBuffer = getBuffer(bias);
|
|
66
67
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
68
|
+
const total = channels * height * width;
|
|
69
|
+
const workgroups = Math.ceil(total / WORKGROUP_SIZES.DEFAULT);
|
|
70
|
+
|
|
71
|
+
await unifiedKernelWrapper(
|
|
72
|
+
'groupnorm_apply',
|
|
73
|
+
target,
|
|
74
|
+
applyVariant,
|
|
75
|
+
[input, statsBuffer, weightBuffer, biasBuffer, output],
|
|
76
|
+
uniforms,
|
|
77
|
+
workgroups
|
|
78
|
+
);
|
|
79
|
+
} catch (error) {
|
|
80
|
+
releaseBuffer(statsBuffer);
|
|
81
|
+
if (ownedOutput) {
|
|
82
|
+
releaseBuffer(ownedOutput);
|
|
83
|
+
}
|
|
84
|
+
throw error;
|
|
85
|
+
}
|
|
75
86
|
|
|
76
87
|
if (recorder) {
|
|
77
88
|
recorder.trackTemporaryBuffer(statsBuffer);
|
|
@@ -326,6 +326,14 @@ export {
|
|
|
326
326
|
type SplitQKVResult,
|
|
327
327
|
} from './split_qkv.js';
|
|
328
328
|
|
|
329
|
+
// Split Q and Gate (de-interleave attentionOutputGate q_proj output)
|
|
330
|
+
export {
|
|
331
|
+
runSplitQG,
|
|
332
|
+
recordSplitQG,
|
|
333
|
+
type SplitQGOptions,
|
|
334
|
+
type SplitQGResult,
|
|
335
|
+
} from './split_qg.js';
|
|
336
|
+
|
|
329
337
|
// Transpose
|
|
330
338
|
export {
|
|
331
339
|
runTranspose,
|
package/src/gpu/kernels/index.js
CHANGED
|
@@ -268,6 +268,12 @@ export {
|
|
|
268
268
|
recordSplitQKV,
|
|
269
269
|
} from './split_qkv.js';
|
|
270
270
|
|
|
271
|
+
// Split Q and Gate (de-interleave attentionOutputGate q_proj output)
|
|
272
|
+
export {
|
|
273
|
+
runSplitQG,
|
|
274
|
+
recordSplitQG,
|
|
275
|
+
} from './split_qg.js';
|
|
276
|
+
|
|
271
277
|
// Transpose
|
|
272
278
|
export {
|
|
273
279
|
runTranspose,
|
|
@@ -78,8 +78,11 @@ export async function runKVQuantize(
|
|
|
78
78
|
});
|
|
79
79
|
|
|
80
80
|
const workgroups = [numKVHeads, numTokens, 1];
|
|
81
|
-
|
|
82
|
-
|
|
81
|
+
try {
|
|
82
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'kv_quantize');
|
|
83
|
+
} finally {
|
|
84
|
+
uniformBuffer.destroy();
|
|
85
|
+
}
|
|
83
86
|
}
|
|
84
87
|
|
|
85
88
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
|
|
2
2
|
import { getKernelCapabilities } from '../device.js';
|
|
3
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
4
|
import { createTensor } from '../tensor.js';
|
|
5
5
|
import { padToQ4KBlock } from '../../config/schema/index.js';
|
|
6
6
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -36,17 +36,25 @@ export async function runLayerNorm(
|
|
|
36
36
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
37
37
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
38
38
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'layernorm_output');
|
|
39
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
39
40
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
41
|
+
try {
|
|
42
|
+
await unifiedKernelWrapper(
|
|
43
|
+
'layernorm',
|
|
44
|
+
null,
|
|
45
|
+
variant,
|
|
46
|
+
[input, weight, bias, outputBuf],
|
|
47
|
+
{ hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
|
|
48
|
+
batchSize
|
|
49
|
+
);
|
|
48
50
|
|
|
49
|
-
|
|
51
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
|
|
52
|
+
} catch (error) {
|
|
53
|
+
if (ownedOutput) {
|
|
54
|
+
releaseBuffer(ownedOutput);
|
|
55
|
+
}
|
|
56
|
+
throw error;
|
|
57
|
+
}
|
|
50
58
|
}
|
|
51
59
|
|
|
52
60
|
export async function recordLayerNorm(
|
|
@@ -66,15 +74,23 @@ export async function recordLayerNorm(
|
|
|
66
74
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
67
75
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
68
76
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'layernorm_output');
|
|
77
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
69
78
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
79
|
+
try {
|
|
80
|
+
await unifiedKernelWrapper(
|
|
81
|
+
'layernorm',
|
|
82
|
+
recorder,
|
|
83
|
+
variant,
|
|
84
|
+
[input, weight, bias, outputBuf],
|
|
85
|
+
{ hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
|
|
86
|
+
batchSize
|
|
87
|
+
);
|
|
78
88
|
|
|
79
|
-
|
|
89
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
|
|
90
|
+
} catch (error) {
|
|
91
|
+
if (ownedOutput) {
|
|
92
|
+
releaseBuffer(ownedOutput);
|
|
93
|
+
}
|
|
94
|
+
throw error;
|
|
95
|
+
}
|
|
80
96
|
}
|
|
@@ -266,9 +266,11 @@ export class LogitMergeKernel {
|
|
|
266
266
|
pass.end();
|
|
267
267
|
|
|
268
268
|
this.#device.queue.submit([encoder.finish()]);
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
269
|
+
this.#device.queue.onSubmittedWorkDone()
|
|
270
|
+
.catch(() => {})
|
|
271
|
+
.finally(() => {
|
|
272
|
+
paramsBuffer.destroy();
|
|
273
|
+
});
|
|
272
274
|
|
|
273
275
|
return mergedBuffer;
|
|
274
276
|
}
|
|
@@ -29,7 +29,13 @@ function selectQ4KFusedVariant(isM1, wantF16Output, aDtype) {
|
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
export function resolveMatmulPhase(M) {
|
|
32
|
+
export function resolveMatmulPhase(M, phaseOverride = null) {
|
|
33
|
+
if (phaseOverride != null) {
|
|
34
|
+
if (phaseOverride !== 'decode' && phaseOverride !== 'prefill') {
|
|
35
|
+
throw new Error(`[Matmul] Invalid phase override "${phaseOverride}". Expected "decode" or "prefill".`);
|
|
36
|
+
}
|
|
37
|
+
return phaseOverride;
|
|
38
|
+
}
|
|
33
39
|
return selectKernelRuleValue('matmul', 'phase', { isDecode: M === 1 });
|
|
34
40
|
}
|
|
35
41
|
|
|
@@ -125,7 +131,9 @@ export function selectMatmulKernel(options = {}) {
|
|
|
125
131
|
const { tiledPrefillMinRows } = getKernelThresholds().matmul;
|
|
126
132
|
|
|
127
133
|
const inputsAreF16 = aDtype === 'f16' && bDtype === 'f16';
|
|
128
|
-
|
|
134
|
+
// F16 weights needing F32a path: weights are F16 and either activation is already F32,
|
|
135
|
+
// or both inputs are F16 but output is F32 (activation will be cast to F32 by executeMatmul)
|
|
136
|
+
const weightsAreF16 = bDtype === 'f16' && (aDtype !== 'f16' || outputDtype !== 'f16');
|
|
129
137
|
const useF16Matmul = outputDtype === 'f16' && preferF16 && inputsAreF16 && capabilities.hasF16;
|
|
130
138
|
const useF16wF32a = preferF16 && weightsAreF16 && capabilities.hasF16;
|
|
131
139
|
const useTiled = isPrefill
|
|
@@ -244,6 +252,30 @@ export function requiresF32Input(variant) {
|
|
|
244
252
|
return !supportsF16Input(variant);
|
|
245
253
|
}
|
|
246
254
|
|
|
255
|
+
function resolveRequiredWeightDtype(config) {
|
|
256
|
+
const shaderFile = String(config?.shaderFile ?? config?.wgsl ?? '');
|
|
257
|
+
if (!shaderFile) {
|
|
258
|
+
return null;
|
|
259
|
+
}
|
|
260
|
+
if (shaderFile.startsWith('fused_matmul_q4')) {
|
|
261
|
+
return 'q4k';
|
|
262
|
+
}
|
|
263
|
+
if (
|
|
264
|
+
shaderFile === 'matmul_f16.wgsl'
|
|
265
|
+
|| shaderFile === 'matmul_f16_tiled.wgsl'
|
|
266
|
+
|| shaderFile === 'matmul_f16w_f32a.wgsl'
|
|
267
|
+
|| shaderFile === 'matmul_f16w_f32a_tiled.wgsl'
|
|
268
|
+
|| shaderFile === 'matmul_gemv_subgroup.wgsl'
|
|
269
|
+
|| shaderFile === 'matmul_gemv_subgroup_f16a.wgsl'
|
|
270
|
+
) {
|
|
271
|
+
return 'f16';
|
|
272
|
+
}
|
|
273
|
+
if (shaderFile === 'matmul_f32.wgsl') {
|
|
274
|
+
return 'f32';
|
|
275
|
+
}
|
|
276
|
+
return null;
|
|
277
|
+
}
|
|
278
|
+
|
|
247
279
|
|
|
248
280
|
function resolveMatmulOverride(
|
|
249
281
|
variantOverride,
|
|
@@ -287,6 +319,16 @@ function resolveMatmulOverride(
|
|
|
287
319
|
);
|
|
288
320
|
}
|
|
289
321
|
|
|
322
|
+
const requiredWeightDtype = resolveRequiredWeightDtype(config);
|
|
323
|
+
const weightDtypeOk = !requiredWeightDtype
|
|
324
|
+
|| bDtype === requiredWeightDtype
|
|
325
|
+
|| (requiredWeightDtype === 'f16' && bDtype === 'q4k');
|
|
326
|
+
if (!weightDtypeOk) {
|
|
327
|
+
return failOrWarn(
|
|
328
|
+
`Matmul kernel "${variantOverride}" requires ${requiredWeightDtype} weights but B dtype is ${bDtype}.`
|
|
329
|
+
);
|
|
330
|
+
}
|
|
331
|
+
|
|
290
332
|
if (supportsF16Input(override) && aDtype !== 'f16') {
|
|
291
333
|
return failOrWarn(`Matmul kernel "${variantOverride}" requires f16 activations but A dtype is ${aDtype}.`);
|
|
292
334
|
}
|
|
@@ -341,7 +383,7 @@ function selectGemvVariant(useF16Gemv, useF32Gemv, hasSubgroups, useVec4, N, mul
|
|
|
341
383
|
export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, transposeB, requestedOutputDtype, options) {
|
|
342
384
|
const capabilities = getKernelCapabilities();
|
|
343
385
|
const strict = getKernelPathStrict();
|
|
344
|
-
const phase = resolveMatmulPhase(M);
|
|
386
|
+
const phase = resolveMatmulPhase(M, options.phaseOverride ?? null);
|
|
345
387
|
let pathVariant = getKernelPathMatmulVariant(options.role, phase, options.layerIdx, options.kernelPath);
|
|
346
388
|
const hadPathVariant = Boolean(pathVariant);
|
|
347
389
|
|
|
@@ -426,7 +468,8 @@ export function selectMatmulVariantAndFlags(mode, M, N, K, aDtype, bDtype, trans
|
|
|
426
468
|
|
|
427
469
|
const canGemv = M === 1 && effectiveBDtype === 'f16' && capabilities.hasF16;
|
|
428
470
|
const useF16Gemv = canGemv && aDtype === 'f16' && wantF16Output;
|
|
429
|
-
|
|
471
|
+
// F32 GEMV: activation is F32, or activation is F16 with F32 output (will be cast to F32)
|
|
472
|
+
const useF32Gemv = canGemv && (aDtype === 'f32' || (aDtype === 'f16' && !wantF16Output));
|
|
430
473
|
const useGemv = useF16Gemv || useF32Gemv;
|
|
431
474
|
const useVec4 = (K % 4 === 0);
|
|
432
475
|
const { multicolThreshold } = getKernelThresholds().matmul;
|
|
@@ -23,6 +23,8 @@ export interface MatmulOptions extends OutputBufferOptions, OutputDtypeOptions,
|
|
|
23
23
|
layerIdx?: number;
|
|
24
24
|
/** Explicit kernel path context for variant selection (avoids global path state). */
|
|
25
25
|
kernelPath?: KernelPathSchema | null;
|
|
26
|
+
/** Optional explicit phase for kernel-path lookup when the runtime rewrites rows (for example prefill last-position logits). */
|
|
27
|
+
phaseOverride?: 'decode' | 'prefill' | null;
|
|
26
28
|
/**
|
|
27
29
|
* Whether B matrix is stored transposed.
|
|
28
30
|
* - true: B is [N,K] (SafeTensors/row-major), needs transpose
|