@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -0,0 +1,48 @@
// Grouped Pointwise Conv2D Kernel (NCHW, f16)
//
// One invocation per (pixel, output channel): gid.x indexes the H*W plane and
// gid.y selects the output channel. Storage is f16; the dot product
// accumulates in f32. Weights are indexed row-major as
// [out_channels][in_channels / groups].

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    out_channels: u32,
    height: u32,
    width: u32,
    groups: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read> weight: array<f16>;
@group(0) @binding(3) var<storage, read> bias: array<f16>;
@group(0) @binding(4) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.height * u.width;
    let pixel = gid.x;
    let oc = gid.y;
    if (oc >= u.out_channels || pixel >= plane) {
        return;
    }

    // Output channel `oc` belongs to group oc / (out_channels / groups) and
    // only reads that group's contiguous slice of input channels.
    let cin_per_group = u.in_channels / u.groups;
    let cout_per_group = u.out_channels / u.groups;
    let first_cin = (oc / cout_per_group) * cin_per_group;
    let weight_base = oc * cin_per_group;

    // In NCHW the element at pixel p of channel c lives at c * plane + p,
    // which equals ((c * height + y) * width + x) for p = y * width + x.
    var acc: f32 = f32(bias[oc]);
    for (var k: u32 = 0u; k < cin_per_group; k++) {
        acc = acc + f32(input[(first_cin + k) * plane + pixel]) * f32(weight[weight_base + k]);
    }

    output[oc * plane + pixel] = f16(acc);
}
@@ -174,6 +174,18 @@ export {
174
174
  type Conv2DOptions,
175
175
  } from './conv2d.js';
176
176
 
177
+ export {
178
+ runDepthwiseConv2D,
179
+ recordDepthwiseConv2D,
180
+ type DepthwiseConv2DOptions,
181
+ } from './depthwise_conv2d.js';
182
+
183
+ export {
184
+ runGroupedPointwiseConv2D,
185
+ recordGroupedPointwiseConv2D,
186
+ type GroupedPointwiseConv2DOptions,
187
+ } from './grouped_pointwise_conv2d.js';
188
+
177
189
  // Gather (Embedding Lookup)
