@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -0,0 +1,43 @@
// Sana linear attention, apply pass (f32 activations).
//
// One invocation writes one output element (channel = gid.x, token = gid.y):
//   out = sum_i S[d, i] * relu(q[i]) / (sum_i S[head_dim, i] * relu(q[i]) + eps)
// q is the token's query slice for the channel's head, d is the channel's
// offset inside that head, and S is the head's summary block of
// (head_dim + 1) rows x head_dim columns, whose last row feeds the denominator.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  eps: f32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f32>;
@group(0) @binding(2) var<storage, read> summary: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let channel = gid.x;
  let tok = gid.y;
  let in_range = channel < u.hidden_size && tok < u.num_tokens;
  if (!in_range) {
    return;
  }

  let hd = u.head_dim;
  let head = channel / hd;
  let lane = channel - head * hd;

  // Loop-invariant bases: the token's row, the head's query slice, and the
  // two summary rows (this channel's row and the trailing normalizer row).
  let row_start = tok * u.hidden_size;
  let q_base = row_start + head * hd;
  let head_block = head * (hd + 1u) * hd;
  let num_base = head_block + lane * hd;
  let den_base = head_block + hd * hd;

  var num: f32 = 0.0;
  var den: f32 = 0.0;
  for (var i = 0u; i < hd; i++) {
    let q = max(query[q_base + i], 0.0);
    num += summary[num_base + i] * q;
    den += summary[den_base + i] * q;
  }

  output[row_start + channel] = num / (den + u.eps);
}
@@ -0,0 +1,46 @@
enable f16;

// Sana linear attention, apply pass (f16 activations, f32 summary/accumulation).
//
// One invocation writes one output element (channel = gid.x, token = gid.y):
//   out = sum_i S[d, i] * relu(q[i]) / (sum_i S[head_dim, i] * relu(q[i]) + eps)
// q is the token's query slice for the channel's head, d is the channel's
// offset inside that head, and S is the head's summary block of
// (head_dim + 1) rows x head_dim columns, whose last row feeds the denominator.
// The f32 result is clamped to the finite f16 range before narrowing.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  eps: f32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f16>;
@group(0) @binding(2) var<storage, read> summary: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let channel = gid.x;
  let tok = gid.y;
  let in_range = channel < u.hidden_size && tok < u.num_tokens;
  if (!in_range) {
    return;
  }

  let hd = u.head_dim;
  let head = channel / hd;
  let lane = channel - head * hd;

  // Loop-invariant bases: the token's row, the head's query slice, and the
  // two summary rows (this channel's row and the trailing normalizer row).
  let row_start = tok * u.hidden_size;
  let q_base = row_start + head * hd;
  let head_block = head * (hd + 1u) * hd;
  let num_base = head_block + lane * hd;
  let den_base = head_block + hd * hd;

  var num: f32 = 0.0;
  var den: f32 = 0.0;
  for (var i = 0u; i < hd; i++) {
    let q = max(f32(query[q_base + i]), 0.0);
    num += summary[num_base + i] * q;
    den += summary[den_base + i] * q;
  }

  // 65504 is the largest finite f16 value.
  output[row_start + channel] = f16(clamp(num / (den + u.eps), -65504.0, 65504.0));
}
@@ -0,0 +1,51 @@
// Sana linear attention, summary pass (f32 activations).
//
// Builds, for every head, a (head_dim + 1) x head_dim matrix S:
//   S[row, col]      = sum_t v[t, row] * relu(k[t, col])   for row < head_dim
//   S[head_dim, col] = sum_t relu(k[t, col])              (normalizer row)
// One invocation computes one entry, flattened as
//   head * (head_dim + 1) * head_dim + row * head_dim + col.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  _pad0: u32,
  _pad1: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f32>;
