@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -780,6 +780,23 @@ function resolveAttentionExecution(recorder) {
780
780
  };
781
781
  }
782
782
 
783
// Validates a value that is about to be placed into an attention bind group.
// In runtimes without a GPUBuffer global (e.g. Node without WebGPU bindings),
// any truthy value is accepted as a best-effort pass-through; otherwise the
// value must be an actual GPUBuffer instance.
// Throws an Error naming the kernel, variant, binding slot, and label, with
// optional extra detail strings (falsy entries are dropped).
function assertAttentionBindGroupBuffer(kernelName, variant, bindingIndex, bindingLabel, buffer, details = []) {
  let acceptable = false;
  if (buffer) {
    acceptable = typeof GPUBuffer === 'undefined' ? true : buffer instanceof GPUBuffer;
  }
  if (acceptable) {
    return;
  }
  const detailText = details.filter(Boolean).join(', ');
  const suffix = detailText ? ` (${detailText})` : '';
  throw new Error(
    `[${kernelName}] variant="${variant}" binding ${bindingIndex} "${bindingLabel}" requires a GPUBuffer${suffix}.`
  );
}
799
+
783
800
  function releaseAttentionUniform(execution, uniformBuffer) {
784
801
  if (!execution.recorder) {
785
802
  releaseUniformBuffer(uniformBuffer);
@@ -867,6 +884,26 @@ async function executeAttentionBDPA(
867
884
  slidingWindow,
868
885
  });
869
886
 
887
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 0, 'uniforms', uniformBuffer);
888
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 1, 'Q', Q?.buffer, [
889
+ `QLabel=${Q?.label ?? 'unknown'}`,
890
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
891
+ ]);
892
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 2, 'basisK', basisK?.buffer, [
893
+ `basisKLabel=${basisK?.label ?? 'unknown'}`,
894
+ `basisKDtype=${basisK?.dtype ?? 'unknown'}`,
895
+ ]);
896
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 3, 'basisV', basisV?.buffer, [
897
+ `basisVLabel=${basisV?.label ?? 'unknown'}`,
898
+ `basisVDtype=${basisV?.dtype ?? 'unknown'}`,
899
+ ]);
900
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 4, 'pagedK', pagedK);
901
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 5, 'pagedV', pagedV);
902
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 6, 'index', index);
903
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 7, 'ropeCos', ropeCos);
904
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 8, 'ropeSin', ropeSin);
905
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 9, 'output', outputBuf);
906
+
870
907
  const bindGroup = execution.device.createBindGroup({
871
908
  label: 'attention_bdpa_bind_group',
872
909
  layout: pipeline.getBindGroupLayout(0),
@@ -982,6 +1019,24 @@ async function executeAttention(
982
1019
 
983
1020
  const kvLenBinding = kvLenBuffer || getKvLenFallbackBuffer(execution.device);
984
1021
  const pageTableBinding = kvPageTable || getPageTableFallbackBuffer(execution.device);
1022
+ assertAttentionBindGroupBuffer('attention', plan.variant, 0, 'uniforms', uniformBuffer);
1023
+ assertAttentionBindGroupBuffer('attention', plan.variant, 1, 'Q', Q?.buffer, [
1024
+ `QLabel=${Q?.label ?? 'unknown'}`,
1025
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
1026
+ ]);
1027
+ assertAttentionBindGroupBuffer('attention', plan.variant, 2, 'K', K?.buffer, [
1028
+ `KLabel=${K?.label ?? 'unknown'}`,
1029
+ `KDtype=${K?.dtype ?? 'unknown'}`,
1030
+ ]);
1031
+ assertAttentionBindGroupBuffer('attention', plan.variant, 3, 'V', V?.buffer, [
1032
+ `VLabel=${V?.label ?? 'unknown'}`,
1033
+ `VDtype=${V?.dtype ?? 'unknown'}`,
1034
+ ]);
1035
+ assertAttentionBindGroupBuffer('attention', plan.variant, 4, 'output', outputBuf);
1036
+ assertAttentionBindGroupBuffer('attention', plan.variant, 5, 'kvLen', kvLenBinding);
1037
+ assertAttentionBindGroupBuffer('attention', plan.variant, 6, 'pageTable', pageTableBinding, [
1038
+ `kvLayout=${kvLayout}`,
1039
+ ]);
985
1040
  const bindGroup = execution.device.createBindGroup({
986
1041
  label: 'attention_bind_group',
987
1042
  layout: pipeline.getBindGroupLayout(0),
@@ -1099,6 +1154,31 @@ async function executeAttentionTiered(
1099
1154
  });
1100
1155
 
1101
1156
  const pageTableBinding = coldPageTable || getPageTableFallbackBuffer(execution.device);
1157
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 0, 'uniforms', uniformBuffer);
1158
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 1, 'Q', Q?.buffer, [
1159
+ `QLabel=${Q?.label ?? 'unknown'}`,
1160
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
1161
+ ]);
1162
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 2, 'hotK', hotK?.buffer, [
1163
+ `hotKLabel=${hotK?.label ?? 'unknown'}`,
1164
+ `hotKDtype=${hotK?.dtype ?? 'unknown'}`,
1165
+ ]);
1166
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 3, 'hotV', hotV?.buffer, [
1167
+ `hotVLabel=${hotV?.label ?? 'unknown'}`,
1168
+ `hotVDtype=${hotV?.dtype ?? 'unknown'}`,
1169
+ ]);
1170
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 4, 'coldK', coldK?.buffer, [
1171
+ `coldKLabel=${coldK?.label ?? 'unknown'}`,
1172
+ `coldKDtype=${coldK?.dtype ?? 'unknown'}`,
1173
+ ]);
1174
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 5, 'coldV', coldV?.buffer, [
1175
+ `coldVLabel=${coldV?.label ?? 'unknown'}`,
1176
+ `coldVDtype=${coldV?.dtype ?? 'unknown'}`,
1177
+ ]);
1178
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 6, 'output', outputBuf);
1179
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 7, 'pageTable', pageTableBinding, [
1180
+ `coldLayout=${coldLayout}`,
1181
+ ]);
1102
1182
  const bindGroup = execution.device.createBindGroup({
1103
1183
  label: 'attention_tiered_bind_group',
1104
1184
  layout: pipeline.getBindGroupLayout(0),
@@ -1200,6 +1280,24 @@ async function executeAttentionTieredQuant(
1200
1280
  packedStride,
1201
1281
  });
1202
1282
 
1283
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 0, 'uniforms', uniformBuffer);
1284
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 1, 'Q', Q?.buffer, [
1285
+ `QLabel=${Q?.label ?? 'unknown'}`,
1286
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
1287
+ ]);
1288
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 2, 'hotK', hotK?.buffer, [
1289
+ `hotKLabel=${hotK?.label ?? 'unknown'}`,
1290
+ `hotKDtype=${hotK?.dtype ?? 'unknown'}`,
1291
+ ]);
1292
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 3, 'hotV', hotV?.buffer, [
1293
+ `hotVLabel=${hotV?.label ?? 'unknown'}`,
1294
+ `hotVDtype=${hotV?.dtype ?? 'unknown'}`,
1295
+ ]);
1296
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 4, 'coldPackedK', coldPackedK);
1297
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 5, 'coldPackedV', coldPackedV);
1298
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 6, 'coldScalesK', coldScalesK);
1299
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 7, 'coldScalesV', coldScalesV);
1300
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 8, 'output', outputBuf);
1203
1301
  const bindGroup = execution.device.createBindGroup({
1204
1302
  label: 'attention_tiered_quant_bind_group',
1205
1303
  layout: pipeline.getBindGroupLayout(0),
@@ -14,6 +14,10 @@ struct Uniforms {
14
14
  dim: u32,
15
15
  data_offset: u32, // byte offset into data buffer (divide by 4 for F32)
16
16
  bias_offset: u32, // byte offset into bias buffer (divide by 4 for F32)
17
+ token_stride: u32,
18
+ _pad0: u32,
19
+ _pad1: u32,
20
+ _pad2: u32,
17
21
  }
18
22
 
19
23
  override WORKGROUP_SIZE: u32 = 256u;
@@ -24,17 +28,15 @@ override WORKGROUP_SIZE: u32 = 256u;
24
28
 
25
29
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
26
30
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
27
- let idx = gid.x;
28
- let total = u.num_tokens * u.dim;
29
- if (idx >= total) {
31
+ let d = gid.x;
32
+ let token = gid.z * max(u.token_stride, 1u) + gid.y;
33
+ if (token >= u.num_tokens || d >= u.dim) {
30
34
  return;
31
35
  }
32
36
 
33
37
  // Convert byte offsets to F32 indices
34
38
  let data_base = u.data_offset / 4u;
35
39
  let bias_base = u.bias_offset / 4u;
36
-
37
- let d = idx % u.dim;
40
+ let idx = token * u.dim + d;
38
41
  data[data_base + idx] = data[data_base + idx] + bias[bias_base + d];
39
42
  }
40
-
@@ -18,6 +18,10 @@ struct Uniforms {
18
18
  dim: u32,
19
19
  data_offset: u32, // byte offset into data buffer (divide by 2 for F16)
20
20
  bias_offset: u32, // byte offset into bias buffer (divide by 2 for F16)
21
+ token_stride: u32,
22
+ _pad0: u32,
23
+ _pad1: u32,
24
+ _pad2: u32,
21
25
  }
22
26
 
23
27
  override WORKGROUP_SIZE: u32 = 256u;
@@ -28,17 +32,16 @@ override WORKGROUP_SIZE: u32 = 256u;
28
32
 
29
33
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
30
34
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
31
- let idx = gid.x;
32
- let total = u.num_tokens * u.dim;
33
- if (idx >= total) {
35
+ let d = gid.x;
36
+ let token = gid.z * max(u.token_stride, 1u) + gid.y;
37
+ if (token >= u.num_tokens || d >= u.dim) {
34
38
  return;
35
39
  }
36
40
 
37
41
  // Convert byte offsets to F16 indices
38
42
  let data_base = u.data_offset / 2u;
39
43
  let bias_base = u.bias_offset / 2u;
40
-
41
- let d = idx % u.dim;
44
+ let idx = token * u.dim + d;
42
45
  let out = f32(data[data_base + idx]) + f32(bias[bias_base + d]);
43
46
  data[data_base + idx] = f16(out);
44
47
  }
@@ -58,7 +58,7 @@ async function _conv2d(target, input, weight, bias, options = {}) {
58
58
  kernel_h: kernelH, kernel_w: kernelW,
59
59
  stride, pad, _pad0: 0, _pad1: 0,
60
60
  },
61
- Math.ceil((outChannels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
61
+ [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
62
62
  );
63
63
 
64
64
  if (tempBias) {
@@ -27,19 +27,18 @@ struct Uniforms {
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
31
30
  let out_height = u.out_height;
32
31
  let out_width = u.out_width;
33
- let out_size = u.out_channels * out_height * out_width;
34
- if (idx >= out_size) {
32
+ let out_spatial = out_height * out_width;
33
+ let out_spatial_idx = gid.x;
34
+ let out_c = gid.y;
35
+ if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
35
36
  return;
36
37
  }
37
38
 
38
- let out_spatial = out_height * out_width;
39
- let out_c = idx / out_spatial;
40
- let rem = idx - out_c * out_spatial;
41
- let out_y = rem / out_width;
42
- let out_x = rem - out_y * out_width;
39
+ let out_y = out_spatial_idx / out_width;
40
+ let out_x = out_spatial_idx - out_y * out_width;
41
+ let idx = out_c * out_spatial + out_spatial_idx;
43
42
 
44
43
  var sum: f32 = bias[out_c];
45
44
 
@@ -29,19 +29,18 @@ struct Uniforms {
29
29
 
30
30
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
31
31
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
32
- let idx = gid.x;
33
32
  let out_height = u.out_height;
34
33
  let out_width = u.out_width;
35
- let out_size = u.out_channels * out_height * out_width;
36
- if (idx >= out_size) {
34
+ let out_spatial = out_height * out_width;
35
+ let out_spatial_idx = gid.x;
36
+ let out_c = gid.y;
37
+ if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
37
38
  return;
38
39
  }
39
40
 
40
- let out_spatial = out_height * out_width;
41
- let out_c = idx / out_spatial;
42
- let rem = idx - out_c * out_spatial;
43
- let out_y = rem / out_width;
44
- let out_x = rem - out_y * out_width;
41
+ let out_y = out_spatial_idx / out_width;
42
+ let out_x = out_spatial_idx - out_y * out_width;
43
+ let idx = out_c * out_spatial + out_spatial_idx;
45
44
 
46
45
  var sum: f32 = f32(bias[out_c]);
47
46
 
@@ -0,0 +1,29 @@
1
// Type declarations for the depthwise conv2d GPU kernel wrappers.
import type { Tensor } from '../tensor.js';
import type { CommandRecorder } from '../command-recorder.js';
import type { OutputBufferOptions } from './types.js';
import type { WeightBuffer } from '../weight-buffer.js';

/**
 * Dimension and execution options for a depthwise 2D convolution.
 * All dimensions are required and must be finite; the implementation does not
 * infer them from the input tensor.
 */
export interface DepthwiseConv2DOptions extends OutputBufferOptions {
  channels: number;   // number of channels (depthwise: one filter per channel)
  height: number;     // input height
  width: number;      // input width
  kernelH: number;    // filter height
  kernelW: number;    // filter width
  stride?: number;    // spatial stride (implementation defaults to 1)
  pad?: number;       // symmetric zero padding (implementation defaults to 0)
}

/**
 * Runs the depthwise conv2d kernel immediately on the default device.
 * `bias` may be null; a zero bias is substituted internally.
 * Resolves to a Tensor of shape [channels, outHeight, outWidth].
 */
export declare function runDepthwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: DepthwiseConv2DOptions
): Promise<Tensor>;

/**
 * Records the depthwise conv2d dispatch into an existing CommandRecorder
 * instead of executing immediately. Same contract as runDepthwiseConv2D.
 */
export declare function recordDepthwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: DepthwiseConv2DOptions
): Promise<Tensor>;
@@ -0,0 +1,99 @@
1
// GPU depthwise conv2d (NCHW): each channel is convolved with its own
// kernelH x kernelW filter. Supports f32 and f16 via rule-selected shader
// variants.
import { getDevice } from '../device.js';
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
import { createTensor, dtypeBytes } from '../tensor.js';
import { getBuffer } from '../weight-buffer.js';
import { unifiedKernelWrapper } from './utils.js';
import { selectRuleValue } from './rule-registry.js';
import { WORKGROUP_SIZES } from './constants.js';

// Picks the shader variant (f16 vs f32) from the kernel rule registry.
function selectDepthwiseConv2DVariant(isF16) {
  return selectRuleValue('depthwiseConv2d', 'variant', { isF16 });
}

// Shared implementation behind the run/record entry points.
// `target` is either a CommandRecorder (detected via its beginComputePass
// method) or null for immediate execution on the default device.
// Returns a Tensor of shape [channels, outHeight, outWidth] in input.dtype.
async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  const {
    channels,
    height,
    width,
    kernelH,
    kernelW,
    stride = 1,
    pad = 0,
    outputBuffer = null,
  } = options;

  // All dimensions must be supplied explicitly; nothing is inferred from the
  // input tensor's shape.
  if (
    !Number.isFinite(channels) ||
    !Number.isFinite(height) ||
    !Number.isFinite(width) ||
    !Number.isFinite(kernelH) ||
    !Number.isFinite(kernelW)
  ) {
    throw new Error('DepthwiseConv2D requires explicit dimensions.');
  }

  // Standard convolution output-size formula (floor division).
  const outHeight = Math.floor((height + pad * 2 - kernelH) / stride) + 1;
  const outWidth = Math.floor((width + pad * 2 - kernelW) / stride) + 1;
  if (outHeight <= 0 || outWidth <= 0) {
    throw new Error(`DepthwiseConv2D invalid output size: ${outHeight}x${outWidth}`);
  }

  const isF16 = input.dtype === 'f16';
  const variant = selectDepthwiseConv2DVariant(isF16);
  const bytesPerElement = dtypeBytes(input.dtype);
  const outputSize = channels * outHeight * outWidth * bytesPerElement;
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'depthwise_conv2d_output');
  const outSpatial = outHeight * outWidth;

  const weightBuffer = getBuffer(weight);
  let biasBuffer = getBuffer(bias);
  let tempBias = null;
  // The shader unconditionally reads a bias buffer, so when the caller passes
  // none, allocate a zero-filled temporary (the kernel then adds 0).
  if (!biasBuffer) {
    const biasSize = channels * bytesPerElement;
    tempBias = acquireBuffer(biasSize, undefined, 'depthwise_conv2d_bias_zero');
    biasBuffer = tempBias;
    // writeBuffer sizes must be 4-byte multiples; NOTE(review): for odd f16
    // channel counts this writes up to 2 bytes past biasSize — assumes
    // acquireBuffer rounds allocations up accordingly; confirm.
    const paddedSize = Math.ceil(biasSize / 4) * 4;
    device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
  }

  // Dispatch grid: x covers output spatial positions (workgroup-sized),
  // y covers channels (matches the shader's gid.x/gid.y usage).
  await unifiedKernelWrapper(
    'depthwise_conv2d',
    target,
    variant,
    [input, weightBuffer, biasBuffer, output],
    {
      channels,
      height,
      width,
      out_height: outHeight,
      out_width: outWidth,
      kernel_h: kernelH,
      kernel_w: kernelW,
      stride,
      pad,
      _pad0: 0,
      _pad1: 0,
    },
    [Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
  );

  // Temporary zero-bias lifetime: recorded work keeps it alive until the
  // recorder submits; immediate execution can release it right away.
  if (tempBias) {
    if (recorder) {
      recorder.trackTemporaryBuffer(tempBias);
    } else {
      releaseBuffer(tempBias);
    }
  }

  return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
}

// Immediate-mode entry point: executes on the default device.
export async function runDepthwiseConv2D(input, weight, bias, options = {}) {
  return _depthwiseConv2D(null, input, weight, bias, options);
}

// Recorded-mode entry point: appends the dispatch to `recorder`.
export async function recordDepthwiseConv2D(recorder, input, weight, bias, options = {}) {
  return _depthwiseConv2D(recorder, input, weight, bias, options);
}
@@ -0,0 +1,55 @@
1
// Depthwise Conv2D kernel (NCHW, f32): each channel is convolved with its own
// kernel_h x kernel_w filter. Dispatch: gid.x = output spatial index,
// gid.y = channel.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  channels: u32,
  height: u32,      // input height
  width: u32,       // input width
  out_height: u32,
  out_width: u32,
  kernel_h: u32,
  kernel_w: u32,
  stride: u32,
  pad: u32,         // symmetric zero padding
  _pad0: u32,       // unused padding slots
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_spatial = u.out_height * u.out_width;
  let spatial_idx = gid.x;
  let channel = gid.y;
  // Guard against over-dispatch on both axes.
  if (spatial_idx >= out_spatial || channel >= u.channels) {
    return;
  }
  let out_y = spatial_idx / u.out_width;
  let out_x = spatial_idx - out_y * u.out_width;

  var sum: f32 = bias[channel];
  let pad = i32(u.pad);

  for (var ky: u32 = 0u; ky < u.kernel_h; ky = ky + 1u) {
    // Signed arithmetic so padded positions can go negative; out-of-bounds
    // rows/cols contribute zero (skipped).
    let in_y = i32(out_y * u.stride + ky) - pad;
    if (in_y < 0 || in_y >= i32(u.height)) {
      continue;
    }
    for (var kx: u32 = 0u; kx < u.kernel_w; kx = kx + 1u) {
      let in_x = i32(out_x * u.stride + kx) - pad;
      if (in_x < 0 || in_x >= i32(u.width)) {
        continue;
      }
      // input layout: [channels, height, width]
      let input_idx = (channel * u.height + u32(in_y)) * u.width + u32(in_x);
      // weight layout: [channels, kernel_h, kernel_w]
      let weight_idx = ((channel * u.kernel_h + ky) * u.kernel_w + kx);
      sum = sum + input[input_idx] * weight[weight_idx];
    }
  }

  output[channel * out_spatial + spatial_idx] = sum;
}
@@ -0,0 +1,59 @@
1
// Depthwise Conv2D Kernel (NCHW, f16)
// f16 variant of the depthwise conv kernel: storage is f16 but accumulation
// runs in f32 (sum is f32) before the final f16 cast.
// Dispatch: gid.x = output spatial index, gid.y = channel.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  channels: u32,
  height: u32,      // input height
  width: u32,       // input width
  out_height: u32,
  out_width: u32,
  kernel_h: u32,
  kernel_w: u32,
  stride: u32,
  pad: u32,         // symmetric zero padding
  _pad0: u32,       // unused padding slots
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read> weight: array<f16>;
@group(0) @binding(3) var<storage, read> bias: array<f16>;
@group(0) @binding(4) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_spatial = u.out_height * u.out_width;
  let spatial_idx = gid.x;
  let channel = gid.y;
  // Guard against over-dispatch on both axes.
  if (spatial_idx >= out_spatial || channel >= u.channels) {
    return;
  }
  let out_y = spatial_idx / u.out_width;
  let out_x = spatial_idx - out_y * u.out_width;

  // Accumulate in f32 to limit precision loss across the kernel window.
  var sum: f32 = f32(bias[channel]);
  let pad = i32(u.pad);

  for (var ky: u32 = 0u; ky < u.kernel_h; ky = ky + 1u) {
    // Signed arithmetic so padded positions can go negative; out-of-bounds
    // rows/cols contribute zero (skipped).
    let in_y = i32(out_y * u.stride + ky) - pad;
    if (in_y < 0 || in_y >= i32(u.height)) {
      continue;
    }
    for (var kx: u32 = 0u; kx < u.kernel_w; kx = kx + 1u) {
      let in_x = i32(out_x * u.stride + kx) - pad;
      if (in_x < 0 || in_x >= i32(u.width)) {
        continue;
      }
      // input layout: [channels, height, width]
      let input_idx = (channel * u.height + u32(in_y)) * u.width + u32(in_x);
      // weight layout: [channels, kernel_h, kernel_w]
      let weight_idx = ((channel * u.kernel_h + ky) * u.kernel_w + kx);
      sum = sum + f32(input[input_idx]) * f32(weight[weight_idx]);
    }
  }

  output[channel * out_spatial + spatial_idx] = f16(sum);
}
@@ -0,0 +1,27 @@
1
// Type declarations for the grouped pointwise (1x1) conv2d GPU kernel wrappers.
import type { Tensor } from '../tensor.js';
import type { CommandRecorder } from '../command-recorder.js';
import type { OutputBufferOptions } from './types.js';
import type { WeightBuffer } from '../weight-buffer.js';

/**
 * Dimension options for a grouped pointwise 2D convolution.
 * The implementation requires inChannels and outChannels to be divisible by
 * groups, and all values to be finite.
 */
export interface GroupedPointwiseConv2DOptions extends OutputBufferOptions {
  inChannels: number;   // total input channels
  outChannels: number;  // total output channels
  height: number;       // spatial height (unchanged by a pointwise conv)
  width: number;        // spatial width (unchanged by a pointwise conv)
  groups: number;       // number of channel groups
}

/**
 * Runs the grouped pointwise conv2d kernel immediately on the default device.
 * `bias` may be null; a zero bias is substituted internally.
 */
export declare function runGroupedPointwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions
): Promise<Tensor>;

/**
 * Records the grouped pointwise conv2d dispatch into an existing
 * CommandRecorder instead of executing immediately.
 */
export declare function recordGroupedPointwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions
): Promise<Tensor>;
@@ -0,0 +1,93 @@
1
+ import { getDevice } from '../device.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
+ import { createTensor, dtypeBytes } from '../tensor.js';
4
+ import { getBuffer } from '../weight-buffer.js';
5
+ import { unifiedKernelWrapper } from './utils.js';
6
+ import { selectRuleValue } from './rule-registry.js';
7
+ import { WORKGROUP_SIZES } from './constants.js';
8
+
9
// Resolves the shader variant for grouped pointwise conv from the rule
// registry, keyed only on whether the tensor dtype is f16.
function selectGroupedPointwiseConv2DVariant(isF16) {
  const criteria = { isF16 };
  return selectRuleValue('groupedPointwiseConv2d', 'variant', criteria);
}
12
+
13
/**
 * Grouped pointwise (1x1) convolution over CHW tensors.
 *
 * Shared implementation behind runGroupedPointwiseConv2D (immediate mode)
 * and recordGroupedPointwiseConv2D (recorded into a CommandRecorder).
 *
 * @param {object|null} target - CommandRecorder (detected via beginComputePass) or null.
 * @param {object} input - Input tensor in CHW layout; its dtype selects the shader variant.
 * @param {GPUBuffer|object} weight - Weight buffer or WeightBuffer wrapper.
 * @param {GPUBuffer|object|null} bias - Optional bias; a zero-filled temporary is
 *   substituted when absent so the shader's bias binding is always populated.
 * @param {object} options - Requires inChannels, outChannels, height, width, groups;
 *   optional outputBuffer to write into.
 * @returns {Promise<object>} Tensor of shape [outChannels, height, width].
 * @throws {Error} When dimensions are missing/invalid or channels are not divisible by groups.
 */
