@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -58,6 +58,36 @@ function resolveNormWeightDtype(weight, hiddenSize) {
|
|
|
58
58
|
return 'f32';
|
|
59
59
|
}
|
|
60
60
|
|
|
61
|
+
/**
 * Verify that the buffer resolved from an RMSNorm weight can be bound as a GPU buffer.
 *
 * Acceptance rules:
 * - `weightBuffer` must be truthy.
 * - If a global `GPUBuffer` constructor is defined, `weightBuffer` must be an instance of it.
 * - If no global `GPUBuffer` is defined, the instanceof check is impossible, so any truthy
 *   value is accepted.
 *
 * @param {*} weight - The caller's weight argument; only used to build the diagnostic message.
 * @param {*} weightBuffer - The buffer obtained from `weight`.
 * @param {number|null|undefined} hiddenSize - Reported in the error message ('unknown' when nullish).
 * @throws {Error} When `weightBuffer` does not satisfy the rules above.
 */
function assertRMSNormWeightBuffer(weight, weightBuffer, hiddenSize) {
  if (weightBuffer) {
    const noGpuBufferGlobal = typeof GPUBuffer === 'undefined';
    if (noGpuBufferGlobal || weightBuffer instanceof GPUBuffer) {
      return;
    }
  }

  // Human-readable runtime type: 'null', 'undefined', the constructor name, or typeof as a fallback.
  const describe = (value) => {
    if (value === null) {
      return 'null';
    }
    if (value === undefined) {
      return 'undefined';
    }
    return value.constructor?.name || typeof value;
  };

  const label = weight?.label ?? 'unknown';
  const details = `(weightType=${describe(weight)}, bufferType=${describe(weightBuffer)}, hiddenSize=${hiddenSize ?? 'unknown'}).`;
  throw new Error(`[rmsnorm] weight "${label}" requires a GPUBuffer ${details}`);
}
|
|
78
|
+
|
|
79
|
+
/**
 * Plan a 2D dispatch grid so that RMSNorm can launch one workgroup per token,
 * even when the token count exceeds the per-dimension workgroup limit.
 *
 * Tokens are laid out row-major over (x, y):
 * - x spans `tokenStride` workgroups.
 * - y spans `ceil(numTokens / tokenStride)` rows.
 * - The kernels rebuild the linear index as `wg_id.y * token_stride + wg_id.x`.
 *
 * Because x * y can exceed `numTokens`, the last row may contain workgroups whose
 * index is >= numTokens.
 * NOTE(review): confirm every rmsnorm entry point bounds-checks token_idx.
 *
 * @param {{ device?: GPUDevice } | null | undefined} target - Object whose `device.limits`
 *   provides `maxComputeWorkgroupsPerDimension`. When absent or non-finite, 65535
 *   (the WebGPU default limit) is assumed.
 * @param {number} numTokens - Number of tokens (rows) to normalize.
 * @returns {{ tokenStride: number, workgroups: [number, number, number] }}
 * @throws {Error} If `numTokens` is negative or not a finite number.
 */