@group(0) @binding(2) var<storage, read> key: array<f32>;
@group(0) @binding(3) var<storage, read> value: array<f32>;
@group(0) @binding(4) var<storage, read_write> summary: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let rows_per_head = u.head_dim + 1u;
  let head_span = rows_per_head * u.head_dim;
  let total = u.num_heads * head_span;
  if (idx >= total) {
    return;
  }

  let head = idx / head_span;
  let rem = idx - head * head_span;
  let row = rem / u.head_dim;
  let col = rem - row * u.head_dim;
  let hidden_base = head * u.head_dim;
  // The last row of each head block is the normalizer: it uses an implicit
  // value of 1.0 instead of reading `value`.
  let is_norm_row = row == u.head_dim;

  var acc: f32 = 0.0;
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
    let token_base = token * u.hidden_size + hidden_base;
    let key_value = max(key[token_base + col], 0.0);

    // Branch rather than `select`: `select` evaluates both operands, so it
    // would load value[token_base + head_dim] on the normalizer row. That is
    // the next head's data, or one element past the end of the buffer for
    // the last head's last token (when hidden_size == num_heads * head_dim).
    var value_value: f32 = 1.0;
    if (!is_norm_row) {
      value_value = value[token_base + row];
    }

    // Not taken for valid shapes. It keeps `query` statically referenced so
    // binding 1 stays in the shader interface (presumably because the host
    // binds it; confirm). The load sits inside the branch so it is not issued
    // on every iteration.
    if (u.hidden_size == 0u) {
      acc = acc + query[token_base + col];
    }

    acc = acc + value_value * key_value;
  }

  summary[idx] = acc;
}
@@ -0,0 +1,53 @@
enable f16;

// Sana linear attention, summary pass (f16 activations, f32 accumulation/output).
//
// Builds, for every head, a (head_dim + 1) x head_dim matrix S:
//   S[row, col]      = sum_t v[t, row] * relu(k[t, col])   for row < head_dim
//   S[head_dim, col] = sum_t relu(k[t, col])              (normalizer row)
// One invocation computes one entry, flattened as
//   head * (head_dim + 1) * head_dim + row * head_dim + col.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  _pad0: u32,
  _pad1: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f16>;
@group(0) @binding(2) var<storage, read> key: array<f16>;
@group(0) @binding(3) var<storage, read> value: array<f16>;
@group(0) @binding(4) var<storage, read_write> summary: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let rows_per_head = u.head_dim + 1u;
  let head_span = rows_per_head * u.head_dim;
  let total = u.num_heads * head_span;
  if (idx >= total) {
    return;
  }

  let head = idx / head_span;
  let rem = idx - head * head_span;
  let row = rem / u.head_dim;
  let col = rem - row * u.head_dim;
  let hidden_base = head * u.head_dim;
  // The last row of each head block is the normalizer: it uses an implicit
  // value of 1.0 instead of reading `value`.
  let is_norm_row = row == u.head_dim;

  var acc: f32 = 0.0;
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
    let token_base = token * u.hidden_size + hidden_base;
    let key_value = max(f32(key[token_base + col]), 0.0);

    // Branch rather than `select`: `select` evaluates both operands, so it
    // would load value[token_base + head_dim] on the normalizer row. That is
    // the next head's data, or one element past the end of the buffer for
    // the last head's last token (when hidden_size == num_heads * head_dim).
    var value_value: f32 = 1.0;
    if (!is_norm_row) {
      value_value = f32(value[token_base + row]);
    }

    // Not taken for valid shapes. It keeps `query` statically referenced so
    // binding 1 stays in the shader interface (presumably because the host
    // binds it; confirm). The load sits inside the branch so it is not issued
    // on every iteration.
    if (u.hidden_size == 0u) {
      acc = acc + f32(query[token_base + col]);
    }

    acc = acc + value_value * key_value;
  }

  summary[idx] = acc;
}
@@ -16,6 +16,7 @@ export interface SiLUOptions extends OutputBufferOptions {
16
16
  size?: number | null;
17
17
  gate?: Tensor | null;
18
18
  gateActivation?: 'silu' | 'sigmoid';
19
+ inputActivation?: 'silu' | 'identity';
19
20
  useVec4?: boolean;
20
21
  biasOffset?: number;
21
22
  swigluLimit: number | null;
@@ -47,6 +47,18 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
47
47
  ];
48
48
  }
49
49
 