async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}) {
  // A recorder is any target exposing beginComputePass; otherwise run immediately.
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  const {
    inChannels,
    outChannels,
    height,
    width,
    groups,
    outputBuffer = null,
  } = options;

  // Dimensions must be positive integers. Number.isFinite alone admitted
  // fractional values — e.g. groups = 0.5 satisfies `inChannels % 0.5 === 0`
  // for even inChannels — producing fractional buffer sizes and a broken dispatch.
  const isPosInt = (v) => Number.isInteger(v) && v > 0;
  if (
    !isPosInt(inChannels) ||
    !isPosInt(outChannels) ||
    !isPosInt(height) ||
    !isPosInt(width) ||
    !Number.isInteger(groups)
  ) {
    throw new Error('GroupedPointwiseConv2D requires explicit dimensions.');
  }
  if (groups <= 0 || inChannels % groups !== 0 || outChannels % groups !== 0) {
    throw new Error(
      `GroupedPointwiseConv2D requires inChannels/outChannels divisible by groups. Got ${inChannels}/${outChannels}/${groups}.`
    );
  }

  const isF16 = input.dtype === 'f16';
  const variant = selectGroupedPointwiseConv2DVariant(isF16);
  const bytesPerElement = dtypeBytes(input.dtype);
  const outputSize = outChannels * height * width * bytesPerElement;
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');
  const spatial = height * width;

  const weightBuffer = getBuffer(weight);
  let biasBuffer = getBuffer(bias);
  let tempBias = null;
  if (!biasBuffer) {
    // No bias supplied: zero-fill a temporary. writeBuffer requires a
    // 4-byte-aligned size, so acquire the padded size up front — previously
    // only biasSize was acquired, and an f16 bias with odd outChannels made
    // the padded zero-fill overrun the requested buffer size.
    const biasSize = outChannels * bytesPerElement;
    const paddedSize = Math.ceil(biasSize / 4) * 4;
    tempBias = acquireBuffer(paddedSize, undefined, 'grouped_pointwise_conv2d_bias_zero');
    biasBuffer = tempBias;
    device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
  }

  await unifiedKernelWrapper(
    'grouped_pointwise_conv2d',
    target,
    variant,
    [input, weightBuffer, biasBuffer, output],
    {
      in_channels: inChannels,
      out_channels: outChannels,
      height,
      width,
      groups,
      // Padding fields must match the WGSL Uniforms struct layout.
      _pad0: 0,
      _pad1: 0,
      _pad2: 0,
    },
    // gid.x covers the spatial plane; one gid.y slot per output channel.
    [Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
  );

  if (tempBias) {
    if (recorder) {
      // Recorded mode: the buffer stays referenced until submission.
      recorder.trackTemporaryBuffer(tempBias);
    } else {
      releaseBuffer(tempBias);
    }
  }

  return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
}
86
+
87
/**
 * Immediate-mode grouped pointwise (1x1) convolution: executes without a
 * CommandRecorder. Resolves to a tensor of shape [outChannels, height, width].
 */
export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
  const result = await _groupedPointwiseConv2D(null, input, weight, bias, options);
  return result;
}
90
+
91
/**
 * Records a grouped pointwise (1x1) convolution into the given
 * CommandRecorder instead of executing it immediately.
 */
