@simulatte/doppler 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +145 -0
- package/README.md +16 -23
- package/package.json +30 -32
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +31 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +5 -20
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +18 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +81 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +15 -2
- package/src/config/merge-contract-check.js +66 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +43 -8
- package/src/config/presets/models/gemma2.json +3 -2
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +3 -2
- package/src/config/schema/manifest.schema.js +17 -4
- package/src/config/schema/storage.schema.js +1 -1
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +104 -11
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +16 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +50 -29
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +40 -16
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +83 -27
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +53 -3
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +59 -40
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +66 -43
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +8 -0
- package/src/inference/browser-harness.js +149 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +10 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +192 -112
- package/src/inference/pipelines/text/attention/record.js +77 -14
- package/src/inference/pipelines/text/attention/run.js +112 -14
- package/src/inference/pipelines/text/config.js +17 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +46 -23
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +52 -0
- package/src/inference/pipelines/text/generator-steps.js +340 -221
- package/src/inference/pipelines/text/generator.js +56 -40
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +94 -25
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +4 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +15 -0
- package/src/inference/pipelines/text/linear-attention.js +113 -9
- package/src/inference/pipelines/text/logits/gpu.js +12 -7
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +13 -12
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +282 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +17 -7
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +10 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +84 -14
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +214 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +27 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +365 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +55 -6
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +30 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +120 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/types/model.d.ts +5 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +50 -26
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
3
|
import { unifiedKernelWrapper } from './utils.js';
|
|
4
4
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -25,19 +25,27 @@ async function _pixelShuffle(target, input, options = {}) {
|
|
|
25
25
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
26
26
|
const outputSize = outChannels * outHeight * outWidth * bytesPerElement;
|
|
27
27
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'pixel_shuffle_output');
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
28
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
29
|
+
|
|
30
|
+
try {
|
|
31
|
+
await unifiedKernelWrapper(
|
|
32
|
+
'pixel_shuffle', target, variant,
|
|
33
|
+
[input, output],
|
|
34
|
+
{
|
|
35
|
+
out_channels: outChannels, out_height: outHeight, out_width: outWidth,
|
|
36
|
+
grid_width: gridWidth, grid_height: gridHeight, patch_size: patchSize,
|
|
37
|
+
patch_channels: inferredPatchChannels, _pad0: 0,
|
|
38
|
+
},
|
|
39
|
+
[Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
40
|
+
);
|
|
41
|
+
|
|
42
|
+
return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'pixel_shuffle_output');
|
|
43
|
+
} catch (error) {
|
|
44
|
+
if (ownedOutput) {
|
|
45
|
+
releaseBuffer(ownedOutput);
|
|
46
|
+
}
|
|
47
|
+
throw error;
|
|
48
|
+
}
|
|
41
49
|
}
|
|
42
50
|
|
|
43
51
|
export async function runPixelShuffle(input, options = {}) {
|
package/src/gpu/kernels/relu.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
3
|
import { unifiedKernelWrapper } from './utils.js';
|
|
4
4
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -35,18 +35,26 @@ async function _relu(target, input, options = {}) {
|
|
|
35
35
|
const size = resolveCount(input, count);
|
|
36
36
|
const variant = selectReluVariant(input.dtype);
|
|
37
37
|
const output = outputBuffer || acquireBuffer(size * dtypeBytes(input.dtype), undefined, 'relu_output');
|
|
38
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
38
39
|
const dispatchPlan = planReluDispatch(target, size);
|
|
39
40
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
41
|
+
try {
|
|
42
|
+
await unifiedKernelWrapper(
|
|
43
|
+
'relu',
|
|
44
|
+
target,
|
|
45
|
+
variant,
|
|
46
|
+
[input, output],
|
|
47
|
+
{ size, _pad0: dispatchPlan.dispatchStride, _pad1: 0, _pad2: 0 },
|
|
48
|
+
dispatchPlan.workgroups
|
|
49
|
+
);
|
|
48
50
|
|
|
49
|
-
|
|
51
|
+
return createTensor(output, input.dtype, [...input.shape], 'relu_output');
|
|
52
|
+
} catch (error) {
|
|
53
|
+
if (ownedOutput) {
|
|
54
|
+
releaseBuffer(ownedOutput);
|
|
55
|
+
}
|
|
56
|
+
throw error;
|
|
57
|
+
}
|
|
50
58
|
}
|
|
51
59
|
|
|
52
60
|
export async function runReLU(input, options = {}) {
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
3
|
import { unifiedKernelWrapper } from './utils.js';
|
|
4
4
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -32,23 +32,31 @@ async function _repeatChannels(target, input, options = {}) {
|
|
|
32
32
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
33
33
|
const outputSize = outChannels * height * width * bytesPerElement;
|
|
34
34
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
|
|
35
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
35
36
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
37
|
+
try {
|
|
38
|
+
await unifiedKernelWrapper(
|
|
39
|
+
'repeat_channels',
|
|
40
|
+
target,
|
|
41
|
+
variant,
|
|
42
|
+
[input, output],
|
|
43
|
+
{
|
|
44
|
+
in_channels: inChannels,
|
|
45
|
+
height,
|
|
46
|
+
width,
|
|
47
|
+
repeats,
|
|
48
|
+
_pad0: 0,
|
|
49
|
+
},
|
|
50
|
+
[Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
51
|
+
);
|
|
52
|
+
|
|
53
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
|
|
54
|
+
} catch (error) {
|
|
55
|
+
if (ownedOutput) {
|
|
56
|
+
releaseBuffer(ownedOutput);
|
|
57
|
+
}
|
|
58
|
+
throw error;
|
|
59
|
+
}
|
|
52
60
|
}
|
|
53
61
|
|
|
54
62
|
export async function runRepeatChannels(input, options = {}) {
|
|
@@ -82,6 +82,7 @@ function planResidualDispatch(target, size, elementsPerWorkgroup) {
|
|
|
82
82
|
async function _residualAdd(target, a, b, size, options = {}) {
|
|
83
83
|
const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
|
|
84
84
|
const { useVec4 = true, outputBuffer = null } = options;
|
|
85
|
+
const ownsOutput = outputBuffer == null;
|
|
85
86
|
|
|
86
87
|
const { a: aAligned, b: bAligned, temps } = await alignResidualInputs(a, b, recorder);
|
|
87
88
|
const outputDtype = inferOutputDtype(aAligned, bAligned);
|
|
@@ -97,15 +98,22 @@ async function _residualAdd(target, a, b, size, options = {}) {
|
|
|
97
98
|
useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
|
|
98
99
|
);
|
|
99
100
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
101
|
+
try {
|
|
102
|
+
await unifiedKernelWrapper(
|
|
103
|
+
'residual', target, variant,
|
|
104
|
+
[aAligned, bAligned, output],
|
|
105
|
+
{ size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
|
|
106
|
+
dispatchPlan.workgroups
|
|
107
|
+
);
|
|
108
|
+
return createTensor(output, outputDtype, [size], 'residual_output');
|
|
109
|
+
} catch (error) {
|
|
110
|
+
if (ownsOutput) {
|
|
111
|
+
releaseBuffer(output);
|
|
112
|
+
}
|
|
113
|
+
throw error;
|
|
114
|
+
} finally {
|
|
115
|
+
cleanupTemps(temps, recorder);
|
|
116
|
+
}
|
|
109
117
|
}
|
|
110
118
|
|
|
111
119
|
async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
|
|
@@ -126,24 +134,26 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
|
|
|
126
134
|
Math.ceil(numTokens / tokenStride),
|
|
127
135
|
];
|
|
128
136
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
137
|
+
try {
|
|
138
|
+
await unifiedKernelWrapper(
|
|
139
|
+
'bias_add', target, variant,
|
|
140
|
+
[data, biasAligned],
|
|
141
|
+
{
|
|
142
|
+
num_tokens: numTokens,
|
|
143
|
+
dim,
|
|
144
|
+
data_offset: dataOffset,
|
|
145
|
+
bias_offset: biasOffset,
|
|
146
|
+
token_stride: tokenStride,
|
|
147
|
+
_pad0: 0,
|
|
148
|
+
_pad1: 0,
|
|
149
|
+
_pad2: 0,
|
|
150
|
+
},
|
|
151
|
+
workgroups
|
|
152
|
+
);
|
|
153
|
+
return createTensor(data.buffer, data.dtype, [numTokens, dim], 'bias_add_output');
|
|
154
|
+
} finally {
|
|
155
|
+
cleanupTemps(temps, recorder);
|
|
156
|
+
}
|
|
147
157
|
}
|
|
148
158
|
|
|
149
159
|
export async function runResidualAdd(a, b, size, options = {}) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getKernelCapabilities } from '../device.js';
|
|
4
|
-
import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor } from '../tensor.js';
|
|
6
6
|
import { getKernelThresholds, padToQ4KBlock } from '../../config/schema/index.js';
|
|
7
7
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -9,6 +9,9 @@ import { selectRuleValue as selectLoaderRule } from '../../rules/rule-registry.j
|
|
|
9
9
|
import { getBuffer, getWeightDtype, getBufferDtype } from '../weight-buffer.js';
|
|
10
10
|
import { unifiedKernelWrapper } from './utils.js';
|
|
11
11
|
|
|
12
|
+
// Conservative fallback dtype for norm weight inference when metadata is unavailable.
|
|
13
|
+
const DEFAULT_DTYPE = 'f32';
|
|
14
|
+
|
|
12
15
|
function inferHiddenSize(input, hiddenSize) {
|
|
13
16
|
if (hiddenSize != null) return hiddenSize;
|
|
14
17
|
const shape = input?.shape;
|
|
@@ -39,9 +42,12 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
39
42
|
return taggedDtype;
|
|
40
43
|
}
|
|
41
44
|
|
|
45
|
+
// Conservative fallback: f32 avoids precision loss when dtype cannot be determined.
|
|
46
|
+
// This path fires for non-GPU buffers or missing hiddenSize, both of which prevent
|
|
47
|
+
// size-based dtype inference below.
|
|
42
48
|
const hasGPUBufferType = typeof GPUBuffer !== 'undefined';
|
|
43
49
|
if (!hasGPUBufferType || !(weightBuffer instanceof GPUBuffer) || hiddenSize == null || hiddenSize <= 0) {
|
|
44
|
-
return
|
|
50
|
+
return DEFAULT_DTYPE;
|
|
45
51
|
}
|
|
46
52
|
|
|
47
53
|
const byteSize = getBufferRequestedSize(weightBuffer);
|
|
@@ -55,7 +61,8 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
55
61
|
sizeMatchesF32,
|
|
56
62
|
});
|
|
57
63
|
}
|
|
58
|
-
|
|
64
|
+
// Buffer size matches neither f16 nor f32 for given hiddenSize; fall back to f32.
|
|
65
|
+
return DEFAULT_DTYPE;
|
|
59
66
|
}
|
|
60
67
|
|
|
61
68
|
function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
|
|
@@ -119,31 +126,39 @@ export async function runRMSNorm(
|
|
|
119
126
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
120
127
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
121
128
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
|
|
129
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
122
130
|
const dispatchPlan = planRMSNormDispatch(null, batchSize);
|
|
123
131
|
|
|
124
132
|
// Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
|
|
125
133
|
const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
|
|
126
134
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
135
|
+
try {
|
|
136
|
+
await unifiedKernelWrapper(
|
|
137
|
+
'rmsnorm',
|
|
138
|
+
null,
|
|
139
|
+
variant,
|
|
140
|
+
[input, normWeightBuffer, outputBuf, residualBuf],
|
|
141
|
+
{
|
|
142
|
+
hidden_size: inferredHiddenSize,
|
|
143
|
+
num_tokens: batchSize,
|
|
144
|
+
eps,
|
|
145
|
+
has_residual: residual ? 1 : 0,
|
|
146
|
+
token_stride: dispatchPlan.tokenStride,
|
|
147
|
+
_pad0: 0,
|
|
148
|
+
_pad1: 0,
|
|
149
|
+
_pad2: 0,
|
|
150
|
+
},
|
|
151
|
+
dispatchPlan.workgroups,
|
|
152
|
+
{ RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
|
|
153
|
+
);
|
|
154
|
+
|
|
155
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
|
|
156
|
+
} catch (error) {
|
|
157
|
+
if (ownedOutput) {
|
|
158
|
+
releaseBuffer(ownedOutput);
|
|
159
|
+
}
|
|
160
|
+
throw error;
|
|
161
|
+
}
|
|
147
162
|
}
|
|
148
163
|
|
|
149
164
|
export async function recordRMSNorm(
|
|
@@ -165,28 +180,36 @@ export async function recordRMSNorm(
|
|
|
165
180
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
166
181
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
167
182
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
|
|
183
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
168
184
|
const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
|
|
169
185
|
|
|
170
186
|
const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
|
|
171
187
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
188
|
+
try {
|
|
189
|
+
await unifiedKernelWrapper(
|
|
190
|
+
'rmsnorm',
|
|
191
|
+
recorder,
|
|
192
|
+
variant,
|
|
193
|
+
[input, normWeightBuffer, outputBuf, residualBuf],
|
|
194
|
+
{
|
|
195
|
+
hidden_size: inferredHiddenSize,
|
|
196
|
+
num_tokens: batchSize,
|
|
197
|
+
eps,
|
|
198
|
+
has_residual: residual ? 1 : 0,
|
|
199
|
+
token_stride: dispatchPlan.tokenStride,
|
|
200
|
+
_pad0: 0,
|
|
201
|
+
_pad1: 0,
|
|
202
|
+
_pad2: 0,
|
|
203
|
+
},
|
|
204
|
+
dispatchPlan.workgroups,
|
|
205
|
+
{ RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
|
|
206
|
+
);
|
|
207
|
+
|
|
208
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'rmsnorm_output');
|
|
209
|
+
} catch (error) {
|
|
210
|
+
if (ownedOutput) {
|
|
211
|
+
releaseBuffer(ownedOutput);
|
|
212
|
+
}
|
|
213
|
+
throw error;
|
|
214
|
+
}
|
|
192
215
|
}
|
package/src/gpu/kernels/rope.js
CHANGED
|
@@ -27,6 +27,9 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
|
|
|
27
27
|
if (rotaryDim <= 0 || rotaryDim > headDim) {
|
|
28
28
|
throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
|
|
29
29
|
}
|
|
30
|
+
if (input.dtype === 'f16' && (rotaryDim !== headDim || interleaved)) {
|
|
31
|
+
throw new Error('RoPE f16 kernel requires rotaryDim === headDim and interleaved === false.');
|
|
32
|
+
}
|
|
30
33
|
|
|
31
34
|
const caps = getKernelCapabilities();
|
|
32
35
|
const useF16 = input.dtype === 'f16' && caps.hasF16;
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice, getKernelCapabilities } from '../device.js';
|
|
4
|
-
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, readBufferSlice, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
6
6
|
import { createPipeline, createUniformBufferWithView, getOrCreateBindGroupLayout } from './utils.js';
|
|
7
7
|
import { allowReadback } from '../perf-guards.js';
|
|
@@ -156,18 +156,19 @@ function ensureOutputBufferSize(outputBuffer, minBytes, label) {
|
|
|
156
156
|
}
|
|
157
157
|
}
|
|
158
158
|
|
|
159
|
-
function readTokenFromOutput(
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
size: 4,
|
|
163
|
-
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
|
|
164
|
-
});
|
|
165
|
-
|
|
166
|
-
const copyEncoder = device.createCommandEncoder({ label: `${label}_copy` });
|
|
167
|
-
copyEncoder.copyBufferToBuffer(outputBuffer, outputIndex * 4, stagingBuffer, 0, 4);
|
|
168
|
-
device.queue.submit([copyEncoder.finish()]);
|
|
159
|
+
async function readTokenFromOutput(outputBuffer, outputIndex) {
|
|
160
|
+
return new Uint32Array(await readBufferSlice(outputBuffer, outputIndex * 4, 4))[0];
|
|
161
|
+
}
|
|
169
162
|
|
|
170
|
-
|
|
163
|
+
function cleanupRunResources(uniformBuffer, ownedBuffers) {
|
|
164
|
+
if (uniformBuffer) {
|
|
165
|
+
uniformBuffer.destroy();
|
|
166
|
+
}
|
|
167
|
+
for (const buffer of ownedBuffers) {
|
|
168
|
+
if (buffer) {
|
|
169
|
+
releaseBuffer(buffer);
|
|
170
|
+
}
|
|
171
|
+
}
|
|
171
172
|
}
|
|
172
173
|
|
|
173
174
|
async function executeArgmaxRun(logits, vocabSize, options) {
|
|
@@ -238,20 +239,14 @@ async function executeArgmaxRun(logits, vocabSize, options) {
|
|
|
238
239
|
|
|
239
240
|
device.queue.submit([encoder.finish()]);
|
|
240
241
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
releaseBuffer(tempLogits);
|
|
249
|
-
releaseBuffer(tempIndices);
|
|
250
|
-
if (ownsOutputBuffer) {
|
|
251
|
-
releaseBuffer(outputBuffer);
|
|
242
|
+
try {
|
|
243
|
+
return await readTokenFromOutput(outputBuffer, outputIndex);
|
|
244
|
+
} finally {
|
|
245
|
+
cleanupRunResources(
|
|
246
|
+
uniformBuffer,
|
|
247
|
+
[tempLogits, tempIndices, ownsOutputBuffer ? outputBuffer : null]
|
|
248
|
+
);
|
|
252
249
|
}
|
|
253
|
-
|
|
254
|
-
return tokenId;
|
|
255
250
|
}
|
|
256
251
|
|
|
257
252
|
async function executeArgmaxRecord(recorder, logits, vocabSize, options) {
|
|
@@ -428,20 +423,14 @@ export async function runGPUSample(
|
|
|
428
423
|
|
|
429
424
|
device.queue.submit([encoder.finish()]);
|
|
430
425
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
releaseBuffer(topkLogits);
|
|
439
|
-
releaseBuffer(topkIndices);
|
|
440
|
-
if (ownsOutputBuffer) {
|
|
441
|
-
releaseBuffer(outputBuffer);
|
|
426
|
+
try {
|
|
427
|
+
return await readTokenFromOutput(outputBuffer, outputIndex);
|
|
428
|
+
} finally {
|
|
429
|
+
cleanupRunResources(
|
|
430
|
+
uniformBuffer,
|
|
431
|
+
[topkLogits, topkIndices, ownsOutputBuffer ? outputBuffer : null]
|
|
432
|
+
);
|
|
442
433
|
}
|
|
443
|
-
|
|
444
|
-
return tokenId;
|
|
445
434
|
}
|
|
446
435
|
|
|
447
436
|
|
|
@@ -64,6 +64,8 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
|
|
|
64
64
|
outputBuffer = null,
|
|
65
65
|
summaryBuffer = null,
|
|
66
66
|
} = options;
|
|
67
|
+
const ownsSummary = summaryBuffer == null;
|
|
68
|
+
const ownsOutput = outputBuffer == null;
|
|
67
69
|
|
|
68
70
|
if (
|
|
69
71
|
!Number.isFinite(numHeads) ||
|
|
@@ -98,18 +100,24 @@ async function _sanaLinearAttention(target, query, key, value, options = {}) {
|
|
|
98
100
|
eps,
|
|
99
101
|
};
|
|
100
102
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
103
|
+
try {
|
|
104
|
+
await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
|
|
105
|
+
await runApply(target, query, temporarySummary, output, uniforms, variant);
|
|
106
|
+
return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
|
|
107
|
+
} catch (error) {
|
|
108
|
+
if (ownsOutput) {
|
|
109
|
+
releaseBuffer(output);
|
|
110
|
+
}
|
|
111
|
+
throw error;
|
|
112
|
+
} finally {
|
|
113
|
+
if (ownsSummary) {
|
|
114
|
+
if (recorder) {
|
|
115
|
+
recorder.trackTemporaryBuffer(temporarySummary);
|
|
116
|
+
} else {
|
|
117
|
+
releaseBuffer(temporarySummary);
|
|
118
|
+
}
|
|
109
119
|
}
|
|
110
120
|
}
|
|
111
|
-
|
|
112
|
-
return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
|
|
113
121
|
}
|
|
114
122
|
|
|
115
123
|
export async function runSanaLinearAttention(query, key, value, options = {}) {
|
package/src/gpu/kernels/scale.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
1
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
2
2
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
3
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
4
4
|
import { unifiedKernelWrapper } from './utils.js';
|
|
@@ -6,6 +6,7 @@ import { selectRuleValue } from './rule-registry.js';
|
|
|
6
6
|
|
|
7
7
|
async function _scale(target, input, scale, options = {}) {
|
|
8
8
|
const { count, outputBuffer = null, inplace = false } = options;
|
|
9
|
+
const ownsOutput = !inplace && outputBuffer == null;
|
|
9
10
|
|
|
10
11
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
11
12
|
const inferredCount = count ?? Math.floor(input.buffer.size / bytesPerElement);
|
|
@@ -16,16 +17,22 @@ async function _scale(target, input, scale, options = {}) {
|
|
|
16
17
|
|
|
17
18
|
const bindings = inplace ? [outputBuf, outputBuf] : [input, outputBuf];
|
|
18
19
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
20
|
+
try {
|
|
21
|
+
await unifiedKernelWrapper(
|
|
22
|
+
'scale',
|
|
23
|
+
target,
|
|
24
|
+
variant,
|
|
25
|
+
bindings,
|
|
26
|
+
{ size: inferredCount, scale },
|
|
27
|
+
Math.ceil(inferredCount / WORKGROUP_SIZES.DEFAULT)
|
|
28
|
+
);
|
|
29
|
+
return createTensor(outputBuf, input.dtype, [...input.shape], 'scale_output');
|
|
30
|
+
} catch (error) {
|
|
31
|
+
if (ownsOutput) {
|
|
32
|
+
releaseBuffer(outputBuf);
|
|
33
|
+
}
|
|
34
|
+
throw error;
|
|
35
|
+
}
|
|
29
36
|
}
|
|
30
37
|
|
|
31
38
|
export async function runScale(input, scale, options = {}) {
|
|
@@ -138,8 +138,10 @@ export async function compileShader(
|
|
|
138
138
|
code: source,
|
|
139
139
|
});
|
|
140
140
|
|
|
141
|
-
// Check for compilation errors
|
|
142
|
-
const compilationInfo =
|
|
141
|
+
// Check for compilation errors (getCompilationInfo not available in all WebGPU providers)
|
|
142
|
+
const compilationInfo = typeof module.getCompilationInfo === 'function'
|
|
143
|
+
? await module.getCompilationInfo()
|
|
144
|
+
: { messages: [] };
|
|
143
145
|
if (compilationInfo.messages.length > 0) {
|
|
144
146
|
for (const msg of compilationInfo.messages) {
|
|
145
147
|
if (msg.type === 'error') {
|