50
/**
 * Plans the workgroup grid for the flat (non-row-split) SiLU kernel.
 *
 * Up to `maxComputeWorkgroupsPerDimension` workgroups (65535 when the device
 * does not report a finite limit) are laid out along x. Any remaining
 * elements are stacked in further chunks along y. The shader rebuilds the
 * element index as `gid.y * dispatchStride + gid.x * laneWidth`, where
 * laneWidth is 4 in vec4 mode and 1 otherwise.
 *
 * @param device - GPU device; only `device.limits` is consulted (may be null).
 * @param size - Total number of elements to process.
 * @param useVec4 - Whether each invocation handles four consecutive elements.
 * @returns `{ dispatchStride, workgroups: [x, y, 1] }`.
 */
function planSiLUDispatch(device, size, useVec4) {
  const reportedLimit = device?.limits?.maxComputeWorkgroupsPerDimension;
  const maxGroupsPerDim = Number.isFinite(reportedLimit) ? reportedLimit : 65535;
  const elementsPerGroup = WORKGROUP_SIZES.DEFAULT * (useVec4 ? 4 : 1);
  const elementsPerChunk = maxGroupsPerDim * elementsPerGroup;

  const dispatchStride = Math.min(size, elementsPerChunk);
  const groupsX = Math.min(maxGroupsPerDim, Math.ceil(dispatchStride / elementsPerGroup));
  const groupsY = Math.max(1, Math.ceil(size / elementsPerChunk));

  return { dispatchStride, workgroups: [groupsX, groupsY, 1] };
}
61
+
50
62
 
51
63
  export async function runSiLU(
52
64
  input,
@@ -60,6 +72,7 @@ export async function runSiLU(
60
72
  useVec4 = false,
61
73
  swigluLimit,
62
74
  gateActivation = 'silu',
75
+ inputActivation = 'silu',
63
76
  } = options;
64
77
  const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
65
78
 
@@ -74,14 +87,17 @@ export async function runSiLU(
74
87
  useSplit: false,
75
88
  useRowsplit: false,
76
89
  });
77
- const constants = gate && gateActivation === 'sigmoid'
78
- ? { ...(overrides || {}), GATE_USE_SIGMOID: true }
79
- : overrides;
90
+ const constants = {
91
+ ...(overrides || {}),
92
+ ...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
93
+ ...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
94
+ };
80
95
  const pipeline = await getPipelineFast('silu', variant, null, constants);
81
96
 
82
97
  const inferredSize = size || (input.buffer.size / bytesPerElement);
83
98
  const outputSize = inferredSize * bytesPerElement;
84
99
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
100
+ const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
85
101
 
86
102
  // Create uniform buffer
87
103
  const uniformBuffer = createUniformBufferWithView(
@@ -89,7 +105,7 @@ export async function runSiLU(
89
105
  16,
90
106
  (view) => {
91
107
  view.setUint32(0, inferredSize, true);
92
- view.setUint32(4, 0, true);
108
+ view.setUint32(4, dispatchPlan.dispatchStride, true);
93
109
  view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
94
110
  view.setFloat32(12, 0, true);
95
111
  },
@@ -106,8 +122,7 @@ export async function runSiLU(
106
122
  entries,
107
123
  });
108
124
 
109
- const workgroups = Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT);
110
- dispatch(device, pipeline, bindGroup, workgroups, 'silu');
125
+ dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
111
126
 
112
127
  uniformBuffer.destroy();
113
128
 
@@ -215,7 +230,7 @@ export async function runSiLURowSplit(
215
230
  ],
216
231
  });
217
232
 
218
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
233
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
219
234
  dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
220
235
 
221
236
  uniformBuffer.destroy();
@@ -269,7 +284,7 @@ export async function recordSiLURowSplit(
269
284
  ],
270
285
  });
271
286
 
272
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
287
+ const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
273
288
  recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
274
289
 
275
290
  return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
@@ -288,6 +303,7 @@ export async function recordSiLU(
288
303
  outputBuffer = null,
289
304
  swigluLimit,
290
305
  gateActivation = 'silu',
306
+ inputActivation = 'silu',
291
307
  } = options;
292
308
  const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
293
309
 
@@ -302,14 +318,17 @@ export async function recordSiLU(
302
318
  useSplit: false,
303
319
  useRowsplit: false,
304
320
  });
