@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
package/src/gpu/device.js CHANGED
@@ -28,6 +28,62 @@ function advanceDeviceEpoch() {
28
28
  deviceEpoch += 1;
29
29
  }
30
30
 
31
/**
 * Reports whether `value` is acceptable as the `buffer` of a bind group entry.
 *
 * Falsy values are always rejected. When the runtime exposes no global
 * `GPUBuffer` constructor there is nothing to type-check against, so any
 * truthy value is accepted.
 *
 * @param {*} value - Candidate buffer.
 * @returns {boolean} `true` when the value passes the check.
 */
function isValidGPUBuffer(value) {
  const hasValue = Boolean(value);
  const canTypeCheck = typeof GPUBuffer !== 'undefined';
  return hasValue && (!canTypeCheck || value instanceof GPUBuffer);
}
40
+
41
/**
 * Produces a short, human-readable type name for a value that was supplied
 * as a bind group buffer; used when building error messages.
 *
 * @param {*} value - The value found in the `buffer` slot.
 * @returns {string} `'null'`, `'undefined'`, `'GPUBuffer'`, the constructor
 *   name of an object (or `'object'` when it has none), or the `typeof` result.
 */
function describeBindGroupBufferValue(value) {
  // `== null` matches exactly null and undefined; String() yields their names.
  if (value == null) {
    return String(value);
  }
  const isRealBuffer = typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
  if (isRealBuffer) {
    return 'GPUBuffer';
  }
  const kind = typeof value;
  if (kind !== 'object') {
    return kind;
  }
  return value.constructor?.name || 'object';
}
50
+
51
/**
 * Checks every buffer-style entry of a bind group descriptor and throws on
 * the first one whose `buffer` does not pass `isValidGPUBuffer`.
 *
 * Entries whose `resource` is missing, is not an object, or has no `buffer`
 * property are ignored. A descriptor without an `entries` array is treated
 * as having no entries.
 *
 * @param {object} [descriptor] - Bind group descriptor (`label`, `entries`).
 * @throws {Error} Naming the bind group label, the binding index, and the
 *   kind of value that was supplied instead of a buffer.
 */
function validateBindGroupDescriptor(descriptor) {
  const groupName = descriptor?.label || 'unlabeled_bind_group';
  const bindings = Array.isArray(descriptor?.entries) ? descriptor.entries : [];

  const offender = bindings.find((candidate) => {
    const bound = candidate?.resource;
    const carriesBuffer = Boolean(bound) && typeof bound === 'object' && 'buffer' in bound;
    return carriesBuffer && !isValidGPUBuffer(bound.buffer);
  });

  if (offender === undefined) {
    return;
  }
  throw new Error(
    `[${groupName}] binding ${offender.binding} requires a GPUBuffer; ` +
      `got ${describeBindGroupBufferValue(offender.resource.buffer)}.`
  );
}
68
+
69
/**
 * Replaces `device.createBindGroup` with a version that runs
 * `validateBindGroupDescriptor` on the descriptor before delegating to the
 * original method.
 *
 * The device is tagged with a non-enumerable, read-only
 * `__dopplerBindGroupValidationWrapped` flag so that a second call leaves
 * it untouched.
 *
 * @param {object} [device] - Device to patch in place.
 * @returns {object} The same `device` reference (falsy input is returned as-is).
 */
function wrapDeviceCreateBindGroup(device) {
  if (!device) {
    return device;
  }
  if (device.__dopplerBindGroupValidationWrapped) {
    return device;
  }

  // Capture the original bound to the device so `this` survives the swap.
  const nativeCreate = device.createBindGroup.bind(device);
  device.createBindGroup = (descriptor) => {
    validateBindGroupDescriptor(descriptor);
    return nativeCreate(descriptor);
  };

  const wrappedFlag = {
    value: true,
    configurable: true,
    enumerable: false,
    writable: false,
  };
  Object.defineProperty(device, '__dopplerBindGroupValidationWrapped', wrappedFlag);
  return device;
}
86
+
31
87
 
