@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199) hide show
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -58,6 +58,36 @@ function resolveNormWeightDtype(weight, hiddenSize) {
58
58
  return 'f32';
59
59
  }
60
60
 
61
/**
 * Fail fast when the RMSNorm weight did not resolve to a usable GPU buffer.
 *
 * If the `GPUBuffer` global is not defined in the current runtime there is
 * no class to test against, so any truthy `weightBuffer` is accepted.
 * Otherwise `weightBuffer` must be an instance of `GPUBuffer`.
 *
 * @param {*} weight - Original weight value; only its optional `label` and
 *   constructor name are read, for the diagnostic message.
 * @param {*} weightBuffer - Buffer resolved from `weight` by the caller.
 * @param {number|null|undefined} hiddenSize - Reported in the message;
 *   printed as `unknown` when nullish.
 * @returns {void}
 * @throws {Error} When `weightBuffer` is falsy or is not a `GPUBuffer`.
 */
function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
  const gpuBufferClassExists = typeof GPUBuffer !== 'undefined';
  const isGpuBuffer =
    Boolean(weightBuffer) &&
    (!gpuBufferClassExists || weightBuffer instanceof GPUBuffer);
  if (isGpuBuffer) {
    return;
  }

  // Shared by both diagnostics so the two descriptions cannot drift apart.
  const describeType = (value) => {
    if (value === null) {
      return 'null';
    }
    if (value === undefined) {
      return 'undefined';
    }
    return value.constructor?.name || typeof value;
  };

  const weightLabel = weight?.label ?? 'unknown';
  const weightType = describeType(weight);
  const bufferType = describeType(weightBuffer);
  throw new Error(
    `[rmsnorm] weight "${weightLabel}" requires a GPUBuffer (weightType=${weightType}, bufferType=${bufferType}, hiddenSize=${hiddenSize ?? 'unknown'}).`
  );
}
78
+
79
/**
 * Plan a 2D workgroup grid with one workgroup per token.
 *
 * Tokens are laid out along X up to the per-dimension workgroup limit and
 * wrap onto additional Y rows beyond it. The returned `tokenStride` is the
 * number of tokens per Y row.
 *
 * @param {{ device?: { limits?: { maxComputeWorkgroupsPerDimension?: number } } } | null | undefined} target
 *   Object whose `device.limits` supplies the limit. When it is absent or
 *   not a finite value >= 1, 65535 is used instead.
 * @param {number} numTokens - Number of tokens; expected to be a
 *   non-negative integer.
 * @returns {{ tokenStride: number, workgroups: [number, number, number] }}
 *   `tokenStride` is always >= 1. For zero (or invalid) token counts the
 *   workgroup grid is empty instead of containing `NaN`.
 */