305
- const constants = gate && gateActivation === 'sigmoid'
306
- ? { ...(overrides || {}), GATE_USE_SIGMOID: true }
307
- : overrides;
321
+ const constants = {
322
+ ...(overrides || {}),
323
+ ...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
324
+ ...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
325
+ };
308
326
  const pipeline = await getPipelineFast('silu', variant, null, constants);
309
327
 
310
328
  const inferredSize = size || (input.buffer.size / bytesPerElement);
311
329
  const outputSize = inferredSize * bytesPerElement;
312
330
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
331
+ const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
313
332
 
314
333
  // Uniform buffer
315
334
  const uniformBuffer = createUniformBufferWithView(
@@ -317,7 +336,7 @@ export async function recordSiLU(
317
336
  16,
318
337
  (view) => {
319
338
  view.setUint32(0, inferredSize, true);
320
- view.setUint32(4, 0, true);
339
+ view.setUint32(4, dispatchPlan.dispatchStride, true);
321
340
  view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
322
341
  view.setFloat32(12, 0, true);
323
342
  },
@@ -333,8 +352,7 @@ export async function recordSiLU(
333
352
  entries,
334
353
  });
335
354
 
336
- const workgroups = Math.ceil(inferredSize / WORKGROUP_SIZES.DEFAULT);
337
- recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu');
355
+ recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
338
356
 
339
357
  return createTensor(output, input.dtype, [inferredSize], 'silu_output');
340
358
  }
@@ -10,13 +10,14 @@
10
10
  override WORKGROUP_SIZE: u32 = 256u;
11
11
  override HAS_GATE: bool = false;
12
12
  override GATE_USE_SIGMOID: bool = false;
13
+ override INPUT_USE_IDENTITY: bool = false;
13
14
  override USE_SPLIT: bool = false;
14
15
  override USE_VEC4: bool = false;
15
16
  override USE_ROWSPLIT: bool = false;
16
17
 
17
18
  struct Uniforms {
18
19
  size: u32, // Total output elements
19
- rowsplit_dim: u32, // Dim for rowsplit variants (0 when unused)
20
+ rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
20
21
  clamp_max: f32, // SwiGLU clamp (0 = disabled)
21
22
  _pad1: f32,
22
23
  }
@@ -35,6 +36,10 @@ fn silu(x: f32) -> f32 {
35
36
  return x * sigmoid(x);
36
37
  }
37
38
 