32
88
  export const FEATURES = ({
33
89
  SHADER_F16: 'shader-f16',
@@ -201,6 +257,7 @@ export async function initDevice() {
201
257
  if (!gpuDevice) {
202
258
  throw createDopplerError(ERROR_CODES.GPU_DEVICE_FAILED, 'Failed to create WebGPU device');
203
259
  }
260
+ wrapDeviceCreateBindGroup(gpuDevice);
204
261
  advanceDeviceEpoch();
205
262
 
206
263
  // Set up device lost handler
@@ -253,6 +310,7 @@ export function setDevice(device, options = {}) {
253
310
  }
254
311
 
255
312
  gpuDevice = device;
313
+ wrapDeviceCreateBindGroup(gpuDevice);
256
314
  advanceDeviceEpoch();
257
315
  wrapQueueForTracking(gpuDevice.queue);
258
316
 
@@ -780,6 +780,23 @@ function resolveAttentionExecution(recorder) {
780
780
  };
781
781
  }
782
782
 
783
/**
 * Guards an attention kernel bind group slot: returns silently when `buffer`
 * is usable, otherwise throws an error identifying the kernel, variant and
 * binding.
 *
 * A buffer is usable when it is truthy and, if the runtime defines a global
 * `GPUBuffer`, is an instance of it.
 *
 * @param {string} kernelName - Kernel name shown in brackets in the message.
 * @param {string} variant - Kernel variant shown in the message.
 * @param {number} bindingIndex - Binding slot number.
 * @param {string} bindingLabel - Human-readable name of the binding.
 * @param {*} buffer - Value about to be bound.
 * @param {Array<string>} [details=[]] - Extra context; falsy items are dropped
 *   and the rest are joined with ", " inside parentheses.
 * @throws {Error} When `buffer` is not usable.
 */
function assertAttentionBindGroupBuffer(kernelName, variant, bindingIndex, bindingLabel, buffer, details = []) {
  const typeCheckAvailable = typeof GPUBuffer !== 'undefined';
  if (buffer && (!typeCheckAvailable || buffer instanceof GPUBuffer)) {
    return;
  }

  const notes = details.filter((item) => Boolean(item)).join(', ');
  const suffix = notes ? ` (${notes})` : '';
  throw new Error(
    `[${kernelName}] variant="${variant}" binding ${bindingIndex} "${bindingLabel}" requires a GPUBuffer${suffix}.`
  );
}
799
+
783
800
  function releaseAttentionUniform(execution, uniformBuffer) {
784
801
  if (!execution.recorder) {
785
802
  releaseUniformBuffer(uniformBuffer);
@@ -867,6 +884,26 @@ async function executeAttentionBDPA(
867
884
  slidingWindow,
868
885
  });
869
886
 
887
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 0, 'uniforms', uniformBuffer);
888
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 1, 'Q', Q?.buffer, [
889
+ `QLabel=${Q?.label ?? 'unknown'}`,
890
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
891
+ ]);
892
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 2, 'basisK', basisK?.buffer, [
893
+ `basisKLabel=${basisK?.label ?? 'unknown'}`,
894
+ `basisKDtype=${basisK?.dtype ?? 'unknown'}`,
895
+ ]);
896
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 3, 'basisV', basisV?.buffer, [
897
+ `basisVLabel=${basisV?.label ?? 'unknown'}`,
898
+ `basisVDtype=${basisV?.dtype ?? 'unknown'}`,
899
+ ]);
900
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 4, 'pagedK', pagedK);
901
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 5, 'pagedV', pagedV);
902
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 6, 'index', index);
903
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 7, 'ropeCos', ropeCos);
904
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 8, 'ropeSin', ropeSin);
905
+ assertAttentionBindGroupBuffer('attention_bdpa', variant, 9, 'output', outputBuf);
906
+
870
907
  const bindGroup = execution.device.createBindGroup({
871
908
  label: 'attention_bdpa_bind_group',
872
909
  layout: pipeline.getBindGroupLayout(0),
@@ -982,6 +1019,24 @@ async function executeAttention(
982
1019
 
983
1020
  const kvLenBinding = kvLenBuffer || getKvLenFallbackBuffer(execution.device);
984
1021
  const pageTableBinding = kvPageTable || getPageTableFallbackBuffer(execution.device);
1022
+ assertAttentionBindGroupBuffer('attention', plan.variant, 0, 'uniforms', uniformBuffer);
1023
+ assertAttentionBindGroupBuffer('attention', plan.variant, 1, 'Q', Q?.buffer, [
1024
+ `QLabel=${Q?.label ?? 'unknown'}`,
1025
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
1026
+ ]);
1027
+ assertAttentionBindGroupBuffer('attention', plan.variant, 2, 'K', K?.buffer, [
1028
+ `KLabel=${K?.label ?? 'unknown'}`,
1029
+ `KDtype=${K?.dtype ?? 'unknown'}`,
1030
+ ]);
1031
+ assertAttentionBindGroupBuffer('attention', plan.variant, 3, 'V', V?.buffer, [
1032
+ `VLabel=${V?.label ?? 'unknown'}`,
1033
+ `VDtype=${V?.dtype ?? 'unknown'}`,
1034
+ ]);
1035
+ assertAttentionBindGroupBuffer('attention', plan.variant, 4, 'output', outputBuf);
1036
+ assertAttentionBindGroupBuffer('attention', plan.variant, 5, 'kvLen', kvLenBinding);
1037
+ assertAttentionBindGroupBuffer('attention', plan.variant, 6, 'pageTable', pageTableBinding, [
1038
+ `kvLayout=${kvLayout}`,
1039
+ ]);
985
1040
  const bindGroup = execution.device.createBindGroup({
986
1041
  label: 'attention_bind_group',
987
1042
  layout: pipeline.getBindGroupLayout(0),
@@ -1099,6 +1154,31 @@ async function executeAttentionTiered(
1099
1154
  });
1100
1155
 
1101
1156
  const pageTableBinding = coldPageTable || getPageTableFallbackBuffer(execution.device);
1157
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 0, 'uniforms', uniformBuffer);
1158
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 1, 'Q', Q?.buffer, [
1159
+ `QLabel=${Q?.label ?? 'unknown'}`,
1160
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
1161
+ ]);
1162
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 2, 'hotK', hotK?.buffer, [
1163
+ `hotKLabel=${hotK?.label ?? 'unknown'}`,
1164
+ `hotKDtype=${hotK?.dtype ?? 'unknown'}`,
1165
+ ]);
1166
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 3, 'hotV', hotV?.buffer, [
1167
+ `hotVLabel=${hotV?.label ?? 'unknown'}`,
1168
+ `hotVDtype=${hotV?.dtype ?? 'unknown'}`,
1169
+ ]);
1170
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 4, 'coldK', coldK?.buffer, [
1171
+ `coldKLabel=${coldK?.label ?? 'unknown'}`,
1172
+ `coldKDtype=${coldK?.dtype ?? 'unknown'}`,
1173
+ ]);
1174
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 5, 'coldV', coldV?.buffer, [
1175
+ `coldVLabel=${coldV?.label ?? 'unknown'}`,
1176
+ `coldVDtype=${coldV?.dtype ?? 'unknown'}`,
1177
+ ]);
1178
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 6, 'output', outputBuf);
1179
+ assertAttentionBindGroupBuffer('attention_tiered', variant, 7, 'pageTable', pageTableBinding, [
1180
+ `coldLayout=${coldLayout}`,
1181
+ ]);
1102
1182
  const bindGroup = execution.device.createBindGroup({
1103
1183
  label: 'attention_tiered_bind_group',
1104
1184
  layout: pipeline.getBindGroupLayout(0),
@@ -1200,6 +1280,24 @@ async function executeAttentionTieredQuant(
1200
1280
  packedStride,
1201
1281
  });
1202
1282
 
1283
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 0, 'uniforms', uniformBuffer);
1284
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 1, 'Q', Q?.buffer, [
1285
+ `QLabel=${Q?.label ?? 'unknown'}`,
1286
+ `QDtype=${Q?.dtype ?? 'unknown'}`,
1287
+ ]);
1288
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 2, 'hotK', hotK?.buffer, [
1289
+ `hotKLabel=${hotK?.label ?? 'unknown'}`,
1290
+ `hotKDtype=${hotK?.dtype ?? 'unknown'}`,
1291
+ ]);
1292
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 3, 'hotV', hotV?.buffer, [
1293
+ `hotVLabel=${hotV?.label ?? 'unknown'}`,
1294
+ `hotVDtype=${hotV?.dtype ?? 'unknown'}`,
1295
+ ]);
1296
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 4, 'coldPackedK', coldPackedK);
1297
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 5, 'coldPackedV', coldPackedV);
1298
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 6, 'coldScalesK', coldScalesK);
1299
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 7, 'coldScalesV', coldScalesV);
1300
+ assertAttentionBindGroupBuffer('attention_tiered_quant', variant, 8, 'output', outputBuf);
1203
1301
  const bindGroup = execution.device.createBindGroup({
1204
1302
  label: 'attention_tiered_quant_bind_group',
1205
1303
  layout: pipeline.getBindGroupLayout(0),
@@ -14,6 +14,10 @@ struct Uniforms {
14
14
  dim: u32,
15
15
  data_offset: u32, // byte offset into data buffer (divide by 4 for F32)
16
16
  bias_offset: u32, // byte offset into bias buffer (divide by 4 for F32)
17
+ token_stride: u32,
18
+ _pad0: u32,
19
+ _pad1: u32,
20
+ _pad2: u32,
17
21
  }
18
22
 
19
23
  override WORKGROUP_SIZE: u32 = 256u;
@@ -24,17 +28,15 @@ override WORKGROUP_SIZE: u32 = 256u;
24
28
 
25
29
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
26
30
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
27
- let idx = gid.x;
28
- let total = u.num_tokens * u.dim;
29
- if (idx >= total) {
31
+ let d = gid.x;
32
+ let token = gid.z * max(u.token_stride, 1u) + gid.y;
33
+ if (token >= u.num_tokens || d >= u.dim) {
30
34
  return;
31
35
  }
32
36
 
33
37
  // Convert byte offsets to F32 indices
34
38
  let data_base = u.data_offset / 4u;
35
39
  let bias_base = u.bias_offset / 4u;
36
-
37
- let d = idx % u.dim;
40
+ let idx = token * u.dim + d;
38
41
  data[data_base + idx] = data[data_base + idx] + bias[bias_base + d];
39
42
  }
40
-
@@ -18,6 +18,10 @@ struct Uniforms {
18
18
  dim: u32,
19
19
  data_offset: u32, // byte offset into data buffer (divide by 2 for F16)
20
20
  bias_offset: u32, // byte offset into bias buffer (divide by 2 for F16)
21
+ token_stride: u32,
22
+ _pad0: u32,
23
+ _pad1: u32,
24
+ _pad2: u32,
21
25
  }
22
26
 
23
27
  override WORKGROUP_SIZE: u32 = 256u;
@@ -28,17 +32,16 @@ override WORKGROUP_SIZE: u32 = 256u;
28
32
 
29
33
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
30
34
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
31
- let idx = gid.x;
32
- let total = u.num_tokens * u.dim;
33
- if (idx >= total) {
35
+ let d = gid.x;
36
+ let token = gid.z * max(u.token_stride, 1u) + gid.y;
37
+ if (token >= u.num_tokens || d >= u.dim) {
34
38
  return;
35
39
  }
36
40
 
37
41
  // Convert byte offsets to F16 indices
38
42
  let data_base = u.data_offset / 2u;
39
43
  let bias_base = u.bias_offset / 2u;
40
-
41
- let d = idx % u.dim;
44
+ let idx = token * u.dim + d;
42
45
  let out = f32(data[data_base + idx]) + f32(bias[bias_base + d]);
43
46
  data[data_base + idx] = f16(out);
44
47
  }
@@ -58,7 +58,7 @@ async function _conv2d(target, input, weight, bias, options = {}) {
58
58
  kernel_h: kernelH, kernel_w: kernelW,
59
59
  stride, pad, _pad0: 0, _pad1: 0,
60
60
  },
61
- Math.ceil((outChannels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
61
+ [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
62
62
  );
63
63
 
64
64
  if (tempBias) {
@@ -27,19 +27,18 @@ struct Uniforms {
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
31
30
  let out_height = u.out_height;
32
31
  let out_width = u.out_width;
33
- let out_size = u.out_channels * out_height * out_width;
34
- if (idx >= out_size) {
32
+ let out_spatial = out_height * out_width;
33
+ let out_spatial_idx = gid.x;
34
+ let out_c = gid.y;
35
+ if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
35
36
  return;
36
37
  }
37
38
 
38
- let out_spatial = out_height * out_width;
39
- let out_c = idx / out_spatial;
40
- let rem = idx - out_c * out_spatial;
41
- let out_y = rem / out_width;
42
- let out_x = rem - out_y * out_width;
39
+ let out_y = out_spatial_idx / out_width;
40
+ let out_x = out_spatial_idx - out_y * out_width;
41
+ let idx = out_c * out_spatial + out_spatial_idx;
43
42
 
44
43
  var sum: f32 = bias[out_c];
45
44
 
@@ -29,19 +29,18 @@ struct Uniforms {
29
29
 
30
30
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
31
31
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
32
- let idx = gid.x;
33
32
  let out_height = u.out_height;
34
33
  let out_width = u.out_width;
35
- let out_size = u.out_channels * out_height * out_width;
36
- if (idx >= out_size) {
34
+ let out_spatial = out_height * out_width;
35
+ let out_spatial_idx = gid.x;
36
+ let out_c = gid.y;
37
+ if (out_c >= u.out_channels || out_spatial_idx >= out_spatial) {
37
38
  return;
38
39
  }
39
40
 
40
- let out_spatial = out_height * out_width;
41
- let out_c = idx / out_spatial;
42
- let rem = idx - out_c * out_spatial;
43
- let out_y = rem / out_width;
44
- let out_x = rem - out_y * out_width;
41
+ let out_y = out_spatial_idx / out_width;
42
+ let out_x = out_spatial_idx - out_y * out_width;
43
+ let idx = out_c * out_spatial + out_spatial_idx;
45
44
 
46
45
  var sum: f32 = f32(bias[out_c]);
47
46
 
@@ -45,6 +45,7 @@ async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
45
45
  const bytesPerElement = dtypeBytes(input.dtype);
46
46
  const outputSize = channels * outHeight * outWidth * bytesPerElement;
47
47
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'depthwise_conv2d_output');
48
+ const outSpatial = outHeight * outWidth;
48
49
 
49
50
  const weightBuffer = getBuffer(weight);
50
51
  let biasBuffer = getBuffer(bias);
@@ -75,7 +76,7 @@ async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
75
76
  _pad0: 0,
76
77
  _pad1: 0,
77
78
  },
78
- Math.ceil((channels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
79
+ [Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
79
80
  );
80
81
 
81
82
  if (tempBias) {
@@ -23,17 +23,14 @@ struct Uniforms {
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
27
26
  let out_spatial = u.out_height * u.out_width;
28
- let out_size = u.channels * out_spatial;
29
- if (idx >= out_size) {
27
+ let spatial_idx = gid.x;
28
+ let channel = gid.y;
29
+ if (spatial_idx >= out_spatial || channel >= u.channels) {
30
30
  return;
31
31
  }
32
-
33
- let channel = idx / out_spatial;
34
- let rem = idx - channel * out_spatial;
35
- let out_y = rem / u.out_width;
36
- let out_x = rem - out_y * u.out_width;
32
+ let out_y = spatial_idx / u.out_width;
33
+ let out_x = spatial_idx - out_y * u.out_width;
37
34
 
38
35
  var sum: f32 = bias[channel];
39
36
  let pad = i32(u.pad);
@@ -54,5 +51,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
54
51
  }
55
52
  }
56
53
 
57
- output[idx] = sum;
54
+ output[channel * out_spatial + spatial_idx] = sum;
58
55
  }
@@ -27,17 +27,14 @@ struct Uniforms {
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
31
30
  let out_spatial = u.out_height * u.out_width;
32
- let out_size = u.channels * out_spatial;
33
- if (idx >= out_size) {
31
+ let spatial_idx = gid.x;
32
+ let channel = gid.y;
33
+ if (spatial_idx >= out_spatial || channel >= u.channels) {
34
34
  return;
35
35
  }
36
-
37
- let channel = idx / out_spatial;
38
- let rem = idx - channel * out_spatial;
39
- let out_y = rem / u.out_width;
40
- let out_x = rem - out_y * u.out_width;
36
+ let out_y = spatial_idx / u.out_width;
37
+ let out_x = spatial_idx - out_y * u.out_width;
41
38
 
42
39
  var sum: f32 = f32(bias[channel]);
43
40
  let pad = i32(u.pad);
@@ -58,5 +55,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
58
55
  }
59
56
  }
60
57
 
61
- output[idx] = f16(sum);
58
+ output[channel * out_spatial + spatial_idx] = f16(sum);
62
59
  }
@@ -42,6 +42,7 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
42
42
  const bytesPerElement = dtypeBytes(input.dtype);
43
43
  const outputSize = outChannels * height * width * bytesPerElement;
44
44
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');
45
+ const spatial = height * width;
45
46
 
46
47
  const weightBuffer = getBuffer(weight);
47
48
  let biasBuffer = getBuffer(bias);
@@ -69,7 +70,7 @@ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}
69
70
  _pad1: 0,
70
71
  _pad2: 0,
71
72
  },
72
- Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
73
+ [Math.ceil(spatial / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
73
74
  );
74
75
 
75
76
  if (tempBias) {
@@ -19,17 +19,14 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
22
  let spatial = u.height * u.width;
24
- let out_size = u.out_channels * spatial;
25
- if (idx >= out_size) {
23
+ let spatial_idx = gid.x;
24
+ let out_channel = gid.y;
25
+ if (spatial_idx >= spatial || out_channel >= u.out_channels) {
26
26
  return;
27
27
  }
28
-
29
- let out_channel = idx / spatial;
30
- let rem = idx - out_channel * spatial;
31
- let y = rem / u.width;
32
- let x = rem - y * u.width;
28
+ let y = spatial_idx / u.width;
29
+ let x = spatial_idx - y * u.width;
33
30
 
34
31
  let in_per_group = u.in_channels / u.groups;
35
32
  let out_per_group = u.out_channels / u.groups;
@@ -43,5 +40,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
43
40
  sum = sum + input[input_idx] * weight[weight_idx];
44
41
  }
45
42
 
46
- output[idx] = sum;
43
+ output[out_channel * spatial + spatial_idx] = sum;
47
44
  }
@@ -23,17 +23,14 @@ struct Uniforms {
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
27
26
  let spatial = u.height * u.width;
28
- let out_size = u.out_channels * spatial;
29
- if (idx >= out_size) {
27
+ let spatial_idx = gid.x;
28
+ let out_channel = gid.y;
29
+ if (spatial_idx >= spatial || out_channel >= u.out_channels) {
30
30
  return;
31
31
  }
32
-
33
- let out_channel = idx / spatial;
34
- let rem = idx - out_channel * spatial;
35
- let y = rem / u.width;
36
- let x = rem - y * u.width;
32
+ let y = spatial_idx / u.width;
33
+ let x = spatial_idx - y * u.width;
37
34
 
38
35
  let in_per_group = u.in_channels / u.groups;
39
36
  let out_per_group = u.out_channels / u.groups;
@@ -47,5 +44,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
47
44
  sum = sum + f32(input[input_idx]) * f32(weight[weight_idx]);
48
45
  }
49
46
 
50
- output[idx] = f16(sum);
47
+ output[out_channel * spatial + spatial_idx] = f16(sum);
51
48
  }
@@ -52,6 +52,23 @@ function buildProfileLabel(options = {}) {
52
52
  return `matmul${roleLabel}${layerLabel}`;
53
53
  }
54
54
 
55
/**
 * Verifies that a value destined for a kernel bind group slot is a usable
 * GPU buffer, throwing a descriptive error otherwise.
 *
 * "Usable" means truthy and, when a global `GPUBuffer` constructor exists,
 * an instance of it.
 *
 * @param {string} kernelName - Kernel name placed in brackets in the message.
 * @param {string} variant - Kernel variant placed in the message.
 * @param {number} bindingIndex - Binding slot number.
 * @param {string} bindingLabel - Readable name of the binding.
 * @param {*} buffer - Value about to be bound.
 * @param {Array<string>} [details=[]] - Optional context strings; falsy
 *   entries are skipped, the rest are comma-joined in parentheses.
 * @throws {Error} When `buffer` fails the check.
 */
function assertBindGroupBuffer(kernelName, variant, bindingIndex, bindingLabel, buffer, details = []) {
  let usable = buffer;
  if (usable && typeof GPUBuffer !== 'undefined') {
    usable = buffer instanceof GPUBuffer;
  }
  if (usable) {
    return;
  }

  const context = details.filter((part) => Boolean(part)).join(', ');
  const head = `[${kernelName}] variant="${variant}" binding ${bindingIndex} "${bindingLabel}" requires a GPUBuffer`;
  const tail = context ? ` (${context})` : '';
  throw new Error(`${head}${tail}.`);
}
71
+
55
72
  function createMatmulBindGroupEntries(variant, uniformBuffer, matmulInput, bBuffer, outputBuffer, offsets, bindingSizes) {
56
73
  const isQ4KF16 = variant === 'q4_fused_multicol_f16'
57
74
  || variant === 'q4_fused_f16a'
@@ -59,6 +76,14 @@ function createMatmulBindGroupEntries(variant, uniformBuffer, matmulInput, bBuff
59
76
  || variant === 'q4_fused_multicol_f16a'
60
77
  || variant === 'q4_fused_batched_f16a';
61
78
 
79
+ assertBindGroupBuffer('matmul', variant, 0, 'uniforms', uniformBuffer);
80
+ assertBindGroupBuffer('matmul', variant, 1, 'input', matmulInput?.buffer, [
81
+ `inputLabel=${matmulInput?.label ?? 'unknown'}`,
82
+ `inputDtype=${matmulInput?.dtype ?? 'unknown'}`,
83
+ ]);
84
+ assertBindGroupBuffer('matmul', variant, 2, 'weights', bBuffer);
85
+ assertBindGroupBuffer('matmul', variant, isQ4KF16 ? 4 : 3, 'output', outputBuffer);
86
+
62
87
  const entries = [
63
88
  { binding: 0, resource: { buffer: uniformBuffer } },
64
89
  { binding: 1, resource: { buffer: matmulInput.buffer, offset: offsets.aOffset, size: bindingSizes.aBindingSize } },
@@ -34,7 +34,7 @@ async function _pixelShuffle(target, input, options = {}) {
34
34
  grid_width: gridWidth, grid_height: gridHeight, patch_size: patchSize,
35
35
  patch_channels: inferredPatchChannels, _pad0: 0,
36
36
  },
37
- Math.ceil((outChannels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
37
+ [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
38
38
  );
39
39
 
40
40
  return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'pixel_shuffle_output');
@@ -19,17 +19,16 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
22
  let spatial_size = u.out_height * u.out_width;
24
- let total = u.out_channels * spatial_size;
25
- if (idx >= total) {
23
+ let spatial = gid.x;
24
+ let c = gid.y;
25
+ if (c >= u.out_channels || spatial >= spatial_size) {
26
26
  return;
27
27
  }
28
28
 
29
- let c = idx / spatial_size;
30
- let spatial = idx - c * spatial_size;
31
29
  let y = spatial / u.out_width;
32
30
  let x = spatial - y * u.out_width;
31
+ let idx = c * spatial_size + spatial;
33
32
 
34
33
  let grid_y = y / u.patch_size;
35
34
  let grid_x = x / u.patch_size;
@@ -22,17 +22,16 @@ struct Uniforms {
22
22
 
23
23
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
24
24
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
25
- let idx = gid.x;
26
25
  let spatial_size = u.out_height * u.out_width;
27
- let total = u.out_channels * spatial_size;
28
- if (idx >= total) {
26
+ let spatial = gid.x;
27
+ let c = gid.y;
28
+ if (c >= u.out_channels || spatial >= spatial_size) {
29
29
  return;
30
30
  }
31
31
 
32
- let c = idx / spatial_size;
33
- let spatial = idx - c * spatial_size;
34
32
  let y = spatial / u.out_width;
35
33
  let x = spatial - y * u.out_width;
34
+ let idx = c * spatial_size + spatial;
36
35
 
37
36
  let grid_y = y / u.patch_size;
38
37
  let grid_x = x / u.patch_size;
@@ -18,19 +18,32 @@ function resolveCount(input, countOverride) {
18
18
  return Math.floor(input.buffer.size / dtypeBytes(input.dtype));
19
19
  }
20
20
 
21
/**
 * Plans the compute dispatch for the ReLU kernel so that no single dispatch
 * dimension exceeds the device's per-dimension workgroup limit.
 *
 * The element range is folded into rows of `dispatchStride` elements: the x
 * dimension covers one row and the y dimension covers the number of rows.
 * The shader reconstructs the flat index as `gid.y * dispatch_stride + gid.x`
 * and bounds-checks it against `size`, so the last row may be partial.
 *
 * Previously the y count was fixed at 1, which left every element past the
 * first `dispatchStride` unprocessed for inputs larger than
 * `maxComputeWorkgroupsPerDimension * WORKGROUP_SIZES.DEFAULT`.
 *
 * @param {object} [target] - Object whose `device.limits` is consulted; when
 *   the limit is absent or not finite, 65535 is assumed.
 * @param {number} size - Number of elements to process.
 * @returns {{dispatchStride: number, workgroups: number[]}} The row length in
 *   elements and the `[x, y, 1]` workgroup counts.
 */
function planReluDispatch(target, size) {
  const reportedLimit = target?.device?.limits?.maxComputeWorkgroupsPerDimension;
  const maxPerDim = Number.isFinite(reportedLimit) ? reportedLimit : 65535;
  const dispatchStride = Math.min(size, maxPerDim * WORKGROUP_SIZES.DEFAULT);
  // One y-row per `dispatchStride` elements; keep a single row for empty input
  // to avoid dividing by zero.
  const rowCount = dispatchStride > 0 ? Math.ceil(size / dispatchStride) : 1;
  return {
    dispatchStride,
    workgroups: [Math.ceil(dispatchStride / WORKGROUP_SIZES.DEFAULT), rowCount, 1],
  };
}
32
+
21
33
  async function _relu(target, input, options = {}) {
22
34
  const { count = null, outputBuffer = null } = options;
23
35
  const size = resolveCount(input, count);
24
36
  const variant = selectReluVariant(input.dtype);
25
37
  const output = outputBuffer || acquireBuffer(size * dtypeBytes(input.dtype), undefined, 'relu_output');
38
+ const dispatchPlan = planReluDispatch(target, size);
26
39
 
27
40
  await unifiedKernelWrapper(
28
41
  'relu',
29
42
  target,
30
43
  variant,
31
44
  [input, output],
32
- { size, _pad0: 0, _pad1: 0, _pad2: 0 },
33
- Math.ceil(size / WORKGROUP_SIZES.DEFAULT)
45
+ { size, _pad0: dispatchPlan.dispatchStride, _pad1: 0, _pad2: 0 },
46
+ dispatchPlan.workgroups
34
47
  );
35
48
 
36
49
  return createTensor(output, input.dtype, [...input.shape], 'relu_output');
@@ -13,7 +13,8 @@ struct Uniforms {
13
13
 
14
14
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
15
15
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
16
- let idx = gid.x;
16
+ let dispatch_stride = max(u._pad0, 1u);
17
+ let idx = gid.y * dispatch_stride + gid.x;
17
18
  if (idx >= u.size) {
18
19
  return;
19
20
  }
@@ -15,7 +15,8 @@ struct Uniforms {
15
15
 
16
16
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
17
17
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
18
- let idx = gid.x;
18
+ let dispatch_stride = max(u._pad0, 1u);
19
+ let idx = gid.y * dispatch_stride + gid.x;
19
20
  if (idx >= u.size) {
20
21
  return;
21
22
  }
@@ -45,7 +45,7 @@ async function _repeatChannels(target, input, options = {}) {
45
45
  repeats,
46
46
  _pad0: 0,
47
47
  },
48
- Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
48
+ [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
49
49
  );
50
50
 
51
51
  return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');