@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -780,6 +780,23 @@ function resolveAttentionExecution(recorder) {
|
|
|
780
780
|
};
|
|
781
781
|
}
|
|
782
782
|
|
|
783
|
+
function assertAttentionBindGroupBuffer(kernelName, variant, bindingIndex, bindingLabel, buffer, details = []) {
|
|
784
|
+
const isGpuBuffer = buffer && (
|
|
785
|
+
typeof GPUBuffer === 'undefined'
|
|
786
|
+
? true
|
|
787
|
+
: buffer instanceof GPUBuffer
|
|
788
|
+
);
|
|
789
|
+
if (isGpuBuffer) {
|
|
790
|
+
return;
|
|
791
|
+
}
|
|
792
|
+
const detailText = details.filter(Boolean).join(', ');
|
|
793
|
+
throw new Error(
|
|
794
|
+
`[${kernelName}] variant="${variant}" binding ${bindingIndex} "${bindingLabel}" requires a GPUBuffer` +
|
|
795
|
+
(detailText ? ` (${detailText})` : '') +
|
|
796
|
+
'.'
|
|
797
|
+
);
|
|
798
|
+
}
|
|
799
|
+
|
|
783
800
|
function releaseAttentionUniform(execution, uniformBuffer) {
|
|
784
801
|
if (!execution.recorder) {
|
|
785
802
|
releaseUniformBuffer(uniformBuffer);
|
|
@@ -867,6 +884,26 @@ async function executeAttentionBDPA(
|
|
|
867
884
|
slidingWindow,
|
|
868
885
|
});
|
|
869
886
|
|
|
887
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 0, 'uniforms', uniformBuffer);
|
|
888
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 1, 'Q', Q?.buffer, [
|
|
889
|
+
`QLabel=${Q?.label ?? 'unknown'}`,
|
|
890
|
+
`QDtype=${Q?.dtype ?? 'unknown'}`,
|
|
891
|
+
]);
|
|
892
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 2, 'basisK', basisK?.buffer, [
|
|
893
|
+
`basisKLabel=${basisK?.label ?? 'unknown'}`,
|
|
894
|
+
`basisKDtype=${basisK?.dtype ?? 'unknown'}`,
|
|
895
|
+
]);
|
|
896
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 3, 'basisV', basisV?.buffer, [
|
|
897
|
+
`basisVLabel=${basisV?.label ?? 'unknown'}`,
|
|
898
|
+
`basisVDtype=${basisV?.dtype ?? 'unknown'}`,
|
|
899
|
+
]);
|
|
900
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 4, 'pagedK', pagedK);
|
|
901
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 5, 'pagedV', pagedV);
|
|
902
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 6, 'index', index);
|
|
903
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 7, 'ropeCos', ropeCos);
|
|
904
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 8, 'ropeSin', ropeSin);
|
|
905
|
+
assertAttentionBindGroupBuffer('attention_bdpa', variant, 9, 'output', outputBuf);
|
|
906
|
+
|
|
870
907
|
const bindGroup = execution.device.createBindGroup({
|
|
871
908
|
label: 'attention_bdpa_bind_group',
|
|
872
909
|
layout: pipeline.getBindGroupLayout(0),
|
|
@@ -982,6 +1019,24 @@ async function executeAttention(
|
|
|
982
1019
|
|
|
983
1020
|
const kvLenBinding = kvLenBuffer || getKvLenFallbackBuffer(execution.device);
|
|
984
1021
|
const pageTableBinding = kvPageTable || getPageTableFallbackBuffer(execution.device);
|
|
1022
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 0, 'uniforms', uniformBuffer);
|
|
1023
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 1, 'Q', Q?.buffer, [
|
|
1024
|
+
`QLabel=${Q?.label ?? 'unknown'}`,
|
|
1025
|
+
`QDtype=${Q?.dtype ?? 'unknown'}`,
|
|
1026
|
+
]);
|
|
1027
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 2, 'K', K?.buffer, [
|
|
1028
|
+
`KLabel=${K?.label ?? 'unknown'}`,
|
|
1029
|
+
`KDtype=${K?.dtype ?? 'unknown'}`,
|
|
1030
|
+
]);
|
|
1031
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 3, 'V', V?.buffer, [
|
|
1032
|
+
`VLabel=${V?.label ?? 'unknown'}`,
|
|
1033
|
+
`VDtype=${V?.dtype ?? 'unknown'}`,
|
|
1034
|
+
]);
|
|
1035
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 4, 'output', outputBuf);
|
|
1036
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 5, 'kvLen', kvLenBinding);
|
|
1037
|
+
assertAttentionBindGroupBuffer('attention', plan.variant, 6, 'pageTable', pageTableBinding, [
|
|
1038
|
+
`kvLayout=${kvLayout}`,
|
|
1039
|
+
]);
|
|
985
1040
|
const bindGroup = execution.device.createBindGroup({
|
|
986
1041
|
label: 'attention_bind_group',
|
|
987
1042
|
layout: pipeline.getBindGroupLayout(0),
|
|
@@ -1099,6 +1154,31 @@ async function executeAttentionTiered(
|
|
|
1099
1154
|
});
|
|
1100
1155
|
|
|
1101
1156
|
const pageTableBinding = coldPageTable || getPageTableFallbackBuffer(execution.device);
|
|
1157
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 0, 'uniforms', uniformBuffer);
|
|
1158
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 1, 'Q', Q?.buffer, [
|
|
1159
|
+
`QLabel=${Q?.label ?? 'unknown'}`,
|
|
1160
|
+
`QDtype=${Q?.dtype ?? 'unknown'}`,
|
|
1161
|
+
]);
|
|
1162
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 2, 'hotK', hotK?.buffer, [
|
|
1163
|
+
`hotKLabel=${hotK?.label ?? 'unknown'}`,
|
|
1164
|
+
`hotKDtype=${hotK?.dtype ?? 'unknown'}`,
|
|
1165
|
+
]);
|
|
1166
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 3, 'hotV', hotV?.buffer, [
|
|
1167
|
+
`hotVLabel=${hotV?.label ?? 'unknown'}`,
|
|
1168
|
+
`hotVDtype=${hotV?.dtype ?? 'unknown'}`,
|
|
1169
|
+
]);
|
|
1170
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 4, 'coldK', coldK?.buffer, [
|
|
1171
|
+
`coldKLabel=${coldK?.label ?? 'unknown'}`,
|
|
1172
|
+
`coldKDtype=${coldK?.dtype ?? 'unknown'}`,
|
|
1173
|
+
]);
|
|
1174
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 5, 'coldV', coldV?.buffer, [
|
|
1175
|
+
`coldVLabel=${coldV?.label ?? 'unknown'}`,
|
|
1176
|
+
`coldVDtype=${coldV?.dtype ?? 'unknown'}`,
|
|
1177
|
+
]);
|
|
1178
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 6, 'output', outputBuf);
|
|
1179
|
+
assertAttentionBindGroupBuffer('attention_tiered', variant, 7, 'pageTable', pageTableBinding, [
|
|
1180
|
+
`coldLayout=${coldLayout}`,
|
|
1181
|
+
]);
|
|
1102
1182
|
const bindGroup = execution.device.createBindGroup({
|
|
1103
1183
|
label: 'attention_tiered_bind_group',
|
|
1104
1184
|
layout: pipeline.getBindGroupLayout(0),
|
|
@@ -1200,6 +1280,24 @@ async function executeAttentionTieredQuant(
|
|
|
1200
1280
|
packedStride,
|
|
1201
1281
|
});
|
|
1202
1282
|
|
|
1283
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 0, 'uniforms', uniformBuffer);
|
|
1284
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 1, 'Q', Q?.buffer, [
|
|
1285
|
+
`QLabel=${Q?.label ?? 'unknown'}`,
|
|
1286
|
+
`QDtype=${Q?.dtype ?? 'unknown'}`,
|
|
1287
|
+
]);
|
|
1288
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 2, 'hotK', hotK?.buffer, [
|
|
1289
|
+
`hotKLabel=${hotK?.label ?? 'unknown'}`,
|
|
1290
|
+
`hotKDtype=${hotK?.dtype ?? 'unknown'}`,
|
|
1291
|
+
]);
|
|
1292
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 3, 'hotV', hotV?.buffer, [
|
|
1293
|
+
`hotVLabel=${hotV?.label ?? 'unknown'}`,
|
|
1294
|
+
`hotVDtype=${hotV?.dtype ?? 'unknown'}`,
|
|
1295
|
+
]);
|
|
1296
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 4, 'coldPackedK', coldPackedK);
|
|
1297
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 5, 'coldPackedV', coldPackedV);
|
|
1298
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 6, 'coldScalesK', coldScalesK);
|
|
1299
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 7, 'coldScalesV', coldScalesV);
|
|
1300
|
+
assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 8, 'output', outputBuf);
|
|
1203
1301
|
const bindGroup = execution.device.createBindGroup({
|
|
1204
1302
|
label: 'attention_tiered_quant_bind_group',
|
|
1205
1303
|
layout: pipeline.getBindGroupLayout(0),
|
|
@@ -14,6 +14,10 @@ struct Uniforms {
|
|
|
14
14
|
dim: u32,
|
|
15
15
|
data_offset: u32, // byte offset into data buffer (divide by 4 for F32)
|
|
16
16
|
bias_offset: u32, // byte offset into bias buffer (divide by 4 for F32)
|
|
17
|
+
token_stride: u32,
|
|
18
|
+
_pad0: u32,
|
|
19
|
+
_pad1: u32,
|
|
20
|
+
_pad2: u32,
|
|
17
21
|
}
|
|
18
22
|
|
|
19
23
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
@@ -24,17 +28,15 @@ override WORKGROUP_SIZE: u32 = 256u;
|
|
|
24
28
|
|
|
25
29
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
26
30
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
27
|
-
let
|
|
28
|
-
let
|
|
29
|
-
if (
|
|
31
|
+
let d = gid.x;
|
|
32
|
+
let token = gid.z * max(u.token_stride, 1u) + gid.y;
|
|
33
|
+
if (token >= u.num_tokens || d >= u.dim) {
|
|
30
34
|
return;
|
|
31
35
|
}
|
|
32
36
|
|
|
33
37
|
// Convert byte offsets to F32 indices
|
|
34
38
|
let data_base = u.data_offset / 4u;
|
|
35
39
|
let bias_base = u.bias_offset / 4u;
|
|
36
|
-
|
|
37
|
-
let d = idx % u.dim;
|
|
40
|
+
let idx = token * u.dim + d;
|
|
38
41
|
data[data_base + idx] = data[data_base + idx] + bias[bias_base + d];
|
|
39
42
|
}
|
|
40
|
-
|
|
@@ -18,6 +18,10 @@ struct Uniforms {
|
|
|
18
18
|
dim: u32,
|
|
19
19
|
data_offset: u32, // byte offset into data buffer (divide by 2 for F16)
|
|
20
20
|
bias_offset: u32, // byte offset into bias buffer (divide by 2 for F16)
|
|
21
|
+
token_stride: u32,
|
|
22
|
+
_pad0: u32,
|
|
23
|
+
_pad1: u32,
|
|
24
|
+
_pad2: u32,
|
|
21
25
|
}
|
|
22
26
|
|
|
23
27
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
@@ -28,17 +32,16 @@ override WORKGROUP_SIZE: u32 = 256u;
|
|
|
28
32
|
|
|
29
33
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
30
34
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
31
|
-
let
|
|
32
|
-
let
|
|
33
|
-
if (
|
|
35
|
+
let d = gid.x;
|
|
36
|
+
let token = gid.z * max(u.token_stride, 1u) + gid.y;
|
|
37
|
+
if (token >= u.num_tokens || d >= u.dim) {
|
|
34
38
|
return;
|
|
35
39
|
}
|
|
36
40
|
|
|
37
41
|
// Convert byte offsets to F16 indices
|
|
38
42
|
let data_base = u.data_offset / 2u;
|
|
39
43
|
let bias_base = u.bias_offset / 2u;
|
|
40
|
-
|
|
41
|
-
let d = idx % u.dim;
|
|
44
|
+
let idx = token * u.dim + d;
|
|
42
45
|
let out = f32(data[data_base + idx]) + f32(bias[bias_base + d]);
|
|
43
46
|
data[data_base + idx] = f16(out);
|
|
44
47
|
}
|
|
@@ -58,7 +58,7 @@ async function _conv2d(target, input, weight, bias, options = {}) {
|
|
|
58
58
|
kernel_h: kernelH, kernel_w: kernelW,
|
|
59
59
|
stride, pad, _pad0: 0, _pad1: 0,
|
|
60
60
|
},
|
|
61
|
-
Math.ceil((
|
|
61
|
+
[Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
62
62
|
);
|
|
63
63
|
|
|
64
64
|
if (tempBias) {
|
|
@@ -27,19 +27,18 @@ struct Uniforms {
|
|
|
27
27
|
|
|
28
28
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
29
29
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
30
|
-
let idx = gid.x;
|
|
31
30
|
let out_height = u.out_height;
|
|
32
31
|
let out_width = u.out_width;
|
|
33
|
-
let
|
|
34
|
-
|
|
32
|
+
let out_spatial = out_height * out_width;
|
|
33
|
+
let out_spatial_idx = gid.x;
|
|
34
|
+
let out_c = gid.y;
|
|
35
|
+
if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
|
|
35
36
|
return;
|
|
36
37
|
}
|
|
37
38
|
|
|
38
|
-
let
|
|
39
|
-
let
|
|
40
|
-
let
|
|
41
|
-
let out_y = rem / out_width;
|
|
42
|
-
let out_x = rem - out_y * out_width;
|
|
39
|
+
let out_y = out_spatial_idx / out_width;
|
|
40
|
+
let out_x = out_spatial_idx - out_y * out_width;
|
|
41
|
+
let idx = out_c * out_spatial + out_spatial_idx;
|
|
43
42
|
|
|
44
43
|
var sum: f32 = bias[out_c];
|
|
45
44
|
|
|
@@ -29,19 +29,18 @@ struct Uniforms {
|
|
|
29
29
|
|
|
30
30
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
31
31
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
32
|
-
let idx = gid.x;
|
|
33
32
|
let out_height = u.out_height;
|
|
34
33
|
let out_width = u.out_width;
|
|
35
|
-
let
|
|
36
|
-
|
|
34
|
+
let out_spatial = out_height * out_width;
|
|
35
|
+
let out_spatial_idx = gid.x;
|
|
36
|
+
let out_c = gid.y;
|
|
37
|
+
if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
|
|
37
38
|
return;
|
|
38
39
|
}
|
|
39
40
|
|
|
40
|
-
let
|
|
41
|
-
let
|
|
42
|
-
let
|
|
43
|
-
let out_y = rem / out_width;
|
|
44
|
-
let out_x = rem - out_y * out_width;
|
|
41
|
+
let out_y = out_spatial_idx / out_width;
|
|
42
|
+
let out_x = out_spatial_idx - out_y * out_width;
|
|
43
|
+
let idx = out_c * out_spatial + out_spatial_idx;
|
|
45
44
|
|
|
46
45
|
var sum: f32 = f32(bias[out_c]);
|
|
47
46
|
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
import type { WeightBuffer } from '../weight-buffer.js';
|
|
5
|
+
|
|
6
|
+
/**
 * Options for the depthwise conv2d kernel entry points.
 *
 * The output size is derived as
 * `floor((size + 2 * pad - kernel) / stride) + 1` per spatial axis.
 */
export interface DepthwiseConv2DOptions extends OutputBufferOptions {
  /** Channel count; the output tensor has the same number of channels. */
  channels: number;
  /** Input height. */
  height: number;
  /** Input width. */
  width: number;
  /** Kernel height. */
  kernelH: number;
  /** Kernel width. */
  kernelW: number;
  /** Step between kernel applications on both axes. Defaults to 1. */
  stride?: number;
  /** Zero padding applied on every side of both axes. Defaults to 0. */
  pad?: number;
}
|
|
15
|
+
|
|
16
|
+
/**
 * Runs a depthwise conv2d without a command recorder.
 *
 * @param input   Input tensor; its dtype selects the f16 or f32 kernel variant.
 * @param weight  Kernel weights.
 * @param bias    Per-channel bias, or `null` to use a zero-filled bias.
 * @param options Dimensions, stride/pad and optional output buffer.
 * @returns Tensor of shape `[channels, outHeight, outWidth]` with the input dtype.
 */
export declare function runDepthwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: DepthwiseConv2DOptions
): Promise<Tensor>;
|
|
22
|
+
|
|
23
|
+
/**
 * Records a depthwise conv2d onto an existing command recorder.
 *
 * @param recorder Recorder that receives the kernel and tracks any temporary
 *                 zero-bias buffer.
 * @param input    Input tensor; its dtype selects the f16 or f32 kernel variant.
 * @param weight   Kernel weights.
 * @param bias     Per-channel bias, or `null` to use a zero-filled bias.
 * @param options  Dimensions, stride/pad and optional output buffer.
 * @returns Tensor of shape `[channels, outHeight, outWidth]` with the input dtype.
 */
export declare function recordDepthwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: DepthwiseConv2DOptions
): Promise<Tensor>;
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import { getDevice } from '../device.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
|
+
import { getBuffer } from '../weight-buffer.js';
|
|
5
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
6
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
7
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
8
|
+
|
|
9
|
+
/**
 * Resolves the shader variant for depthwise conv2d from the rule registry.
 *
 * @param {boolean} isF16 Whether the input tensor is f16.
 * @returns The value of the `depthwiseConv2d` / `variant` rule for this context.
 */
function selectDepthwiseConv2DVariant(isF16) {
  const ruleContext = { isF16 };
  return selectRuleValue('depthwiseConv2d', 'variant', ruleContext);
}
|
|
12
|
+
|
|
13
|
+
async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
|
|
14
|
+
const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
|
|
15
|
+
const device = target?.device || getDevice();
|
|
16
|
+
const {
|
|
17
|
+
channels,
|
|
18
|
+
height,
|
|
19
|
+
width,
|
|
20
|
+
kernelH,
|
|
21
|
+
kernelW,
|
|
22
|
+
stride = 1,
|
|
23
|
+
pad = 0,
|
|
24
|
+
outputBuffer = null,
|
|
25
|
+
} = options;
|
|
26
|
+
|
|
27
|
+
if (
|
|
28
|
+
!Number.isFinite(channels) ||
|
|
29
|
+
!Number.isFinite(height) ||
|
|
30
|
+
!Number.isFinite(width) ||
|
|
31
|
+
!Number.isFinite(kernelH) ||
|
|
32
|
+
!Number.isFinite(kernelW)
|
|
33
|
+
) {
|
|
34
|
+
throw new Error('DepthwiseConv2D requires explicit dimensions.');
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
const outHeight = Math.floor((height + pad * 2 - kernelH) / stride) + 1;
|
|
38
|
+
const outWidth = Math.floor((width + pad * 2 - kernelW) / stride) + 1;
|
|
39
|
+
if (outHeight <= 0 || outWidth <= 0) {
|
|
40
|
+
throw new Error(`DepthwiseConv2D invalid output size: ${outHeight}x${outWidth}`);
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
const isF16 = input.dtype === 'f16';
|
|
44
|
+
const variant = selectDepthwiseConv2DVariant(isF16);
|
|
45
|
+
const bytesPerElement = dtypeBytes(input.dtype);
|
|
46
|
+
const outputSize = channels * outHeight * outWidth * bytesPerElement;
|
|
47
|
+
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'depthwise_conv2d_output');
|
|
48
|
+
const outSpatial = outHeight * outWidth;
|
|
49
|
+
|
|
50
|
+
const weightBuffer = getBuffer(weight);
|
|
51
|
+
let biasBuffer = getBuffer(bias);
|
|
52
|
+
let tempBias = null;
|
|
53
|
+
if (!biasBuffer) {
|
|
54
|
+
const biasSize = channels * bytesPerElement;
|
|
55
|
+
tempBias = acquireBuffer(biasSize, undefined, 'depthwise_conv2d_bias_zero');
|
|
56
|
+
biasBuffer = tempBias;
|
|
57
|
+
const paddedSize = Math.ceil(biasSize / 4) * 4;
|
|
58
|
+
device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
await unifiedKernelWrapper(
|
|
62
|
+
'depthwise_conv2d',
|
|
63
|
+
target,
|
|
64
|
+
variant,
|
|
65
|
+
[input, weightBuffer, biasBuffer, output],
|
|
66
|
+
{
|
|
67
|
+
channels,
|
|
68
|
+
height,
|
|
69
|
+
width,
|
|
70
|
+
out_height: outHeight,
|
|
71
|
+
out_width: outWidth,
|
|
72
|
+
kernel_h: kernelH,
|
|
73
|
+
kernel_w: kernelW,
|
|
74
|
+
stride,
|
|
75
|
+
pad,
|
|
76
|
+
_pad0: 0,
|
|
77
|
+
_pad1: 0,
|
|
78
|
+
},
|
|
79
|
+
[Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
|
|
80
|
+
);
|
|
81
|
+
|
|
82
|
+
if (tempBias) {
|
|
83
|
+
if (recorder) {
|
|
84
|
+
recorder.trackTemporaryBuffer(tempBias);
|
|
85
|
+
} else {
|
|
86
|
+
releaseBuffer(tempBias);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/**
 * Runs a depthwise conv2d without a command recorder.
 * See _depthwiseConv2D for the argument and return contract.
 */
export async function runDepthwiseConv2D(input, weight, bias, options = {}) {
  const pending = _depthwiseConv2D(null, input, weight, bias, options);
  return pending;
}
|
|
96
|
+
|
|
97
|
+
/**
 * Records a depthwise conv2d onto `recorder`.
 * See _depthwiseConv2D for the argument and return contract.
 */
export async function recordDepthwiseConv2D(recorder, input, weight, bias, options = {}) {
  const pending = _depthwiseConv2D(recorder, input, weight, bias, options);
  return pending;
}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
2
|
+
|
|
3
|
+
// Kernel parameters for the f32 depthwise conv2d shader.
struct Uniforms {
    channels: u32,    // channel count; bounds gid.y
    height: u32,      // input height
    width: u32,       // input width
    out_height: u32,  // output height
    out_width: u32,   // output width
    kernel_h: u32,    // kernel height
    kernel_w: u32,    // kernel width
    stride: u32,      // step between kernel applications, both axes
    pad: u32,         // zero padding on each side, both axes
    // Not read by the shader; presumably pads the struct to 48 bytes.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
|
|
17
|
+
|
|
18
|
+
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
19
|
+
@group(0) @binding(1) var<storage, read> input: array<f32>;
|
|
20
|
+
@group(0) @binding(2) var<storage, read> weight: array<f32>;
|
|
21
|
+
@group(0) @binding(3) var<storage, read> bias: array<f32>;
|
|
22
|
+
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
|
|
23
|
+
|
|
24
|
+
// Depthwise conv2d, f32. One invocation per output element:
// gid.x walks the flattened output plane, gid.y selects the channel.
// Input and output are indexed channel-major (channel, row, column).
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.out_height * u.out_width;
    let pix = gid.x;
    let c = gid.y;
    if (pix >= plane || c >= u.channels) {
        return;
    }

    let oy = pix / u.out_width;
    let ox = pix % u.out_width;
    let pad_i = i32(u.pad);

    // Start from the bias, then accumulate every in-bounds kernel tap.
    var acc: f32 = bias[c];
    for (var ky: u32 = 0u; ky < u.kernel_h; ky += 1u) {
        let iy = i32(oy * u.stride + ky) - pad_i;
        if (iy >= 0 && iy < i32(u.height)) {
            for (var kx: u32 = 0u; kx < u.kernel_w; kx += 1u) {
                let ix = i32(ox * u.stride + kx) - pad_i;
                // Taps that land in the padding contribute nothing.
                if (ix >= 0 && ix < i32(u.width)) {
                    let src_idx = (c * u.height + u32(iy)) * u.width + u32(ix);
                    let tap_idx = (c * u.kernel_h + ky) * u.kernel_w + kx;
                    acc = acc + input[src_idx] * weight[tap_idx];
                }
            }
        }
    }

    output[c * plane + pix] = acc;
}
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
// Depthwise Conv2D Kernel (NCHW, f16)
|
|
2
|
+
|
|
3
|
+
enable f16;
|
|
4
|
+
|
|
5
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
6
|
+
|
|
7
|
+
// Kernel parameters for the f16 depthwise conv2d shader.
struct Uniforms {
    channels: u32,    // channel count; bounds gid.y
    height: u32,      // input height
    width: u32,       // input width
    out_height: u32,  // output height
    out_width: u32,   // output width
    kernel_h: u32,    // kernel height
    kernel_w: u32,    // kernel width
    stride: u32,      // step between kernel applications, both axes
    pad: u32,         // zero padding on each side, both axes
    // Not read by the shader; presumably pads the struct to 48 bytes.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
|
|
21
|
+
|
|
22
|
+
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
23
|
+
@group(0) @binding(1) var<storage, read> input: array<f16>;
|
|
24
|
+
@group(0) @binding(2) var<storage, read> weight: array<f16>;
|
|
25
|
+
@group(0) @binding(3) var<storage, read> bias: array<f16>;
|
|
26
|
+
@group(0) @binding(4) var<storage, read_write> output: array<f16>;
|
|
27
|
+
|
|
28
|
+
// Depthwise conv2d, f16 storage with f32 accumulation. One invocation per
// output element: gid.x walks the flattened output plane, gid.y selects the
// channel. Input and output are indexed channel-major (channel, row, column).
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.out_height * u.out_width;
    let pix = gid.x;
    let c = gid.y;
    if (pix >= plane || c >= u.channels) {
        return;
    }

    let oy = pix / u.out_width;
    let ox = pix % u.out_width;
    let pad_i = i32(u.pad);

    // Accumulate in f32 and narrow to f16 only at the final store.
    var acc: f32 = f32(bias[c]);
    for (var ky: u32 = 0u; ky < u.kernel_h; ky += 1u) {
        let iy = i32(oy * u.stride + ky) - pad_i;
        if (iy >= 0 && iy < i32(u.height)) {
            for (var kx: u32 = 0u; kx < u.kernel_w; kx += 1u) {
                let ix = i32(ox * u.stride + kx) - pad_i;
                // Taps that land in the padding contribute nothing.
                if (ix >= 0 && ix < i32(u.width)) {
                    let src_idx = (c * u.height + u32(iy)) * u.width + u32(ix);
                    let tap_idx = (c * u.kernel_h + ky) * u.kernel_w + kx;
                    acc = acc + f32(input[src_idx]) * f32(weight[tap_idx]);
                }
            }
        }
    }

    output[c * plane + pix] = f16(acc);
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
import type { WeightBuffer } from '../weight-buffer.js';
|
|
5
|
+
|
|
6
|
+
/**
 * Dimensions of a grouped pointwise convolution, together with the
 * output-buffer options inherited from `OutputBufferOptions`.
 *
 * Every numeric field is required: the implementation throws when any of
 * them is not a finite number.
 */
export interface GroupedPointwiseConv2DOptions extends OutputBufferOptions {
  /** Number of input channels; rejected unless divisible by `groups`. */
  inChannels: number;
  /** Number of output channels; rejected unless divisible by `groups`. */
  outChannels: number;
  /** Height of the input plane; the output plane has the same height. */
  height: number;
  /** Width of the input plane; the output plane has the same width. */
  width: number;
  /** Number of channel groups; must be greater than zero. */
  groups: number;
}
|
|
13
|
+
|
|
14
|
+
/**
 * Executes a grouped pointwise convolution without a command recorder.
 *
 * @param input - Activations to convolve; the result uses the same dtype.
 * @param weight - Filter weights as a raw GPU buffer or a `WeightBuffer`.
 * @param bias - Per-output-channel bias, or `null` to use an all-zero bias.
 * @param options - Convolution dimensions and optional output buffer.
 * @returns A tensor of shape `[outChannels, height, width]`.
 * @throws Error when a dimension is not finite or the channel counts are not
 *   divisible by `groups`.
 */
export declare function runGroupedPointwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions
): Promise<Tensor>;
|
|
20
|
+
|
|
21
|
+
/**
 * Records a grouped pointwise convolution through a command recorder.
 *
 * @param recorder - Recorder that receives the kernel; it also takes over any
 *   temporary zero-bias buffer created when `bias` is `null`.
 * @param input - Activations to convolve; the result uses the same dtype.
 * @param weight - Filter weights as a raw GPU buffer or a `WeightBuffer`.
 * @param bias - Per-output-channel bias, or `null` to use an all-zero bias.
 * @param options - Convolution dimensions and optional output buffer.
 * @returns A tensor of shape `[outChannels, height, width]`.
 * @throws Error when a dimension is not finite or the channel counts are not
 *   divisible by `groups`.
 */
export declare function recordGroupedPointwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions
): Promise<Tensor>;
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import { getDevice } from '../device.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
|
+
import { getBuffer } from '../weight-buffer.js';
|
|
5
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
6
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
7
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
8
|
+
|
|
9
|
+
/**
 * Looks up the kernel variant for the grouped pointwise convolution in the
 * rule registry.
 *
 * @param isF16 - Whether the input tensor is half precision.
 * @returns Whatever the registry yields for the `variant` rule.
 */
function selectGroupedPointwiseConv2DVariant(isF16) {
  const ruleContext = { isF16 };
  return selectRuleValue('groupedPointwiseConv2d', 'variant', ruleContext);
}
|
|
12
|
+
|
|
13
|
+
async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}) {
|
|
14
|
+
const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
|
|
15
|
+
const device = target?.device || getDevice();
|
|
16
|
+
const {
|
|
17
|
+
inChannels,
|
|
18
|
+
outChannels,
|
|
19
|
+
height,
|
|
20
|
+
width,
|
|
21
|
+
groups,
|
|
22
|
+
outputBuffer = null,
|
|
23
|
+
} = options;
|
|
24
|
+
|
|
25
|
+
if (
|
|
26
|
+
!Number.isFinite(inChannels) ||
|
|
27
|
+
!Number.isFinite(outChannels) ||
|
|
28
|
+
!Number.isFinite(height) ||
|
|
29
|
+
!Number.isFinite(width) ||
|
|
30
|
+
!Number.isFinite(groups)
|
|
31
|
+
) {
|
|
32
|
+
throw new Error('GroupedPointwiseConv2D requires explicit dimensions.');
|
|
33
|
+
}
|
|
34
|
+
if (groups <= 0 || inChannels % groups !== 0 || outChannels % groups !== 0) {
|
|
35
|
+
throw new Error(
|
|
36
|
+
`GroupedPointwiseConv2D requires inChannels/outChannels divisible by groups. Got ${inChannels}/${outChannels}/${groups}.`
|
|
37
|
+
);
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
const isF16 = input.dtype === 'f16';
|
|
41
|
+
const variant = selectGroupedPointwiseConv2DVariant(isF16);
|
|
42
|
+
const bytesPerElement = dtypeBytes(input.dtype);
|
|
43
|
+
const outputSize = outChannels * height * width * bytesPerElement;
|
|
44
|
+
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');
|
|
45
|
+
const spatial = height * width;
|
|
46
|
+
|
|
47
|
+
const weightBuffer = getBuffer(weight);
|
|
48
|
+
let biasBuffer = getBuffer(bias);
|
|
49
|
+
let tempBias = null;
|
|
50
|
+
if (!biasBuffer) {
|
|
51
|
+
const biasSize = outChannels * bytesPerElement;
|
|
52
|
+
tempBias = acquireBuffer(biasSize, undefined, 'grouped_pointwise_conv2d_bias_zero');
|
|
53
|
+
biasBuffer = tempBias;
|
|
54
|
+
const paddedSize = Math.ceil(biasSize / 4) * 4;
|
|
55
|
+
device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
await unifiedKernelWrapper(
|
|
59
|
+
'grouped_pointwise_conv2d',
|
|
60
|
+
target,
|
|
61
|
+
variant,
|
|
62
|
+
[input, weightBuffer, biasBuffer, output],
|
|
63
|
+
{
|
|
64
|
+
in_channels: inChannels,
|
|
65
|
+
out_channels: outChannels,
|
|
66
|
+
height,
|
|
67
|
+
width,
|
|
68
|
+
groups,
|
|
69
|
+
_pad0: 0,
|
|
70
|
+
_pad1: 0,
|
|
71
|
+
_pad2: 0,
|
|
72
|
+
},
|
|
73
|
+
[Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
74
|
+
);
|
|
75
|
+
|
|
76
|
+
if (tempBias) {
|
|
77
|
+
if (recorder) {
|
|
78
|
+
recorder.trackTemporaryBuffer(tempBias);
|
|
79
|
+
} else {
|
|
80
|
+
releaseBuffer(tempBias);
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
/**
 * Runs the grouped pointwise convolution immediately, without a recorder.
 *
 * @param input - Input tensor.
 * @param weight - Filter weights.
 * @param bias - Optional bias.
 * @param options - Convolution dimensions and optional output buffer.
 * @returns The output tensor produced by the shared implementation.
 */
export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
  // Passing null as the target selects the non-recording path.
  const noRecorder = null;
  return _groupedPointwiseConv2D(noRecorder, input, weight, bias, options);
}
|
|
90
|
+
|
|
91
|
+
/**
 * Records the grouped pointwise convolution through the given recorder.
 *
 * @param recorder - Command recorder forwarded as the dispatch target.
 * @param input - Input tensor.
 * @param weight - Filter weights.
 * @param bias - Optional bias.
 * @param options - Convolution dimensions and optional output buffer.
 * @returns The output tensor produced by the shared implementation.
 */
export async function recordGroupedPointwiseConv2D(recorder, input, weight, bias, options = {}) {
  const dispatchTarget = recorder;
  return _groupedPointwiseConv2D(dispatchTarget, input, weight, bias, options);
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
2
|
+
|
|
3
|
+
// Dispatch parameters for the grouped pointwise convolution kernel.
// Eight u32 fields (32 bytes); the three trailing `_pad*` fields are never
// read by the kernel.
struct Uniforms {
    in_channels: u32,   // total input channels across all groups
    out_channels: u32,  // total output channels; gid.y >= out_channels exits early
    height: u32,        // plane height, shared by input and output
    width: u32,         // plane width, shared by input and output
    groups: u32,        // number of channel groups; divides both channel counts in the kernel
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
|
|
13
|
+
|
|
14
|
+
// Kernel parameters.
@group(0) @binding(0) var<uniform> u: Uniforms;
// Input activations, addressed as (channel * height + y) * width + x.
@group(0) @binding(1) var<storage, read> input: array<f32>;
// Weights, addressed as out_channel * (in_channels / groups) + i.
@group(0) @binding(2) var<storage, read> weight: array<f32>;
// One bias value per output channel; it seeds the accumulator.
@group(0) @binding(3) var<storage, read> bias: array<f32>;
// Result, addressed as out_channel * (height * width) + flattened position.
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
|
|
19
|
+
|
|
20
|
+
// Grouped pointwise convolution. Each invocation produces one output element:
// gid.x is the flattened spatial position and gid.y is the output channel.
// An output channel only reads the input channels belonging to its own group,
// all sampled at the same spatial position.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane_size = u.height * u.width;
    let pos = gid.x;
    let oc = gid.y;
    // Invocations beyond the plane or the output channel count do nothing.
    if (pos >= plane_size || oc >= u.out_channels) {
        return;
    }

    // Recover the 2D coordinate from the flattened position.
    let row = pos / u.width;
    let col = pos - row * u.width;

    let group_in = u.in_channels / u.groups;
    let group_out = u.out_channels / u.groups;
    // First input channel of the group that owns this output channel.
    let first_in = (oc / group_out) * group_in;
    // Start of this output channel's weights.
    let w_base = oc * group_in;

    var acc: f32 = bias[oc];
    for (var k: u32 = 0u; k < group_in; k += 1u) {
        let in_idx = ((first_in + k) * u.height + row) * u.width + col;
        acc += input[in_idx] * weight[w_base + k];
    }

    output[oc * plane_size + pos] = acc;
}
|