39
+ fn apply_input_activation(x: f32) -> f32 {
40
+ return select(silu(x), x, INPUT_USE_IDENTITY);
41
+ }
42
+
38
43
  fn clamp_swiglu(x: f32) -> f32 {
39
44
  if (u.clamp_max <= 0.0) {
40
45
  return x;
@@ -46,8 +51,9 @@ fn clamp_swiglu(x: f32) -> f32 {
46
51
  fn main(
47
52
  @builtin(global_invocation_id) global_id: vec3<u32>
48
53
  ) {
54
+ let dispatch_stride = max(u.rowsplit_dim, 1u);
49
55
  if (USE_VEC4) {
50
- let base_idx = global_id.x * 4u;
56
+ let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
51
57
  if (base_idx >= u.size) {
52
58
  return;
53
59
  }
@@ -55,12 +61,12 @@ fn main(
55
61
  let remaining = min(4u, u.size - base_idx);
56
62
  for (var i: u32 = 0u; i < remaining; i = i + 1u) {
57
63
  let x = input[base_idx + i];
58
- output[base_idx + i] = silu(x);
64
+ output[base_idx + i] = apply_input_activation(x);
59
65
  }
60
66
  return;
61
67
  }
62
68
 
63
- let idx = global_id.x;
69
+ let idx = global_id.y * dispatch_stride + global_id.x;
64
70
  if (idx >= u.size) {
65
71
  return;
66
72
  }
@@ -70,12 +76,16 @@ fn main(
70
76
  return;
71
77
  }
72
78
  let dim = u.rowsplit_dim;
73
- let token_idx = idx / dim;
74
- let dim_idx = idx % dim;
79
+ let num_tokens = u.size / dim;
80
+ let token_idx = global_id.y;
81
+ let dim_idx = global_id.x;
82
+ if (token_idx >= num_tokens || dim_idx >= dim) {
83
+ return;
84
+ }
75
85
  let row_base = token_idx * dim * 2u;
76
86
  let g = input[row_base + dim_idx];
77
87
  let up = input[row_base + dim + dim_idx];
78
- output[idx] = clamp_swiglu(silu(g) * up);
88
+ output[token_idx * dim + dim_idx] = clamp_swiglu(silu(g) * up);
79
89
  return;
80
90
  }
81
91
 
@@ -83,7 +93,7 @@ fn main(
83
93
  let up = input[idx];
84
94
  let g = gate[idx];
85
95
  let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
86
- output[idx] = clamp_swiglu(gateAct * up);
96
+ output[idx] = clamp_swiglu(gateAct * apply_input_activation(up));
87
97
  return;
88
98
  }
89
99
 
@@ -95,5 +105,5 @@ fn main(
95
105
  }
96
106
 
97
107
  let x = input[idx];
98
- output[idx] = silu(x);
108
+ output[idx] = apply_input_activation(x);
99
109
  }
@@ -9,13 +9,14 @@ enable f16;
9
9
  override WORKGROUP_SIZE: u32 = 256u;
10
10
  override HAS_GATE: bool = false;
11
11
  override GATE_USE_SIGMOID: bool = false;
12
+ override INPUT_USE_IDENTITY: bool = false;
12
13
  override USE_SPLIT: bool = false;
13
14
  override USE_VEC4: bool = false;
14
15
  override USE_ROWSPLIT: bool = false;
15
16
 
16
17
  struct Uniforms {
17
18
  size: u32, // Total output elements
18
- rowsplit_dim: u32, // Dim for rowsplit variants (0 when unused)
19
+ rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
19
20
  clamp_max: f32, // SwiGLU clamp (0 = disabled)
20
21
  _pad1: f32,
21
22
  }
@@ -34,6 +35,10 @@ fn silu(x: f32) -> f32 {
34
35
  return x * sigmoid(x);
35
36
  }
36
37
 
38
+ fn apply_input_activation(x: f32) -> f32 {
39
+ return select(silu(x), x, INPUT_USE_IDENTITY);
40
+ }
41
+
37
42
  fn clamp_swiglu(x: f32) -> f32 {
38
43
  if (u.clamp_max <= 0.0) {
39
44
  return x;
@@ -45,8 +50,9 @@ fn clamp_swiglu(x: f32) -> f32 {
45
50
  fn main(
46
51
  @builtin(global_invocation_id) global_id: vec3<u32>
47
52
  ) {
53
+ let dispatch_stride = max(u.rowsplit_dim, 1u);
48
54
  if (USE_VEC4) {
49
- let base_idx = global_id.x * 4u;
55
+ let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
50
56
  if (base_idx >= u.size) {
51
57
  return;
52
58
  }
@@ -54,12 +60,12 @@ fn main(
54
60
  let remaining = min(4u, u.size - base_idx);
55
61
  for (var i: u32 = 0u; i < remaining; i = i + 1u) {
56
62
  let x = f32(input[base_idx + i]);
57
- output[base_idx + i] = f16(silu(x));
63
+ output[base_idx + i] = f16(apply_input_activation(x));
58
64
  }
59
65
  return;
60
66
  }
61
67
 
62
- let idx = global_id.x;
68
+ let idx = global_id.y * dispatch_stride + global_id.x;
63
69
  if (idx >= u.size) {
64
70
  return;
65
71
  }
@@ -69,12 +75,16 @@ fn main(
69
75
  return;
70
76
  }
71
77
  let dim = u.rowsplit_dim;
72
- let token_idx = idx / dim;
73
- let dim_idx = idx % dim;
78
+ let num_tokens = u.size / dim;
79
+ let token_idx = global_id.y;
80
+ let dim_idx = global_id.x;
81
+ if (token_idx >= num_tokens || dim_idx >= dim) {
82
+ return;
83
+ }
74
84
  let row_base = token_idx * dim * 2u;
75
85
  let g = f32(input[row_base + dim_idx]);
76
86
  let up = f32(input[row_base + dim + dim_idx]);
77
- output[idx] = f16(clamp_swiglu(silu(g) * up));
87
+ output[token_idx * dim + dim_idx] = f16(clamp_swiglu(silu(g) * up));
78
88
  return;
79
89
  }
80
90
 
@@ -82,7 +92,7 @@ fn main(
82
92
  let up = f32(input[idx]);
83
93
  let g = f32(gate[idx]);
84
94
  let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
85
- output[idx] = f16(clamp_swiglu(gateAct * up));
95
+ output[idx] = f16(clamp_swiglu(gateAct * apply_input_activation(up)));
86
96
  return;
87
97
  }
88
98
 
@@ -94,5 +104,5 @@ fn main(
94
104
  }
95
105
 
96
106
  let x = f32(input[idx]);
97
- output[idx] = f16(silu(x));
107
+ output[idx] = f16(apply_input_activation(x));
98
108
  }
@@ -3,19 +3,32 @@ import { createTensor, dtypeBytes } from '../tensor.js';
3
3
  import { WORKGROUP_SIZES } from './constants.js';
4
4
  import { unifiedKernelWrapper } from './utils.js';
5
5
 
6
/**
 * Plans the x extent of the transpose dispatch.
 *
 * The column count is capped at what a single dispatch dimension can address:
 * `maxComputeWorkgroupsPerDimension` workgroups of WORKGROUP_SIZES.DEFAULT
 * lanes, with 65535 used when the device does not report a finite limit.
 * The capped count is returned as `dispatchStride`, and x covers exactly that
 * many columns. The y and z entries are 1.
 *
 * @param target - Kernel target; only `target.device.limits` is consulted.
 * @param cols - Number of source columns.
 * @returns `{ dispatchStride, workgroups: [x, 1, 1] }`.
 */
function planTransposeDispatch(target, cols) {
  const reportedLimit = target?.device?.limits?.maxComputeWorkgroupsPerDimension;
  const maxGroupsPerDim = Number.isFinite(reportedLimit) ? reportedLimit : 65535;
  const groupSize = WORKGROUP_SIZES.DEFAULT;

  const dispatchStride = Math.min(cols, maxGroupsPerDim * groupSize);
  const groupsX = Math.ceil(dispatchStride / groupSize);

  return { dispatchStride, workgroups: [groupsX, 1, 1] };
}
17
+
6
18
  async function _transpose(target, input, rows, cols, options = {}) {
7
19
  const { outputBuffer = null } = options;
8
20
  const bytesPerElement = dtypeBytes(input.dtype);
9
21
  const outputSize = rows * cols * bytesPerElement;
10
22
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'transpose_output');
23
+ const dispatchPlan = planTransposeDispatch(target, cols);
11
24
 
12
25
  await unifiedKernelWrapper(
13
26
  'transpose',
14
27
  target,
15
28
  'default',
16
29
  [input, outputBuf],
17
- { rows, cols },
18
- Math.ceil((rows * cols) / WORKGROUP_SIZES.DEFAULT)
30
+ { rows, cols, _pad0: dispatchPlan.dispatchStride, _pad1: 0 },
31
+ [dispatchPlan.workgroups[0], rows, 1]
19
32
  );
20
33
 
21
34
  return createTensor(outputBuf, input.dtype, [cols, rows], 'transpose_output');
@@ -19,14 +19,13 @@ struct Uniforms {
19
19
 
// Transposes a row-major [rows, cols] `input` into a row-major [cols, rows]
// `output`. One invocation moves one element.
//
// Dispatch layout: gid.y selects the source row. The source column is
// gid.z * dispatch_stride + gid.x, where dispatch_stride arrives in `_pad0`
// (0 is treated as 1). With a z extent of 1, gid.x alone spans
// [0, dispatch_stride).
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let dispatch_stride = max(u._pad0, 1u);
  let row = gid.y;
  // Column chunks advance along z and must not depend on the row. An offset
  // of `row * dispatch_stride` pushes every row > 0 to col >= cols whenever
  // the host passes stride == cols, so only row 0 would ever be written.
  let col = gid.z * dispatch_stride + gid.x;
  if (row >= u.rows || col >= u.cols) {
    return;
  }

  let idx = row * u.cols + col;
  let out_idx = col * u.rows + row;
  output[out_idx] = input[idx];
}
@@ -31,6 +31,7 @@ async function _upsample2d(target, input, options = {}) {
31
31
 
32
32
  const outHeight = resolvedHeight * scale;
33
33
  const outWidth = resolvedWidth * scale;
34
+ const outSpatial = outHeight * outWidth;
34
35
  const bytesPerElement = dtypeBytes(input.dtype);
35
36
  const outputSize = channels * outHeight * outWidth * bytesPerElement;
36
37
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'upsample2d_output');
@@ -43,7 +44,7 @@ async function _upsample2d(target, input, options = {}) {
43
44
  out_height: outHeight, out_width: outWidth, scale,
44
45
  _pad0: 0, _pad1: 0,
45
46
  },
46
- Math.ceil((channels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
47
+ [Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
47
48
  );
48
49
 
49
50
  return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'upsample2d_output');
@@ -19,19 +19,16 @@ struct Uniforms {
19
19
 
20
20
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
21
21
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22
- let idx = gid.x;
23
22
  let out_spatial = u.out_height * u.out_width;
24
- let total = u.channels * out_spatial;
25
- if (idx >= total) {
23
+ let spatial_idx = gid.x;
24
+ let channel = gid.y;
25
+ if (spatial_idx >= out_spatial || channel >= u.channels) {
26
26
  return;
27
27
  }
28
-
29
- let channel = idx / out_spatial;
30
- let rem = idx - channel * out_spatial;
31
- let out_y = rem / u.out_width;
32
- let out_x = rem - out_y * u.out_width;
28
+ let out_y = spatial_idx / u.out_width;
29
+ let out_x = spatial_idx - out_y * u.out_width;
33
30
  let in_y = out_y / u.scale;
34
31
  let in_x = out_x / u.scale;
35
32
  let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
36
- output[idx] = input[in_idx];
33
+ output[channel * out_spatial + spatial_idx] = input[in_idx];
37
34
  }
@@ -23,19 +23,16 @@ struct Uniforms {
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
27
26
  let out_spatial = u.out_height * u.out_width;
28
- let total = u.channels * out_spatial;
29
- if (idx >= total) {
27
+ let spatial_idx = gid.x;
28
+ let channel = gid.y;
29
+ if (spatial_idx >= out_spatial || channel >= u.channels) {
30
30
  return;
31
31
  }
32
-
33
- let channel = idx / out_spatial;
34
- let rem = idx - channel * out_spatial;
35
- let out_y = rem / u.out_width;
36
- let out_x = rem - out_y * u.out_width;
32
+ let out_y = spatial_idx / u.out_width;
33
+ let out_x = spatial_idx - out_y * u.out_width;
37
34
  let in_y = out_y / u.scale;
38
35
  let in_x = out_x / u.scale;
39
36
  let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
40
- output[idx] = input[in_idx];
37
+ output[channel * out_spatial + spatial_idx] = input[in_idx];
41
38
  }
@@ -116,9 +116,24 @@ export async function unifiedKernelWrapper(opName, target, variant, bindings, un
116
116
  index = config.variantMetadata.outputBinding;
117
117
  }
118
118
 
119
+ const buffer = binding?.buffer || binding;
120
+ const isGpuBuffer = buffer && (
121
+ typeof GPUBuffer === 'undefined'
122
+ ? true
123
+ : buffer instanceof GPUBuffer
124
+ );
125
+ if (!isGpuBuffer) {
126
+ const bindingLabel = binding?.label ?? buffer?.label ?? 'unknown';
127
+ const bufferType = buffer === null ? 'null' : buffer === undefined ? 'undefined' : buffer.constructor?.name || typeof buffer;
128
+ throw new Error(
129
+ `Kernel "${opName}/${variant}" binding "${bindingConfig.name}" (index ${index}) requires a GPUBuffer ` +
130
+ `(label=${bindingLabel}, type=${bufferType}).`
131
+ );
132
+ }
133
+
119
134
  bindGroupEntries.push({
120
135
  binding: index,
121
- resource: { buffer: binding?.buffer || binding }
136
+ resource: { buffer }
122
137
  });
123
138
  }
124
139
 
@@ -1,5 +1,5 @@
1
1
  export declare const DOPPLER_VERSION: string;
2
- export { doppler } from './client/doppler-api.js';
2
+ export { doppler } from './client/doppler-api.browser.js';
3
3
 
4
4
  export {
5
5
  DopplerLoader,