@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
|
@@ -58,36 +58,46 @@ async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
|
|
|
58
58
|
device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
|
|
59
59
|
}
|
|
60
60
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
61
|
+
try {
|
|
62
|
+
await unifiedKernelWrapper(
|
|
63
|
+
'depthwise_conv2d',
|
|
64
|
+
target,
|
|
65
|
+
variant,
|
|
66
|
+
[input, weightBuffer, biasBuffer, output],
|
|
67
|
+
{
|
|
68
|
+
channels,
|
|
69
|
+
height,
|
|
70
|
+
width,
|
|
71
|
+
out_height: outHeight,
|
|
72
|
+
out_width: outWidth,
|
|
73
|
+
kernel_h: kernelH,
|
|
74
|
+
kernel_w: kernelW,
|
|
75
|
+
stride,
|
|
76
|
+
pad,
|
|
77
|
+
_pad0: 0,
|
|
78
|
+
_pad1: 0,
|
|
79
|
+
},
|
|
80
|
+
[Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
|
|
81
|
+
);
|
|
81
82
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
83
|
+
if (tempBias) {
|
|
84
|
+
if (recorder) {
|
|
85
|
+
recorder.trackTemporaryBuffer(tempBias);
|
|
86
|
+
} else {
|
|
87
|
+
releaseBuffer(tempBias);
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
|
|
92
|
+
} catch (error) {
|
|
93
|
+
if (tempBias) {
|
|
86
94
|
releaseBuffer(tempBias);
|
|
87
95
|
}
|
|
96
|
+
if (!outputBuffer) {
|
|
97
|
+
releaseBuffer(output);
|
|
98
|
+
}
|
|
99
|
+
throw error;
|
|
88
100
|
}
|
|
89
|
-
|
|
90
|
-
return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
|
|
91
101
|
}
|
|
92
102
|
|
|
93
103
|
export async function runDepthwiseConv2D(input, weight, bias, options = {}) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice, getKernelCapabilities } from '../device.js';
|
|
4
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor } from '../tensor.js';
|
|
6
6
|
import { GPU_LIMITS, TILE_SIZES, WORKGROUP_SIZES } from './constants.js';
|
|
7
7
|
import { Q6K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, Q8_0_BLOCK_SIZE } from '../../loader/quantization-constants.js';
|
|
@@ -69,6 +69,17 @@ export function createDequantBindGroupLayout() {
|
|
|
69
69
|
]);
|
|
70
70
|
}
|
|
71
71
|
|
|
72
|
+
function cleanupDequantResources(uniformBuffer, ownedBuffers) {
|
|
73
|
+
if (uniformBuffer) {
|
|
74
|
+
releaseUniformBuffer(uniformBuffer);
|
|
75
|
+
}
|
|
76
|
+
for (const buffer of ownedBuffers) {
|
|
77
|
+
if (buffer) {
|
|
78
|
+
releaseBuffer(buffer);
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
}
|
|
82
|
+
|
|
72
83
|
|
|
73
84
|
export async function dequantize(
|
|
74
85
|
quantized,
|
|
@@ -76,12 +87,17 @@ export async function dequantize(
|
|
|
76
87
|
options = {}
|
|
77
88
|
) {
|
|
78
89
|
const device = getDevice();
|
|
90
|
+
const capabilities = getKernelCapabilities();
|
|
79
91
|
const {
|
|
80
92
|
outputOffset = 0,
|
|
81
93
|
outputBuffer = null,
|
|
82
94
|
outputDtype = 'f32',
|
|
83
95
|
} = options;
|
|
84
96
|
|
|
97
|
+
if (outputDtype === 'f16' && capabilities?.hasF16 !== true) {
|
|
98
|
+
throw new Error('[dequantize] f16 output requires shader-f16 support.');
|
|
99
|
+
}
|
|
100
|
+
|
|
85
101
|
// Select kernel
|
|
86
102
|
const variant = selectDequantKernel({ ...options, outputDtype });
|
|
87
103
|
const pipeline = await getPipelineFast('dequant', variant);
|
|
@@ -92,7 +108,8 @@ export async function dequantize(
|
|
|
92
108
|
const outputSize = numBlocks * QK_K * bytesPerElem;
|
|
93
109
|
|
|
94
110
|
// Create output buffer if not provided
|
|
95
|
-
const
|
|
111
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'dequant_output');
|
|
112
|
+
const output = outputBuffer || ownedOutput;
|
|
96
113
|
|
|
97
114
|
// Create uniform buffer
|
|
98
115
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -108,21 +125,24 @@ export async function dequantize(
|
|
|
108
125
|
device
|
|
109
126
|
);
|
|
110
127
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
128
|
+
try {
|
|
129
|
+
const bindGroup = device.createBindGroup({
|
|
130
|
+
label: 'dequant_bind_group',
|
|
131
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
132
|
+
entries: [
|
|
133
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
134
|
+
{ binding: 1, resource: { buffer: quantized } },
|
|
135
|
+
{ binding: 2, resource: { buffer: output } },
|
|
136
|
+
],
|
|
137
|
+
});
|
|
138
|
+
|
|
139
|
+
const workgroups = calculateDequantWorkgroups(variant, numBlocks);
|
|
140
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'dequant');
|
|
141
|
+
} catch (error) {
|
|
142
|
+
cleanupDequantResources(uniformBuffer, [ownedOutput]);
|
|
143
|
+
throw error;
|
|
144
|
+
}
|
|
124
145
|
|
|
125
|
-
// Release uniform buffer back to cache (or destroy if not cached)
|
|
126
146
|
releaseUniformBuffer(uniformBuffer);
|
|
127
147
|
|
|
128
148
|
|
|
@@ -140,7 +160,11 @@ export async function dequantizeRowwise(
|
|
|
140
160
|
options = {}
|
|
141
161
|
) {
|
|
142
162
|
const device = getDevice();
|
|
163
|
+
const capabilities = getKernelCapabilities();
|
|
143
164
|
const { outputBuffer = null, outputDtype = 'f16' } = options;
|
|
165
|
+
if (outputDtype === 'f16' && capabilities?.hasF16 !== true) {
|
|
166
|
+
throw new Error('[dequantizeRowwise] f16 output requires shader-f16 support.');
|
|
167
|
+
}
|
|
144
168
|
const finalOutputDtype = selectSharedRuleValue('shared', 'dtype', 'f16OrF32FromDtype', { dtype: outputDtype });
|
|
145
169
|
const pipelineVariant = selectKernelRuleValue(
|
|
146
170
|
'dequant',
|
|
@@ -157,7 +181,8 @@ export async function dequantizeRowwise(
|
|
|
157
181
|
const bytesPerElem = finalOutputDtype === 'f16' ? 2 : 4;
|
|
158
182
|
const outputSize = rows * K * bytesPerElem;
|
|
159
183
|
|
|
160
|
-
const
|
|
184
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'dequant_rowwise_output');
|
|
185
|
+
const output = outputBuffer || ownedOutput;
|
|
161
186
|
|
|
162
187
|
const uniformBuffer = createUniformBufferWithView(
|
|
163
188
|
'dequant_rowwise_uniforms',
|
|
@@ -172,18 +197,23 @@ export async function dequantizeRowwise(
|
|
|
172
197
|
device
|
|
173
198
|
);
|
|
174
199
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
200
|
+
try {
|
|
201
|
+
const bindGroup = device.createBindGroup({
|
|
202
|
+
label: 'dequant_rowwise_bind_group',
|
|
203
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
204
|
+
entries: [
|
|
205
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
206
|
+
{ binding: 1, resource: { buffer: quantized } },
|
|
207
|
+
{ binding: 2, resource: { buffer: output } },
|
|
208
|
+
],
|
|
209
|
+
});
|
|
210
|
+
|
|
211
|
+
const workgroups = [numBlocks, 1, 1];
|
|
212
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'dequant_rowwise');
|
|
213
|
+
} catch (error) {
|
|
214
|
+
cleanupDequantResources(uniformBuffer, [ownedOutput]);
|
|
215
|
+
throw error;
|
|
216
|
+
}
|
|
187
217
|
|
|
188
218
|
releaseUniformBuffer(uniformBuffer);
|
|
189
219
|
|
|
@@ -208,7 +238,8 @@ export async function dequantizeMXFP4(
|
|
|
208
238
|
const pipeline = await getPipelineFast('dequant', 'mxfp4');
|
|
209
239
|
|
|
210
240
|
const outputSize = totalElements * 4; // F32 output
|
|
211
|
-
const
|
|
241
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'mxfp4_dequant_output');
|
|
242
|
+
const output = outputBuffer || ownedOutput;
|
|
212
243
|
|
|
213
244
|
// Create uniform buffer
|
|
214
245
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -224,26 +255,29 @@ export async function dequantizeMXFP4(
|
|
|
224
255
|
device
|
|
225
256
|
);
|
|
226
257
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
258
|
+
try {
|
|
259
|
+
const bindGroup = device.createBindGroup({
|
|
260
|
+
label: 'mxfp4_dequant_bind_group',
|
|
261
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
262
|
+
entries: [
|
|
263
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
264
|
+
{ binding: 1, resource: { buffer: blocks } },
|
|
265
|
+
{ binding: 2, resource: { buffer: scales } },
|
|
266
|
+
{ binding: 3, resource: { buffer: output } },
|
|
267
|
+
],
|
|
268
|
+
});
|
|
269
|
+
|
|
270
|
+
const workgroups = Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT);
|
|
271
|
+
const dispatchSize = [
|
|
272
|
+
Math.min(workgroups, GPU_LIMITS.MAX_WORKGROUPS),
|
|
273
|
+
Math.max(1, Math.ceil(workgroups / GPU_LIMITS.MAX_WORKGROUPS)),
|
|
274
|
+
1,
|
|
275
|
+
];
|
|
276
|
+
dispatch(device, pipeline, bindGroup, dispatchSize, 'mxfp4_dequant');
|
|
277
|
+
} catch (error) {
|
|
278
|
+
cleanupDequantResources(uniformBuffer, [ownedOutput]);
|
|
279
|
+
throw error;
|
|
280
|
+
}
|
|
247
281
|
|
|
248
282
|
releaseUniformBuffer(uniformBuffer);
|
|
249
283
|
|
|
@@ -284,7 +318,8 @@ export async function dequantizeMXFP4Expert(
|
|
|
284
318
|
const totalOutput = outDim * numGroups * 32;
|
|
285
319
|
const bytesPerElement = outputDtype === 'f16' ? 2 : 4;
|
|
286
320
|
const outputSize = totalOutput * bytesPerElement;
|
|
287
|
-
const
|
|
321
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'mxfp4_expert_output');
|
|
322
|
+
const output = outputBuffer || ownedOutput;
|
|
288
323
|
|
|
289
324
|
// Create uniform buffer
|
|
290
325
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -301,26 +336,29 @@ export async function dequantizeMXFP4Expert(
|
|
|
301
336
|
device
|
|
302
337
|
);
|
|
303
338
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
339
|
+
try {
|
|
340
|
+
const bindGroup = device.createBindGroup({
|
|
341
|
+
label: 'mxfp4_expert_bind_group',
|
|
342
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
343
|
+
entries: [
|
|
344
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
345
|
+
{ binding: 1, resource: { buffer: blocks } },
|
|
346
|
+
{ binding: 2, resource: { buffer: scales } },
|
|
347
|
+
{ binding: 3, resource: { buffer: output } },
|
|
348
|
+
],
|
|
349
|
+
});
|
|
350
|
+
|
|
351
|
+
const workgroups = Math.ceil(totalOutput / WORKGROUP_SIZES.DEFAULT);
|
|
352
|
+
const dispatchSize = [
|
|
353
|
+
Math.min(workgroups, GPU_LIMITS.MAX_WORKGROUPS),
|
|
354
|
+
Math.max(1, Math.ceil(workgroups / GPU_LIMITS.MAX_WORKGROUPS)),
|
|
355
|
+
1,
|
|
356
|
+
];
|
|
357
|
+
dispatch(device, pipeline, bindGroup, dispatchSize, 'mxfp4_expert');
|
|
358
|
+
} catch (error) {
|
|
359
|
+
cleanupDequantResources(uniformBuffer, [ownedOutput]);
|
|
360
|
+
throw error;
|
|
361
|
+
}
|
|
324
362
|
|
|
325
363
|
releaseUniformBuffer(uniformBuffer);
|
|
326
364
|
|
|
@@ -350,7 +388,8 @@ export async function dequantizeQ6K(
|
|
|
350
388
|
const outputSize = numBlocks * QK_K * bytesPerElem;
|
|
351
389
|
|
|
352
390
|
// Create output buffer if not provided
|
|
353
|
-
const
|
|
391
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'q6k_dequant_output');
|
|
392
|
+
const output = outputBuffer || ownedOutput;
|
|
354
393
|
|
|
355
394
|
// Calculate workgroups for 2D dispatch
|
|
356
395
|
const maxWorkgroups = GPU_LIMITS.MAX_WORKGROUPS;
|
|
@@ -370,26 +409,28 @@ export async function dequantizeQ6K(
|
|
|
370
409
|
device
|
|
371
410
|
);
|
|
372
411
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
412
|
+
try {
|
|
413
|
+
const bindGroup = device.createBindGroup({
|
|
414
|
+
label: 'q6k_dequant_bind_group',
|
|
415
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
416
|
+
entries: [
|
|
417
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
418
|
+
{ binding: 1, resource: { buffer: quantized } },
|
|
419
|
+
{ binding: 2, resource: { buffer: output } },
|
|
420
|
+
],
|
|
421
|
+
});
|
|
422
|
+
|
|
423
|
+
const workgroups = [
|
|
424
|
+
workgroupsX,
|
|
425
|
+
numBlocks > maxWorkgroups ? Math.ceil(numBlocks / maxWorkgroups) : 1,
|
|
426
|
+
1
|
|
427
|
+
];
|
|
428
|
+
|
|
429
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'q6k_dequant');
|
|
430
|
+
} catch (error) {
|
|
431
|
+
cleanupDequantResources(uniformBuffer, [ownedOutput]);
|
|
432
|
+
throw error;
|
|
433
|
+
}
|
|
393
434
|
|
|
394
435
|
releaseUniformBuffer(uniformBuffer);
|
|
395
436
|
|
|
@@ -419,7 +460,8 @@ export async function dequantizeQ8_0(
|
|
|
419
460
|
const outputSize = numBlocks * Q8_0_BLOCK_SIZE * bytesPerElem;
|
|
420
461
|
|
|
421
462
|
// Create output buffer if not provided
|
|
422
|
-
const
|
|
463
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'q8_0_dequant_output');
|
|
464
|
+
const output = outputBuffer || ownedOutput;
|
|
423
465
|
|
|
424
466
|
// Calculate workgroups for 2D dispatch
|
|
425
467
|
const maxWorkgroups = GPU_LIMITS.MAX_WORKGROUPS;
|
|
@@ -439,26 +481,28 @@ export async function dequantizeQ8_0(
|
|
|
439
481
|
device
|
|
440
482
|
);
|
|
441
483
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
484
|
+
try {
|
|
485
|
+
const bindGroup = device.createBindGroup({
|
|
486
|
+
label: 'q8_0_dequant_bind_group',
|
|
487
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
488
|
+
entries: [
|
|
489
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
490
|
+
{ binding: 1, resource: { buffer: quantized } },
|
|
491
|
+
{ binding: 2, resource: { buffer: output } },
|
|
492
|
+
],
|
|
493
|
+
});
|
|
494
|
+
|
|
495
|
+
const workgroups = [
|
|
496
|
+
workgroupsX,
|
|
497
|
+
numBlocks > maxWorkgroups ? Math.ceil(numBlocks / maxWorkgroups) : 1,
|
|
498
|
+
1
|
|
499
|
+
];
|
|
500
|
+
|
|
501
|
+
dispatch(device, pipeline, bindGroup, workgroups, 'q8_0_dequant');
|
|
502
|
+
} catch (error) {
|
|
503
|
+
cleanupDequantResources(uniformBuffer, [ownedOutput]);
|
|
504
|
+
throw error;
|
|
505
|
+
}
|
|
462
506
|
|
|
463
507
|
releaseUniformBuffer(uniformBuffer);
|
|
464
508
|
|
|
@@ -491,7 +535,8 @@ export async function recordDequantize(
|
|
|
491
535
|
const outputSize = numBlocks * QK_K * bytesPerElem;
|
|
492
536
|
|
|
493
537
|
// Output buffer
|
|
494
|
-
const
|
|
538
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'dequant_output');
|
|
539
|
+
const output = outputBuffer || ownedOutput;
|
|
495
540
|
|
|
496
541
|
// Uniform buffer
|
|
497
542
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -505,18 +550,25 @@ export async function recordDequantize(
|
|
|
505
550
|
);
|
|
506
551
|
|
|
507
552
|
// Bind group
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
553
|
+
try {
|
|
554
|
+
const bindGroup = device.createBindGroup({
|
|
555
|
+
label: 'dequant_bind_group',
|
|
556
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
557
|
+
entries: [
|
|
558
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
559
|
+
{ binding: 1, resource: { buffer: quantized } },
|
|
560
|
+
{ binding: 2, resource: { buffer: output } },
|
|
561
|
+
],
|
|
562
|
+
});
|
|
563
|
+
|
|
564
|
+
const workgroups = calculateDequantWorkgroups(variant, numBlocks);
|
|
565
|
+
recordDispatch(recorder, pipeline, bindGroup, workgroups, 'dequant');
|
|
566
|
+
} catch (error) {
|
|
567
|
+
if (ownedOutput) {
|
|
568
|
+
releaseBuffer(ownedOutput);
|
|
569
|
+
}
|
|
570
|
+
throw error;
|
|
571
|
+
}
|
|
520
572
|
|
|
521
573
|
|
|
522
574
|
const dtype = selectSharedRuleValue('shared', 'dtype', 'f16OrF32FromDtype', { dtype: outputDtype });
|
|
@@ -16,6 +16,7 @@ export interface EnergyUpdateOptions {
|
|
|
16
16
|
export interface EnergyQuintelUpdateOptions {
|
|
17
17
|
count?: number;
|
|
18
18
|
size?: number;
|
|
19
|
+
flags?: number;
|
|
19
20
|
stepSize?: number;
|
|
20
21
|
gradientScale?: number;
|
|
21
22
|
countDiff?: number;
|
|
@@ -26,48 +27,29 @@ export interface EnergyQuintelUpdateOptions {
|
|
|
26
27
|
centerTarget?: number;
|
|
27
28
|
clampMin?: number;
|
|
28
29
|
clampMax?: number;
|
|
29
|
-
rules?: {
|
|
30
|
-
mirrorX?: boolean;
|
|
31
|
-
mirrorY?: boolean;
|
|
32
|
-
diagonal?: boolean;
|
|
33
|
-
count?: boolean;
|
|
34
|
-
center?: boolean;
|
|
35
|
-
};
|
|
36
30
|
}
|
|
37
31
|
|
|
38
32
|
export interface EnergyQuintelReduceOptions {
|
|
39
33
|
count?: number;
|
|
40
34
|
size?: number;
|
|
35
|
+
flags?: number;
|
|
41
36
|
symmetryWeight?: number;
|
|
42
37
|
centerWeight?: number;
|
|
43
38
|
binarizeWeight?: number;
|
|
44
39
|
centerTarget?: number;
|
|
45
|
-
rules?: {
|
|
46
|
-
mirrorX?: boolean;
|
|
47
|
-
mirrorY?: boolean;
|
|
48
|
-
diagonal?: boolean;
|
|
49
|
-
count?: boolean;
|
|
50
|
-
center?: boolean;
|
|
51
|
-
};
|
|
52
40
|
outputBuffer?: GPUBuffer | null;
|
|
53
41
|
}
|
|
54
42
|
|
|
55
43
|
export interface EnergyQuintelGradOptions {
|
|
56
44
|
count?: number;
|
|
57
45
|
size?: number;
|
|
46
|
+
flags?: number;
|
|
58
47
|
countDiff?: number;
|
|
59
48
|
symmetryWeight?: number;
|
|
60
49
|
countWeight?: number;
|
|
61
50
|
centerWeight?: number;
|
|
62
51
|
binarizeWeight?: number;
|
|
63
52
|
centerTarget?: number;
|
|
64
|
-
rules?: {
|
|
65
|
-
mirrorX?: boolean;
|
|
66
|
-
mirrorY?: boolean;
|
|
67
|
-
diagonal?: boolean;
|
|
68
|
-
count?: boolean;
|
|
69
|
-
center?: boolean;
|
|
70
|
-
};
|
|
71
53
|
outputBuffer?: GPUBuffer | null;
|
|
72
54
|
}
|
|
73
55
|
|