@simulatte/doppler 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +25 -17
- package/package.json +20 -4
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +49 -7
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +43 -4
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +28 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +6 -3
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +11 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +8 -1
- package/src/config/schema/manifest.schema.js +19 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/rope-config.js +42 -0
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +131 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +113 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +37 -26
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +34 -23
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +83 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +31 -10
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +69 -23
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +96 -28
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +14 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +19 -12
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +148 -82
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +31 -10
- package/src/gpu/kernels/transpose.wgsl +6 -5
- package/src/gpu/kernels/upsample2d.js +22 -13
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +35 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1950
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +17 -7
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +73 -10
- package/src/inference/pipelines/text/attention/run.js +73 -10
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +71 -5
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +64 -50
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +78 -1002
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +134 -29
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +14 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +17 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +176 -33
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +26 -473
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +218 -273
- package/src/tooling/node-command-runner.js +44 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +30 -105
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +8 -0
- package/src/training/checkpoint-watch.js +139 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +46 -7
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +58 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +793 -0
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +455 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +31 -5
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +24 -984
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +530 -0
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +179 -63
|
@@ -16,6 +16,7 @@ export interface EnergyUpdateOptions {
|
|
|
16
16
|
export interface EnergyQuintelUpdateOptions {
|
|
17
17
|
count?: number;
|
|
18
18
|
size?: number;
|
|
19
|
+
flags?: number;
|
|
19
20
|
stepSize?: number;
|
|
20
21
|
gradientScale?: number;
|
|
21
22
|
countDiff?: number;
|
|
@@ -26,48 +27,29 @@ export interface EnergyQuintelUpdateOptions {
|
|
|
26
27
|
centerTarget?: number;
|
|
27
28
|
clampMin?: number;
|
|
28
29
|
clampMax?: number;
|
|
29
|
-
rules?: {
|
|
30
|
-
mirrorX?: boolean;
|
|
31
|
-
mirrorY?: boolean;
|
|
32
|
-
diagonal?: boolean;
|
|
33
|
-
count?: boolean;
|
|
34
|
-
center?: boolean;
|
|
35
|
-
};
|
|
36
30
|
}
|
|
37
31
|
|
|
38
32
|
export interface EnergyQuintelReduceOptions {
|
|
39
33
|
count?: number;
|
|
40
34
|
size?: number;
|
|
35
|
+
flags?: number;
|
|
41
36
|
symmetryWeight?: number;
|
|
42
37
|
centerWeight?: number;
|
|
43
38
|
binarizeWeight?: number;
|
|
44
39
|
centerTarget?: number;
|
|
45
|
-
rules?: {
|
|
46
|
-
mirrorX?: boolean;
|
|
47
|
-
mirrorY?: boolean;
|
|
48
|
-
diagonal?: boolean;
|
|
49
|
-
count?: boolean;
|
|
50
|
-
center?: boolean;
|
|
51
|
-
};
|
|
52
40
|
outputBuffer?: GPUBuffer | null;
|
|
53
41
|
}
|
|
54
42
|
|
|
55
43
|
export interface EnergyQuintelGradOptions {
|
|
56
44
|
count?: number;
|
|
57
45
|
size?: number;
|
|
46
|
+
flags?: number;
|
|
58
47
|
countDiff?: number;
|
|
59
48
|
symmetryWeight?: number;
|
|
60
49
|
countWeight?: number;
|
|
61
50
|
centerWeight?: number;
|
|
62
51
|
binarizeWeight?: number;
|
|
63
52
|
centerTarget?: number;
|
|
64
|
-
rules?: {
|
|
65
|
-
mirrorX?: boolean;
|
|
66
|
-
mirrorY?: boolean;
|
|
67
|
-
diagonal?: boolean;
|
|
68
|
-
count?: boolean;
|
|
69
|
-
center?: boolean;
|
|
70
|
-
};
|
|
71
53
|
outputBuffer?: GPUBuffer | null;
|
|
72
54
|
}
|
|
73
55
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { getDevice } from '../device.js';
|
|
2
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
4
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
5
5
|
import { dispatch, recordDispatch } from './dispatch.js';
|
|
@@ -61,15 +61,14 @@ function resolveQuintelSize(state, sizeOverride) {
|
|
|
61
61
|
return null;
|
|
62
62
|
}
|
|
63
63
|
|
|
64
|
-
function
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
if (
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
return flags >>> 0;
|
|
64
|
+
function resolveQuintelFlags(options, op) {
|
|
65
|
+
if (options.rules !== undefined) {
|
|
66
|
+
throw new Error(`${op}: quintel kernel flags must be resolved before dispatch.`);
|
|
67
|
+
}
|
|
68
|
+
if (!Number.isFinite(options.flags)) {
|
|
69
|
+
throw new Error(`${op}: flags is required for quintel kernels.`);
|
|
70
|
+
}
|
|
71
|
+
return options.flags >>> 0;
|
|
73
72
|
}
|
|
74
73
|
|
|
75
74
|
function resolveExecution(recorder) {
|
|
@@ -103,6 +102,12 @@ function releaseUniformBuffer(execution, uniformBuffer) {
|
|
|
103
102
|
}
|
|
104
103
|
}
|
|
105
104
|
|
|
105
|
+
function releaseOwnedBuffer(ownedBuffer) {
|
|
106
|
+
if (ownedBuffer) {
|
|
107
|
+
releaseBuffer(ownedBuffer);
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
|
|
106
111
|
function writeQuintelUpdateUniform(view, params) {
|
|
107
112
|
view.setUint32(0, params.elementCount, true);
|
|
108
113
|
view.setUint32(4, params.boardSize, true);
|
|
@@ -149,6 +154,7 @@ async function executeEnergyEval(recorder, state, target, options = {}, op) {
|
|
|
149
154
|
|
|
150
155
|
const outputSize = elementCount * 4;
|
|
151
156
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_eval_output');
|
|
157
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
152
158
|
|
|
153
159
|
const variant = selectEnergyEvalVariant(state.dtype);
|
|
154
160
|
const pipeline = await getPipelineFast('energy_eval', variant);
|
|
@@ -157,23 +163,27 @@ async function executeEnergyEval(recorder, state, target, options = {}, op) {
|
|
|
157
163
|
view.setUint32(0, elementCount, true);
|
|
158
164
|
view.setFloat32(4, scale, true);
|
|
159
165
|
});
|
|
166
|
+
try {
|
|
167
|
+
const bindGroup = execution.device.createBindGroup({
|
|
168
|
+
label: 'energy_eval_bind_group',
|
|
169
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
170
|
+
entries: [
|
|
171
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
172
|
+
{ binding: 1, resource: { buffer: state.buffer } },
|
|
173
|
+
{ binding: 2, resource: { buffer: target.buffer } },
|
|
174
|
+
{ binding: 3, resource: { buffer: output } },
|
|
175
|
+
],
|
|
176
|
+
});
|
|
160
177
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
});
|
|
171
|
-
|
|
172
|
-
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
173
|
-
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_eval');
|
|
174
|
-
releaseUniformBuffer(execution, uniformBuffer);
|
|
175
|
-
|
|
176
|
-
return createTensor(output, 'f32', [elementCount], 'energy_eval_output');
|
|
178
|
+
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
179
|
+
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_eval');
|
|
180
|
+
return createTensor(output, 'f32', [elementCount], 'energy_eval_output');
|
|
181
|
+
} catch (error) {
|
|
182
|
+
releaseOwnedBuffer(ownedOutput);
|
|
183
|
+
throw error;
|
|
184
|
+
} finally {
|
|
185
|
+
releaseUniformBuffer(execution, uniformBuffer);
|
|
186
|
+
}
|
|
177
187
|
}
|
|
178
188
|
|
|
179
189
|
async function executeEnergyUpdate(recorder, state, target, options = {}, op) {
|
|
@@ -191,21 +201,23 @@ async function executeEnergyUpdate(recorder, state, target, options = {}, op) {
|
|
|
191
201
|
view.setFloat32(8, gradientScale, true);
|
|
192
202
|
});
|
|
193
203
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
205
|
-
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_update');
|
|
206
|
-
releaseUniformBuffer(execution, uniformBuffer);
|
|
204
|
+
try {
|
|
205
|
+
const bindGroup = execution.device.createBindGroup({
|
|
206
|
+
label: 'energy_update_bind_group',
|
|
207
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
208
|
+
entries: [
|
|
209
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
210
|
+
{ binding: 1, resource: { buffer: state.buffer } },
|
|
211
|
+
{ binding: 2, resource: { buffer: target.buffer } },
|
|
212
|
+
],
|
|
213
|
+
});
|
|
207
214
|
|
|
208
|
-
|
|
215
|
+
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
216
|
+
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_update');
|
|
217
|
+
return state;
|
|
218
|
+
} finally {
|
|
219
|
+
releaseUniformBuffer(execution, uniformBuffer);
|
|
220
|
+
}
|
|
209
221
|
}
|
|
210
222
|
|
|
211
223
|
async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
|
|
@@ -224,7 +236,6 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
|
|
|
224
236
|
centerTarget = 1.0,
|
|
225
237
|
clampMin = 0.0,
|
|
226
238
|
clampMax = 1.0,
|
|
227
|
-
rules = {},
|
|
228
239
|
} = options;
|
|
229
240
|
const elementCount = inferCount(state, count);
|
|
230
241
|
const boardSize = resolveQuintelSize(state, size);
|
|
@@ -234,7 +245,7 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
|
|
|
234
245
|
|
|
235
246
|
const variant = selectEnergyQuintelUpdateVariant(state.dtype);
|
|
236
247
|
const pipeline = await getPipelineFast('energy_quintel_update', variant);
|
|
237
|
-
const flags =
|
|
248
|
+
const flags = resolveQuintelFlags(options, op);
|
|
238
249
|
|
|
239
250
|
const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_uniforms', 64, (view) => {
|
|
240
251
|
writeQuintelUpdateUniform(view, {
|
|
@@ -254,20 +265,22 @@ async function executeEnergyQuintelUpdate(recorder, state, options = {}, op) {
|
|
|
254
265
|
});
|
|
255
266
|
});
|
|
256
267
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
267
|
-
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_update');
|
|
268
|
-
releaseUniformBuffer(execution, uniformBuffer);
|
|
268
|
+
try {
|
|
269
|
+
const bindGroup = execution.device.createBindGroup({
|
|
270
|
+
label: 'energy_quintel_update_bind_group',
|
|
271
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
272
|
+
entries: [
|
|
273
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
274
|
+
{ binding: 1, resource: { buffer: state.buffer } },
|
|
275
|
+
],
|
|
276
|
+
});
|
|
269
277
|
|
|
270
|
-
|
|
278
|
+
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
279
|
+
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_update');
|
|
280
|
+
return state;
|
|
281
|
+
} finally {
|
|
282
|
+
releaseUniformBuffer(execution, uniformBuffer);
|
|
283
|
+
}
|
|
271
284
|
}
|
|
272
285
|
|
|
273
286
|
async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
|
|
@@ -280,7 +293,6 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
|
|
|
280
293
|
centerWeight = 1.0,
|
|
281
294
|
binarizeWeight = 0.0,
|
|
282
295
|
centerTarget = 1.0,
|
|
283
|
-
rules = {},
|
|
284
296
|
outputBuffer = null,
|
|
285
297
|
} = options;
|
|
286
298
|
const elementCount = inferCount(state, count);
|
|
@@ -291,7 +303,7 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
|
|
|
291
303
|
|
|
292
304
|
const variant = selectEnergyQuintelReduceVariant(state.dtype);
|
|
293
305
|
const pipeline = await getPipelineFast('energy_quintel_reduce', variant);
|
|
294
|
-
const flags =
|
|
306
|
+
const flags = resolveQuintelFlags(options, op);
|
|
295
307
|
|
|
296
308
|
const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_reduce_uniforms', 48, (view) => {
|
|
297
309
|
writeQuintelReduceUniform(view, {
|
|
@@ -308,21 +320,27 @@ async function executeEnergyQuintelReduce(recorder, state, options = {}, op) {
|
|
|
308
320
|
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
309
321
|
const outputSize = workgroups * 16;
|
|
310
322
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_quintel_reduce_output');
|
|
323
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
324
|
+
|
|
325
|
+
try {
|
|
326
|
+
const bindGroup = execution.device.createBindGroup({
|
|
327
|
+
label: 'energy_quintel_reduce_bind_group',
|
|
328
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
329
|
+
entries: [
|
|
330
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
331
|
+
{ binding: 1, resource: { buffer: state.buffer } },
|
|
332
|
+
{ binding: 2, resource: { buffer: output } },
|
|
333
|
+
],
|
|
334
|
+
});
|
|
311
335
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
});
|
|
321
|
-
|
|
322
|
-
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_reduce');
|
|
323
|
-
releaseUniformBuffer(execution, uniformBuffer);
|
|
324
|
-
|
|
325
|
-
return createTensor(output, 'f32', [workgroups, 4], 'energy_quintel_reduce_output');
|
|
336
|
+
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_reduce');
|
|
337
|
+
return createTensor(output, 'f32', [workgroups, 4], 'energy_quintel_reduce_output');
|
|
338
|
+
} catch (error) {
|
|
339
|
+
releaseOwnedBuffer(ownedOutput);
|
|
340
|
+
throw error;
|
|
341
|
+
} finally {
|
|
342
|
+
releaseUniformBuffer(execution, uniformBuffer);
|
|
343
|
+
}
|
|
326
344
|
}
|
|
327
345
|
|
|
328
346
|
async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
|
|
@@ -337,7 +355,6 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
|
|
|
337
355
|
centerWeight = 1.0,
|
|
338
356
|
binarizeWeight = 0.0,
|
|
339
357
|
centerTarget = 1.0,
|
|
340
|
-
rules = {},
|
|
341
358
|
outputBuffer = null,
|
|
342
359
|
} = options;
|
|
343
360
|
const elementCount = inferCount(state, count);
|
|
@@ -348,7 +365,7 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
|
|
|
348
365
|
|
|
349
366
|
const variant = selectEnergyQuintelGradVariant(state.dtype);
|
|
350
367
|
const pipeline = await getPipelineFast('energy_quintel_grad', variant);
|
|
351
|
-
const flags =
|
|
368
|
+
const flags = resolveQuintelFlags(options, op);
|
|
352
369
|
|
|
353
370
|
const uniformBuffer = createUniformBuffer(execution, 'energy_quintel_grad_uniforms', 64, (view) => {
|
|
354
371
|
writeQuintelGradUniform(view, {
|
|
@@ -366,22 +383,28 @@ async function executeEnergyQuintelGrad(recorder, state, options = {}, op) {
|
|
|
366
383
|
|
|
367
384
|
const outputSize = elementCount * 4;
|
|
368
385
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'energy_quintel_grad_output');
|
|
386
|
+
const ownedOutput = outputBuffer ? null : output;
|
|
387
|
+
|
|
388
|
+
try {
|
|
389
|
+
const bindGroup = execution.device.createBindGroup({
|
|
390
|
+
label: 'energy_quintel_grad_bind_group',
|
|
391
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
392
|
+
entries: [
|
|
393
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
394
|
+
{ binding: 1, resource: { buffer: state.buffer } },
|
|
395
|
+
{ binding: 2, resource: { buffer: output } },
|
|
396
|
+
],
|
|
397
|
+
});
|
|
369
398
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
}
|
|
379
|
-
|
|
380
|
-
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
381
|
-
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_grad');
|
|
382
|
-
releaseUniformBuffer(execution, uniformBuffer);
|
|
383
|
-
|
|
384
|
-
return createTensor(output, 'f32', [elementCount], 'energy_quintel_grad_output');
|
|
399
|
+
const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
|
|
400
|
+
dispatchEnergy(execution, pipeline, bindGroup, workgroups, 'energy_quintel_grad');
|
|
401
|
+
return createTensor(output, 'f32', [elementCount], 'energy_quintel_grad_output');
|
|
402
|
+
} catch (error) {
|
|
403
|
+
releaseOwnedBuffer(ownedOutput);
|
|
404
|
+
throw error;
|
|
405
|
+
} finally {
|
|
406
|
+
releaseUniformBuffer(execution, uniformBuffer);
|
|
407
|
+
}
|
|
385
408
|
}
|
|
386
409
|
|
|
387
410
|
export async function runEnergyEval(state, target, options = {}) {
|
|
@@ -16,7 +16,7 @@ export function hasRequiredFeatures(
|
|
|
16
16
|
for (const feature of required) {
|
|
17
17
|
if (feature === 'shader-f16' && !capabilities.hasF16) return false;
|
|
18
18
|
if (feature === 'subgroups' && !capabilities.hasSubgroups) return false;
|
|
19
|
-
if (feature === 'subgroups-f16' && !capabilities.
|
|
19
|
+
if (feature === 'subgroups-f16' && !capabilities.hasSubgroupsF16) return false;
|
|
20
20
|
}
|
|
21
21
|
return true;
|
|
22
22
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
|
|
3
3
|
import { getDevice, getKernelCapabilities } from '../device.js';
|
|
4
|
-
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
4
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
5
5
|
import { createTensor } from '../tensor.js';
|
|
6
6
|
import { KernelBase } from './kernel-base.js';
|
|
7
7
|
import { createUniformBufferWithView } from './utils.js';
|
|
@@ -77,6 +77,17 @@ function resolveSwigluLimit(value, context) {
|
|
|
77
77
|
return value;
|
|
78
78
|
}
|
|
79
79
|
|
|
80
|
+
function releaseRunResources(uniformBuffer, ownedBuffers) {
|
|
81
|
+
if (uniformBuffer) {
|
|
82
|
+
uniformBuffer.destroy();
|
|
83
|
+
}
|
|
84
|
+
for (const buffer of ownedBuffers) {
|
|
85
|
+
if (buffer) {
|
|
86
|
+
releaseBuffer(buffer);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
|
|
80
91
|
|
|
81
92
|
export async function runFusedFFN(
|
|
82
93
|
input,
|
|
@@ -132,7 +143,8 @@ export async function runFusedFFN(
|
|
|
132
143
|
const outputBytesPerElement = isF16Native ? 2 : 4;
|
|
133
144
|
const outputDtype = isF16Native ? 'f16' : 'f32';
|
|
134
145
|
const outputSize = batchSize * intermediateSize * outputBytesPerElement;
|
|
135
|
-
const
|
|
146
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'fused_ffn_output');
|
|
147
|
+
const output = outputBuffer || ownedOutput;
|
|
136
148
|
|
|
137
149
|
// Create uniform buffer
|
|
138
150
|
const uniformBuffer = createFFNUniformBuffer(device, null, {
|
|
@@ -145,41 +157,42 @@ export async function runFusedFFN(
|
|
|
145
157
|
swigluLimit: activation === 'silu' ? swigluLimit : null,
|
|
146
158
|
});
|
|
147
159
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
+
try {
|
|
161
|
+
const bindGroup = device.createBindGroup({
|
|
162
|
+
label: 'fused_ffn_bind_group',
|
|
163
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
164
|
+
entries: [
|
|
165
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
166
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
167
|
+
{ binding: 2, resource: { buffer: getBuffer(W_gate) } },
|
|
168
|
+
{ binding: 3, resource: { buffer: getBuffer(W_up) } },
|
|
169
|
+
{ binding: 4, resource: { buffer: output } },
|
|
170
|
+
],
|
|
171
|
+
});
|
|
172
|
+
|
|
173
|
+
let workgroupsX;
|
|
174
|
+
let workgroupsY = 1;
|
|
175
|
+
|
|
176
|
+
if (variant === 'multi') {
|
|
177
|
+
const outputsPerWg = 4;
|
|
178
|
+
workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
|
|
179
|
+
} else if (variant === 'q4k' || variant === 'q4k_batched') {
|
|
180
|
+
const colsPerWg = 32;
|
|
181
|
+
workgroupsX = Math.ceil(intermediateSize / colsPerWg);
|
|
182
|
+
workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
|
|
183
|
+
} else if (variant === 'batched' || variant === 'f16_native_batched') {
|
|
184
|
+
workgroupsX = intermediateSize;
|
|
185
|
+
workgroupsY = batchSize;
|
|
186
|
+
} else {
|
|
187
|
+
workgroupsX = intermediateSize;
|
|
188
|
+
}
|
|
160
189
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
if (variant === 'multi') {
|
|
167
|
-
const outputsPerWg = 4;
|
|
168
|
-
workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
|
|
169
|
-
} else if (variant === 'q4k' || variant === 'q4k_batched') {
|
|
170
|
-
// Q4K uses multi-column: 32 columns per workgroup
|
|
171
|
-
const colsPerWg = 32;
|
|
172
|
-
workgroupsX = Math.ceil(intermediateSize / colsPerWg);
|
|
173
|
-
workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
|
|
174
|
-
} else if (variant === 'batched' || variant === 'f16_native_batched') {
|
|
175
|
-
workgroupsX = intermediateSize;
|
|
176
|
-
workgroupsY = batchSize;
|
|
177
|
-
} else {
|
|
178
|
-
workgroupsX = intermediateSize;
|
|
190
|
+
kernel.dispatch(pipeline, bindGroup, workgroupsX, workgroupsY);
|
|
191
|
+
} catch (error) {
|
|
192
|
+
releaseRunResources(uniformBuffer, [ownedOutput]);
|
|
193
|
+
throw error;
|
|
179
194
|
}
|
|
180
195
|
|
|
181
|
-
kernel.dispatch(pipeline, bindGroup, workgroupsX, workgroupsY);
|
|
182
|
-
|
|
183
196
|
uniformBuffer.destroy();
|
|
184
197
|
|
|
185
198
|
return createTensor(output, outputDtype, [batchSize, intermediateSize], 'fused_ffn_output');
|
|
@@ -240,7 +253,8 @@ export async function recordFusedFFN(
|
|
|
240
253
|
const outputBytesPerElement = isF16Native ? 2 : 4;
|
|
241
254
|
const outputDtype = isF16Native ? 'f16' : 'f32';
|
|
242
255
|
const outputSize = batchSize * intermediateSize * outputBytesPerElement;
|
|
243
|
-
const
|
|
256
|
+
const ownedOutput = outputBuffer ? null : acquireBuffer(outputSize, undefined, 'fused_ffn_output');
|
|
257
|
+
const output = outputBuffer || ownedOutput;
|
|
244
258
|
|
|
245
259
|
const uniformBuffer = createFFNUniformBuffer(device, recorder, {
|
|
246
260
|
M: batchSize,
|
|
@@ -252,39 +266,44 @@ export async function recordFusedFFN(
|
|
|
252
266
|
swigluLimit: activation === 'silu' ? swigluLimit : null,
|
|
253
267
|
});
|
|
254
268
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
269
|
+
try {
|
|
270
|
+
const bindGroup = device.createBindGroup({
|
|
271
|
+
label: 'fused_ffn_bind_group',
|
|
272
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
273
|
+
entries: [
|
|
274
|
+
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
275
|
+
{ binding: 1, resource: { buffer: input.buffer } },
|
|
276
|
+
{ binding: 2, resource: { buffer: getBuffer(W_gate) } },
|
|
277
|
+
{ binding: 3, resource: { buffer: getBuffer(W_up) } },
|
|
278
|
+
{ binding: 4, resource: { buffer: output } },
|
|
279
|
+
],
|
|
280
|
+
});
|
|
281
|
+
|
|
282
|
+
let workgroupsX;
|
|
283
|
+
let workgroupsY = 1;
|
|
284
|
+
|
|
285
|
+
if (variant === 'multi') {
|
|
286
|
+
const outputsPerWg = 4;
|
|
287
|
+
workgroupsX = Math.ceil(intermediateSize / outputsPerWg);
|
|
288
|
+
} else if (variant === 'q4k' || variant === 'q4k_batched') {
|
|
289
|
+
const colsPerWg = 32;
|
|
290
|
+
workgroupsX = Math.ceil(intermediateSize / colsPerWg);
|
|
291
|
+
workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
|
|
292
|
+
} else if (variant === 'batched' || variant === 'f16_native_batched') {
|
|
293
|
+
workgroupsX = intermediateSize;
|
|
294
|
+
workgroupsY = batchSize;
|
|
295
|
+
} else {
|
|
296
|
+
workgroupsX = intermediateSize;
|
|
297
|
+
}
|
|
267
298
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
} else if (variant === 'q4k' || variant === 'q4k_batched') {
|
|
275
|
-
// Q4K uses multi-column: 32 columns per workgroup
|
|
276
|
-
const colsPerWg = 32;
|
|
277
|
-
workgroupsX = Math.ceil(intermediateSize / colsPerWg);
|
|
278
|
-
workgroupsY = variant === 'q4k_batched' ? batchSize : 1;
|
|
279
|
-
} else if (variant === 'batched' || variant === 'f16_native_batched') {
|
|
280
|
-
workgroupsX = intermediateSize;
|
|
281
|
-
workgroupsY = batchSize;
|
|
282
|
-
} else {
|
|
283
|
-
workgroupsX = intermediateSize;
|
|
299
|
+
kernel.record(recorder, pipeline, bindGroup, workgroupsX, workgroupsY);
|
|
300
|
+
} catch (error) {
|
|
301
|
+
if (ownedOutput) {
|
|
302
|
+
releaseBuffer(ownedOutput);
|
|
303
|
+
}
|
|
304
|
+
throw error;
|
|
284
305
|
}
|
|
285
306
|
|
|
286
|
-
kernel.record(recorder, pipeline, bindGroup, workgroupsX, workgroupsY);
|
|
287
|
-
|
|
288
307
|
return createTensor(output, outputDtype, [batchSize, intermediateSize], 'fused_ffn_output');
|
|
289
308
|
}
|
|
290
309
|
|