@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice } from '../device.js';
|
|
4
|
-
import { acquireBuffer, getBufferRequestedSize } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, getBufferRequestedSize, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor } from '../tensor.js';
|
|
6
6
|
import { getBuffer } from '../weight-buffer.js';
|
|
7
7
|
import { dispatch, recordDispatch } from './dispatch.js';
|
|
@@ -91,7 +91,8 @@ export async function runMatmulRMSNormFused(
|
|
|
91
91
|
// Output buffer: [1, N] - size depends on dtype
|
|
92
92
|
const bytesPerElement = dtype === 'f16' ? 2 : 4;
|
|
93
93
|
const outputSize = N * bytesPerElement;
|
|
94
|
-
const
|
|
94
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
|
|
95
|
+
const output = outputBuffer || ownedOutput;
|
|
95
96
|
|
|
96
97
|
// Create uniform buffer (8 u32/f32 = 32 bytes, padded for alignment)
|
|
97
98
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -110,36 +111,44 @@ export async function runMatmulRMSNormFused(
|
|
|
110
111
|
);
|
|
111
112
|
|
|
112
113
|
// Create placeholder for residual if not provided
|
|
114
|
+
const ownsResidualBuffer = !residual;
|
|
113
115
|
const residualBuffer = residual || device.createBuffer({
|
|
114
116
|
label: 'matmul_rmsnorm_residual_placeholder',
|
|
115
117
|
size: 4,
|
|
116
118
|
usage: GPUBufferUsage.STORAGE,
|
|
117
119
|
});
|
|
118
120
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
121
|
+
try {
|
|
122
|
+
const bindGroup = device.createBindGroup({
|
|
123
|
+
label: 'matmul_rmsnorm_fused_bind_group',
|
|
124
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
125
|
+
entries: [
|
|
126
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
127
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
128
|
+
{ binding: 2, resource: { buffer: weightBuffer } },
|
|
129
|
+
{ binding: 3, resource: { buffer: normWeightBuffer } },
|
|
130
|
+
{ binding: 4, resource: { buffer: output } },
|
|
131
|
+
{ binding: 5, resource: { buffer: residualBuffer } },
|
|
132
|
+
],
|
|
133
|
+
});
|
|
134
|
+
|
|
135
|
+
const workgroups = 1;
|
|
136
|
+
const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
|
|
137
|
+
dispatch(device, pipeline, bindGroup, workgroups, dispatchLabel);
|
|
138
|
+
} catch (error) {
|
|
139
|
+
uniformBuffer.destroy();
|
|
140
|
+
if (ownsResidualBuffer) {
|
|
141
|
+
residualBuffer.destroy();
|
|
142
|
+
}
|
|
143
|
+
if (ownedOutput) {
|
|
144
|
+
releaseBuffer(ownedOutput);
|
|
145
|
+
}
|
|
146
|
+
throw error;
|
|
147
|
+
}
|
|
139
148
|
|
|
140
149
|
// Cleanup
|
|
141
150
|
uniformBuffer.destroy();
|
|
142
|
-
if (
|
|
151
|
+
if (ownsResidualBuffer) residualBuffer.destroy();
|
|
143
152
|
|
|
144
153
|
// Output dtype matches input dtype
|
|
145
154
|
return createTensor(output, input.dtype, [1, N], 'matmul_rmsnorm_fused_output');
|
|
@@ -199,7 +208,8 @@ export async function recordMatmulRMSNormFused(
|
|
|
199
208
|
// Output buffer - size depends on dtype
|
|
200
209
|
const bytesPerElement = dtype === 'f16' ? 2 : 4;
|
|
201
210
|
const outputSize = N * bytesPerElement;
|
|
202
|
-
const
|
|
211
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'matmul_rmsnorm_fused_output');
|
|
212
|
+
const output = outputBuffer || ownedOutput;
|
|
203
213
|
|
|
204
214
|
// Uniform buffer via recorder (8 u32/f32 = 32 bytes, padded for alignment)
|
|
205
215
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -217,35 +227,42 @@ export async function recordMatmulRMSNormFused(
|
|
|
217
227
|
);
|
|
218
228
|
|
|
219
229
|
// Placeholder for residual
|
|
230
|
+
const ownsResidualBuffer = !residual;
|
|
220
231
|
const residualBuffer = residual || device.createBuffer({
|
|
221
232
|
label: 'matmul_rmsnorm_residual_placeholder',
|
|
222
233
|
size: 4,
|
|
223
234
|
usage: GPUBufferUsage.STORAGE,
|
|
224
235
|
});
|
|
225
236
|
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
237
|
+
try {
|
|
238
|
+
const bindGroup = device.createBindGroup({
|
|
239
|
+
label: 'matmul_rmsnorm_fused_bind_group',
|
|
240
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
241
|
+
entries: [
|
|
242
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
243
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
244
|
+
{ binding: 2, resource: { buffer: weightBuffer } },
|
|
245
|
+
{ binding: 3, resource: { buffer: normWeightBuffer } },
|
|
246
|
+
{ binding: 4, resource: { buffer: output } },
|
|
247
|
+
{ binding: 5, resource: { buffer: residualBuffer } },
|
|
248
|
+
],
|
|
249
|
+
});
|
|
250
|
+
|
|
251
|
+
const workgroups = 1;
|
|
252
|
+
const dispatchLabel = label ? `matmul_rmsnorm_fused:${label}` : 'matmul_rmsnorm_fused';
|
|
253
|
+
recordDispatch(recorder, pipeline, bindGroup, workgroups, dispatchLabel);
|
|
254
|
+
} catch (error) {
|
|
255
|
+
if (ownsResidualBuffer) {
|
|
256
|
+
residualBuffer.destroy();
|
|
257
|
+
}
|
|
258
|
+
if (ownedOutput) {
|
|
259
|
+
releaseBuffer(ownedOutput);
|
|
260
|
+
}
|
|
261
|
+
throw error;
|
|
262
|
+
}
|
|
246
263
|
|
|
247
264
|
// Track placeholder for cleanup
|
|
248
|
-
if (
|
|
265
|
+
if (ownsResidualBuffer) {
|
|
249
266
|
recorder.trackTemporaryBuffer(residualBuffer);
|
|
250
267
|
}
|
|
251
268
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { getKernelCapabilities } from '../device.js';
|
|
2
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { WORKGROUP_SIZES, VEC4_ELEMENTS_PER_WG } from './constants.js';
|
|
4
4
|
import { unifiedKernelWrapper } from './utils.js';
|
|
5
5
|
import { trace } from '../../debug/index.js';
|
|
@@ -26,7 +26,6 @@ async function _gather(
|
|
|
26
26
|
options = {}
|
|
27
27
|
) {
|
|
28
28
|
const {
|
|
29
|
-
useVec4 = true,
|
|
30
29
|
outputBuffer = null,
|
|
31
30
|
embeddingDtype,
|
|
32
31
|
outputDtype,
|
|
@@ -43,9 +42,22 @@ async function _gather(
|
|
|
43
42
|
if (outputDtype == null) {
|
|
44
43
|
throw new Error('[Gather] outputDtype is required.');
|
|
45
44
|
}
|
|
45
|
+
if (embeddingDtype === 'f16' && !caps.hasF16) {
|
|
46
|
+
throw new Error('[Gather] embeddingDtype=f16 requires shader-f16 support.');
|
|
47
|
+
}
|
|
48
|
+
if (outputDtype === 'f16' && !caps.hasF16) {
|
|
49
|
+
throw new Error('[Gather] outputDtype=f16 requires shader-f16 support.');
|
|
50
|
+
}
|
|
46
51
|
|
|
47
|
-
const
|
|
48
|
-
const
|
|
52
|
+
const requestedVec4 = options.useVec4;
|
|
53
|
+
const wantsVec4 = requestedVec4 ?? true;
|
|
54
|
+
if (requestedVec4 === true && hiddenSize % 4 !== 0) {
|
|
55
|
+
throw new Error('[Gather] useVec4=true requires hiddenSize to be divisible by 4.');
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
const useF16Input = embeddingDtype === 'f16';
|
|
59
|
+
const useF16Output = outputDtype === 'f16';
|
|
60
|
+
const useVec4 = wantsVec4 && hiddenSize % 4 === 0;
|
|
49
61
|
|
|
50
62
|
trace.embed(
|
|
51
63
|
`Gather: numTokens=${numTokens}, hiddenSize=${hiddenSize}, vocabSize=${vocabSize}, ` +
|
|
@@ -64,6 +76,7 @@ async function _gather(
|
|
|
64
76
|
const paddedHiddenSize = padToQ4KBlock(hiddenSize);
|
|
65
77
|
const outputSize = numTokens * paddedHiddenSize * bytesPerElement;
|
|
66
78
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gather_output');
|
|
79
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
67
80
|
|
|
68
81
|
const uniforms = {
|
|
69
82
|
num_tokens: numTokens,
|
|
@@ -82,16 +95,22 @@ async function _gather(
|
|
|
82
95
|
? Math.ceil((numTokens * hiddenSize) / VEC4_ELEMENTS_PER_WG)
|
|
83
96
|
: Math.ceil((numTokens * hiddenSize) / WORKGROUP_SIZES.DEFAULT));
|
|
84
97
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
98
|
+
try {
|
|
99
|
+
await unifiedKernelWrapper(
|
|
100
|
+
'gather',
|
|
101
|
+
target,
|
|
102
|
+
variant,
|
|
103
|
+
[indices, embeddings, output],
|
|
104
|
+
uniforms,
|
|
105
|
+
workgroups
|
|
106
|
+
);
|
|
107
|
+
return createTensor(output, actualDtype, [numTokens, hiddenSize], 'gather_output');
|
|
108
|
+
} catch (error) {
|
|
109
|
+
if (ownedOutput) {
|
|
110
|
+
releaseBuffer(ownedOutput);
|
|
111
|
+
}
|
|
112
|
+
throw error;
|
|
113
|
+
}
|
|
95
114
|
}
|
|
96
115
|
|
|
97
116
|
export async function runGather(
|
|
@@ -116,4 +135,3 @@ export async function recordGather(
|
|
|
116
135
|
) {
|
|
117
136
|
return _gather(recorder, indices, embeddings, numTokens, hiddenSize, vocabSize, options);
|
|
118
137
|
}
|
|
119
|
-
|
package/src/gpu/kernels/gelu.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
|
|
2
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
4
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
5
5
|
import { unifiedKernelWrapper } from './utils.js';
|
|
@@ -26,16 +26,24 @@ async function _gelu(target, input, options = {}) {
|
|
|
26
26
|
const outputSize = inferredSize * bytesPerElement;
|
|
27
27
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'gelu_output');
|
|
28
28
|
const gateBuffer = gate ?? input;
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
29
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
30
|
+
|
|
31
|
+
try {
|
|
32
|
+
await unifiedKernelWrapper(
|
|
33
|
+
'gelu', target, variant,
|
|
34
|
+
[input, output, gateBuffer],
|
|
35
|
+
{ size: inferredSize, rowsplit_dim: 0 },
|
|
36
|
+
Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT),
|
|
37
|
+
overrides
|
|
38
|
+
);
|
|
39
|
+
|
|
40
|
+
return createTensor(output, input.dtype, [inferredSize], 'gelu_output');
|
|
41
|
+
} catch (error) {
|
|
42
|
+
if (ownedOutput) {
|
|
43
|
+
releaseBuffer(ownedOutput);
|
|
44
|
+
}
|
|
45
|
+
throw error;
|
|
46
|
+
}
|
|
39
47
|
}
|
|
40
48
|
|
|
41
49
|
export async function runGeLU(input, options = {}) {
|
|
@@ -55,33 +55,43 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
|
|
|
55
55
|
device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
|
|
56
56
|
}
|
|
57
57
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
58
|
+
try {
|
|
59
|
+
await unifiedKernelWrapper(
|
|
60
|
+
'grouped_pointwise_conv2d',
|
|
61
|
+
target,
|
|
62
|
+
variant,
|
|
63
|
+
[input, weightBuffer, biasBuffer, output],
|
|
64
|
+
{
|
|
65
|
+
in_channels: inChannels,
|
|
66
|
+
out_channels: outChannels,
|
|
67
|
+
height,
|
|
68
|
+
width,
|
|
69
|
+
groups,
|
|
70
|
+
_pad0: 0,
|
|
71
|
+
_pad1: 0,
|
|
72
|
+
_pad2: 0,
|
|
73
|
+
},
|
|
74
|
+
[Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
75
|
+
);
|
|
76
|
+
|
|
77
|
+
if (tempBias) {
|
|
78
|
+
if (recorder) {
|
|
79
|
+
recorder.trackTemporaryBuffer(tempBias);
|
|
80
|
+
} else {
|
|
81
|
+
releaseBuffer(tempBias);
|
|
82
|
+
}
|
|
83
|
+
}
|
|
75
84
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
} else {
|
|
85
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
|
|
86
|
+
} catch (error) {
|
|
87
|
+
if (tempBias) {
|
|
80
88
|
releaseBuffer(tempBias);
|
|
81
89
|
}
|
|
90
|
+
if (!outputBuffer) {
|
|
91
|
+
releaseBuffer(output);
|
|
92
|
+
}
|
|
93
|
+
throw error;
|
|
82
94
|
}
|
|
83
|
-
|
|
84
|
-
return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
|
|
85
95
|
}
|
|
86
96
|
|
|
87
97
|
export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
|
|
@@ -17,6 +17,9 @@ function validateOptions(options) {
|
|
|
17
17
|
if (!Number.isFinite(numGroups) || numGroups <= 0) {
|
|
18
18
|
throw new Error('GroupNorm requires numGroups > 0.');
|
|
19
19
|
}
|
|
20
|
+
if (channels % numGroups !== 0) {
|
|
21
|
+
throw new Error('GroupNorm requires channels to be divisible by numGroups.');
|
|
22
|
+
}
|
|
20
23
|
if (!Number.isFinite(eps)) {
|
|
21
24
|
throw new Error('GroupNorm requires eps.');
|
|
22
25
|
}
|
|
@@ -44,34 +47,42 @@ async function _groupNorm(target, input, weight, bias, options = {}) {
|
|
|
44
47
|
|
|
45
48
|
const statsSize = numGroups * 2 * 4;
|
|
46
49
|
const statsBuffer = acquireBuffer(statsSize, undefined, 'groupnorm_stats');
|
|
47
|
-
|
|
48
|
-
await unifiedKernelWrapper(
|
|
49
|
-
'groupnorm_stats',
|
|
50
|
-
target,
|
|
51
|
-
statsVariant,
|
|
52
|
-
[input, statsBuffer],
|
|
53
|
-
uniforms,
|
|
54
|
-
numGroups
|
|
55
|
-
);
|
|
56
|
-
|
|
57
50
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
58
51
|
const outputSize = channels * height * width * bytesPerElement;
|
|
59
|
-
const
|
|
52
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'groupnorm_output');
|
|
53
|
+
const output = outputBuffer || ownedOutput;
|
|
60
54
|
|
|
61
|
-
|
|
62
|
-
|
|
55
|
+
try {
|
|
56
|
+
await unifiedKernelWrapper(
|
|
57
|
+
'groupnorm_stats',
|
|
58
|
+
target,
|
|
59
|
+
statsVariant,
|
|
60
|
+
[input, statsBuffer],
|
|
61
|
+
uniforms,
|
|
62
|
+
numGroups
|
|
63
|
+
);
|
|
63
64
|
|
|
64
|
-
|
|
65
|
-
|
|
65
|
+
const weightBuffer = getBuffer(weight);
|
|
66
|
+
const biasBuffer = getBuffer(bias);
|
|
66
67
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
68
|
+
const total = channels * height * width;
|
|
69
|
+
const workgroups = Math.ceil(total / WORKGROUP_SIZES.DEFAULT);
|
|
70
|
+
|
|
71
|
+
await unifiedKernelWrapper(
|
|
72
|
+
'groupnorm_apply',
|
|
73
|
+
target,
|
|
74
|
+
applyVariant,
|
|
75
|
+
[input, statsBuffer, weightBuffer, biasBuffer, output],
|
|
76
|
+
uniforms,
|
|
77
|
+
workgroups
|
|
78
|
+
);
|
|
79
|
+
} catch (error) {
|
|
80
|
+
releaseBuffer(statsBuffer);
|
|
81
|
+
if (ownedOutput) {
|
|
82
|
+
releaseBuffer(ownedOutput);
|
|
83
|
+
}
|
|
84
|
+
throw error;
|
|
85
|
+
}
|
|
75
86
|
|
|
76
87
|
if (recorder) {
|
|
77
88
|
recorder.trackTemporaryBuffer(statsBuffer);
|
|
@@ -78,8 +78,11 @@ export async function runKVQuantize(
|
|
|
78
78
|
});
|
|
79
79
|
|
|
80
80
|
const workgroups = [numKVHeads, numTokens, 1];
|
|
81
|
-
|
|
82
|
-
|
|
81
|
+
try {
|
|
82
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'kv_quantize');
|
|
83
|
+
} finally {
|
|
84
|
+
uniformBuffer.destroy();
|
|
85
|
+
}
|
|
83
86
|
}
|
|
84
87
|
|
|
85
88
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
|
|
2
2
|
import { getKernelCapabilities } from '../device.js';
|
|
3
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
4
4
|
import { createTensor } from '../tensor.js';
|
|
5
5
|
import { padToQ4KBlock } from '../../config/schema/index.js';
|
|
6
6
|
import { selectRuleValue } from './rule-registry.js';
|
|
@@ -36,17 +36,25 @@ export async function runLayerNorm(
|
|
|
36
36
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
37
37
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
38
38
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'layernorm_output');
|
|
39
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
39
40
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
41
|
+
try {
|
|
42
|
+
await unifiedKernelWrapper(
|
|
43
|
+
'layernorm',
|
|
44
|
+
null,
|
|
45
|
+
variant,
|
|
46
|
+
[input, weight, bias, outputBuf],
|
|
47
|
+
{ hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
|
|
48
|
+
batchSize
|
|
49
|
+
);
|
|
48
50
|
|
|
49
|
-
|
|
51
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
|
|
52
|
+
} catch (error) {
|
|
53
|
+
if (ownedOutput) {
|
|
54
|
+
releaseBuffer(ownedOutput);
|
|
55
|
+
}
|
|
56
|
+
throw error;
|
|
57
|
+
}
|
|
50
58
|
}
|
|
51
59
|
|
|
52
60
|
export async function recordLayerNorm(
|
|
@@ -66,15 +74,23 @@ export async function recordLayerNorm(
|
|
|
66
74
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
67
75
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
68
76
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'layernorm_output');
|
|
77
|
+
const ownedOutput = outputBuffer ? null : outputBuf;
|
|
69
78
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
79
|
+
try {
|
|
80
|
+
await unifiedKernelWrapper(
|
|
81
|
+
'layernorm',
|
|
82
|
+
recorder,
|
|
83
|
+
variant,
|
|
84
|
+
[input, weight, bias, outputBuf],
|
|
85
|
+
{ hidden_size: inferredHiddenSize, num_tokens: batchSize, eps },
|
|
86
|
+
batchSize
|
|
87
|
+
);
|
|
78
88
|
|
|
79
|
-
|
|
89
|
+
return createTensor(outputBuf, input.dtype, [batchSize, inferredHiddenSize], 'layernorm_output');
|
|
90
|
+
} catch (error) {
|
|
91
|
+
if (ownedOutput) {
|
|
92
|
+
releaseBuffer(ownedOutput);
|
|
93
|
+
}
|
|
94
|
+
throw error;
|
|
95
|
+
}
|
|
80
96
|
}
|
|
@@ -266,9 +266,11 @@ export class LogitMergeKernel {
|
|
|
266
266
|
pass.end();
|
|
267
267
|
|
|
268
268
|
this.#device.queue.submit([encoder.finish()]);
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
269
|
+
this.#device.queue.onSubmittedWorkDone()
|
|
270
|
+
.catch(() => {})
|
|
271
|
+
.finally(() => {
|
|
272
|
+
paramsBuffer.destroy();
|
|
273
|
+
});
|
|
272
274
|
|
|
273
275
|
return mergedBuffer;
|
|
274
276
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { getDevice } from '../device.js';
|
|
1
|
+
import { getDevice, getKernelCapabilities } from '../device.js';
|
|
2
2
|
import { createTensor } from '../tensor.js';
|
|
3
3
|
import { getBuffer, getLayout, getWeightDtype } from '../weight-buffer.js';
|
|
4
4
|
import { log, trace, isTraceEnabled } from '../../debug/index.js';
|
|
@@ -110,6 +110,7 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
110
110
|
const mode = isRecord ? 'record' : 'run';
|
|
111
111
|
const opLabel = isRecord ? 'recordMatmul' : 'runMatmul';
|
|
112
112
|
const device = recorder?.device || getDevice();
|
|
113
|
+
const capabilities = getKernelCapabilities();
|
|
113
114
|
|
|
114
115
|
const {
|
|
115
116
|
alpha = 1.0,
|
|
@@ -139,6 +140,13 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
139
140
|
const bDtype = toMatmulDtype(weightDtype ?? options.bDtype);
|
|
140
141
|
const requestedOutputDtype = options.outputDtype || A.dtype;
|
|
141
142
|
|
|
143
|
+
if (bDtype === 'f16' && capabilities?.hasF16 !== true) {
|
|
144
|
+
throw new Error(`[${opLabel}] f16 weights require shader-f16 support.`);
|
|
145
|
+
}
|
|
146
|
+
if (requestedOutputDtype === 'f16' && capabilities?.hasF16 !== true) {
|
|
147
|
+
throw new Error(`[${opLabel}] f16 output requires shader-f16 support.`);
|
|
148
|
+
}
|
|
149
|
+
|
|
142
150
|
if (!isRecord && isTraceEnabled('kernels') && !weightDtype && !options.bDtype && M <= 2) {
|
|
143
151
|
log.warn('Matmul', `runMatmul: B buffer dtype unknown! size=${bBuffer.size}, M=${M}, N=${N}, K=${K}. Assuming f32.`);
|
|
144
152
|
}
|
|
@@ -228,6 +236,7 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
228
236
|
N,
|
|
229
237
|
outputBuffer
|
|
230
238
|
);
|
|
239
|
+
const ownsOutput = outputBuffer == null;
|
|
231
240
|
|
|
232
241
|
if (!Number.isFinite(outputSize) || outputSize <= 0) {
|
|
233
242
|
throw new Error(`[${opLabel}] Invalid output size: ${outputSize} (M=${M}, N=${N})`);
|
|
@@ -239,50 +248,60 @@ async function executeMatmul(recorder, A, B, M, N, K, options = {}) {
|
|
|
239
248
|
}
|
|
240
249
|
|
|
241
250
|
const dispatchPlan = calculateMatmulDispatch(variant, useQ4KFused, useGemv, M, N, config);
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
251
|
+
let uniformBuffer = null;
|
|
252
|
+
let completed = false;
|
|
253
|
+
try {
|
|
254
|
+
uniformBuffer = createMatmulUniformBuffer(
|
|
255
|
+
'matmul_uniforms',
|
|
256
|
+
M,
|
|
257
|
+
N,
|
|
258
|
+
K,
|
|
259
|
+
alpha,
|
|
260
|
+
useQ4KFused,
|
|
261
|
+
transposeB,
|
|
262
|
+
dispatchPlan.uniformWorkgroupsX,
|
|
263
|
+
recorder || null,
|
|
264
|
+
device
|
|
265
|
+
);
|
|
254
266
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
267
|
+
const entries = createMatmulBindGroupEntries(
|
|
268
|
+
variant,
|
|
269
|
+
uniformBuffer,
|
|
270
|
+
matmulInput,
|
|
271
|
+
bBuffer,
|
|
272
|
+
C,
|
|
273
|
+
{ aOffset, bOffset, cOffset },
|
|
274
|
+
{
|
|
275
|
+
aBindingSize: bindingSizes.aBindingSize,
|
|
276
|
+
bBindingSize: bindingSizes.bBindingSize,
|
|
277
|
+
cBindingSize,
|
|
278
|
+
}
|
|
279
|
+
);
|
|
268
280
|
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
281
|
+
const bindGroup = device.createBindGroup({
|
|
282
|
+
label: 'matmul_bind_group',
|
|
283
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
284
|
+
entries,
|
|
285
|
+
});
|
|
274
286
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
287
|
+
if (isRecord) {
|
|
288
|
+
kernel.record(recorder, pipeline, bindGroup, dispatchPlan.workgroups, buildProfileLabel(options));
|
|
289
|
+
} else {
|
|
290
|
+
kernel.dispatch(pipeline, bindGroup, dispatchPlan.workgroups);
|
|
291
|
+
}
|
|
292
|
+
completed = true;
|
|
293
|
+
return createTensor(C, actualOutputDtype, [M, N], 'matmul_output');
|
|
294
|
+
} finally {
|
|
295
|
+
if (!isRecord && uniformBuffer) {
|
|
296
|
+
releaseUniformBuffer(uniformBuffer);
|
|
297
|
+
}
|
|
298
|
+
if (!isRecord && castedInput) {
|
|
281
299
|
releaseBuffer(castedInput.buffer);
|
|
282
300
|
}
|
|
301
|
+
if (!completed && ownsOutput) {
|
|
302
|
+
releaseBuffer(C);
|
|
303
|
+
}
|
|
283
304
|
}
|
|
284
|
-
|
|
285
|
-
return createTensor(C, actualOutputDtype, [M, N], 'matmul_output');
|
|
286
305
|
}
|
|
287
306
|
|
|
288
307
|
|