function planRMSNormDispatch(target, numTokens) {
  // Fail here with a clear message rather than later in dispatchWorkgroups,
  // whose [EnforceRange] conversion would reject NaN/negative counts.
  if (!Number.isFinite(numTokens) || numTokens < 0) {
    throw new Error(`[rmsnorm] numTokens must be a finite non-negative number, got ${numTokens}`);
  }

  const device = target?.device;
  const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
    ? device.limits.maxComputeWorkgroupsPerDimension
    : 65535;

  if (numTokens === 0) {
    // Empty batch: issue a zero-workgroup dispatch instead of computing
    // ceil(0 / 0) = NaN rows. The stride is 1, which the shader's max(stride, 1u) treats like 0.
    return { tokenStride: 1, workgroups: [0, 0, 1] };
  }

  const tokenStride = Math.min(numTokens, maxPerDim);
  return {
    tokenStride,
    workgroups: [tokenStride, Math.ceil(numTokens / tokenStride), 1],
  };
}
|
|
90
|
+
|
|
61
91
|
export function selectRMSNormKernel(options = {}, isF16 = false) {
|
|
62
92
|
const { residual = null, hiddenSize = null } = options;
|
|
63
93
|
const { smallThreshold } = getKernelThresholds().rmsnorm;
|
|
@@ -82,23 +112,34 @@ export async function runRMSNorm(
|
|
|
82
112
|
const variant = selectRMSNormKernel(options, isF16);
|
|
83
113
|
const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
|
|
84
114
|
const normWeightBuffer = getBuffer(weight);
|
|
115
|
+
assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
|
|
85
116
|
const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
|
|
86
117
|
|
|
87
118
|
const bytesPerElement = isF16 ? 2 : 4;
|
|
88
119
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
89
120
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
90
121
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
|
|
122
|
+
const dispatchPlan = planRMSNormDispatch(null, batchSize);
|
|
91
123
|
|
|
92
124
|
// Shader layout always includes the residual binding; when unused, bind a harmless placeholder.
|
|
93
|
-
const residualBuf = residual?.buffer || input
|
|
125
|
+
const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
|
|
94
126
|
|
|
95
127
|
await unifiedKernelWrapper(
|
|
96
128
|
'rmsnorm',
|
|
97
129
|
null,
|
|
98
130
|
variant,
|
|
99
131
|
[input, normWeightBuffer, outputBuf, residualBuf],
|
|
100
|
-
{
|
|
101
|
-
|
|
132
|
+
{
|
|
133
|
+
hidden_size: inferredHiddenSize,
|
|
134
|
+
num_tokens: batchSize,
|
|
135
|
+
eps,
|
|
136
|
+
has_residual: residual ? 1 : 0,
|
|
137
|
+
token_stride: dispatchPlan.tokenStride,
|
|
138
|
+
_pad0: 0,
|
|
139
|
+
_pad1: 0,
|
|
140
|
+
_pad2: 0,
|
|
141
|
+
},
|
|
142
|
+
dispatchPlan.workgroups,
|
|
102
143
|
{ RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
|
|
103
144
|
);
|
|
104
145
|
|
|
@@ -117,22 +158,33 @@ export async function recordRMSNorm(
|
|
|
117
158
|
const variant = selectRMSNormKernel(options, isF16);
|
|
118
159
|
const inferredHiddenSize = inferHiddenSize(input, hiddenSize);
|
|
119
160
|
const normWeightBuffer = getBuffer(weight);
|
|
161
|
+
assertRMSNormWeightBuffer(weight, normWeightBuffer, inferredHiddenSize);
|
|
120
162
|
const normWeightDtype = resolveNormWeightDtype(weight, inferredHiddenSize);
|
|
121
163
|
|
|
122
164
|
const bytesPerElement = isF16 ? 2 : 4;
|
|
123
165
|
const paddedHiddenSize = padToQ4KBlock(inferredHiddenSize);
|
|
124
166
|
const outputSize = batchSize * paddedHiddenSize * bytesPerElement;
|
|
125
167
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'rmsnorm_output');
|
|
168
|
+
const dispatchPlan = planRMSNormDispatch(recorder, batchSize);
|
|
126
169
|
|
|
127
|
-
const residualBuf = residual?.buffer || input
|
|
170
|
+
const residualBuf = residual?.buffer || residual || input?.buffer || input || outputBuf;
|
|
128
171
|
|
|
129
172
|
await unifiedKernelWrapper(
|
|
130
173
|
'rmsnorm',
|
|
131
174
|
recorder,
|
|
132
175
|
variant,
|
|
133
176
|
[input, normWeightBuffer, outputBuf, residualBuf],
|
|
134
|
-
{
|
|
135
|
-
|
|
177
|
+
{
|
|
178
|
+
hidden_size: inferredHiddenSize,
|
|
179
|
+
num_tokens: batchSize,
|
|
180
|
+
eps,
|
|
181
|
+
has_residual: residual ? 1 : 0,
|
|
182
|
+
token_stride: dispatchPlan.tokenStride,
|
|
183
|
+
_pad0: 0,
|
|
184
|
+
_pad1: 0,
|
|
185
|
+
_pad2: 0,
|
|
186
|
+
},
|
|
187
|
+
dispatchPlan.workgroups,
|
|
136
188
|
{ RMS_NORM_OFFSET: rmsNormWeightOffset, WEIGHT_IS_F16: normWeightDtype === 'f16' }
|
|
137
189
|
);
|
|
138
190
|
|
|
@@ -39,6 +39,10 @@ struct Uniforms {
|
|
|
39
39
|
num_tokens: u32, // Number of tokens to process
|
|
40
40
|
eps: f32, // Epsilon for numerical stability (typically 1e-5 or 1e-6)
|
|
41
41
|
has_residual: u32, // Runtime flag: 1 = add residual after norm
|
|
42
|
+
token_stride: u32, // Workgroup rows per dispatch row
|
|
43
|
+
_pad0: u32,
|
|
44
|
+
_pad1: u32,
|
|
45
|
+
_pad2: u32,
|
|
42
46
|
}
|
|
43
47
|
|
|
44
48
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -82,6 +86,10 @@ fn should_add_residual() -> bool {
|
|
|
82
86
|
return HAS_RESIDUAL || (u.has_residual != 0u);
|
|
83
87
|
}
|
|
84
88
|
|
|
89
|
+
// Map a workgroup id from the 2D dispatch grid back to a linear token index.
// Each grid row is `token_stride` workgroups wide, matching the host-side
// dispatch plan. A stride of 0 is clamped to 1, so a zeroed uniform maps
// wg_id.x straight to the token. The result may exceed u.num_tokens for
// padding workgroups in the last row.
fn token_index(wg_id: vec3<u32>) -> u32 {
  let row_width = max(u.token_stride, 1u);
  return wg_id.x + row_width * wg_id.y;
}
|
|
92
|
+
|
|
85
93
|
// =============================================================================
|
|
86
94
|
// Main Entry Point
|
|
87
95
|
// =============================================================================
|
|
@@ -93,7 +101,7 @@ fn main(
|
|
|
93
101
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
94
102
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
95
103
|
) {
|
|
96
|
-
let token_idx = wg_id
|
|
104
|
+
let token_idx = token_index(wg_id);
|
|
97
105
|
let thread_idx = local_id.x;
|
|
98
106
|
let size = u.size;
|
|
99
107
|
|
|
@@ -163,7 +171,7 @@ fn main_small(
|
|
|
163
171
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
164
172
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
165
173
|
) {
|
|
166
|
-
let token_idx = wg_id
|
|
174
|
+
let token_idx = token_index(wg_id);
|
|
167
175
|
let thread_idx = local_id.x;
|
|
168
176
|
let size = u.size;
|
|
169
177
|
|
|
@@ -219,7 +227,7 @@ fn main_cached(
|
|
|
219
227
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
220
228
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
221
229
|
) {
|
|
222
|
-
let token_idx = wg_id
|
|
230
|
+
let token_idx = token_index(wg_id);
|
|
223
231
|
let thread_idx = local_id.x;
|
|
224
232
|
let size = u.size;
|
|
225
233
|
|
|
@@ -288,7 +296,7 @@ fn main_subgroup(
|
|
|
288
296
|
@builtin(subgroup_invocation_id) sg_lane: u32,
|
|
289
297
|
@builtin(subgroup_size) sg_size: u32,
|
|
290
298
|
) {
|
|
291
|
-
let token_idx = wg_id
|
|
299
|
+
let token_idx = token_index(wg_id);
|
|
292
300
|
let thread_idx = local_id.x;
|
|
293
301
|
let size = u.size;
|
|
294
302
|
|
|
@@ -362,7 +370,7 @@ fn main_small_subgroup(
|
|
|
362
370
|
@builtin(subgroup_invocation_id) sg_lane: u32,
|
|
363
371
|
@builtin(subgroup_size) sg_size: u32,
|
|
364
372
|
) {
|
|
365
|
-
let token_idx = wg_id
|
|
373
|
+
let token_idx = token_index(wg_id);
|
|
366
374
|
let thread_idx = local_id.x;
|
|
367
375
|
let size = u.size;
|
|
368
376
|
|
|
@@ -414,4 +422,4 @@ fn main_small_subgroup(
|
|
|
414
422
|
}
|
|
415
423
|
output[base_offset + thread_idx] = result;
|
|
416
424
|
}
|
|
417
|
-
}
|
|
425
|
+
}
|
|
@@ -20,6 +20,10 @@ struct Uniforms {
|
|
|
20
20
|
num_tokens: u32, // Number of tokens to process
|
|
21
21
|
eps: f32, // Epsilon for numerical stability
|
|
22
22
|
has_residual: u32, // 1 if residual input provided, 0 otherwise
|
|
23
|
+
token_stride: u32, // Workgroup rows per dispatch row
|
|
24
|
+
_pad0: u32,
|
|
25
|
+
_pad1: u32,
|
|
26
|
+
_pad2: u32,
|
|
23
27
|
}
|
|
24
28
|
|
|
25
29
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -47,6 +51,10 @@ fn load_weight(idx: u32) -> f32 {
|
|
|
47
51
|
return bitcast<f32>(weight[idx]);
|
|
48
52
|
}
|
|
49
53
|
|
|
54
|
+
// Map a workgroup id from the 2D dispatch grid back to a linear token index.
// Each grid row is `token_stride` workgroups wide, matching the host-side
// dispatch plan. A stride of 0 is clamped to 1, so a zeroed uniform maps
// wg_id.x straight to the token. The result may exceed u.num_tokens for
// padding workgroups in the last row.
fn token_index(wg_id: vec3<u32>) -> u32 {
  let row_width = max(u.token_stride, 1u);
  return wg_id.x + row_width * wg_id.y;
}
|
|
57
|
+
|
|
50
58
|
// Main RMSNorm kernel - one workgroup per token
|
|
51
59
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
52
60
|
fn main(
|
|
@@ -54,7 +62,7 @@ fn main(
|
|
|
54
62
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
55
63
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
56
64
|
) {
|
|
57
|
-
let token_idx = wg_id
|
|
65
|
+
let token_idx = token_index(wg_id);
|
|
58
66
|
let thread_idx = local_id.x;
|
|
59
67
|
let size = u.size;
|
|
60
68
|
|
|
@@ -121,7 +129,7 @@ fn rmsnorm_small_f16(
|
|
|
121
129
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
|
122
130
|
@builtin(workgroup_id) wg_id: vec3<u32>
|
|
123
131
|
) {
|
|
124
|
-
let token_idx = wg_id
|
|
132
|
+
let token_idx = token_index(wg_id);
|
|
125
133
|
let thread_idx = local_id.x;
|
|
126
134
|
let size = u.size;
|
|
127
135
|
|
package/src/gpu/kernels/rope.js
CHANGED
|
@@ -13,18 +13,26 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
|
|
|
13
13
|
const {
|
|
14
14
|
numHeads = 1,
|
|
15
15
|
headDim = 64,
|
|
16
|
+
rotaryDim = headDim,
|
|
17
|
+
interleaved = false,
|
|
16
18
|
ropeTheta = ropeDefaults.defaultTheta,
|
|
17
19
|
} = options;
|
|
18
20
|
|
|
19
21
|
if (headDim % 2 !== 0) {
|
|
20
22
|
throw new Error(`RoPE headDim must be even, got ${headDim}`);
|
|
21
23
|
}
|
|
24
|
+
if (rotaryDim % 2 !== 0) {
|
|
25
|
+
throw new Error(`RoPE rotaryDim must be even, got ${rotaryDim}`);
|
|
26
|
+
}
|
|
27
|
+
if (rotaryDim <= 0 || rotaryDim > headDim) {
|
|
28
|
+
throw new Error(`RoPE rotaryDim must be in (0, headDim]; got ${rotaryDim} for headDim ${headDim}`);
|
|
29
|
+
}
|
|
22
30
|
|
|
23
31
|
const caps = getKernelCapabilities();
|
|
24
32
|
const useF16 = input.dtype === 'f16' && caps.hasF16;
|
|
25
33
|
const variant = selectRuleValue('rope', 'variant', { useF16 });
|
|
26
34
|
|
|
27
|
-
const halfDim =
|
|
35
|
+
const halfDim = rotaryDim / 2;
|
|
28
36
|
const workgroups = Math.ceil((seqLen * numHeads * halfDim) / WORKGROUP_SIZES.DEFAULT);
|
|
29
37
|
|
|
30
38
|
await unifiedKernelWrapper(
|
|
@@ -34,9 +42,11 @@ async function _rope(target, input, freqsCos, freqsSin, seqLen, options = {}) {
|
|
|
34
42
|
seq_len: seqLen,
|
|
35
43
|
num_heads: numHeads,
|
|
36
44
|
head_dim: headDim,
|
|
45
|
+
rotary_dim: rotaryDim,
|
|
37
46
|
start_pos: options.startPos ?? ropeDefaults.defaultStartPos,
|
|
38
47
|
rope_base: ropeTheta,
|
|
39
48
|
rope_scale: 1.0,
|
|
49
|
+
interleaved: interleaved ? 1 : 0,
|
|
40
50
|
},
|
|
41
51
|
workgroups
|
|
42
52
|
);
|
|
@@ -26,8 +26,8 @@ struct Uniforms {
|
|
|
26
26
|
start_pos: u32, // Starting position (for decode)
|
|
27
27
|
rope_base: f32, // Base frequency (default 10000)
|
|
28
28
|
rope_scale: f32, // Scaling factor for extended context
|
|
29
|
-
|
|
30
|
-
|
|
29
|
+
rotary_dim: u32, // Rotary slice within head_dim
|
|
30
|
+
interleaved: u32, // 1 = adjacent pairs, 0 = rotate-half
|
|
31
31
|
}
|
|
32
32
|
|
|
33
33
|
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
@@ -46,7 +46,8 @@ fn main(
|
|
|
46
46
|
let start_pos = u.start_pos;
|
|
47
47
|
|
|
48
48
|
// Global thread index (one thread per complex pair)
|
|
49
|
-
let
|
|
49
|
+
let rotary_dim = u.rotary_dim;
|
|
50
|
+
let half_dim = rotary_dim / 2u;
|
|
50
51
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
51
52
|
let idx = global_id.x;
|
|
52
53
|
|
|
@@ -68,16 +69,18 @@ fn main(
|
|
|
68
69
|
|
|
69
70
|
// Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
|
|
70
71
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
71
|
-
let
|
|
72
|
-
let
|
|
72
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
73
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
74
|
+
let x0 = input[base_idx + first_idx];
|
|
75
|
+
let x1 = input[base_idx + second_idx];
|
|
73
76
|
|
|
74
77
|
// Apply rotation
|
|
75
78
|
let y0 = x0 * cos_val - x1 * sin_val;
|
|
76
79
|
let y1 = x0 * sin_val + x1 * cos_val;
|
|
77
80
|
|
|
78
81
|
// Write back
|
|
79
|
-
input[base_idx +
|
|
80
|
-
input[base_idx +
|
|
82
|
+
input[base_idx + first_idx] = y0;
|
|
83
|
+
input[base_idx + second_idx] = y1;
|
|
81
84
|
}
|
|
82
85
|
|
|
83
86
|
// Compute frequencies on-the-fly (no precomputation needed)
|
|
@@ -91,9 +94,10 @@ fn rope_compute_freqs(
|
|
|
91
94
|
let start_pos = u.start_pos;
|
|
92
95
|
let rope_base = u.rope_base;
|
|
93
96
|
let rope_scale = u.rope_scale;
|
|
97
|
+
let rotary_dim = u.rotary_dim;
|
|
94
98
|
|
|
95
99
|
let idx = global_id.x;
|
|
96
|
-
let half_dim =
|
|
100
|
+
let half_dim = rotary_dim / 2u;
|
|
97
101
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
98
102
|
|
|
99
103
|
if (idx >= total_pairs) {
|
|
@@ -109,7 +113,7 @@ fn rope_compute_freqs(
|
|
|
109
113
|
let actual_pos = f32(start_pos + pos) / rope_scale;
|
|
110
114
|
|
|
111
115
|
// Compute frequency: 1 / (base^(2*pair_idx/head_dim))
|
|
112
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
116
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
113
117
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
114
118
|
let theta = actual_pos * freq;
|
|
115
119
|
|
|
@@ -118,12 +122,14 @@ fn rope_compute_freqs(
|
|
|
118
122
|
|
|
119
123
|
// Apply "rotate-half" layout: pair (x[i], x[i + half_dim])
|
|
120
124
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
121
|
-
let
|
|
122
|
-
let
|
|
125
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
126
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
127
|
+
let x0 = input[base_idx + first_idx];
|
|
128
|
+
let x1 = input[base_idx + second_idx];
|
|
123
129
|
|
|
124
130
|
// Apply rotation
|
|
125
|
-
input[base_idx +
|
|
126
|
-
input[base_idx +
|
|
131
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
132
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
127
133
|
}
|
|
128
134
|
|
|
129
135
|
// Apply RoPE to both Q and K in one pass
|
|
@@ -138,10 +144,11 @@ fn rope_qk(
|
|
|
138
144
|
let start_pos = u.start_pos;
|
|
139
145
|
let rope_base = u.rope_base;
|
|
140
146
|
let rope_scale = u.rope_scale;
|
|
147
|
+
let rotary_dim = u.rotary_dim;
|
|
141
148
|
|
|
142
149
|
let idx = global_id.x;
|
|
143
150
|
// Each thread handles one Q-K pair at one dimension pair
|
|
144
|
-
let half_dim =
|
|
151
|
+
let half_dim = rotary_dim / 2u;
|
|
145
152
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
146
153
|
|
|
147
154
|
if (idx >= total_pairs) {
|
|
@@ -156,7 +163,7 @@ fn rope_qk(
|
|
|
156
163
|
let actual_pos = f32(start_pos + pos) / rope_scale;
|
|
157
164
|
|
|
158
165
|
// Compute frequency
|
|
159
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
166
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
160
167
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
161
168
|
let theta = actual_pos * freq;
|
|
162
169
|
|
|
@@ -168,16 +175,18 @@ fn rope_qk(
|
|
|
168
175
|
let k_base_idx = q_base_idx + head_dim; // K starts after Q
|
|
169
176
|
|
|
170
177
|
// Process Q
|
|
171
|
-
let
|
|
172
|
-
let
|
|
173
|
-
input[q_base_idx +
|
|
174
|
-
input[q_base_idx +
|
|
178
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
179
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
180
|
+
let q0 = input[q_base_idx + first_idx];
|
|
181
|
+
let q1 = input[q_base_idx + second_idx];
|
|
182
|
+
input[q_base_idx + first_idx] = q0 * cos_val - q1 * sin_val;
|
|
183
|
+
input[q_base_idx + second_idx] = q0 * sin_val + q1 * cos_val;
|
|
175
184
|
|
|
176
185
|
// Process K
|
|
177
|
-
let k0 = input[k_base_idx +
|
|
178
|
-
let k1 = input[k_base_idx +
|
|
179
|
-
input[k_base_idx +
|
|
180
|
-
input[k_base_idx +
|
|
186
|
+
let k0 = input[k_base_idx + first_idx];
|
|
187
|
+
let k1 = input[k_base_idx + second_idx];
|
|
188
|
+
input[k_base_idx + first_idx] = k0 * cos_val - k1 * sin_val;
|
|
189
|
+
input[k_base_idx + second_idx] = k0 * sin_val + k1 * cos_val;
|
|
181
190
|
}
|
|
182
191
|
|
|
183
192
|
// Precompute frequency table (run once at init)
|
|
@@ -190,9 +199,10 @@ fn precompute_freqs(
|
|
|
190
199
|
let seq_len = u.seq_len; // maxSeqLen for precomputation
|
|
191
200
|
let rope_base = u.rope_base;
|
|
192
201
|
let rope_scale = u.rope_scale;
|
|
202
|
+
let rotary_dim = u.rotary_dim;
|
|
193
203
|
|
|
194
204
|
let idx = global_id.x;
|
|
195
|
-
let half_dim =
|
|
205
|
+
let half_dim = rotary_dim / 2u;
|
|
196
206
|
let total_elements = seq_len * half_dim;
|
|
197
207
|
|
|
198
208
|
if (idx >= total_elements) {
|
|
@@ -203,7 +213,7 @@ fn precompute_freqs(
|
|
|
203
213
|
let dim_idx = idx % half_dim;
|
|
204
214
|
|
|
205
215
|
let actual_pos = f32(pos) / rope_scale;
|
|
206
|
-
let exponent = f32(dim_idx * 2u) / f32(
|
|
216
|
+
let exponent = f32(dim_idx * 2u) / f32(rotary_dim);
|
|
207
217
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
208
218
|
let theta = actual_pos * freq;
|
|
209
219
|
|
|
@@ -218,6 +228,7 @@ fn rope_ntk_scaled(
|
|
|
218
228
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
219
229
|
) {
|
|
220
230
|
let head_dim = u.head_dim;
|
|
231
|
+
let rotary_dim = u.rotary_dim;
|
|
221
232
|
let num_heads = u.num_heads;
|
|
222
233
|
let seq_len = u.seq_len;
|
|
223
234
|
let start_pos = u.start_pos;
|
|
@@ -225,7 +236,7 @@ fn rope_ntk_scaled(
|
|
|
225
236
|
let rope_scale = u.rope_scale;
|
|
226
237
|
|
|
227
238
|
let idx = global_id.x;
|
|
228
|
-
let half_dim =
|
|
239
|
+
let half_dim = rotary_dim / 2u;
|
|
229
240
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
230
241
|
|
|
231
242
|
if (idx >= total_pairs) {
|
|
@@ -234,7 +245,7 @@ fn rope_ntk_scaled(
|
|
|
234
245
|
|
|
235
246
|
// NTK scaling: increase base proportionally to scale factor
|
|
236
247
|
// This preserves high-frequency components better than linear interpolation
|
|
237
|
-
rope_base = rope_base * pow(rope_scale, f32(
|
|
248
|
+
rope_base = rope_base * pow(rope_scale, f32(rotary_dim) / (f32(rotary_dim) - 2.0));
|
|
238
249
|
|
|
239
250
|
let pos = idx / (num_heads * half_dim);
|
|
240
251
|
let remainder = idx % (num_heads * half_dim);
|
|
@@ -243,7 +254,7 @@ fn rope_ntk_scaled(
|
|
|
243
254
|
|
|
244
255
|
let actual_pos = f32(start_pos + pos);
|
|
245
256
|
|
|
246
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
257
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
247
258
|
let freq = 1.0 / pow(rope_base, exponent);
|
|
248
259
|
let theta = actual_pos * freq;
|
|
249
260
|
|
|
@@ -251,11 +262,13 @@ fn rope_ntk_scaled(
|
|
|
251
262
|
let sin_val = sin(theta);
|
|
252
263
|
|
|
253
264
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
254
|
-
let
|
|
255
|
-
let
|
|
265
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
266
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
267
|
+
let x0 = input[base_idx + first_idx];
|
|
268
|
+
let x1 = input[base_idx + second_idx];
|
|
256
269
|
|
|
257
|
-
input[base_idx +
|
|
258
|
-
input[base_idx +
|
|
270
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
271
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
259
272
|
}
|
|
260
273
|
|
|
261
274
|
// YaRN-style RoPE with attention scaling
|
|
@@ -265,6 +278,7 @@ fn rope_yarn(
|
|
|
265
278
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
266
279
|
) {
|
|
267
280
|
let head_dim = u.head_dim;
|
|
281
|
+
let rotary_dim = u.rotary_dim;
|
|
268
282
|
let num_heads = u.num_heads;
|
|
269
283
|
let seq_len = u.seq_len;
|
|
270
284
|
let start_pos = u.start_pos;
|
|
@@ -272,7 +286,7 @@ fn rope_yarn(
|
|
|
272
286
|
let rope_scale = u.rope_scale;
|
|
273
287
|
|
|
274
288
|
let idx = global_id.x;
|
|
275
|
-
let half_dim =
|
|
289
|
+
let half_dim = rotary_dim / 2u;
|
|
276
290
|
let total_pairs = seq_len * num_heads * half_dim;
|
|
277
291
|
|
|
278
292
|
if (idx >= total_pairs) {
|
|
@@ -292,7 +306,7 @@ fn rope_yarn(
|
|
|
292
306
|
let alpha: f32 = 1.0;
|
|
293
307
|
|
|
294
308
|
// Compute original frequency
|
|
295
|
-
let exponent = f32(pair_idx * 2u) / f32(
|
|
309
|
+
let exponent = f32(pair_idx * 2u) / f32(rotary_dim);
|
|
296
310
|
let orig_freq = 1.0 / pow(rope_base, exponent);
|
|
297
311
|
|
|
298
312
|
// Compute wavelength
|
|
@@ -300,8 +314,8 @@ fn rope_yarn(
|
|
|
300
314
|
|
|
301
315
|
// Interpolation factor based on wavelength
|
|
302
316
|
var ramp: f32;
|
|
303
|
-
let low_wavelength = f32(
|
|
304
|
-
let high_wavelength = f32(
|
|
317
|
+
let low_wavelength = f32(rotary_dim) / beta_fast;
|
|
318
|
+
let high_wavelength = f32(rotary_dim) / beta_slow;
|
|
305
319
|
|
|
306
320
|
if (wavelength < low_wavelength) {
|
|
307
321
|
ramp = 0.0; // No interpolation for high frequencies
|
|
@@ -320,9 +334,11 @@ fn rope_yarn(
|
|
|
320
334
|
let sin_val = sin(theta);
|
|
321
335
|
|
|
322
336
|
let base_idx = pos * num_heads * head_dim + head_idx * head_dim;
|
|
323
|
-
let
|
|
324
|
-
let
|
|
337
|
+
let first_idx = select(pair_idx, pair_idx * 2u, u.interleaved == 1u);
|
|
338
|
+
let second_idx = select(pair_idx + half_dim, pair_idx * 2u + 1u, u.interleaved == 1u);
|
|
339
|
+
let x0 = input[base_idx + first_idx];
|
|
340
|
+
let x1 = input[base_idx + second_idx];
|
|
325
341
|
|
|
326
|
-
input[base_idx +
|
|
327
|
-
input[base_idx +
|
|
342
|
+
input[base_idx + first_idx] = x0 * cos_val - x1 * sin_val;
|
|
343
|
+
input[base_idx + second_idx] = x0 * sin_val + x1 * cos_val;
|
|
328
344
|
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
/**
 * Options accepted by runSanaLinearAttention / recordSanaLinearAttention.
 */
export interface SanaLinearAttentionOptions extends OutputBufferOptions {
  /** Number of attention heads (required). */
  numHeads: number;
  /** Channels per head (required); numHeads * headDim must equal hiddenSize. */
  headDim: number;
  /** Token count; the implementation falls back to query.shape[0] when omitted. */
  numTokens?: number;
  /** Hidden width; the implementation falls back to query.shape[1] when omitted. */
  hiddenSize?: number;
  /** Epsilon passed to the apply kernel; the implementation defaults it to 1e-15. */
  eps?: number;
  /**
   * Optional caller-owned scratch buffer for the per-head summary. When null or
   * omitted, a pooled temporary is acquired and reclaimed by the implementation.
   */
  summaryBuffer?: GPUBuffer | null;
}
|
|
13
|
+
|
|
14
|
+
/**
 * Runs Sana linear attention immediately (no command recorder).
 *
 * @param query - Query tensor; its dtype selects the kernel variant and output dtype.
 * @param key - Key tensor.
 * @param value - Value tensor.
 * @param options - Head layout and optional buffers.
 * @returns Tensor shaped [numTokens, hiddenSize] with the query's dtype.
 */
export declare function runSanaLinearAttention(query: Tensor, key: Tensor, value: Tensor, options: SanaLinearAttentionOptions): Promise<Tensor>;
|
|
20
|
+
|
|
21
|
+
/**
 * Records Sana linear attention onto a command recorder instead of executing it
 * immediately. Temporary buffers are handed to the recorder for later reclamation.
 *
 * @param recorder - Recorder that receives the kernel passes.
 * @param query - Query tensor; its dtype selects the kernel variant and output dtype.
 * @param key - Key tensor.
 * @param value - Value tensor.
 * @param options - Head layout and optional buffers.
 * @returns Tensor shaped [numTokens, hiddenSize] with the query's dtype.
 */
export declare function recordSanaLinearAttention(recorder: CommandRecorder, query: Tensor, key: Tensor, value: Tensor, options: SanaLinearAttentionOptions): Promise<Tensor>;
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import { getDevice } from '../device.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
5
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
6
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
7
|
+
|
|
8
|
+
/**
 * Looks up the kernel variant for Sana linear attention in the rule registry.
 *
 * @param {boolean} isF16 - Whether the query tensor is f16.
 * @returns {*} Whatever the 'sanaLinearAttention' / 'variant' rule yields.
 */
function selectSanaLinearAttentionVariant(isF16) {
  const ruleContext = { isF16 };
  const chosenVariant = selectRuleValue('sanaLinearAttention', 'variant', ruleContext);
  return chosenVariant;
}
|
|
11
|
+
|
|
12
|
+
/**
 * Issues the `sana_linear_attention_summary` kernel, which reads query, key and
 * value and writes into `summaryBuffer`.
 *
 * The workgroup count covers num_heads * (head_dim + 1) * head_dim elements in
 * chunks of WORKGROUP_SIZES.DEFAULT (presumably one invocation per summary
 * element; confirm against the WGSL kernel).
 *
 * @param {*} target - Recorder, or null for immediate execution.
 * @param {*} query - Query tensor.
 * @param {*} key - Key tensor.
 * @param {*} value - Value tensor.
 * @param {*} summaryBuffer - Destination for the summary.
 * @param {{num_heads:number, head_dim:number, num_tokens:number, hidden_size:number}} uniforms
 * @param {*} variant - Kernel variant from the rule registry.
 * @returns {Promise<void>}
 */
async function runSummary(target, query, key, value, summaryBuffer, uniforms, variant) {
  const {
    num_heads: headCount,
    head_dim: headWidth,
    num_tokens: tokenCount,
    hidden_size: hiddenWidth,
  } = uniforms;

  // Key order matters if the wrapper packs uniforms positionally; keep it fixed.
  const kernelUniforms = {
    num_heads: headCount,
    head_dim: headWidth,
    num_tokens: tokenCount,
    hidden_size: hiddenWidth,
    _pad0: 0,
    _pad1: 0,
  };
  const bindings = [query, key, value, summaryBuffer];
  const elementCount = headCount * (headWidth + 1) * headWidth;
  const workgroupCount = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);

  await unifiedKernelWrapper(
    'sana_linear_attention_summary',
    target,
    variant,
    bindings,
    kernelUniforms,
    workgroupCount
  );
}
|
|
30
|
+
|
|
31
|
+
/**
 * Issues the `sana_linear_attention_apply` kernel, which reads the query and the
 * previously computed summary and writes into `outputBuffer`.
 *
 * Dispatch grid: x = ceil(hidden_size / WORKGROUP_SIZES.DEFAULT),
 * y = num_tokens, z = 1.
 *
 * @param {*} target - Recorder, or null for immediate execution.
 * @param {*} query - Query tensor.
 * @param {*} summaryBuffer - Buffer filled by the summary kernel.
 * @param {*} outputBuffer - Destination buffer.
 * @param {{num_heads:number, head_dim:number, num_tokens:number, hidden_size:number, eps:number}} uniforms
 * @param {*} variant - Kernel variant from the rule registry.
 * @returns {Promise<void>}
 */
async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
  const {
    num_heads: headCount,
    head_dim: headWidth,
    num_tokens: tokenCount,
    hidden_size: hiddenWidth,
    eps: epsilon,
  } = uniforms;

  // Key order matters if the wrapper packs uniforms positionally; keep it fixed.
  const kernelUniforms = {
    num_heads: headCount,
    head_dim: headWidth,
    num_tokens: tokenCount,
    hidden_size: hiddenWidth,
    eps: epsilon,
    _pad0: 0,
    _pad1: 0,
    _pad2: 0,
  };
  const bindings = [query, summaryBuffer, outputBuffer];
  const columnGroups = Math.ceil(hiddenWidth / WORKGROUP_SIZES.DEFAULT);
  const grid = [columnGroups, tokenCount, 1];

  await unifiedKernelWrapper(
    'sana_linear_attention_apply',
    target,
    variant,
    bindings,
    kernelUniforms,
    grid
  );
}
|
|
50
|
+
|
|
51
|
+
/**
 * Returns a buffer this module acquired from the pool.
 *
 * In recorded mode the buffer is handed to the recorder via
 * trackTemporaryBuffer (presumably reclaimed once the recorded work is done —
 * confirm in CommandRecorder), because passes already recorded may still
 * reference it. In immediate mode it is released to the pool directly.
 */
function disposeSanaOwnedBuffer(recorder, buffer) {
  if (recorder) {
    recorder.trackTemporaryBuffer(buffer);
  } else {
    releaseBuffer(buffer);
  }
}

/**
 * Shared implementation of Sana linear attention for immediate and recorded
 * execution.
 *
 * Runs two kernels in order:
 * - `sana_linear_attention_summary` reads query, key and value, and writes a
 *   summary buffer sized for numHeads * (headDim + 1) * headDim f32 values.
 * - `sana_linear_attention_apply` reads query and that summary, and writes the
 *   output.
 *
 * Buffer ownership: a caller-supplied `summaryBuffer` / `outputBuffer` is never
 * released here. A pooled summary temporary is always reclaimed once the
 * kernels have been issued, or on failure. A pooled output buffer is returned
 * inside the result tensor on success, and reclaimed on failure.
 *
 * @param {object|null} target - Command recorder (anything exposing
 *   `beginComputePass`) for recorded mode, or null for immediate mode.
 * @param {object} query - Query tensor; its dtype picks the kernel variant and
 *   output dtype, and its shape supplies default numTokens/hiddenSize.
 * @param {object} key - Key tensor.
 * @param {object} value - Value tensor.
 * @param {object} [options]
 * @param {number} options.numHeads - Head count (positive integer).
 * @param {number} options.headDim - Channels per head (positive integer).
 * @param {number} [options.numTokens] - Defaults to query.shape[0]; non-negative integer.
 * @param {number} [options.hiddenSize] - Defaults to query.shape[1]; must equal numHeads * headDim.
 * @param {number} [options.eps=1e-15] - Passed to the apply kernel.
 * @param {GPUBuffer|null} [options.outputBuffer] - Optional destination buffer.
 * @param {GPUBuffer|null} [options.summaryBuffer] - Optional scratch buffer for the summary.
 * @returns {Promise<object>} Tensor of shape [numTokens, hiddenSize] in query's dtype.
 * @throws {Error} If no WebGPU device is available, a dimension is missing or not
 *   a valid integer, or hiddenSize != numHeads * headDim. Kernel errors are
 *   rethrown after self-acquired buffers have been reclaimed.
 */
async function _sanaLinearAttention(target, query, key, value, options = {}) {
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  if (!device) {
    throw new Error('SanaLinearAttention requires a WebGPU device.');
  }

  const {
    numHeads,
    headDim,
    numTokens = query.shape?.[0],
    hiddenSize = query.shape?.[1],
    eps = 1e-15,
    outputBuffer = null,
    summaryBuffer = null,
  } = options;

  // These values size GPU buffers and dispatch grids, so finiteness is not
  // enough. A fractional numHeads (e.g. 2.5 with headDim 4 and hiddenSize 10)
  // would pass the hiddenSize check below yet yield meaningless sizes.
  const hasValidDimensions =
    Number.isInteger(numHeads) && numHeads > 0 &&
    Number.isInteger(headDim) && headDim > 0 &&
    Number.isInteger(numTokens) && numTokens >= 0 &&
    Number.isInteger(hiddenSize);
  if (!hasValidDimensions) {
    throw new Error('SanaLinearAttention requires numHeads, headDim, numTokens, and hiddenSize.');
  }
  if (hiddenSize !== numHeads * headDim) {
    throw new Error(`SanaLinearAttention hiddenSize mismatch: ${hiddenSize} != ${numHeads} * ${headDim}`);
  }

  const isF16 = query.dtype === 'f16';
  const variant = selectSanaLinearAttentionVariant(isF16);
  const temporarySummary = summaryBuffer || acquireBuffer(
    numHeads * (headDim + 1) * headDim * Float32Array.BYTES_PER_ELEMENT,
    undefined,
    'sana_linear_attention_summary'
  );

  let output = outputBuffer;
  try {
    output = output || acquireBuffer(
      numTokens * hiddenSize * dtypeBytes(query.dtype),
      undefined,
      'sana_linear_attention_output'
    );

    const uniforms = {
      num_heads: numHeads,
      head_dim: headDim,
      num_tokens: numTokens,
      hidden_size: hiddenSize,
      eps,
    };

    await runSummary(target, query, key, value, temporarySummary, uniforms, variant);
    await runApply(target, query, temporarySummary, output, uniforms, variant);
  } catch (error) {
    // Previously a failing kernel leaked the pooled output buffer. Reclaim it
    // only if this call acquired it; caller-supplied buffers stay untouched.
    if (!outputBuffer && output) {
      disposeSanaOwnedBuffer(recorder, output);
    }
    throw error;
  } finally {
    // The summary is scratch: reclaim it on success and on failure alike.
    if (!summaryBuffer) {
      disposeSanaOwnedBuffer(recorder, temporarySummary);
    }
  }

  return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
}
|
|
114
|
+
|
|
115
|
+
/**
 * Executes Sana linear attention immediately, without a command recorder.
 * Pooled temporaries are released directly once the kernels have been issued.
 *
 * @param {object} query - Query tensor.
 * @param {object} key - Key tensor.
 * @param {object} value - Value tensor.
 * @param {object} [options] - See SanaLinearAttentionOptions.
 * @returns {Promise<object>} Output tensor.
 */
export async function runSanaLinearAttention(query, key, value, options = {}) {
  const immediateTarget = null;
  const pending = _sanaLinearAttention(immediateTarget, query, key, value, options);
  return pending;
}
|
|
118
|
+
|
|
119
|
+
/**
 * Records Sana linear attention onto `recorder` rather than running it
 * immediately. Pooled temporaries are handed to the recorder for reclamation.
 *
 * @param {object} recorder - Command recorder receiving the kernel passes.
 * @param {object} query - Query tensor.
 * @param {object} key - Key tensor.
 * @param {object} value - Value tensor.
 * @param {object} [options] - See SanaLinearAttentionOptions.
 * @returns {Promise<object>} Output tensor.
 */
export async function recordSanaLinearAttention(recorder, query, key, value, options = {}) {
  const pending = _sanaLinearAttention(recorder, query, key, value, options);
  return pending;
}
|