@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
@@ -14,16 +14,15 @@ struct Uniforms {
14
14
 
15
15
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
16
16
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
17
- let idx = gid.x;
18
17
  let spatial = u.height * u.width;
19
18
  let out_channels = u.in_channels * u.repeats;
20
- let total = out_channels * spatial;
21
- if (idx >= total) {
19
+ let spatial_idx = gid.x;
20
+ let out_channel = gid.y;
21
+ if (out_channel >= out_channels || spatial_idx >= spatial) {
22
22
  return;
23
23
  }
24
24
 
25
- let out_channel = idx / spatial;
26
25
  let channel = out_channel / u.repeats;
27
- let spatial_idx = idx - out_channel * spatial;
26
+ let idx = out_channel * spatial + spatial_idx;
28
27
  output[idx] = input[channel * spatial + spatial_idx];
29
28
  }
@@ -16,16 +16,15 @@ struct Uniforms {
16
16
 
17
17
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
18
18
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
19
- let idx = gid.x;
20
19
  let spatial = u.height * u.width;
21
20
  let out_channels = u.in_channels * u.repeats;
22
- let total = out_channels * spatial;
23
- if (idx >= total) {
21
+ let spatial_idx = gid.x;
22
+ let out_channel = gid.y;
23
+ if (out_channel >= out_channels || spatial_idx >= spatial) {
24
24
  return;
25
25
  }
26
26
 
27
- let out_channel = idx / spatial;
28
27
  let channel = out_channel / u.repeats;
29
- let spatial_idx = idx - out_channel * spatial;
28
+ let idx = out_channel * spatial + spatial_idx;
30
29
  output[idx] = input[channel * spatial + spatial_idx];
31
30
  }
@@ -63,6 +63,22 @@ function cleanupTemps(temps, recorder) {
63
63
  }
64
64
  }
65
65
 
66
+ function planResidualDispatch(target, size, elementsPerWorkgroup) {
67
+ const device = target?.device;
68
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
69
+ ? device.limits.maxComputeWorkgroupsPerDimension
70
+ : 65535;
71
+ const dispatchStride = Math.min(size, maxPerDim * elementsPerWorkgroup);
72
+ return {
73
+ dispatchStride,
74
+ workgroups: [
75
+ Math.ceil(dispatchStride / elementsPerWorkgroup),
76
+ Math.ceil(size / dispatchStride),
77
+ 1,
78
+ ],
79
+ };
80
+ }
81
+
66
82
  async function _residualAdd(target, a, b, size, options = {}) {
67
83
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
68
84
  const { useVec4 = true, outputBuffer = null } = options;
@@ -75,15 +91,17 @@ async function _residualAdd(target, a, b, size, options = {}) {
75
91
  const outputSize = size * bytesPerElement;
76
92
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'residual_output');
77
93
 
78
- const workgroups = useVec4
79
- ? Math.ceil(size / VEC4_ELEMENTS_PER_WG)
80
- : Math.ceil(size / WORKGROUP_SIZES.DEFAULT);
94
+ const dispatchPlan = planResidualDispatch(
95
+ target,
96
+ size,
97
+ useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
98
+ );
81
99
 
82
100
  await unifiedKernelWrapper(
83
101
  'residual', target, variant,
84
102
  [aAligned, bAligned, output],
85
- { size },
86
- workgroups
103
+ { size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
104
+ dispatchPlan.workgroups
87
105
  );
88
106
 
89
107
  cleanupTemps(temps, recorder);
@@ -96,13 +114,31 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
96
114
 
97
115
  const { bias: biasAligned, temps } = await alignBiasTensor(data, bias, recorder);
98
116
  const variant = selectBiasAddVariant(data.dtype, biasAligned.dtype);
99
-
100
- const workgroups = Math.ceil((numTokens * dim) / WORKGROUP_SIZES.DEFAULT);
117
+ const device = target?.device;
118
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
119
+ ? device.limits.maxComputeWorkgroupsPerDimension
120
+ : 65535;
121
+ const tokenStride = Math.min(numTokens, maxPerDim);
122
+
123
+ const workgroups = [
124
+ Math.ceil(dim / WORKGROUP_SIZES.DEFAULT),
125
+ tokenStride,
126
+ Math.ceil(numTokens / tokenStride),
127
+ ];
101
128
 
102
129
  await unifiedKernelWrapper(
103
130
  'bias_add', target, variant,
104
131
  [data, biasAligned],
105
- { num_tokens: numTokens, dim, data_offset: dataOffset, bias_offset: biasOffset },
132
+ {
133
+ num_tokens: numTokens,
134
+ dim,
135
+ data_offset: dataOffset,
136
+ bias_offset: biasOffset,
137
+ token_stride: tokenStride,
138
+ _pad0: 0,
139
+ _pad1: 0,
140
+ _pad2: 0,
141
+ },
106
142
  workgroups
107
143
  );
108
144
 
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE: u32 = 256u;
23
23
 
24
24
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
25
25
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x;
26
+ let dispatch_stride = max(u._pad1, 1u);
27
+ let idx = gid.y * dispatch_stride + gid.x;
27
28
  if (idx >= u.size) {
28
29
  return;
29
30
  }
@@ -35,7 +36,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
35
36
  // This avoids requiring a different bind group layout with read_write on 'a'
36
37
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
37
38
  fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
38
- let idx = gid.x;
39
+ let dispatch_stride = max(u._pad1, 1u);
40
+ let idx = gid.y * dispatch_stride + gid.x;
39
41
  if (idx >= u.size) {
40
42
  return;
41
43
  }
@@ -45,7 +47,8 @@ fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
45
47
  // Fused residual + scale: output = a + scale * b
46
48
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
47
49
  fn add_scaled(@builtin(global_invocation_id) gid: vec3<u32>) {
48
- let idx = gid.x;
50
+ let dispatch_stride = max(u._pad1, 1u);
51
+ let idx = gid.y * dispatch_stride + gid.x;
49
52
  if (idx >= u.size) {
50
53
  return;
51
54
  }
@@ -27,7 +27,8 @@ override WORKGROUP_SIZE: u32 = 256u;
27
27
 
28
28
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
29
29
  fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
30
- let idx = gid.x;
30
+ let dispatch_stride = max(u._pad1, 1u);
31
+ let idx = gid.y * dispatch_stride + gid.x;
31
32
  if (idx >= u.size) {
32
33
  return;
33
34
  }
@@ -25,7 +25,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
25
25
  // Vectorized version for better throughput
26
26
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
27
27
  fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
28
- let idx = gid.x * 4u;
28
+ let dispatch_stride = max(u._pad1, 4u);
29
+ let idx = gid.y * dispatch_stride + gid.x * 4u;
29
30
  let size = u.size;
30
31
 
31
32
  if (idx >= size) {
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
23
23
  // Vectorized version for better throughput
24
24
  @compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
25
25
  fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
26
- let idx = gid.x * 4u;
26
+ let dispatch_stride = max(u._pad1, 4u);
27
+ let idx = gid.y * dispatch_stride + gid.x * 4u;
27
28
  let size = u.size;
28
29
 
29
30
  if (idx >= size) {
@@ -58,6 +58,36 @@ function resolveNormWeightDtype(weight, hiddenSize) {
58
58
  return 'f32';
59
59
  }
60
60
 
61
+ function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
62
+ const isGpuBuffer = weightBuffer && (
63
+ typeof GPUBuffer === 'undefined'
64
+ ? true
65
+ : weightBuffer instanceof GPUBuffer
66
+ );
67
+ if (isGpuBuffer) {
68
+ return;
69
+ }
70
+ const weightLabel = weight?.label ?? 'unknown';
71
+ const weightType = weight === null ? 'null' : weight === undefined ? 'undefined' : weight.constructor?.name || typeof weight;
72
+ const bufferType = weightBuffer === null ? 'null' : weightBuffer === undefined ? 'undefined' : weightBuffer.constructor?.name || typeof weightBuffer;
73
+ throw new Error(
74
+ `[rmsnorm] weight "${weightLabel}" requires a GPUBuffer ` +
75
+ `(weightType=${weightType}, bufferType=${bufferType}, hiddenSize=${hiddenSize ?? 'unknown'}).`
76
+ );
77
+ }
78
+
79
+ function planRMSNormDispatch(target, numTokens) {
80
+ const device = target?.device;
81
+ const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
82
+ ? device.limits.maxComputeWorkgroupsPerDimension
83
+ : 65535;
84
+ const tokenStride = Math.min(numTokens, maxPerDim);
85
+ return {
86
+ tokenStride,
87
+ workgroups: [tokenStride, Math.ceil(numTokens / tokenStride), 1],
88
+ };
89
+ }
90
+
61
91
  export function selectRMSNormKernel(options = {}, isF16 = false) {
62
92
  const { residual = null, hiddenSize = null } = options;
63
93
  const { smallThreshold } = getKernelThresholds().rmsnorm;
@@ -82,23 +112,34 @@ export async function runRMSNorm(
82
112
  const variant = selectRMSNormKernel(options, isF16);
83
113
  const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
84
114
  const normWeightBuffer = getBuffer(weight);
115
+ assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
85
116
  const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
86
117
 
87
118
  const bytesPerElement = isF16 ? 2 : 4;
88
119
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
89
120
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
90
121
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
122
+ const dispatchPlan = planRMSNormDispatch(null, batchSize);
91
123
 
92
124
  // Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
93
- const residualBuf = residual?.buffer || input.buffer;
125
+ const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
94
126
 
95
127
  await unifiedKernelWrapper(
96
128
  'rmsnorm',
97
129
  null,
98
130
  variant,
99
131
  [input, normWeightBuffer, outputBuf, residualBuf],
100
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps, has_residual: residual ? 1 : 0 },
101
- batchSize,
132
+ {
133
+ hidden_size: inferredHiddenSize,
134
+ num_tokens: batchSize,
135
+ eps,
136
+ has_residual: residual ? 1 : 0,
137
+ token_stride: dispatchPlan.tokenStride,
138
+ _pad0: 0,
139
+ _pad1: 0,
140
+ _pad2: 0,
141
+ },
142
+ dispatchPlan.workgroups,
102
143
  { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
103
144
  );
104
145
 
@@ -117,22 +158,33 @@ export async function recordRMSNorm(
117
158
  const variant = selectRMSNormKernel(options, isF16);
118
159
  const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
119
160
  const normWeightBuffer = getBuffer(weight);
161
+ assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
120
162
  const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
121
163
 
122
164
  const bytesPerElement = isF16 ? 2 : 4;
123
165
  const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
124
166
  const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
125
167
  const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
168
+ const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
126
169
 
127
- const residualBuf = residual?.buffer || input.buffer;
170
+ const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
128
171
 
129
172
  await unifiedKernelWrapper(
130
173
  'rmsnorm',
131
174
  recorder,
132
175
  variant,
133
176
  [input, normWeightBuffer, outputBuf, residualBuf],
134
- { hidden_size: inferredHiddenSize, num_tokens: batchSize, eps, has_residual: residual ? 1 : 0 },
135
- batchSize,
177
+ {
178
+ hidden_size: inferredHiddenSize,
179
+ num_tokens: batchSize,
180
+ eps,
181
+ has_residual: residual ? 1 : 0,
182
+ token_stride: dispatchPlan.tokenStride,
183
+ _pad0: 0,
184
+ _pad1: 0,
185
+ _pad2: 0,
186
+ },
187
+ dispatchPlan.workgroups,
136
188
  { RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
137
189
  );
138
190
 
@@ -39,6 +39,10 @@ struct Uniforms {
39
39
  num_tokens: u32, // Number of tokens to process
40
40
  eps: f32, // Epsilon for numerical stability (typically 1e-5 or 1e-6)
41
41
  has_residual: u32, // Runtime flag: 1 = add residual after norm
42
+ token_stride: u32, // Workgroup rows per dispatch row
43
+ _pad0: u32,
44
+ _pad1: u32,
45
+ _pad2: u32,
42
46
  }
43
47
 
44
48
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -82,6 +86,10 @@ fn should_add_residual() -> bool {
82
86
  return HAS_RESIDUAL || (u.has_residual != 0u);
83
87
  }
84
88
 
89
+ fn token_index(wg_id: vec3<u32>) -> u32 {
90
+ return wg_id.y * max(u.token_stride, 1u) + wg_id.x;
91
+ }
92
+
85
93
  // =============================================================================
86
94
  // Main Entry Point
87
95
  // =============================================================================
@@ -93,7 +101,7 @@ fn main(
93
101
  @builtin(local_invocation_id) local_id: vec3<u32>,
94
102
  @builtin(workgroup_id) wg_id: vec3<u32>
95
103
  ) {
96
- let token_idx = wg_id.x;
104
+ let token_idx = token_index(wg_id);
97
105
  let thread_idx = local_id.x;
98
106
  let size = u.size;
99
107
 
@@ -163,7 +171,7 @@ fn main_small(
163
171
  @builtin(local_invocation_id) local_id: vec3<u32>,
164
172
  @builtin(workgroup_id) wg_id: vec3<u32>
165
173
  ) {
166
- let token_idx = wg_id.x;
174
+ let token_idx = token_index(wg_id);
167
175
  let thread_idx = local_id.x;
168
176
  let size = u.size;
169
177
 
@@ -219,7 +227,7 @@ fn main_cached(
219
227
  @builtin(local_invocation_id) local_id: vec3<u32>,
220
228
  @builtin(workgroup_id) wg_id: vec3<u32>
221
229
  ) {
222
- let token_idx = wg_id.x;
230
+ let token_idx = token_index(wg_id);
223
231
  let thread_idx = local_id.x;
224
232
  let size = u.size;
225
233
 
@@ -288,7 +296,7 @@ fn main_subgroup(
288
296
  @builtin(subgroup_invocation_id) sg_lane: u32,
289
297
  @builtin(subgroup_size) sg_size: u32,
290
298
  ) {
291
- let token_idx = wg_id.x;
299
+ let token_idx = token_index(wg_id);
292
300
  let thread_idx = local_id.x;
293
301
  let size = u.size;
294
302
 
@@ -362,7 +370,7 @@ fn main_small_subgroup(
362
370
  @builtin(subgroup_invocation_id) sg_lane: u32,
363
371
  @builtin(subgroup_size) sg_size: u32,
364
372
  ) {
365
- let token_idx = wg_id.x;
373
+ let token_idx = token_index(wg_id);
366
374
  let thread_idx = local_id.x;
367
375
  let size = u.size;
368
376
 
@@ -414,4 +422,4 @@ fn main_small_subgroup(
414
422
  }
415
423
  output[base_offset + thread_idx] = result;
416
424
  }
417
- }
425
+ }
@@ -20,6 +20,10 @@ struct Uniforms {
20
20
  num_tokens: u32, // Number of tokens to process
21
21
  eps: f32, // Epsilon for numerical stability
22
22
  has_residual: u32, // 1 if residual input provided, 0 otherwise
23
+ token_stride: u32, // Workgroup rows per dispatch row
24
+ _pad0: u32,
25
+ _pad1: u32,
26
+ _pad2: u32,
23
27
  }
24
28
 
25
29
  @group(0) @binding(0) var<uniform> u: Uniforms;
@@ -47,6 +51,10 @@ fn load_weight(idx: u32) -> f32 {
47
51
  return bitcast<f32>(weight[idx]);
48
52
  }
49
53
 
54
+ fn token_index(wg_id: vec3<u32>) -> u32 {
55
+ return wg_id.y * max(u.token_stride, 1u) + wg_id.x;
56
+ }
57
+
50
58
  // Main RMSNorm kernel - one workgroup per token
51
59
  @compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
52
60
  fn main(
@@ -54,7 +62,7 @@ fn main(
54
62
  @builtin(local_invocation_id) local_id: vec3<u32>,
55
63
  @builtin(workgroup_id) wg_id: vec3<u32>
56
64
  ) {
57
- let token_idx = wg_id.x;
65
+ let token_idx = token_index(wg_id);
58
66
  let thread_idx = local_id.x;
59
67
  let size = u.size;
60
68
 
@@ -121,7 +129,7 @@ fn rmsnorm_small_f16(
121
129
  @builtin(local_invocation_id) local_id: vec3<u32>,
122
130
  @builtin(workgroup_id) wg_id: vec3<u32>
123
131
  ) {
124
- let token_idx = wg_id.x;
132
+ let token_idx = token_index(wg_id);
125
133
  let thread_idx = local_id.x;
126
134
  let size = u.size;
127
135
 
@@ -15,6 +15,8 @@ import type { OutputBufferOptions } from './types.js';
15
15
  export interface RoPEOptions extends OutputBufferOptions {
16
16
  numHeads?: number;
17
17
  headDim?: number;
18
+ rotaryDim?: number;
19
+ interleaved?: boolean;
18
20
  ropeTheta?: number;
19
21
  startPos?: number;
20
22
  }
@@ -13,18 +13,26 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
13
13
  const {
14
14
  numHeads = 1,
15
15
  headDim = 64,
16
+ rotaryDim = headDim,
17
+ interleaved = false,
16
18
  ropeTheta = ropeDefaults.defaultTheta,
17
19
  } = options;
18
20
 
19
21
  if (headDim % 2 !== 0) {
20
22
  throw new Error(`RoPE headDim must be even, got ${headDim}`);
21
23
  }
24
+ if (rotaryDim % 2 !== 0) {
25
+ throw new Error(`RoPE rotaryDim must be even, got ${rotaryDim}`);
26
+ }
27
+ if (rotaryDim <= 0 || rotaryDim > headDim) {
28
+ throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
29
+ }
22
30
 
23
31
  const caps = getKernelCapabilities();
24
32
  const useF16 = input.dtype === 'f16' && caps.hasF16;
25
33
  const variant = selectRuleValue('rope', 'variant', { useF16 });
26
34
 
27
- const halfDim = headDim / 2;
35
+ const halfDim = rotaryDim / 2;
28
36
  const workgroups = Math.ceil((seqLen * numHeads * halfDim) / WORKGROUP_SIZES.DEFAULT);
29
37
 
30
38
  await unifiedKernelWrapper(
@@ -34,9 +42,11 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
34
42
  seq_len: seqLen,
35
43
  num_heads: numHeads,
36
44
  head_dim: headDim,
45
+ rotary_dim: rotaryDim,
37
46
  start_pos: options.startPos ?? ropeDefaults.defaultStartPos,
38
47
  rope_base: ropeTheta,
39
48
  rope_scale: 1.0,
49
+ interleaved: interleaved ? 1 : 0,
40
50
  },
41
51
  workgroups
42
52
  );