178
190
  export {
179
191
  runGather,
@@ -250,6 +262,24 @@ export {
250
262
  type SampleResult,
251
263
  } from './sample.js';
252
264
 
265
+ export {
266
+ runSanaLinearAttention,
267
+ recordSanaLinearAttention,
268
+ type SanaLinearAttentionOptions,
269
+ } from './sana_linear_attention.js';
270
+
271
+ export {
272
+ runRepeatChannels,
273
+ recordRepeatChannels,
274
+ type RepeatChannelsOptions,
275
+ } from './repeat_channels.js';
276
+
277
+ export {
278
+ runReLU,
279
+ recordReLU,
280
+ type ReLUOptions,
281
+ } from './relu.js';
282
+
253
283
  // Fused FFN (Tier 2 P0)
254
284
  export {
255
285
  runFusedFFN,
@@ -139,6 +139,16 @@ export {
139
139
  recordConv2D,
140
140
  } from './conv2d.js';
141
141
 
142
+ export {
143
+ runDepthwiseConv2D,
144
+ recordDepthwiseConv2D,
145
+ } from './depthwise_conv2d.js';
146
+
147
+ export {
148
+ runGroupedPointwiseConv2D,
149
+ recordGroupedPointwiseConv2D,
150
+ } from './grouped_pointwise_conv2d.js';
151
+
142
152
  // Gather (Embedding Lookup)
143
153
  export {
144
154
  runGather,
@@ -205,6 +215,21 @@ export {
205
215
  isGPUSamplingAvailable,
206
216
  } from './sample.js';
207
217
 
218
+ export {
219
+ runSanaLinearAttention,
220
+ recordSanaLinearAttention,
221
+ } from './sana_linear_attention.js';
222
+
223
+ export {
224
+ runRepeatChannels,
225
+ recordRepeatChannels,
226
+ } from './repeat_channels.js';
227
+
228
+ export {
229
+ runReLU,
230
+ recordReLU,
231
+ } from './relu.js';
232
+
208
233
  // Fused FFN (Tier 2 P0)
209
234
  export {
210
235
  runFusedFFN,
@@ -52,6 +52,23 @@ function buildProfileLabel(options = {}) {
52
52
  return `matmul${roleLabel}${layerLabel}`;
53
53
  }
54
54
 
55
/**
 * Throw a descriptive error unless `buffer` can be bound as a bind-group buffer.
 *
 * A value passes when it is truthy and, if a global `GPUBuffer` class exists,
 * is an instance of it. When no `GPUBuffer` global is defined, any truthy
 * value is accepted.
 *
 * @param {string} kernelName  Shown in brackets at the start of the message.
 * @param {string} variant  Kernel variant reported in the message.
 * @param {number} bindingIndex  Binding slot reported in the message.
 * @param {string} bindingLabel  Human-readable name of the binding.
 * @param {*} buffer  Candidate buffer to check.
 * @param {Array<*>} [details]  Extra context; falsy entries are dropped and the
 *   rest are joined with ", " inside parentheses.
 * @throws {Error} When the check fails.
 */
function assertBindGroupBuffer(kernelName, variant, bindingIndex, bindingLabel, buffer, details = []) {
  const gpuBufferClassAvailable = typeof GPUBuffer !== 'undefined';
  if (buffer && (!gpuBufferClassAvailable || buffer instanceof GPUBuffer)) {
    return;
  }

  const context = details.filter(Boolean).join(', ');
  const suffix = context ? ` (${context})` : '';
  throw new Error(
    `[${kernelName}] variant="${variant}" binding ${bindingIndex} "${bindingLabel}" requires a GPUBuffer${suffix}.`
  );
}
71
+
55
72
  function createMatmulBindGroupEntries(variant, uniformBuffer, matmulInput, bBuffer, outputBuffer, offsets, bindingSizes) {
56
73
  const isQ4KF16 = variant === 'q4_fused_multicol_f16'
57
74
  || variant === 'q4_fused_f16a'
@@ -59,6 +76,14 @@ function createMatmulBindGroupEntries(variant, uniformBuffer, matmulInput, bBuff
59
76
  || variant === 'q4_fused_multicol_f16a'
60
77
  || variant === 'q4_fused_batched_f16a';
61
78
 
79
+ assertBindGroupBuffer('matmul', variant, 0, 'uniforms', uniformBuffer);
80
+ assertBindGroupBuffer('matmul', variant, 1, 'input', matmulInput?.buffer, [
81
+ `inputLabel=${matmulInput?.label ?? 'unknown'}`,
82
+ `inputDtype=${matmulInput?.dtype ?? 'unknown'}`,
83
+ ]);
84
+ assertBindGroupBuffer('matmul', variant, 2, 'weights', bBuffer);
85
+ assertBindGroupBuffer('matmul', variant, isQ4KF16 ? 4 : 3, 'output', outputBuffer);
86
+
62
87
  const entries = [
63
88
  { binding: 0, resource: { buffer: uniformBuffer } },
64
89
  { binding: 1, resource: { buffer: matmulInput.buffer, offset: offsets.aOffset, size: bindingSizes.aBindingSize } },
@@ -34,7 +34,7 @@ async function _pixelShuffle(target, input, options = {}) {
34
34
  grid_width: gridWidth, grid_height: gridHeight, patch_size: patchSize,
35
35
  patch_channels: inferredPatchChannels, _pad0: 0,
36
36
  },
37
- Math.ceil((outChannels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
37
+ [Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
38
38
  );
39
39
 
40
40
  return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'pixel_shuffle_output');
@@ -19,17 +19,16 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
22
  let spatial_size = u.out_height * u.out_width;
24
- let total = u.out_channels * spatial_size;
25
- if (idx >= total) {
23
+ let spatial = gid.x;
24
+ let c = gid.y;
25
+ if (c >= u.out_channels || spatial >= spatial_size) {
26
26
  return;
27
27
  }
28
28
 
29
- let c = idx / spatial_size;
30
- let spatial = idx - c * spatial_size;
31
29
  let y = spatial / u.out_width;
32
30
  let x = spatial - y * u.out_width;
31
+ let idx = c * spatial_size + spatial;
33
32
 
34
33
  let grid_y = y / u.patch_size;
35
34
  let grid_x = x / u.patch_size;
@@ -22,17 +22,16 @@ struct Uniforms {
22
22
 
23
23
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
24
24
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
25
- let idx = gid.x;
26
25
  let spatial_size = u.out_height * u.out_width;
27
- let total = u.out_channels * spatial_size;
28
- if (idx >= total) {
26
+ let spatial = gid.x;
27
+ let c = gid.y;
28
+ if (c >= u.out_channels || spatial >= spatial_size) {
29
29
  return;
30
30
  }
31
31
 
32
- let c = idx / spatial_size;
33
- let spatial = idx - c * spatial_size;
34
32
  let y = spatial / u.out_width;
35
33
  let x = spatial - y * u.out_width;
34
+ let idx = c * spatial_size + spatial;
36
35
 
37
36
  let grid_y = y / u.patch_size;
38
37
  let grid_x = x / u.patch_size;
@@ -0,0 +1,18 @@
import type { Tensor } from '../tensor.js';
import type { CommandRecorder } from '../command-recorder.js';
import type { OutputBufferOptions } from './types.js';

/** Options accepted by {@link runReLU} and {@link recordReLU}. */
export interface ReLUOptions extends OutputBufferOptions {
  /**
   * Number of elements to process. A finite positive value is used (floored);
   * otherwise the count is derived from the input tensor.
   */
  count?: number | null;
}

/** Apply ReLU (`max(x, 0)`) elementwise and resolve to the output tensor. */
export declare function runReLU(input: Tensor, options?: ReLUOptions): Promise<Tensor>;

/** Encode an elementwise ReLU (`max(x, 0)`) using `recorder` and resolve to the output tensor. */
export declare function recordReLU(recorder: CommandRecorder, input: Tensor, options?: ReLUOptions): Promise<Tensor>;
@@ -0,0 +1,58 @@
import { acquireBuffer } from '../../memory/buffer-pool.js';
import { createTensor, dtypeBytes } from '../tensor.js';
import { unifiedKernelWrapper } from './utils.js';
import { selectRuleValue } from './rule-registry.js';
import { WORKGROUP_SIZES } from './constants.js';

/**
 * Pick the shader variant for an element dtype from the `relu` rule table.
 * @param {string} dtype  Element dtype of the input tensor.
 * @returns Variant identifier understood by `unifiedKernelWrapper`.
 */
function selectReluVariant(dtype) {
  return selectRuleValue('relu', 'variant', { dtype });
}

/**
 * Resolve how many elements the kernel should process.
 *
 * Priority: a finite, positive `countOverride` (floored); otherwise the product
 * of `input.shape` when it is a non-empty array; otherwise the byte size of
 * `input.buffer` divided by the element size of `input.dtype`.
 */
function resolveCount(input, countOverride) {
  if (Number.isFinite(countOverride) && countOverride > 0) {
    return Math.floor(countOverride);
  }
  if (Array.isArray(input.shape) && input.shape.length > 0) {
    return input.shape.reduce((acc, value) => acc * value, 1);
  }
  return Math.floor(input.buffer.size / dtypeBytes(input.dtype));
}

/**
 * Plan a (possibly 2D) dispatch covering `size` elements.
 *
 * The shader flattens invocations as `idx = gid.y * dispatch_stride + gid.x`,
 * so each row of workgroups covers `dispatchStride` elements and enough rows
 * are dispatched to cover `size`. The row width is capped so the X workgroup
 * count never exceeds `maxComputeWorkgroupsPerDimension` (65535 when the
 * device limit is unavailable, e.g. when `target` is null).
 *
 * @param {object|null} target  Command recorder, or null; only `target.device.limits` is read.
 * @param {number} size  Total element count.
 * @param {number} [elementsPerWorkgroup]  Elements one workgroup covers along X.
 * @returns {{ dispatchStride: number, workgroups: [number, number, number] }}
 */
function planReluDispatch(target, size, elementsPerWorkgroup = WORKGROUP_SIZES.DEFAULT) {
  const device = target?.device;
  const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
    ? device.limits.maxComputeWorkgroupsPerDimension
    : 65535;
  const dispatchStride = Math.min(size, maxPerDim * elementsPerWorkgroup);
  // Dispatch enough rows to reach `size`: with a single row, elements at
  // index >= maxPerDim * elementsPerWorkgroup were never written.
  // Guard size 0 so the row count is not 0 / 0 = NaN.
  const rows = dispatchStride > 0 ? Math.ceil(size / dispatchStride) : 1;
  return {
    dispatchStride,
    workgroups: [Math.ceil(dispatchStride / elementsPerWorkgroup), rows, 1],
  };
}

/**
 * Shared implementation of runReLU / recordReLU.
 *
 * Writes `max(x, 0)` for each resolved element into `outputBuffer` (or a buffer
 * from `acquireBuffer` when none is given) and returns a tensor with the
 * input's dtype and shape.
 */
async function _relu(target, input, options = {}) {
  const { count = null, outputBuffer = null } = options;
  const size = resolveCount(input, count);
  const variant = selectReluVariant(input.dtype);
  const output = outputBuffer || acquireBuffer(size * dtypeBytes(input.dtype), undefined, 'relu_output');
  const dispatchPlan = planReluDispatch(target, size);

  await unifiedKernelWrapper(
    'relu',
    target,
    variant,
    [input, output],
    // _pad0 carries the row stride the shader uses to flatten gid.x / gid.y.
    { size, _pad0: dispatchPlan.dispatchStride, _pad1: 0, _pad2: 0 },
    dispatchPlan.workgroups
  );

  // resolveCount tolerates a missing shape; fall back to a flat shape instead of throwing here.
  const outputShape = input.shape == null ? [size] : [...input.shape];
  return createTensor(output, input.dtype, outputShape, 'relu_output');
}

/** Run ReLU immediately (no recorder). */
export async function runReLU(input, options = {}) {
  return _relu(null, input, options);
}

/** Encode ReLU into the given command recorder. */
export async function recordReLU(recorder, input, options = {}) {
  return _relu(recorder, input, options);
}
@@ -0,0 +1,22 @@
// Elementwise ReLU over f32 storage: output[i] = max(input[i], 0).
// The uniform field _pad0 holds the row stride for 2D dispatches (0 is treated as 1).

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    size: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row_stride = max(u._pad0, 1u);
    let i = gid.x + gid.y * row_stride;
    if (i < u.size) {
        output[i] = max(input[i], 0.0);
    }
}
@@ -0,0 +1,24 @@
enable f16;

// Elementwise ReLU over f16 storage: output[i] = max(input[i], 0).
// The uniform field _pad0 holds the row stride for 2D dispatches (0 is treated as 1).

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    size: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row_stride = max(u._pad0, 1u);
    let i = gid.x + gid.y * row_stride;
    if (i < u.size) {
        output[i] = max(input[i], 0.0h);
    }
}
@@ -0,0 +1,21 @@
import type { Tensor } from '../tensor.js';
import type { CommandRecorder } from '../command-recorder.js';
import type { OutputBufferOptions } from './types.js';

/** Geometry for repeating the channels of a channel-major (C, H, W) tensor. */
export interface RepeatChannelsOptions extends OutputBufferOptions {
  /** Channel count of the input tensor. */
  inChannels: number;
  /** Spatial height. */
  height: number;
  /** Spatial width. */
  width: number;
  /** Times each input channel is repeated; the output has inChannels * repeats channels. */
  repeats: number;
}

/** Repeat input channels immediately and resolve to the output tensor. */
export declare function runRepeatChannels(input: Tensor, options: RepeatChannelsOptions): Promise<Tensor>;

/** Encode a channel repeat using `recorder` and resolve to the output tensor. */
export declare function recordRepeatChannels(recorder: CommandRecorder, input: Tensor, options: RepeatChannelsOptions): Promise<Tensor>;
@@ -0,0 +1,60 @@
import { acquireBuffer } from '../../memory/buffer-pool.js';
import { createTensor, dtypeBytes } from '../tensor.js';
import { unifiedKernelWrapper } from './utils.js';
import { selectRuleValue } from './rule-registry.js';
import { WORKGROUP_SIZES } from './constants.js';

/** Pick the shader variant for `dtype` from the `repeatChannels` rule table. */
function selectRepeatChannelsVariant(dtype) {
  return selectRuleValue('repeatChannels', 'variant', { dtype });
}

/** True when `value` is an integer greater than or equal to 1. */
function isPositiveInteger(value) {
  return Number.isInteger(value) && value >= 1;
}

/**
 * Shared implementation of runRepeatChannels / recordRepeatChannels.
 *
 * Output channel c copies input channel floor(c / repeats), so each input
 * channel appears `repeats` times consecutively. The result has shape
 * [inChannels * repeats, height, width] and the input's dtype.
 *
 * @throws {Error} When inChannels, height, width or repeats is not a positive integer.
 */
async function _repeatChannels(target, input, options = {}) {
  const {
    inChannels,
    height,
    width,
    repeats,
    outputBuffer = null,
  } = options;

  // Fractional or non-positive dimensions would produce fractional buffer
  // sizes / workgroup counts or empty buffers; reject them up front.
  if (![inChannels, height, width, repeats].every(isPositiveInteger)) {
    throw new Error('RepeatChannels requires positive integer inChannels, height, width, and repeats.');
  }

  const outChannels = inChannels * repeats;
  const variant = selectRepeatChannelsVariant(input.dtype);
  const bytesPerElement = dtypeBytes(input.dtype);
  const outputSize = outChannels * height * width * bytesPerElement;
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');

  await unifiedKernelWrapper(
    'repeat_channels',
    target,
    variant,
    [input, output],
    {
      in_channels: inChannels,
      height,
      width,
      repeats,
      _pad0: 0,
    },
    // X covers the H*W plane, Y selects the output channel.
    [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
  );

  return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
}

/** Repeat channels immediately (no recorder). */
export async function runRepeatChannels(input, options = {}) {
  return _repeatChannels(null, input, options);
}

/** Encode the channel repeat into the given command recorder. */
export async function recordRepeatChannels(recorder, input, options = {}) {
  return _repeatChannels(recorder, input, options);
}
@@ -0,0 +1,28 @@
// Repeat channels of a channel-major f32 tensor: output channel c copies input
// channel c / repeats. gid.x indexes the H*W plane, gid.y the output channel.

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    height: u32,
    width: u32,
    repeats: u32,
    _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.height * u.width;
    let pixel = gid.x;
    let dst_channel = gid.y;
    // repeats == 0 makes the channel bound 0, so the division below is never reached.
    if (pixel >= plane || dst_channel >= u.in_channels * u.repeats) {
        return;
    }

    let src_channel = dst_channel / u.repeats;
    output[dst_channel * plane + pixel] = input[src_channel * plane + pixel];
}
@@ -0,0 +1,30 @@
enable f16;

// Repeat channels of a channel-major f16 tensor: output channel c copies input
// channel c / repeats. gid.x indexes the H*W plane, gid.y the output channel.

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    height: u32,
    width: u32,
    repeats: u32,
    _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.height * u.width;
    let pixel = gid.x;
    let dst_channel = gid.y;
    // repeats == 0 makes the channel bound 0, so the division below is never reached.
    if (pixel >= plane || dst_channel >= u.in_channels * u.repeats) {
        return;
    }

    let src_channel = dst_channel / u.repeats;
    output[dst_channel * plane + pixel] = input[src_channel * plane + pixel];
}
@@ -63,6 +63,22 @@ function cleanupTemps(temps, recorder) {
63
63
  }
64
64
  }
65
65
 
66
/**
 * Plan a (possibly 2D) dispatch covering `size` elements.
 *
 * Each row of workgroups covers `dispatchStride` elements. The residual
 * shaders read that stride from the `_pad1` uniform and flatten the invocation
 * id as `gid.y * stride + gid.x` (scalar) or `gid.y * stride + gid.x * 4` (vec4).
 * The row width is capped so the X workgroup count stays within
 * `maxComputeWorkgroupsPerDimension` (65535 when the device limit is unavailable).
 *
 * @param {object|null} target  Recorder/target whose `device.limits` is consulted, if present.
 * @param {number} size  Total element count.
 * @param {number} elementsPerWorkgroup  Elements one workgroup covers along X.
 * @returns {{ dispatchStride: number, workgroups: [number, number, number] }}
 */
function planResidualDispatch(target, size, elementsPerWorkgroup) {
  const limit = target?.device?.limits?.maxComputeWorkgroupsPerDimension;
  const maxPerDim = Number.isFinite(limit) ? limit : 65535;
  const dispatchStride = Math.min(size, maxPerDim * elementsPerWorkgroup);
  // An empty input would otherwise give 0 / 0 = NaN for the Y count, which is
  // not a valid workgroup count; dispatch zero X workgroups instead.
  const rows = dispatchStride > 0 ? Math.ceil(size / dispatchStride) : 1;
  return {
    dispatchStride,
    workgroups: [Math.ceil(dispatchStride / elementsPerWorkgroup), rows, 1],
  };
}
81
+
66
82
  async function _residualAdd(target, a, b, size, options = {}) {
67
83
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
68
84
  const { useVec4 = true, outputBuffer = null } = options;
@@ -75,15 +91,17 @@ async function _residualAdd(target, a, b, size, options = {}) {
75
91
  const outputSize = size * bytesPerElement;
76
92
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'residual_output');
77
93
 
78
- const workgroups = useVec4
79
- ? Math.ceil(size / VEC4_ELEMENTS_PER_WG)
80
- : Math.ceil(size / WORKGROUP_SIZES.DEFAULT);
94
+ const dispatchPlan = planResidualDispatch(
95
+ target,
96
+ size,
97
+ useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
98
+ );
81
99
 
82
100
  await unifiedKernelWrapper(
83
101
  'residual', target, variant,
84
102
  [aAligned, bAligned, output],
85
- { size },
86
- workgroups
103
+ { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
104
+ dispatchPlan.workgroups
87
105
  );
88
106
 
89
107
  cleanupTemps(temps, recorder);
@@ -96,13 +114,31 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
96
114
 
97
115
  const { bias: biasAligned, temps } = await alignBiasTensor(data, bias, recorder);
98
116
  const variant = selectBiasAddVariant(data.dtype, biasAligned.dtype);
99
-
100
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
117
+ const device = target?.device;
118
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
119
+ ? device.limits.maxComputeWorkgroupsPerDimension
120
+ : 65535;
121
+ const tokenStride = Math.min(numTokens, maxPerDim);
122
+
123
+ const workgroups = [
124
+ Math.ceil(dim / WORKGROUP_SIZES.DEFAULT),
125
+ tokenStride,
126
+ Math.ceil(numTokens / tokenStride),
127
+ ];
101
128
 
102
129
  await unifiedKernelWrapper(
103
130
  'bias_add', target, variant,
104
131
  [data, biasAligned],
105
- { num_tokens: numTokens, dim, data_offset: dataOffset, bias_offset: biasOffset },
132
+ {
133
+ num_tokens: numTokens,
134
+ dim,
135
+ data_offset: dataOffset,
136
+ bias_offset: biasOffset,
137
+ token_stride: tokenStride,
138
+ _pad0: 0,
139
+ _pad1: 0,
140
+ _pad2: 0,
141
+ },
106
142
  workgroups
107
143
  );
108
144
 
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE: u32 = 256u;
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
26
+ let dispatch_stride = max(u._pad1, 1u);
27
+ let idx = gid.y * dispatch_stride + gid.x;
27
28
  if (idx >= u.size) {
28
29
  return;
29
30
  }
@@ -35,7 +36,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
35
36
  // This avoids requiring a different bind group layout with read_write on 'a'
36
37
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
37
38
  fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
38
- let idx = gid.x;
39
+ let dispatch_stride = max(u._pad1, 1u);
40
+ let idx = gid.y * dispatch_stride + gid.x;
39
41
  if (idx >= u.size) {
40
42
  return;
41
43
  }
@@ -45,7 +47,8 @@ fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
45
47
  // Fused residual + scale: output = a + scale * b
46
48
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
47
49
  fn add_scaled(@builtin(global_invocation_id) gid: vec3<u32>) {
48
- let idx = gid.x;
50
+ let dispatch_stride = max(u._pad1, 1u);
51
+ let idx = gid.y * dispatch_stride + gid.x;
49
52
  if (idx >= u.size) {
50
53
  return;
51
54
  }
@@ -27,7 +27,8 @@ override WORKGROUP_SIZE: u32 = 256u;
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
30
+ let dispatch_stride = max(u._pad1, 1u);
31
+ let idx = gid.y * dispatch_stride + gid.x;
31
32
  if (idx >= u.size) {
32
33
  return;
33
34
  }
@@ -25,7 +25,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
25
25
  // Vectorized version for better throughput
26
26
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
27
27
  fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
28
- let idx = gid.x * 4u;
28
+ let dispatch_stride = max(u._pad1, 4u);
29
+ let idx = gid.y * dispatch_stride + gid.x * 4u;
29
30
  let size = u.size;
30
31
 
31
32
  if (idx >= size) {
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
23
23
  // Vectorized version for better throughput
24
24
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
25
25
  fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x * 4u;
26
+ let dispatch_stride = max(u._pad1, 4u);
27
+ let idx = gid.y * dispatch_stride + gid.x * 4u;
27
28
  let size = u.size;
28
29
 
29
30
  if (idx >= size) {