export async function recordGroupedPointwiseConv2D(recorder, input, weight, bias, options = {}) {
  const result = await _groupedPointwiseConv2D(recorder, input, weight, bias, options);
  return result;
}
@@ -0,0 +1,44 @@
1
// Grouped pointwise (1x1) convolution over CHW f32 tensors.
// Dispatch: gid.x covers the flattened spatial plane, gid.y the output channels.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  in_channels: u32,
  out_channels: u32,
  height: u32,
  width: u32,
  groups: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  // One invocation per (output channel, pixel).
  let plane = u.height * u.width;
  let pixel = gid.x;
  let oc = gid.y;
  if (pixel >= plane || oc >= u.out_channels) {
    return;
  }
  let row = pixel / u.width;
  let col = pixel - row * u.width;

  // Channels are partitioned into `groups`; an output channel reads only
  // the input channels belonging to its own group.
  let ic_per_group = u.in_channels / u.groups;
  let oc_per_group = u.out_channels / u.groups;
  let base_ic = (oc / oc_per_group) * ic_per_group;

  // 1x1 kernel: a dot product over the group's input channels at this pixel.
  // Weight layout: [out_channels, in_channels / groups].
  var acc: f32 = bias[oc];
  for (var k: u32 = 0u; k < ic_per_group; k = k + 1u) {
    let src = ((base_ic + k) * u.height + row) * u.width + col;
    acc = acc + input[src] * weight[oc * ic_per_group + k];
  }

  output[oc * plane + pixel] = acc;
}