function planRMSNormDispatch(target, numTokens) {
  const FALLBACK_MAX_PER_DIM = 65535;
  const reportedLimit = target?.device?.limits?.maxComputeWorkgroupsPerDimension;
  // A limit of 0 or a negative value would make the stride 0 and the row
  // count NaN, so only trust limits that allow at least one workgroup.
  const maxPerDim = Number.isFinite(reportedLimit) && reportedLimit >= 1
    ? reportedLimit
    : FALLBACK_MAX_PER_DIM;

  // Nothing to dispatch: avoid Math.ceil(0 / 0) === NaN as a workgroup count.
  if (!Number.isFinite(numTokens) || numTokens <= 0) {
    return { tokenStride: 1, workgroups: [0, 0, 1] };
  }

  const tokenStride = Math.min(numTokens, maxPerDim);
  return {
    tokenStride,
    workgroups: [tokenStride, Math.ceil(numTokens / tokenStride), 1],
  };
}
90
+
61
91
  export function selectRMSNormKernel(options = {}, isF16 = false) {
62
92
  const { residual = null, hiddenSize = null } = options;
63
93
  const { smallThreshold } = getKernelThresholds().rmsnorm;
@@ -82,23 +112,34 @@ export async function runRMSNorm(
82
112
  const variant = selectRMSNormKernel(options, isF16);
83
113
  const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
84
114
  const normWeightBuffer = getBuffer(weight);
115
+ assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
85
116
  const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
86
117
 
87
118
  const bytesPerElement = isF16 ? 2 : 4;
88
119
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
89
120
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
90
121
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
122
+ const dispatchPlan = planRMSNormDispatch(null, batchSize);
91
123
 
92
124
  // Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
93
- const residualBuf = residual?.buffer || input.buffer;
125
+ const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
94
126
 
95
127
  await unifiedKernelWrapper(
96
128
  'rmsnorm',
97
129
  null,
98
130
  variant,
99
131
  [input, normWeightBuffer, outputBuf, residualBuf],
100
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps, has_residual: residual ? 1 : 0 },
101
- batchSize,
132
+ {
133
+ hidden_size: inferredHiddenSize,
134
+ num_tokens: batchSize,
135
+ eps,
136
+ has_residual: residual ? 1 : 0,
137
+ token_stride: dispatchPlan.tokenStride,
138
+ _pad0: 0,
139
+ _pad1: 0,
140
+ _pad2: 0,
141
+ },
142
+ dispatchPlan.workgroups,
102
143
  { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
103
144
  );
104
145
 
@@ -117,22 +158,33 @@ export async function recordRMSNorm(
117
158
  const variant = selectRMSNormKernel(options, isF16);
118
159
  const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
119
160
  const normWeightBuffer = getBuffer(weight);
161
+ assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
120
162
  const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
121
163
 
122
164
  const bytesPerElement = isF16 ? 2 : 4;
123
165
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
124
166
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
125
167
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
168
+ const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
126
169
 
127
- const residualBuf = residual?.buffer || input.buffer;
170
+ const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
128
171
 
129
172
  await unifiedKernelWrapper(
130
173
  'rmsnorm',
131
174
  recorder,
132
175
  variant,
133
176
  [input, normWeightBuffer, outputBuf, residualBuf],
134
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps, has_residual: residual ? 1 : 0 },
135
- batchSize,
177
+ {
178
+ hidden_size: inferredHiddenSize,
179
+ num_tokens: batchSize,
180
+ eps,
181
+ has_residual: residual ? 1 : 0,
182
+ token_stride: dispatchPlan.tokenStride,
183
+ _pad0: 0,
184
+ _pad1: 0,
185
+ _pad2: 0,
186
+ },
187
+ dispatchPlan.workgroups,
136
188
  { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
137
189
  );
138
190
 
@@ -39,6 +39,10 @@ struct Uniforms {
39
39
  num_tokens: u32, // Number of tokens to process
40
40
  eps: f32, // Epsilon for numerical stability (typically 1e-5 or 1e-6)
41
41
  has_residual: u32, // Runtime flag: 1 = add residual after norm
42
+ token_stride: u32, // Workgroup rows per dispatch row
43
+ _pad0: u32,
44
+ _pad1: u32,
45
+ _pad2: u32,
42
46
  }
43
47
 
44
48
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -82,6 +86,10 @@ fn should_add_residual() -> bool {
82
86
  return HAS_RESIDUAL || (u.has_residual != 0u);
83
87
  }
84
88
 
89
// Flattens a 2D workgroup id into a token index: each Y row holds
// `u.token_stride` tokens and X selects the token within the row.
// The stride is clamped to at least 1, so a zero stride in the uniforms
// still yields distinct indices per row.
fn token_index(wg_id: vec3<u32>) -> u32 {
    let tokens_per_row = max(u.token_stride, 1u);
    return wg_id.y * tokens_per_row + wg_id.x;
}
92
+
85
93
  // =============================================================================
86
94
  // Main Entry Point
87
95
  // =============================================================================
@@ -93,7 +101,7 @@ fn main(
93
101
  @builtin(local_invocation_id) local_id: vec3<u32>,
94
102
  @builtin(workgroup_id) wg_id: vec3<u32>
95
103
  ) {
96
- let token_idx = wg_id.x;
104
+ let token_idx = token_index(wg_id);
97
105
  let thread_idx = local_id.x;
98
106
  let size = u.size;
99
107
 
@@ -163,7 +171,7 @@ fn main_small(
163
171
  @builtin(local_invocation_id) local_id: vec3<u32>,
164
172
  @builtin(workgroup_id) wg_id: vec3<u32>
165
173
  ) {
166
- let token_idx = wg_id.x;
174
+ let token_idx = token_index(wg_id);
167
175
  let thread_idx = local_id.x;
168
176
  let size = u.size;
169
177
 
@@ -219,7 +227,7 @@ fn main_cached(
219
227
  @builtin(local_invocation_id) local_id: vec3<u32>,
220
228
  @builtin(workgroup_id) wg_id: vec3<u32>
221
229
  ) {
222
- let token_idx = wg_id.x;
230
+ let token_idx = token_index(wg_id);
223
231
  let thread_idx = local_id.x;
224
232
  let size = u.size;
225
233
 
@@ -288,7 +296,7 @@ fn main_subgroup(
288
296
  @builtin(subgroup_invocation_id) sg_lane: u32,
289
297
  @builtin(subgroup_size) sg_size: u32,
290
298
  ) {
291
- let token_idx = wg_id.x;
299
+ let token_idx = token_index(wg_id);
292
300
  let thread_idx = local_id.x;
293
301
  let size = u.size;
294
302
 
@@ -362,7 +370,7 @@ fn main_small_subgroup(
362
370
  @builtin(subgroup_invocation_id) sg_lane: u32,
363
371
  @builtin(subgroup_size) sg_size: u32,
364
372
  ) {
365
- let token_idx = wg_id.x;
373
+ let token_idx = token_index(wg_id);
366
374
  let thread_idx = local_id.x;
367
375
  let size = u.size;
368
376
 
@@ -414,4 +422,4 @@ fn main_small_subgroup(
414
422
  }
415
423
  output[base_offset + thread_idx] = result;
416
424
  }
417
- }
425
+ }
@@ -20,6 +20,10 @@ struct Uniforms {
20
20
  num_tokens: u32, // Number of tokens to process
21
21
  eps: f32, // Epsilon for numerical stability
22
22
  has_residual: u32, // 1 if residual input provided, 0 otherwise
23
+ token_stride: u32, // Workgroup rows per dispatch row
24
+ _pad0: u32,
25
+ _pad1: u32,
26
+ _pad2: u32,
23
27
  }
24
28
 
25
29
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -47,6 +51,10 @@ fn load_weight(idx: u32) -> f32 {
47
51
  return bitcast<f32>(weight[idx]);
48
52
  }
49
53
 
54
// Flattens a 2D workgroup id into a token index: each Y row holds
// `u.token_stride` tokens and X selects the token within the row.
// The stride is clamped to at least 1, so a zero stride in the uniforms
// still yields distinct indices per row.
fn token_index(wg_id: vec3<u32>) -> u32 {
    let tokens_per_row = max(u.token_stride, 1u);
    return wg_id.y * tokens_per_row + wg_id.x;
}
57
+
50
58
  // Main RMSNorm kernel - one workgroup per token
51
59
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
52
60
  fn main(
@@ -54,7 +62,7 @@ fn main(
54
62
  @builtin(local_invocation_id) local_id: vec3<u32>,
55
63
  @builtin(workgroup_id) wg_id: vec3<u32>
56
64
  ) {
57
- let token_idx = wg_id.x;
65
+ let token_idx = token_index(wg_id);
58
66
  let thread_idx = local_id.x;
59
67
  let size = u.size;
60
68
 
@@ -121,7 +129,7 @@ fn rmsnorm_small_f16(
121
129
  @builtin(local_invocation_id) local_id: vec3<u32>,
122
130
  @builtin(workgroup_id) wg_id: vec3<u32>
123
131
  ) {
124
- let token_idx = wg_id.x;
132
+ let token_idx = token_index(wg_id);
125
133
  let thread_idx = local_id.x;
126
134
  let size = u.size;
127
135
 
@@ -15,6 +15,8 @@ import type { OutputBufferOptions } from './types.js';
15
15
  export interface RoPEOptions extends OutputBufferOptions {
16
16
  numHeads?: number;
17
17
  headDim?: number;
18
+ rotaryDim?: number;
19
+ interleaved?: boolean;
18
20
  ropeTheta?: number;
19
21
  startPos?: number;
20
22
  }
@@ -13,18 +13,26 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
13
13
  const {
14
14
  numHeads = 1,
15
15
  headDim = 64,
16
+ rotaryDim = headDim,
17
+ interleaved = false,
16
18
  ropeTheta = ropeDefaults.defaultTheta,
17
19
  } = options;
18
20
 
19
21
  if (headDim % 2 !== 0) {
20
22
  throw new Error(`RoPE headDim must be even, got ${headDim}`);
21
23
  }
24
+ if (rotaryDim % 2 !== 0) {
25
+ throw new Error(`RoPE rotaryDim must be even, got ${rotaryDim}`);
26
+ }
27
+ if (rotaryDim <= 0 || rotaryDim > headDim) {
28
+ throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
29
+ }
22
30
 
23
31
  const caps = getKernelCapabilities();
24
32
  const useF16 = input.dtype === 'f16' && caps.hasF16;
25
33
  const variant = selectRuleValue('rope', 'variant', { useF16 });
26
34
 
27
- const halfDim = headDim / 2;
35
+ const halfDim = rotaryDim / 2;
28
36
  const workgroups = Math.ceil((seqLen * numHeads * halfDim) / WORKGROUP_SIZES.DEFAULT);
29
37
 
30
38
  await unifiedKernelWrapper(
@@ -34,9 +42,11 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
34
42
  seq_len: seqLen,
35
43
  num_heads: numHeads,
36
44
  head_dim: headDim,
45
+ rotary_dim: rotaryDim,
37
46
  start_pos: options.startPos ?? ropeDefaults.defaultStartPos,
38
47
  rope_base: ropeTheta,
39
48
  rope_scale: 1.0,
49
+ interleaved: interleaved ? 1 : 0,
40
50
  },
41
51
  workgroups
42
52
  );
@@ -26,8 +26,8 @@ struct Uniforms {
26
26
  start_pos: u32, // Starting position (for decode)
27
27
  rope_base: f32, // Base frequency (default 10000)
28
28
  rope_scale: f32, // Scaling factor for extended context
29
- _pad0: u32,
30
- _pad1: u32,
29
+ rotary_dim: u32, // Rotary slice within head_dim
30
+ interleaved: u32, // 1 = adjacent pairs, 0 = rotate-half
31
31
  }
32
32
 
33
33
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -46,7 +46,8 @@ fn main(
46
46
  let start_pos = u.start_pos;
47
47
 
48
48
  // Global thread index (one thread per complex pair)
49
- let half_dim = head_dim / 2u;
49
+ let rotary_dim = u.rotary_dim;
50
+ let half_dim = rotary_dim / 2u;
50
51
  let total_pairs = seq_len * num_heads * half_dim;
51
52
  let idx = global_id.x;
52
53
 
@@ -68,16 +69,18 @@ fn main(
68
69
 
69
70
  // Select the pair layout: rotate-half (x[i], x[i + half_dim]) by default, or adjacent pairs (x[2i], x[2i + 1]) when u.interleaved == 1
70
71
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
71
- let x0 = input[base_idx + pair_idx];
72
- let x1 = input[base_idx + pair_idx + half_dim];
72
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
73
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
74
+ let x0 = input[base_idx + first_idx];
75
+ let x1 = input[base_idx + second_idx];
73
76
 
74
77
  // Apply rotation
75
78
  let y0 = x0 * cos_val - x1 * sin_val;
76
79
  let y1 = x0 * sin_val + x1 * cos_val;
77
80
 
78
81
  // Write back
79
- input[base_idx + pair_idx] = y0;
80
- input[base_idx + pair_idx + half_dim] = y1;
82
+ input[base_idx + first_idx] = y0;
83
+ input[base_idx + second_idx] = y1;
81
84
  }
82
85
 
83
86
  // Compute frequencies on-the-fly (no precomputation needed)
@@ -91,9 +94,10 @@ fn rope_compute_freqs(
91
94
  let start_pos = u.start_pos;
92
95
  let rope_base = u.rope_base;
93
96
  let rope_scale = u.rope_scale;
97
+ let rotary_dim = u.rotary_dim;
94
98
 
95
99
  let idx = global_id.x;
96
- let half_dim = head_dim / 2u;
100
+ let half_dim = rotary_dim / 2u;
97
101
  let total_pairs = seq_len * num_heads * half_dim;
98
102
 
99
103
  if (idx >= total_pairs) {
@@ -109,7 +113,7 @@ fn rope_compute_freqs(
109
113
  let actual_pos = f32(start_pos + pos) / rope_scale;
110
114
 
111
115
  // Compute frequency: 1 / (base^(2*pair_idx/rotary_dim))
112
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
116
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
113
117
  let freq = 1.0 / pow(rope_base, exponent);
114
118
  let theta = actual_pos * freq;
115
119
 
@@ -118,12 +122,14 @@ fn rope_compute_freqs(
118
122
 
119
123
  // Select the pair layout: rotate-half (x[i], x[i + half_dim]) by default, or adjacent pairs (x[2i], x[2i + 1]) when u.interleaved == 1
120
124
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
121
- let x0 = input[base_idx + pair_idx];
122
- let x1 = input[base_idx + pair_idx + half_dim];
125
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
126
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
127
+ let x0 = input[base_idx + first_idx];
128
+ let x1 = input[base_idx + second_idx];
123
129
 
124
130
  // Apply rotation
125
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
126
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
131
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
132
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
127
133
  }
128
134
 
129
135
  // Apply RoPE to both Q and K in one pass
@@ -138,10 +144,11 @@ fn rope_qk(
138
144
  let start_pos = u.start_pos;
139
145
  let rope_base = u.rope_base;
140
146
  let rope_scale = u.rope_scale;
147
+ let rotary_dim = u.rotary_dim;
141
148
 
142
149
  let idx = global_id.x;
143
150
  // Each thread handles one Q-K pair at one dimension pair
144
- let half_dim = head_dim / 2u;
151
+ let half_dim = rotary_dim / 2u;
145
152
  let total_pairs = seq_len * num_heads * half_dim;
146
153
 
147
154
  if (idx >= total_pairs) {
@@ -156,7 +163,7 @@ fn rope_qk(
156
163
  let actual_pos = f32(start_pos + pos) / rope_scale;
157
164
 
158
165
  // Compute frequency
159
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
166
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
160
167
  let freq = 1.0 / pow(rope_base, exponent);
161
168
  let theta = actual_pos * freq;
162
169
 
@@ -168,16 +175,18 @@ fn rope_qk(
168
175
  let k_base_idx = q_base_idx + head_dim; // K starts after Q
169
176
 
170
177
  // Process Q
171
- let q0 = input[q_base_idx + pair_idx];
172
- let q1 = input[q_base_idx + pair_idx + half_dim];
173
- input[q_base_idx + pair_idx] = q0 * cos_val - q1 * sin_val;
174
- input[q_base_idx + pair_idx + half_dim] = q0 * sin_val + q1 * cos_val;
178
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
179
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
180
+ let q0 = input[q_base_idx + first_idx];
181
+ let q1 = input[q_base_idx + second_idx];
182
+ input[q_base_idx + first_idx] = q0 * cos_val - q1 * sin_val;
183
+ input[q_base_idx + second_idx] = q0 * sin_val + q1 * cos_val;
175
184
 
176
185
  // Process K
177
- let k0 = input[k_base_idx + pair_idx];
178
- let k1 = input[k_base_idx + pair_idx + half_dim];
179
- input[k_base_idx + pair_idx] = k0 * cos_val - k1 * sin_val;
180
- input[k_base_idx + pair_idx + half_dim] = k0 * sin_val + k1 * cos_val;
186
+ let k0 = input[k_base_idx + first_idx];
187
+ let k1 = input[k_base_idx + second_idx];
188
+ input[k_base_idx + first_idx] = k0 * cos_val - k1 * sin_val;
189
+ input[k_base_idx + second_idx] = k0 * sin_val + k1 * cos_val;
181
190
  }
182
191
 
183
192
  // Precompute frequency table (run once at init)
@@ -190,9 +199,10 @@ fn precompute_freqs(
190
199
  let seq_len = u.seq_len; // maxSeqLen for precomputation
191
200
  let rope_base = u.rope_base;
192
201
  let rope_scale = u.rope_scale;
202
+ let rotary_dim = u.rotary_dim;
193
203
 
194
204
  let idx = global_id.x;
195
- let half_dim = head_dim / 2u;
205
+ let half_dim = rotary_dim / 2u;
196
206
  let total_elements = seq_len * half_dim;
197
207
 
198
208
  if (idx >= total_elements) {
@@ -203,7 +213,7 @@ fn precompute_freqs(
203
213
  let dim_idx = idx % half_dim;
204
214
 
205
215
  let actual_pos = f32(pos) / rope_scale;
206
- let exponent = f32(dim_idx * 2u) / f32(head_dim);
216
+ let exponent = f32(dim_idx * 2u) / f32(rotary_dim);
207
217
  let freq = 1.0 / pow(rope_base, exponent);
208
218
  let theta = actual_pos * freq;
209
219
 
@@ -218,6 +228,7 @@ fn rope_ntk_scaled(
218
228
  @builtin(global_invocation_id) global_id: vec3<u32>
219
229
  ) {
220
230
  let head_dim = u.head_dim;
231
+ let rotary_dim = u.rotary_dim;
221
232
  let num_heads = u.num_heads;
222
233
  let seq_len = u.seq_len;
223
234
  let start_pos = u.start_pos;
@@ -225,7 +236,7 @@ fn rope_ntk_scaled(
225
236
  let rope_scale = u.rope_scale;
226
237
 
227
238
  let idx = global_id.x;
228
- let half_dim = head_dim / 2u;
239
+ let half_dim = rotary_dim / 2u;
229
240
  let total_pairs = seq_len * num_heads * half_dim;
230
241
 
231
242
  if (idx >= total_pairs) {
@@ -234,7 +245,7 @@ fn rope_ntk_scaled(
234
245
 
235
246
  // NTK scaling: increase base proportionally to scale factor
236
247
  // This preserves high-frequency components better than linear interpolation
237
- rope_base = rope_base * pow(rope_scale, f32(head_dim) / (f32(head_dim) - 2.0));
248
+ rope_base = rope_base * pow(rope_scale, f32(rotary_dim) / (f32(rotary_dim) - 2.0));
238
249
 
239
250
  let pos = idx / (num_heads * half_dim);
240
251
  let remainder = idx % (num_heads * half_dim);
@@ -243,7 +254,7 @@ fn rope_ntk_scaled(
243
254
 
244
255
  let actual_pos = f32(start_pos + pos);
245
256
 
246
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
257
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
247
258
  let freq = 1.0 / pow(rope_base, exponent);
248
259
  let theta = actual_pos * freq;
249
260
 
@@ -251,11 +262,13 @@ fn rope_ntk_scaled(
251
262
  let sin_val = sin(theta);
252
263
 
253
264
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
254
- let x0 = input[base_idx + pair_idx];
255
- let x1 = input[base_idx + pair_idx + half_dim];
265
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
266
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
267
+ let x0 = input[base_idx + first_idx];
268
+ let x1 = input[base_idx + second_idx];
256
269
 
257
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
258
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
270
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
271
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
259
272
  }
260
273
 
261
274
  // YaRN-style RoPE with attention scaling
@@ -265,6 +278,7 @@ fn rope_yarn(
265
278
  @builtin(global_invocation_id) global_id: vec3<u32>
266
279
  ) {
267
280
  let head_dim = u.head_dim;
281
+ let rotary_dim = u.rotary_dim;
268
282
  let num_heads = u.num_heads;
269
283
  let seq_len = u.seq_len;
270
284
  let start_pos = u.start_pos;
@@ -272,7 +286,7 @@ fn rope_yarn(
272
286
  let rope_scale = u.rope_scale;
273
287
 
274
288
  let idx = global_id.x;
275
- let half_dim = head_dim / 2u;
289
+ let half_dim = rotary_dim / 2u;
276
290
  let total_pairs = seq_len * num_heads * half_dim;
277
291
 
278
292
  if (idx >= total_pairs) {
@@ -292,7 +306,7 @@ fn rope_yarn(
292
306
  let alpha: f32 = 1.0;
293
307
 
294
308
  // Compute original frequency
295
- let exponent = f32(pair_idx * 2u) / f32(head_dim);
309
+ let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
296
310
  let orig_freq = 1.0 / pow(rope_base, exponent);
297
311
 
298
312
  // Compute wavelength
@@ -300,8 +314,8 @@ fn rope_yarn(
300
314
 
301
315
  // Interpolation factor based on wavelength
302
316
  var ramp: f32;
303
- let low_wavelength = f32(head_dim) / beta_fast;
304
- let high_wavelength = f32(head_dim) / beta_slow;
317
+ let low_wavelength = f32(rotary_dim) / beta_fast;
318
+ let high_wavelength = f32(rotary_dim) / beta_slow;
305
319
 
306
320
  if (wavelength < low_wavelength) {
307
321
  ramp = 0.0; // No interpolation for high frequencies
@@ -320,9 +334,11 @@ fn rope_yarn(
320
334
  let sin_val = sin(theta);
321
335
 
322
336
  let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
323
- let x0 = input[base_idx + pair_idx];
324
- let x1 = input[base_idx + pair_idx + half_dim];
337
+ let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
338
+ let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
339
+ let x0 = input[base_idx + first_idx];
340
+ let x1 = input[base_idx + second_idx];
325
341
 
326
- input[base_idx + pair_idx] = x0 * cos_val - x1 * sin_val;
327
- input[base_idx + pair_idx + half_dim] = x0 * sin_val + x1 * cos_val;
342
+ input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
343
+ input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
328
344
  }
@@ -0,0 +1,27 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+
5
// Options for runSanaLinearAttention / recordSanaLinearAttention.
// Notes below describe how the accompanying JavaScript implementation in this
// diff consumes each field:
//   numHeads, headDim - required; must be finite numbers.
//   numTokens         - optional; falls back to query.shape[0].
//   hiddenSize        - optional; falls back to query.shape[1]; must equal
//                       numHeads * headDim or the call throws.
//   eps               - optional; defaults to 1e-15 and is passed to the apply kernel.
//   summaryBuffer     - optional scratch buffer; when omitted/null one of
//                       numHeads * (headDim + 1) * headDim * 4 bytes is acquired
//                       internally and disposed of by the implementation.
+ export interface SanaLinearAttentionOptions extends OutputBufferOptions {
6
+ numHeads: number;
7
+ headDim: number;
8
+ numTokens?: number;
9
+ hiddenSize?: number;
10
+ eps?: number;
11
+ summaryBuffer?: GPUBuffer | null;
12
+ }
13
+
14
// Runs Sana linear attention without a command recorder (the implementation
// passes a null target to the shared helper). Resolves to a tensor of shape
// [numTokens, hiddenSize] carrying the query's dtype.
+ export declare function runSanaLinearAttention(
15
+ query: Tensor,
16
+ key: Tensor,
17
+ value: Tensor,
18
+ options: SanaLinearAttentionOptions
19
+ ): Promise<Tensor>;
20
+
21
// Same operation as runSanaLinearAttention, but the given recorder is passed
// through as the kernel target; an internally acquired summary buffer is handed
// to recorder.trackTemporaryBuffer instead of being released immediately.
+ export declare function recordSanaLinearAttention(
22
+ recorder: CommandRecorder,
23
+ query: Tensor,
24
+ key: Tensor,
25
+ value: Tensor,
26
+ options: SanaLinearAttentionOptions
27
+ ): Promise<Tensor>;
@@ -0,0 +1,121 @@
1
+ import { getDevice } from '../device.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
+ import { createTensor, dtypeBytes } from '../tensor.js';
4
+ import { unifiedKernelWrapper } from './utils.js';
5
+ import { selectRuleValue } from './rule-registry.js';
6
+ import { WORKGROUP_SIZES } from './constants.js';
7
+
8
/**
 * Looks up the kernel variant for Sana linear attention in the rule registry.
 *
 * @param {boolean} isF16 - Whether the query tensor's dtype is 'f16'.
 * @returns {*} Whatever the 'sanaLinearAttention' / 'variant' rule yields.
 */
function selectSanaLinearAttentionVariant(isF16) {
  const ruleContext = { isF16 };
  return selectRuleValue('sanaLinearAttention', 'variant', ruleContext);
}
11
+
12
/**
 * Dispatches the 'sana_linear_attention_summary' kernel through
 * unifiedKernelWrapper, binding query, key, value and the summary buffer.
 *
 * @param {*} target - Forwarded unchanged to unifiedKernelWrapper.
 * @param {*} query - First binding.
 * @param {*} key - Second binding.
 * @param {*} value - Third binding.
 * @param {*} summaryBuffer - Fourth binding; receives the summary.
 * @param {{num_heads: number, head_dim: number, num_tokens: number, hidden_size: number}} uniforms
 * @param {*} variant - Kernel variant forwarded to unifiedKernelWrapper.
 * @returns {Promise<void>}
 */
async function runSummary(target, query, key, value, summaryBuffer, uniforms, variant) {
  const { num_heads, head_dim, num_tokens, hidden_size } = uniforms;

  // The summary holds num_heads * (head_dim + 1) * head_dim entries; the
  // dispatch is sized to cover them at WORKGROUP_SIZES.DEFAULT per workgroup.
  const summaryElements = num_heads * (head_dim + 1) * head_dim;
  const workgroupCount = Math.ceil(summaryElements / WORKGROUP_SIZES.DEFAULT);

  const bindings = [query, key, value, summaryBuffer];
  // Key order is kept exactly as before in case the wrapper packs uniforms by
  // insertion order.
  const kernelUniforms = {
    num_heads,
    head_dim,
    num_tokens,
    hidden_size,
    _pad0: 0,
    _pad1: 0,
  };

  await unifiedKernelWrapper(
    'sana_linear_attention_summary',
    target,
    variant,
    bindings,
    kernelUniforms,
    workgroupCount
  );
}
30
+
31
/**
 * Dispatches the 'sana_linear_attention_apply' kernel through
 * unifiedKernelWrapper, binding the query, the summary buffer and the output.
 *
 * @param {*} target - Forwarded unchanged to unifiedKernelWrapper.
 * @param {*} query - First binding.
 * @param {*} summaryBuffer - Second binding.
 * @param {*} outputBuffer - Third binding; receives the attention output.
 * @param {{num_heads: number, head_dim: number, num_tokens: number, hidden_size: number, eps: number}} uniforms
 * @param {*} variant - Kernel variant forwarded to unifiedKernelWrapper.
 * @returns {Promise<void>}
 */
async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
  const { num_heads, head_dim, num_tokens, hidden_size, eps } = uniforms;

  // 2D dispatch: x covers hidden_size in WORKGROUP_SIZES.DEFAULT-sized chunks,
  // y has one entry per token.
  const dispatch = [
    Math.ceil(hidden_size / WORKGROUP_SIZES.DEFAULT),
    num_tokens,
    1,
  ];

  const bindings = [query, summaryBuffer, outputBuffer];
  // Key order is kept exactly as before in case the wrapper packs uniforms by
  // insertion order.
  const kernelUniforms = {
    num_heads,
    head_dim,
    num_tokens,
    hidden_size,
    eps,
    _pad0: 0,
    _pad1: 0,
    _pad2: 0,
  };

  await unifiedKernelWrapper(
    'sana_linear_attention_apply',
    target,
    variant,
    bindings,
    kernelUniforms,
    dispatch
  );
}
50
+
51
/**
 * Shared implementation behind runSanaLinearAttention and
 * recordSanaLinearAttention: runs the summary kernel, then the apply kernel,
 * and wraps the output buffer in a tensor.
 *
 * Buffers acquired here are never leaked: the scratch summary buffer is always
 * disposed of, and an internally acquired output buffer is disposed of when
 * any step fails (on success its ownership passes to the returned tensor).
 * Caller-supplied `summaryBuffer` / `outputBuffer` are never released here.
 *
 * @param {object|null} target - A recorder (recognised by having a
 *   `beginComputePass` method) or null/any other value for the non-recording path.
 * @param {object} query - Query tensor; its `dtype` and `shape` drive defaults.
 * @param {object} key - Key tensor, bound to the summary kernel.
 * @param {object} value - Value tensor, bound to the summary kernel.
 * @param {object} [options]
 * @param {number} options.numHeads - Required, finite.
 * @param {number} options.headDim - Required, finite.
 * @param {number} [options.numTokens] - Defaults to `query.shape[0]`.
 * @param {number} [options.hiddenSize] - Defaults to `query.shape[1]`; must
 *   equal `numHeads * headDim`.
 * @param {number} [options.eps=1e-15] - Passed to the apply kernel.
 * @param {*} [options.outputBuffer=null] - Destination buffer; acquired when absent.
 * @param {*} [options.summaryBuffer=null] - Scratch buffer; acquired when absent.
 * @returns {Promise<object>} Tensor of shape [numTokens, hiddenSize] with the
 *   query's dtype, labelled 'sana_linear_attention_output'.
 * @throws {Error} When no device is available, a required dimension is missing
 *   or non-finite, or hiddenSize does not match numHeads * headDim.
 */
async function _sanaLinearAttention(target, query, key, value, options = {}) {
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  if (!device) {
    throw new Error('SanaLinearAttention requires a WebGPU device.');
  }

  const {
    numHeads,
    headDim,
    numTokens = query.shape?.[0],
    hiddenSize = query.shape?.[1],
    eps = 1e-15,
    outputBuffer = null,
    summaryBuffer = null,
  } = options;

  if (
    !Number.isFinite(numHeads) ||
    !Number.isFinite(headDim) ||
    !Number.isFinite(numTokens) ||
    !Number.isFinite(hiddenSize)
  ) {
    throw new Error('SanaLinearAttention requires numHeads, headDim, numTokens, and hiddenSize.');
  }
  if (hiddenSize !== numHeads * headDim) {
    throw new Error(`SanaLinearAttention hiddenSize mismatch: ${hiddenSize} != ${numHeads} * ${headDim}`);
  }

  const isF16 = query.dtype === 'f16';
  const variant = selectSanaLinearAttentionVariant(isF16);

  // Disposal rule for buffers this function acquired itself: while recording,
  // the recorder takes over the buffer (work referencing it may still be
  // pending); otherwise it goes straight back to the pool.
  const discardOwnedBuffer = (buffer) => {
    if (recorder) {
      recorder.trackTemporaryBuffer(buffer);
    } else {
      releaseBuffer(buffer);
    }
  };

  let temporarySummary = summaryBuffer;
  let output = outputBuffer;
  try {
    temporarySummary = summaryBuffer || acquireBuffer(
      numHeads * (headDim + 1) * headDim * Float32Array.BYTES_PER_ELEMENT,
      undefined,
      'sana_linear_attention_summary'
    );
    output = outputBuffer || acquireBuffer(
      numTokens * hiddenSize * dtypeBytes(query.dtype),
      undefined,
      'sana_linear_attention_output'
    );

    const uniforms = {
      num_heads: numHeads,
      head_dim: headDim,
      num_tokens: numTokens,
      hidden_size: hiddenSize,
      eps,
    };

    await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
    await runApply(target, query, temporarySummary, output, uniforms, variant);
  } catch (error) {
    // The output never reaches the caller on failure, so an internally
    // acquired one must not stay checked out of the pool.
    if (!outputBuffer && output) {
      discardOwnedBuffer(output);
    }
    throw error;
  } finally {
    // The scratch summary is only needed for the two dispatches above.
    if (!summaryBuffer && temporarySummary) {
      discardOwnedBuffer(temporarySummary);
    }
  }

  return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
}
114
+
115
/**
 * Runs Sana linear attention without a command recorder.
 *
 * @param {object} query - Query tensor.
 * @param {object} key - Key tensor.
 * @param {object} value - Value tensor.
 * @param {object} [options] - See SanaLinearAttentionOptions.
 * @returns {Promise<object>} The output tensor produced by the shared helper.
 */
export async function runSanaLinearAttention(query, key, value, options = {}) {
  // A null target makes the shared helper take its non-recording path.
  const noRecorder = null;
  return _sanaLinearAttention(noRecorder, query, key, value, options);
}
118
+
119
/**
 * Records Sana linear attention against the given recorder.
 *
 * @param {object} recorder - Passed to the shared helper as its target.
 * @param {object} query - Query tensor.
 * @param {object} key - Key tensor.
 * @param {object} value - Value tensor.
 * @param {object} [options] - See SanaLinearAttentionOptions.
 * @returns {Promise<object>} The output tensor produced by the shared helper.
 */
export async function recordSanaLinearAttention(recorder, query, key, value, options = {}) {
  const tensorInputs = [query, key, value];
  return _sanaLinearAttention(recorder, ...tensorInputs, options);
}